Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
75d832e
Neighborhood filter
ahijevyc Sep 7, 2024
b850bc4
Merge remote-tracking branch 'upstream/main' into neighborhood_filter
ahijevyc Sep 7, 2024
8ec0193
ruff recommendations
ahijevyc Sep 9, 2024
5605949
added Callable to Type checking
ahijevyc Sep 9, 2024
70ba961
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Sep 9, 2024
0c7bc1e
np.vstack().T faster than np.c
ahijevyc Sep 9, 2024
d6d8a33
Fix some comments
ahijevyc Sep 9, 2024
47b9cda
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Sep 17, 2024
f75db0d
Merge branch 'main' into ahijevyc/neighborhood_filter
aaronzedwick Oct 1, 2024
bddd2fa
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Oct 29, 2024
a0b6361
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Mar 17, 2025
6c59af7
missing imports
ahijevyc Mar 17, 2025
67d0c11
Merge branch 'main' into ahijevyc/neighborhood_filter
philipc2 Mar 18, 2025
8979745
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Mar 19, 2025
a8875cd
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc May 28, 2025
f4af498
Merge branch 'UXARRAY:main' into ahijevyc/neighborhood_filter
ahijevyc Aug 4, 2025
45407aa
Merge branch 'main' into ahijevyc/neighborhood_filter
erogluorhan Sep 3, 2025
9fa100c
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Jan 21, 2026
56e7821
Merge branch 'main' into ahijevyc/neighborhood_filter
erogluorhan Feb 26, 2026
14c37b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
a37b88f
Update dataset.py to address pre-commit errors
erogluorhan Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 115 additions & 3 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np


from typing import TYPE_CHECKING, Optional, Union, Hashable, Literal
from typing import TYPE_CHECKING, Callable, Optional, Union, Hashable, Literal

from uxarray.constants import GRID_DIMS
from uxarray.formatting_html import array_repr

