forked from flexflow/flexflow-train
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharray_shape.cc
More file actions
75 lines (58 loc) · 1.8 KB
/
array_shape.cc
File metadata and controls
75 lines (58 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include "kernels/array_shape.h"
#include "utils/containers.h"
namespace FlexFlow {

// Builds a shape by copying, in order, the `num_dims` extents stored at
// `_dims`. The buffer is only read here; the caller keeps ownership of it.
ArrayShape::ArrayShape(size_t *_dims, size_t num_dims)
    : dims(_dims, _dims + num_dims) {}

// Alias of num_elements().
std::size_t ArrayShape::get_volume() const {
  return num_elements();
}

// Alias of num_dims().
std::size_t ArrayShape::get_dim() const {
  return num_dims();
}

// Total element count: the product of all dimension extents, computed by
// `product` from utils/containers.h.
std::size_t ArrayShape::num_elements() const {
  return product(this->dims);
}

// Number of dimensions (rank) of the shape.
std::size_t ArrayShape::num_dims() const {
  return this->dims.size();
}

// Extent of dimension `idx`. Delegates to `dims.at`.
// NOTE(review): presumably bounds-checked like std::vector::at; confirm in the
// dims container's definition.
std::size_t ArrayShape::operator[](legion_dim_t idx) const {
  return dims.at(idx);
}

// Same as operator[]: extent of dimension `idx`, via `dims.at`.
std::size_t ArrayShape::at(legion_dim_t idx) const {
  return dims.at(idx);
}

// Index of the last dimension, i.e. num_dims() - 1.
// NOTE(review): for a rank-0 shape `dims.size() - 1` wraps around; callers
// presumably never ask for this on an empty shape. Confirm.
legion_dim_t ArrayShape::last_idx() const {
  return legion_dim_t(dims.size() - 1);
}

// Converts a negative (from-the-end) index into the equivalent non-negative
// legion_dim_t: -1 maps to the last dimension, -num_dims() to the first.
// `idx` must lie in [-num_dims(), -1].
legion_dim_t ArrayShape::neg_idx(int idx) const {
  assert(idx < 0 && "Idx should be negative for negative indexing");
  // Reject indices that reach before the first dimension. Without this check
  // `dims.size() + idx` silently wraps around to a huge unsigned value.
  assert(static_cast<long long>(dims.size()) + idx >= 0 &&
         "Negative index is out of range for this shape");
  return legion_dim_t(dims.size() + idx);
}

// Extent of dimension `idx` if it exists, otherwise an empty optional.
optional<std::size_t> ArrayShape::at_maybe(std::size_t idx) const {
  if (idx < dims.size()) {
    return dims[legion_dim_t(idx)];
  } else {
    return {};
  }
}

// Returns a new shape with the dimensions in reverse order.
ArrayShape ArrayShape::reversed_dim_order() const {
  std::vector<std::size_t> dims_reversed(dims.rbegin(), dims.rend());
  return ArrayShape(dims_reversed);
}

// Returns the half-open slice [start, end) of the dimensions as a new shape.
// A missing `start` defaults to 0; a missing `end` defaults to num_dims().
// Requires start <= end <= num_dims().
ArrayShape ArrayShape::sub_shape(optional<legion_dim_t> start,
                                 optional<legion_dim_t> end) {
  size_t s = start.has_value() ? start.value().value() : 0;
  size_t e = end.has_value() ? end.value().value() : dims.size();
  // Guard the iterator arithmetic below. An inverted or out-of-range slice
  // (including a negative index that wrapped through size_t) would otherwise
  // build a vector from an invalid iterator range, which is undefined behavior.
  assert(s <= e && "sub_shape: start must not exceed end");
  assert(e <= dims.size() && "sub_shape: end is out of range");
  std::vector<std::size_t> sub_dims(dims.begin() + s, dims.begin() + e);
  return ArrayShape(sub_dims);
}

// Two shapes are equal when they have the same rank and identical extents.
bool ArrayShape::operator==(ArrayShape const &other) const {
  // Cheap rank check first, before the element-wise comparison.
  if (this->dims.size() != other.dims.size()) {
    return false;
  }
  return this->dims == other.dims;
}

// Free-function convenience wrapper around ArrayShape::get_volume().
size_t get_volume(ArrayShape const &shape) {
  return shape.get_volume();
}

} // namespace FlexFlow