1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
17
18 #include <cstdint>
19 #include <cstring>
20 #include <initializer_list>
21 #include <iterator>
22
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24
25 namespace tflite {
26
// Legacy fixed-rank shape descriptor.
// `sizes[i]` holds the extent of dimension i and `strides[i]` the element
// stride of that dimension. ToRuntimeDims() in this header fills `sizes` in
// the reverse of RuntimeShape order, with strides[0] == 1 and each subsequent
// stride equal to the product of the preceding sizes.
template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};
32
33 class RuntimeShape {
34 public:
35 // Shapes with dimensions up to 5 are stored directly in the structure, while
36 // larger shapes are separately allocated.
37 static constexpr int kMaxSmallSize = 5;
38
39 RuntimeShape& operator=(RuntimeShape const&) = delete;
40
RuntimeShape()41 RuntimeShape() : size_(0) {}
42
RuntimeShape(int dimensions_count)43 explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
44 if (dimensions_count > kMaxSmallSize) {
45 dims_pointer_ = new int32_t[dimensions_count];
46 }
47 }
48
RuntimeShape(int shape_size,int32_t value)49 RuntimeShape(int shape_size, int32_t value) : size_(0) {
50 Resize(shape_size);
51 for (int i = 0; i < shape_size; ++i) {
52 SetDim(i, value);
53 }
54 }
55
RuntimeShape(int dimensions_count,const int32_t * dims_data)56 RuntimeShape(int dimensions_count, const int32_t* dims_data) : size_(0) {
57 ReplaceWith(dimensions_count, dims_data);
58 }
59
RuntimeShape(const std::initializer_list<int> init_list)60 RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
61 BuildFrom(init_list);
62 }
63
64 // Avoid using this constructor. We should be able to delete it when C++17
65 // rolls out.
RuntimeShape(RuntimeShape const & other)66 RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
67 if (size_ > kMaxSmallSize) {
68 dims_pointer_ = new int32_t[size_];
69 }
70 std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
71 }
72
73 bool operator==(const RuntimeShape& comp) const {
74 return this->size_ == comp.size_ &&
75 std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
76 0;
77 }
78
~RuntimeShape()79 ~RuntimeShape() {
80 if (size_ > kMaxSmallSize) {
81 delete[] dims_pointer_;
82 }
83 }
84
DimensionsCount()85 inline int32_t DimensionsCount() const { return size_; }
Dims(int i)86 inline int32_t Dims(int i) const {
87 TFLITE_DCHECK_GE(i, 0);
88 TFLITE_DCHECK_LT(i, size_);
89 return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
90 }
SetDim(int i,int32_t val)91 inline void SetDim(int i, int32_t val) {
92 TFLITE_DCHECK_GE(i, 0);
93 TFLITE_DCHECK_LT(i, size_);
94 if (size_ > kMaxSmallSize) {
95 dims_pointer_[i] = val;
96 } else {
97 dims_[i] = val;
98 }
99 }
100
DimsData()101 inline int32_t* DimsData() {
102 return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
103 }
DimsData()104 inline const int32_t* DimsData() const {
105 return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
106 }
107 // The caller must ensure that the shape is no bigger than 5-D.
DimsDataUpTo5D()108 inline const int32_t* DimsDataUpTo5D() const { return dims_; }
109
Resize(int dimensions_count)110 inline void Resize(int dimensions_count) {
111 if (size_ > kMaxSmallSize) {
112 delete[] dims_pointer_;
113 }
114 size_ = dimensions_count;
115 if (dimensions_count > kMaxSmallSize) {
116 dims_pointer_ = new int32_t[dimensions_count];
117 }
118 }
119
ReplaceWith(int dimensions_count,const int32_t * dims_data)120 inline void ReplaceWith(int dimensions_count, const int32_t* dims_data) {
121 Resize(dimensions_count);
122 int32_t* dst_dims = DimsData();
123 std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
124 }
125
126 template <typename T>
BuildFrom(const T & src_iterable)127 inline void BuildFrom(const T& src_iterable) {
128 const int dimensions_count =
129 std::distance(src_iterable.begin(), src_iterable.end());
130 Resize(dimensions_count);
131 int32_t* data = DimsData();
132 for (auto it : src_iterable) {
133 *data = it;
134 ++data;
135 }
136 }
137
138 // This will probably be factored out. Old code made substantial use of 4-D
139 // shapes, and so this function is used to extend smaller shapes. Note that
140 // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
141 // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
142 // inputs should already be 4-D, so this function should not be needed.
ExtendedShape(int new_shape_size,const RuntimeShape & shape)143 inline static RuntimeShape ExtendedShape(int new_shape_size,
144 const RuntimeShape& shape) {
145 return RuntimeShape(new_shape_size, shape, 1);
146 }
147
BuildFrom(const std::initializer_list<int> init_list)148 inline void BuildFrom(const std::initializer_list<int> init_list) {
149 BuildFrom<const std::initializer_list<int>>(init_list);
150 }
151
152 // Returns the total count of elements, that is the size when flattened into a
153 // vector.
FlatSize()154 inline int FlatSize() const {
155 int buffer_size = 1;
156 const int* dims_data = reinterpret_cast<const int*>(DimsData());
157 for (int i = 0; i < size_; i++) {
158 buffer_size *= dims_data[i];
159 }
160 return buffer_size;
161 }
162
163 bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
164
165 private:
166 // For use only by ExtendedShape(), written to guarantee (return-value) copy
167 // elision in C++17.
168 // This creates a shape padded to the desired size with the specified value.
RuntimeShape(int new_shape_size,const RuntimeShape & shape,int pad_value)169 RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
170 : size_(0) {
171 // If the following check fails, it is likely because a 4D-only kernel is
172 // being used with an array of larger dimension count.
173 TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
174 Resize(new_shape_size);
175 const int size_increase = new_shape_size - shape.DimensionsCount();
176 for (int i = 0; i < size_increase; ++i) {
177 SetDim(i, pad_value);
178 }
179 std::memcpy(DimsData() + size_increase, shape.DimsData(),
180 sizeof(int32_t) * shape.DimensionsCount());
181 }
182
183 int32_t size_;
184 union {
185 int32_t dims_[kMaxSmallSize];
186 int32_t* dims_pointer_;
187 };
188 };
189
190 // Converts inference-style shape to legacy tflite::Dims<4>.
ToRuntimeDims(const tflite::RuntimeShape & array_shape)191 inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
192 tflite::Dims<4> result;
193 const int dimensions_count = array_shape.DimensionsCount();
194 TFLITE_CHECK_LE(dimensions_count, 4);
195 int cum_prod = 1;
196 for (int i = 0; i < 4; i++) {
197 const int new_dim =
198 (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
199 result.sizes[i] = new_dim;
200 result.strides[i] = cum_prod;
201 cum_prod *= new_dim;
202 }
203 return result;
204 }
205
206 // TODO(b/80418076): Move to legacy ops file, update invocations.
DimsToShape(const tflite::Dims<4> & dims)207 inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
208 return RuntimeShape(
209 {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
210 }
211
// Since tensors with '0' in their shape are valid in TF, these offset functions
// allow that as long as the corresponding index is also 0. It is up to the
// calling ops to ensure that they perform verification checks on tensor shapes
// if they don't support a particular behavior.
216
Offset(const RuntimeShape & shape,int i0,int i1,int i2,int i3)217 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
218 TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
219 const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
220 TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
221 (i0 >= 0 && i0 < dims_data[0]));
222 TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
223 (i1 >= 0 && i1 < dims_data[1]));
224 TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
225 (i2 >= 0 && i2 < dims_data[2]));
226 TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
227 (i3 >= 0 && i3 < dims_data[3]));
228 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
229 }
230
Offset(const RuntimeShape & shape,int i0,int i1,int i2,int i3,int i4)231 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
232 int i4) {
233 TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
234 const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
235 TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
236 (i0 >= 0 && i0 < dims_data[0]));
237 TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
238 (i1 >= 0 && i1 < dims_data[1]));
239 TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
240 (i2 >= 0 && i2 < dims_data[2]));
241 TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
242 (i3 >= 0 && i3 < dims_data[3]));
243 TFLITE_DCHECK((dims_data[4] == 0 && i4 == 0) ||
244 (i4 >= 0 && i4 < dims_data[4]));
245 return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
246 dims_data[4] +
247 i4;
248 }
249
Offset(const RuntimeShape & shape,int * index)250 inline int Offset(const RuntimeShape& shape, int* index) {
251 return Offset(shape, index[0], index[1], index[2], index[3]);
252 }
253
254 } // namespace tflite
255
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
257