from html import escape
Expand Down Expand Up @@ -1046,8 +1047,6 @@ def isel(self, ignore_grid=False, *args, **kwargs):
> uxda.subset(n_node=[1, 2, 3])
"""

from uxarray.constants import GRID_DIMS

if any(grid_dim in kwargs for grid_dim in GRID_DIMS) and not ignore_grid:
# slicing a grid-dimension through Grid object

Expand Down Expand Up @@ -1104,3 +1103,116 @@ def _slice_from_grid(self, sliced_grid):
dims=self.dims,
attrs=self.attrs,
)

def neighborhood_filter(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation looks great! May we move the bulk of the logic into the uxarray.grid.neighbors module and call that helper from here?

We can keep the data-mapping checks here, and anything related to constructing and returining the final data array but the bulk of the computations would go inside a helper in the module mentioned above.

Copy link
Copy Markdown
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to think about how to do that, but I am happy to defer to you.

self,
func: Callable = np.mean,
r: float = 1.0,
) -> UxDataArray:
"""Apply neighborhood filter
Parameters:
-----------
func: Callable, default=np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
Returns:
--------
destination_data : np.ndarray
Filtered data.
"""

if self._face_centered():
data_mapping = "face centers"
elif self._node_centered():
data_mapping = "nodes"
elif self._edge_centered():
data_mapping = "edge centers"
else:
raise ValueError(
"Data_mapping is not face, node, or edge. Could not define data_mapping."
)

# reconstruct because the cached tree could be built from
# face centers, edge centers or nodes.
tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True)
Comment on lines +1995 to +1996
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aaronzedwick

We should probably fix this logic in get_ball_tree(), since we shouldn't need to manually set reconstruct=False

        if self._ball_tree is None or reconstruct:
            self._ball_tree = BallTree(
                self,
                coordinates=coordinates,
                distance_metric=distance_metric,
                coordinate_system=coordinate_system,
                reconstruct=reconstruct,
            )
        else:
            if coordinates != self._ball_tree._coordinates:
                self._ball_tree.coordinates = coordinates

The coordinates != self._ball_tree._coordinates check should be included in the first if

Copy link
Copy Markdown
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. So, move the coordinates check to the if-clause like this?

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )

What if the coordinate_system is different? Would that also require a newly constructed tree?

Copy link
Copy Markdown
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever logic is fixed in Grid.get_ball_tree should also be applied to Grid.get_kdtree.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking coordinate system also (coordinate_system is not a hidden variable of _ball_tree; it has no underscore):

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or coordinate_system != self._ball_tree.coordinate_system
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )


coordinate_system = tree.coordinate_system

if coordinate_system == "spherical":
if data_mapping == "nodes":
lon, lat = (
self.uxgrid.node_lon.values,
self.uxgrid.node_lat.values,
)
elif data_mapping == "face centers":
lon, lat = (
self.uxgrid.face_lon.values,
self.uxgrid.face_lat.values,
)
elif data_mapping == "edge centers":
lon, lat = (
self.uxgrid.edge_lon.values,
self.uxgrid.edge_lat.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)

dest_coords = np.c_[lon, lat]

elif coordinate_system == "cartesian":
if data_mapping == "nodes":
x, y, z = (
self.uxgrid.node_x.values,
self.uxgrid.node_y.values,
self.uxgrid.node_z.values,
)
elif data_mapping == "face centers":
x, y, z = (
self.uxgrid.face_x.values,
self.uxgrid.face_y.values,
self.uxgrid.face_z.values,
)
elif data_mapping == "edge centers":
x, y, z = (
self.uxgrid.edge_x.values,
self.uxgrid.edge_y.values,
self.uxgrid.edge_z.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)
Comment on lines +1997 to +2047
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use #974 's new remap.utils._remap_grid_parse instead of this code block.


dest_coords = np.c_[x, y, z]
Comment thread
ahijevyc marked this conversation as resolved.
Outdated

else:
raise ValueError(
f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}"
)

neighbor_indices = tree.query_radius(dest_coords, r=r)

destination_data = np.empty(self.data.shape)

# assert last dimension is a GRID dimension.
assert self.dims[-1] in GRID_DIMS, (
f"expected last dimension of uxDataArray {self.data.dims[-1]} "
f"to be one of {GRID_DIMS}"
)
# Apply function to indices on last axis.
for i, idx in enumerate(neighbor_indices):
if len(idx):
destination_data[..., i] = func(self.data[..., idx])

# construct data array for filtered variable
uxda_filter = self._copy()

uxda_filter.data = destination_data

return uxda_filter
37 changes: 36 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import sys

from typing import Optional, IO, Union
from typing import Callable, Optional, IO, Union

from uxarray.constants import GRID_DIMS
from uxarray.grid import Grid
from uxarray.core.dataarray import UxDataArray

Expand Down Expand Up @@ -338,6 +339,40 @@ def to_array(self) -> UxDataArray:
xarr = super().to_array()
return UxDataArray(xarr, uxgrid=self.uxgrid)

def neighborhood_filter(
self,
func: Callable = np.mean,
r: float = 1.0,
):
"""Neighborhood function implementation for ``UxDataset``.
Parameters
---------
func : Callable = np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood
"""
Comment thread
ahijevyc marked this conversation as resolved.

destination_uxds = self._copy()
# Loop through uxDataArrays in uxDataset
for var_name in self.data_vars:
uxda = self[var_name]

# Skip if uxDataArray has no GRID dimension.
grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS]
if len(grid_dims) == 0:
continue

# Put GRID dimension last for UxDataArray.neighborhood_filter.
remember_dim_order = uxda.dims
uxda = uxda.transpose(..., grid_dims[0])
# Filter uxDataArray.
uxda = uxda.neighborhood_filter(func, r)
# Restore old dimension order.
destination_uxds[var_name] = uxda.transpose(*remember_dim_order)

return destination_uxds

def nearest_neighbor_remap(
self,
destination_obj: Union[Grid, UxDataArray, UxDataset],
Expand Down