feat: add API specification for returning the k largest elements#722
Open
kgryte wants to merge 8 commits intodata-apis:mainfrom
Open
feat: add API specification for returning the k largest elements#722kgryte wants to merge 8 commits intodata-apis:mainfrom
k largest elements#722kgryte wants to merge 8 commits intodata-apis:mainfrom
Conversation
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Update: This PR
resolves RFC: add topk and / or argpartition #629 by adding one new API to the Array API specification
top_k: returns a tuple whose first element is an array containing the topklargest (or smallest) values and whose second element is an array containing the indices of thosekvalues.The design decisions largely follow the discussion below. The specialized methods
top_k_indicesandtop_k_valueshave been dropped from the specification.This PR
resolves RFC: add topk and / or argpartition #629 by adding
3new APIs to the Array API specificationtop_k: returns a tuple whose first element is an array containing the topklargest (or smallest) values and whose second element is an array containing the indices of thosekvalues.top_k_indices: returns an array containing the indices of theklargest (or smallest) values.top_k_values: returns an array containing theklargest (or smallest) values.Prior Art
As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the
klargest or smallest values.partitionandargpartition, but these return full arrays. WhenaxisisNone, NumPy operates on a flattened input array. To get the topkvalues, one must index and, if wanting sorted values, sort.partitionandargpartitionand follows NumPy; however, forargpartition, for implementation reasons, it performs a full sort.topkand switches "modes" (largest or smallest) based on whetherkis positive or negative.top_kwhichonly returns valuesalways returns both values and indices, as well as NumPy equivalentpartitionandargpartitionAPIs (however, JAX differs in how it handles NaNs). The function only supports searching along the last axis.topkwhich always returns both values and indices.top_kwhich always returns both values and indices and only supports searching along the last axis.Proposed APIs
This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.
top_k
Returns a tuple containing the
klargest (or smallest) elements inx.Returns an array containing the indices of the
klargest (or smallest) elements inx.Returns an array containing the
klargest (or smallest) elements inx.Design Decision Rationale
axisisNonein order to matchmin,max,argmin, andargmax. In those APIs, whenaxisisNone(the default), the functions operate over a flattened array. Given thattop_k*may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules fortop_k*.axisonly supportsintandNonein order to matchargminandargmax. Inminandmax, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.top_kwas chosen overtopkdue to naming concerns discussed elsewhere (namelytop kvsto pk). Furthermore, "top k" follows ML conventions, as opposed tomaxk/max_kornlargest/nsmallestas found in other languages.unique. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, return values and counts, etc), we chose define specific API which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.unique_*naming convention, rather than thearg*naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned as inunique_*seems reasonable and follows existing precedent in the specification.largestkeyword argument. This PR chooses to name the kwargmodein order to be more explicit (what doeslargest=Falsemean to the lay reader?) and follows precedent elsewhere in the specification (e.g.,linalg.qr) wheremodeis used to toggle between different operating modes.sortedkwarg in order to instruct the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly callsort(except in Dask which doesn't currently support full sorting) after callingtop_kortop_k_values, and (c) can be addressed in a future specification extension. Additionally, if we supportsorted, we may also want to support astablekwarg as insortto allow ensuring that returned indices are consistent when provided the same input array.kexceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returningm < kvalues).Questions
kexceeds the number of elements?argminandargmax, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given thattop_k*can be implemented as a partial sort, presumably we do not want specify a first occurrence restriction. Is this a reasonable assumption?minandmax?Nonebeing the default foraxis, where the default behavior is searching over a flattened array?Considerations
The APIs included in this PR have implications for the following array libraries:
unique_*, will need to be added to the main namespace.topkandargtopk.laxnamespace, this PR would introduce breaking changes, asJAX would need to return both values and indices, by default,and JAX would need to flatten by default rather than search along the last dimension. However, if implemented in itsnumpynamespace, these will simply be new APIs. In both scenarios, JAX will need to add support foraxisandmodebehavior.topk.top_k_valuesandtop_k_indices). If implemented in itsmathnamespace, this PR would introduce breaking changes as TensorFlow would need to flatten by default. However, if implemented in itsnumpynamespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support foraxisandmodebehavior.Related Links