feat: add API specification for returning the k largest elements #722

Open
kgryte wants to merge 8 commits into data-apis:main from kgryte:feat/top_k

@kgryte kgryte commented Dec 14, 2023

Update: This PR

  • resolves RFC: add topk and / or argpartition #629 by adding one new API to the Array API specification

    • top_k: returns a tuple whose first element is an array containing the top k largest (or smallest) values and whose second element is an array containing the indices of those k values.

The design decisions largely follow the discussion below. The specialized methods top_k_indices and top_k_values have been dropped from the specification.


This PR

  • resolves RFC: add topk and / or argpartition #629 by adding 3 new APIs to the Array API specification

    • top_k: returns a tuple whose first element is an array containing the top k largest (or smallest) values and whose second element is an array containing the indices of those k values.
    • top_k_indices: returns an array containing the indices of the k largest (or smallest) values.
    • top_k_values: returns an array containing the k largest (or smallest) values.

Prior Art

As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the k largest or smallest values.

  • NumPy has partition and argpartition, but these return full arrays. When axis is None, NumPy operates on a flattened input array. To get the top k values, one must index and, if wanting sorted values, sort.
  • CuPy has partition and argpartition and follows NumPy; however, for argpartition, for implementation reasons, it performs a full sort.
  • Dask has topk and switches "modes" (largest or smallest) based on whether k is positive or negative.
  • JAX has top_k, which always returns both values and indices, as well as NumPy-equivalent partition and argpartition APIs (however, JAX differs in how it handles NaNs). The top_k function only supports searching along the last axis.
  • PyTorch has topk which always returns both values and indices.
  • TensorFlow has top_k which always returns both values and indices and only supports searching along the last axis.
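
To make the NumPy bullet concrete, here is a minimal sketch of the "partition, index, then sort" workaround for a one-dimensional array. The sample data and variable names are illustrative only; they are not part of the proposal.

```python
import numpy as np

x = np.array([5, 1, 9, 3, 7, 2])
k = 3

# argpartition returns a full-length index array; after partitioning around
# position -k, the last k slots hold the indices of the k largest elements,
# in no particular order.
idx = np.argpartition(x, -k)[-k:]
vals = x[idx]

# An explicit sort is needed if ordered results are wanted (largest first).
order = np.argsort(vals)[::-1]
top_idx = idx[order]
top_vals = vals[order]
```

Three separate steps (partition, slice, sort) are needed to get what a single top-k call returns in Dask, JAX, PyTorch, or TensorFlow.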

Proposed APIs

This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.

top_k

def top_k(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[array, array]

Returns a tuple containing the k largest (or smallest) elements in x.

def top_k_indices(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the indices of the k largest (or smallest) elements in x.

def top_k_values(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the k largest (or smallest) elements in x.
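
For illustration, the top_k signature above could be realized on top of NumPy roughly as follows. This is a sketch under stated assumptions, not a reference implementation: the use of np.argpartition, the unordered result, and raising ValueError when k is out of range are choices made for this example (the proposal leaves the latter unspecified).

```python
from typing import Literal, Optional, Tuple

import numpy as np


def top_k(
    x,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[np.ndarray, np.ndarray]:
    """Sketch: return (values, indices) of the k largest or smallest elements."""
    if mode not in ("largest", "smallest"):
        raise ValueError(f"unsupported mode: {mode!r}")
    x = np.asarray(x)
    if axis is None:
        # Mirror min/max/argmin/argmax: operate over the flattened array.
        x = x.reshape(-1)
        axis = 0
    n = x.shape[axis]
    if not 0 < k <= n:
        # Assumption of this sketch; the proposal does not mandate an error.
        raise ValueError("k must satisfy 0 < k <= number of elements along axis")
    if mode == "largest":
        # After partitioning around position n - k, the last k slots along
        # `axis` hold the indices of the k largest elements (unordered).
        part = np.argpartition(x, n - k, axis=axis)
        indices = np.take(part, np.arange(n - k, n), axis=axis)
    else:
        # The first k slots hold the indices of the k smallest elements.
        part = np.argpartition(x, k - 1, axis=axis)
        indices = np.take(part, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices
```

With this sketch, top_k_values and top_k_indices would simply be the first and second elements of the returned tuple.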

Design Decision Rationale

  • The default for axis is None in order to match min, max, argmin, and argmax. In those APIs, when axis is None (the default), the functions operate over a flattened array. Given that top_k* may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules for top_k*.
  • axis only supports int and None in order to match argmin and argmax. In min and max, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.
  • top_k was chosen over topk due to naming concerns discussed elsewhere (namely, topk can be read as either "top k" or "to pk"). Furthermore, "top k" follows ML conventions, as opposed to maxk/max_k or nlargest/nsmallest as found in other languages.
  • The PR includes three separate APIs following the lead of unique. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, returning values and counts, etc.), we chose to define specific APIs which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.
  • The PR follows the unique_* naming convention, rather than the arg* naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned as in unique_* seems reasonable and follows existing precedent in the specification.
  • The APIs include a "mode" option to specify the type (largest or smallest) of values to return. Most existing array libraries supporting a "top k" API return only the largest values; however, PyTorch supports returning either the smallest or largest and does so via a largest keyword argument. This PR chooses to name the kwarg mode in order to be more explicit (what does largest=False mean to the lay reader?) and follows precedent elsewhere in the specification (e.g., linalg.qr) where mode is used to toggle between different operating modes.
  • The PR does not include a sorted kwarg for instructing the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly call sort (except in Dask, which doesn't currently support full sorting) after calling top_k or top_k_values, and (c) sorting can be addressed in a future specification extension. Additionally, if we support sorted, we may also want to support a stable kwarg as in sort to ensure that returned indices are consistent when provided the same input array.
  • The PR leaves unspecified what should happen when k exceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returning m < k values).
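
As an illustration of point (b) in the sorted bullet, a downstream user could order the results after the fact. The sketch below uses NumPy and hard-coded arrays standing in for the unordered output of a top_k call along the last axis; carrying the indices along with argsort and take_along_axis is one possible approach, not something the proposal prescribes.

```python
import numpy as np

# Stand-ins for the (unordered) values and indices returned by a top_k call
# with k=2 along the last axis of a 2 x 3 input.
values = np.array([[8, 4], [7, 9]])
indices = np.array([[1, 0], [0, 2]])

# Compute an ordering that places the largest value first in each row.
order = np.flip(np.argsort(values, axis=-1), axis=-1)

# Apply the same ordering to both arrays so they stay aligned.
sorted_values = np.take_along_axis(values, order, axis=-1)
sorted_indices = np.take_along_axis(indices, order, axis=-1)
```

Sorting only k elements per slice is cheap relative to a full sort of the input, which is part of why a sorted kwarg can reasonably be deferred.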

Questions

  • Should we be more strict in specifying what should happen when k exceeds the number of elements?
  • Should zero-dimensional arrays be supported?
  • In argmin and argmax, the specification requires returning the index of the first occurrence when a minimum/maximum value occurs multiple times. Given that top_k* can be implemented as a partial sort, presumably we do not want to specify a first occurrence restriction. Is this a reasonable assumption?
  • Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?
  • Are we okay with None being the default for axis, where the default behavior is searching over a flattened array?
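
The first-occurrence question can be illustrated with NumPy. In the sketch below (sample data is illustrative), argmax is pinned to the first index of the maximum, while a partition-based top k may return any k of the tied indices.

```python
import numpy as np

x = np.array([3, 1, 3, 3])

# argmax must report the first occurrence of the maximum.
first_max = int(np.argmax(x))

# A partial sort gives no such guarantee: with three tied maxima and k=2,
# any two of the indices {0, 2, 3} form a valid answer.
k = 2
idx = np.argpartition(x, -k)[-k:]
vals = x[idx]
```

Requiring first-occurrence semantics for top_k* would effectively require a stable selection, which rules out some partial-sort implementations.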

Considerations

The APIs included in this PR have implications for the following array libraries:

  • NumPy: these will be new APIs, and, similar to unique_*, will need to be added to the main namespace.
  • CuPy: same as NumPy.
  • Dask: will need to introduce new APIs; however, the new APIs can be implemented as lightweight wrappers around Dask's existing topk and argtopk.
  • JAX: were JAX to place the APIs in its lax namespace, this PR would introduce breaking changes, as JAX would need to return both values and indices, by default, and JAX would need to flatten by default rather than search along the last dimension. However, if implemented in its numpy namespace, these will simply be new APIs. In both scenarios, JAX will need to add support for axis and mode behavior.
  • PyTorch: these will be new APIs; however, the new APIs can be implemented as lightweight wrappers around PyTorch's existing topk.
  • TensorFlow: additional APIs (top_k_values and top_k_indices). If implemented in its math namespace, this PR would introduce breaking changes as TensorFlow would need to flatten by default. However, if implemented in its numpy namespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support for axis and mode behavior.
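
For libraries whose existing routine only searches along the last axis and only returns the largest values, axis and mode support could be layered on as a thin adapter. The sketch below emulates that situation in NumPy. last_axis_top_k is a hypothetical stand-in for such a routine, and top_k_any_axis is an illustrative name; neither is an API of JAX or TensorFlow. The negation trick for mode="smallest" is an assumption that only holds for signed real dtypes (not unsigned integers or booleans, and the most negative integer would overflow).

```python
import numpy as np


def last_axis_top_k(x, k):
    """Hypothetical stand-in: k largest values along the last axis only."""
    idx = np.argsort(x, axis=-1)[..., ::-1][..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx


def top_k_any_axis(x, k, /, *, axis=None, mode="largest"):
    """Sketch of an adapter adding `axis` and `mode` to last_axis_top_k."""
    x = np.asarray(x)
    if axis is None:
        # Flatten by default, matching the proposed top_k behavior.
        x = x.reshape(-1)
        axis = -1
    # Move the requested axis to the end so the last-axis routine applies.
    moved = np.moveaxis(x, axis, -1)
    if mode == "smallest":
        # Searching the negated data for the largest values finds the
        # smallest values of the original data (signed real dtypes only).
        _, idx = last_axis_top_k(-moved, k)
    else:
        _, idx = last_axis_top_k(moved, k)
    vals = np.take_along_axis(moved, idx, axis=-1)
    # Restore the original axis position in both outputs.
    return np.moveaxis(vals, -1, axis), np.moveaxis(idx, -1, axis)
```

For example, top_k_any_axis(x, 1, axis=0) on a 2 x 3 input returns two arrays of shape (1, 3).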

Related Links


Labels

  • API extension: Adds new functions or objects to the API.
  • Needs Review: Pull request which needs review.
  • status: Blocked: Issue or pull request which is currently blocked.
