/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_

#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <iterator>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 5 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 5;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32_t[dimensions_count];
    }
  }

  RuntimeShape(int shape_size, int32_t value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32_t* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }
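
  // Illustrative usage sketch (not part of this header's API surface): the
  // constructors above cover the common ways a kernel builds a shape. The
  // variable names and values below are hypothetical.
  //
  //   RuntimeShape empty;                        // rank 0, FlatSize() == 1
  //   RuntimeShape filled(3, 1);                 // {1, 1, 1}
  //   const int32_t raw[4] = {1, 8, 8, 3};
  //   RuntimeShape from_data(4, raw);            // {1, 8, 8, 3}
  //   RuntimeShape from_list({1, 224, 224, 3});  // via initializer_list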

  // Avoid using this constructor.  We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32_t[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
               0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
  }

  inline int32_t DimensionsCount() const { return size_; }
  inline int32_t Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32_t val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32_t* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32_t* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 5-D.
  inline const int32_t* DimsDataUpTo5D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32_t[dimensions_count];
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32_t* dims_data) {
    Resize(dimensions_count);
    int32_t* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32_t* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }
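
  // Illustrative sketch: BuildFrom() accepts any iterable of ints, e.g. a
  // std::vector<int> holding tensor metadata. Names and values below are
  // hypothetical.
  //
  //   std::vector<int> dims = {2, 3, 5};
  //   RuntimeShape shape;
  //   shape.BuildFrom(dims);  // shape.DimensionsCount() == 3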

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should
  // be reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }
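
  // Illustrative sketch of the padding behavior (hypothetical values): leading
  // dimensions are filled with 1 so a lower-rank shape can feed a 4-D kernel.
  //
  //   RuntimeShape hw({8, 3});                                 // 2-D
  //   RuntimeShape nhwc = RuntimeShape::ExtendedShape(4, hw);  // {1, 1, 8, 3}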

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is, the size when flattened into
  // a vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = reinterpret_cast<const int*>(DimsData());
    for (int i = 0; i < size_; i++) {
      buffer_size *= dims_data[i];
    }
    return buffer_size;
  }
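
  // Illustrative sketch (hypothetical values): FlatSize() is the product of
  // all dimensions, so a zero-sized dimension yields a flat size of 0 and a
  // rank-0 shape yields 1.
  //
  //   RuntimeShape shape({2, 3, 4});
  //   int n = shape.FlatSize();  // 2 * 3 * 4 == 24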

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32_t) * shape.DimensionsCount());
  }

  int32_t size_;
  union {
    int32_t dims_[kMaxSmallSize];
    int32_t* dims_pointer_;
  };
};

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
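
// Illustrative sketch (hypothetical values): the legacy Dims<4> layout stores
// sizes innermost-first with row-major strides, so an NHWC RuntimeShape of
// {1, 4, 4, 3} maps to sizes {3, 4, 4, 1} and strides {1, 3, 12, 48}.
//
//   RuntimeShape nhwc({1, 4, 4, 3});
//   tflite::Dims<4> legacy = ToRuntimeDims(nhwc);
//   // legacy.sizes   == {3, 4, 4, 1}
//   // legacy.strides == {1, 3, 12, 48}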

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Since tensors with '0' in their shape are valid in TF, these offset functions
// allow that as long as the corresponding index is also 0. It is up to the
// calling ops to ensure that they perform verification checks on tensor shapes
// if they don't support a particular behavior.

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
                (i0 >= 0 && i0 < dims_data[0]));
  TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
                (i1 >= 0 && i1 < dims_data[1]));
  TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
                (i2 >= 0 && i2 < dims_data[2]));
  TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
                (i3 >= 0 && i3 < dims_data[3]));
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
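
// Illustrative sketch (hypothetical values): Offset() computes the row-major
// flat index of element (n, h, w, c) in a 4-D buffer.
//
//   RuntimeShape nhwc({2, 4, 4, 3});
//   int flat = Offset(nhwc, 1, 2, 3, 1);
//   // ((1 * 4 + 2) * 4 + 3) * 3 + 1 == 82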

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
                  int i4) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
                (i0 >= 0 && i0 < dims_data[0]));
  TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
                (i1 >= 0 && i1 < dims_data[1]));
  TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
                (i2 >= 0 && i2 < dims_data[2]));
  TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
                (i3 >= 0 && i3 < dims_data[3]));
  TFLITE_DCHECK((dims_data[4] == 0 && i4 == 0) ||
                (i4 >= 0 && i4 < dims_data[4]));
  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
             dims_data[4] +
         i4;
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_