-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathxarray_util.py
More file actions
144 lines (128 loc) · 5.09 KB
/
xarray_util.py
File metadata and controls
144 lines (128 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from collections import defaultdict
import xarray as xr
import numpy as np
def generic_subset(d, dim, slices):
    """Subset dataset d along a single dimension.

    slices maps variable names to either slice objects (interpreted as
    INCLUSIVE value ranges tested against the variable's values) or
    boolean masks; all are assumed to apply to the same dimension dim
    within xarray dataset d.  The per-variable masks are AND-ed together.
    """
    combined = None
    for var_name, selector in slices.items():
        if hasattr(selector, 'start'):
            # slice object: inclusive value-range test on the variable
            this_one = (d[var_name] >= selector.start) & (d[var_name] <= selector.stop)
        else:
            # already a boolean mask
            this_one = selector
        combined = this_one if combined is None else (combined & this_one)
    return d[{dim: combined}]
def get_1d_dims(d):
    """
    Find all dimensions in a dataset that are purely 1-dimensional,
    i.e., those dimensions that are not part of a 2D or higher-D
    variable.

    arguments
        d: xarray Dataset
    returns
        dims1d: a list of dimension names, in d.dims order
    """
    # Assume all dims correspond to 1D vars, then discard any dim that
    # appears on a multi-dimensional variable.
    dims1d = list(d.dims.keys())
    for var in d.variables.values():
        if len(var.dims) > 1:
            for vardim in var.dims:
                # Bug fix: the original tested `vardim in dims1d` but then
                # removed `str(vardim)`, which raises ValueError whenever
                # the two differ; remove the element actually found.
                if vardim in dims1d:
                    dims1d.remove(vardim)
    return dims1d
def gen_1d_datasets(d):
    """
    Generate a sequence of datasets having only those variables
    along each dimension that is only used for 1-dimensional variables.

    arguments
        d: xarray Dataset
    returns
        generator yielding a sequence of single-dimension datasets
    """
    for keep_dim in get_1d_dims(d):
        # Drop every other dimension so only keep_dim's variables remain.
        other_dims = [dd for dd in d.dims.keys() if dd != keep_dim]
        yield d.drop_dims(other_dims)
def get_1d_datasets(d):
    """
    Generate a list of datasets having only those variables
    along each dimension that is only used for 1-dimensional variables.

    arguments
        d: xarray Dataset
    returns
        a list of single-dimension datasets
    """
    # Bug fix: the original forwarded `*args, **kwargs` to
    # gen_1d_datasets even though neither name exists in this scope,
    # raising NameError on every call.
    return list(gen_1d_datasets(d))
def get_scalar_vars(d):
    """Return the names of all zero-dimensional (scalar) variables in d,
    in d.variables order."""
    return [name for name, var in d.variables.items() if len(var.dims) == 0]
def concat_1d_dims(datasets, stack_scalars=None):
    """
    For each xarray Dataset in datasets, concatenate (preserving the order of datasets)
    all variables along dimensions that are only used for one-dimensional variables.

    arguments
        datasets: iterable of xarray Datasets
        stack_scalars: if given, create a new dimension named with this value
            that aggregates all scalar variables and coordinates
    returns
        a new xarray Dataset with only the single-dimension variables
    """
    # dictionary mapping dimension names to a list of all
    # datasets having only that dimension
    all_1d_datasets = defaultdict(list)
    for d in datasets:
        scalars = get_scalar_vars(d)
        for d_1d_initial in gen_1d_datasets(d):
            # Get rid of scalars
            # NOTE(review): Dataset.drop is deprecated in newer xarray;
            # drop_vars is the modern spelling — confirm against pinned version.
            d_1d = d_1d_initial.drop(scalars)
            # Each dataset from gen_1d_datasets carries exactly one dim,
            # so dims[0] is that dimension's name.
            dims = tuple(d_1d.dims.keys())
            all_1d_datasets[dims[0]].append(d_1d)
        # NOTE(review): indentation reconstructed — this branch is placed
        # inside the per-dataset loop because it reads the loop-local
        # names d and scalars; confirm against the original layout.
        if stack_scalars:
            # restore scalars along new dimension stack_scalars
            scalar_dataset = xr.Dataset()
            for scalar_var in scalars:
                # promote from scalar to an array with a dimension, and remove
                # the coordinate info so that it's just a regular variable.
                as_1d = d[scalar_var].expand_dims(stack_scalars).reset_coords(drop=True)
                scalar_dataset[scalar_var] = as_1d # xr.DataArray(as_1d, dims=[stack_scalars])
            all_1d_datasets[stack_scalars].append(scalar_dataset)
    unified = xr.Dataset()
    for dim in all_1d_datasets:
        # concat in dataset order; 'minimal' avoids broadcasting coords/vars
        # that don't share this dimension
        combined = xr.concat(all_1d_datasets[dim], dim, coords='minimal', data_vars='minimal')
        unified.update(combined)
    return unified
# datasets=[]
# for i, size in enumerate((4, 6)):
# a = xr.DataArray(10*i + np.arange(size), dims='x')
# b = xr.DataArray(10*i + np.arange(size/2), dims='y')
# c = xr.DataArray(20*i + np.arange(size*3), dims='t')
# d = xr.DataArray(11*i + np.arange(size*3), dims='t')
# T = xr.DataArray(10*i + np.arange(size)**2, dims='x')
# D = xr.DataArray(10*i + np.arange(size/2)**2, dims='y')
# z = xr.DataArray(10*i + np.arange(size*4)**2, dims='z')
# u = xr.DataArray(10*i + np.arange(size*5)**2, dims='u')
# v = xr.DataArray(12*i + np.arange(size*5)**2, dims='u')
# P = xr.DataArray(10*i + np.ones((size,int(size/2))), dims=['x', 'y'])
# Q = xr.DataArray(20*i + np.ones((size,int(size/2))), dims=['x', 'y'])
# d = xr.Dataset({'x':a,'y':b, 't':c, 'd':d, 'u':u, 'v':v, 'z':z, 'T':T, 'D':D, 'P':P, 'Q':Q})
# datasets.append(d)
# # datasets.append(d[{'x':slice(None, None), 'y':slice(0,0)}])
# for d in datasets: print(d,'\n')
# # xr.combine_by_coords(datasets, coords='all')
# # xr.combine_nested(datasets, coords='all', data_vars='all')
# # print(get_1d_dims(d))
# assert(get_1d_dims(d)==['t', 'u', 'z'])
# # for d1 in get_1d_datasets(d):
# # print(d1,'\n')
# combined = concat_1d_dims(datasets)
# print(combined)