• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
18 
19 #include <sys/types.h>
20 
21 #include <algorithm>
22 #include <array>
23 #include <functional>
24 #include <numeric>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include "absl/hash/hash.h"
30 
31 namespace tflite {
32 namespace gpu {
33 
// Named tensor axes. Each Layout is an ordered list of these axes (see
// internal_shape::LayoutTraits), and every axis other than UNKNOWN is given a
// single-letter member name in StrongShape (see TFLITE_GPU_AXIS_TRAITS).
enum class Axis {
  UNKNOWN = 0,          // Returned by axis(index) for an index outside a layout.
  CHANNELS = 1,         // StrongShape member `c`.
  INPUT_CHANNELS = 2,   // StrongShape member `i`.
  OUTPUT_CHANNELS = 3,  // StrongShape member `o`.
  HEIGHT = 4,           // StrongShape member `h`.
  WIDTH = 5,            // StrongShape member `w`.
  BATCH = 6,            // StrongShape member `b`.
  VALUE = 7,            // StrongShape member `v` (SCALAR / LINEAR layouts).
  DEPTH = 8,            // StrongShape member `d`.
};

// Returns a string representation of `t`; defined out of line.
std::string ToString(Axis t);
47 
// Layout represents an axis set and its order. The axis order of every layout
// is fixed by internal_shape::LayoutTraits (first listed axis has index 0).
enum class Layout {
  UNKNOWN = 0,  // No axes.
  SCALAR = 1,   // VALUE
  LINEAR = 2,   // VALUE
  HW = 3,       // HEIGHT, WIDTH
  CHW = 4,      // CHANNELS, HEIGHT, WIDTH
  HWC = 5,      // HEIGHT, WIDTH, CHANNELS
  OIHW = 6,     // OUTPUT_CHANNELS, INPUT_CHANNELS, HEIGHT, WIDTH
  OHWI = 7,     // OUTPUT_CHANNELS, HEIGHT, WIDTH, INPUT_CHANNELS
  IHWO = 8,     // INPUT_CHANNELS, HEIGHT, WIDTH, OUTPUT_CHANNELS
  IOHW = 9,     // INPUT_CHANNELS, OUTPUT_CHANNELS, HEIGHT, WIDTH
  BHWC = 10,    // BATCH, HEIGHT, WIDTH, CHANNELS
  HWDC = 11,    // HEIGHT, WIDTH, DEPTH, CHANNELS
  BHWDC = 12,   // BATCH, HEIGHT, WIDTH, DEPTH, CHANNELS
  HWD = 13,     // HEIGHT, WIDTH, DEPTH
  OHWDI = 14,   // OUTPUT_CHANNELS, HEIGHT, WIDTH, DEPTH, INPUT_CHANNELS
};

// Returns a string representation of `l`; defined out of line.
std::string ToString(Layout l);
68 
// Returns the number of axes in the fixed layout T (0 for Layout::UNKNOWN).
template <Layout T>
constexpr int Size();

// Returns the number of axes in the given runtime layout; defined out of line.
int Size(Layout layout);

// Returns the Axis at `index` in the fixed layout T, or Axis::UNKNOWN if
// `index` falls outside the layout.
template <Layout T>
constexpr Axis GetAxis(int index);

// Returns the axis at `index` in the given runtime layout; defined out of line
// (presumably mirrors the template version — confirm in the .cc file).
Axis GetAxis(Layout layout, int32_t index);

// Returns the position of `axis` in the fixed layout T, or -1 if T lacks it.
template <Layout T>
constexpr int GetAxisIndex(Axis axis);

// Returns the position of `axis` in the given runtime layout; defined out of
// line (presumably -1 when absent, like the template version — confirm).
int GetAxisIndex(Layout layout, Axis axis);

// Returns true if the fixed layout T contains `axis`.
template <Layout T>
constexpr bool HasAxis(Axis axis);

// Returns true if the given runtime layout contains `axis`; defined out of
// line.
bool HasAxis(Layout layout, Axis axis);
96 
97 // Stores Layout(axis set and order) and value for dimensions.
98 struct Shape {
ShapeShape99   Shape() : layout(Layout::UNKNOWN), dimensions() {}
100 
ShapeShape101   explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}
102 
ShapeShape103   Shape(Layout t, std::vector<int32_t> d)
104       : layout(t), dimensions(std::move(d)) {}
105 
106   bool operator==(const Shape& other) const {
107     return (layout == other.layout) && (dimensions == other.dimensions);
108   }
109 
110   bool operator!=(const Shape& other) const { return !operator==(other); }
111 
112   // All methods below are matching same methods defined in StrongShape to
113   // make sure generic algorithms work both ways.
114 
115   // Returns back a dimension or -1 if it is not found.
116   template <Axis D>
117   int32_t get() const;
118   int32_t get(Axis axis) const;
119 
120   template <Axis D>
121   bool set(int32_t t);
122   bool set(Axis axis, int32_t t);
123 
axisShape124   Axis axis(int index) const { return GetAxis(layout, index); }
125 
indexShape126   int index(Axis axis) const { return GetAxisIndex(layout, axis); }
127 
hasShape128   bool has(Axis axis) const { return HasAxis(layout, axis); }
129 
DimensionsProductShape130   int64_t DimensionsProduct() const {
131     return std::accumulate(dimensions.begin(), dimensions.end(), 1ll,
132                            std::multiplies<int64_t>());
133   }
134 
135   Layout layout = Layout::UNKNOWN;
136 
137   std::vector<int32_t> dimensions;
138 };
139 
// Returns a string representation of `s`; defined out of line.
std::string ToString(const Shape& s);

// StrongShape provides convenient explicit access to dimensions stored in a
// shape, e.g. StrongShape<Layout::HW> s; provides the s.h and s.w members.
//
// Conversion is possible both ways between Shape and StrongShape.
//
//   OIHW oihw;  // specific shape
//   Shape l = oihw.ToShape();
//
//   OHWI other;  // notice: not the same but a compatible shape.
//   if (!other.Adopt(l)) {
//     // error handling
//   }
//
// StrongShape supports the following set of operations:
//
//   // Returns the number of axes in the shape class.
//   static constexpr int size();
//
//   // Returns the Axis for the given index, or Axis::UNKNOWN if the index
//   // falls outside of the range defined in this shape.
//   static constexpr Axis axis(int index);
//
//   // Returns the index for the given axis, or -1 if the axis is not defined
//   // in this shape.
//   static constexpr int index(Axis axis);
//
//   // Returns true if the axis is defined in this shape.
//   static constexpr bool has(Axis axis);
//
//   // Getters; return -1 for an index/axis not defined in this shape.
//   int32_t get(int index) const;
//   int32_t get(Axis axis) const;
//   int32_t get<Axis>() const;
//
//   // Setters that return false if the set was not successful (index/axis
//   // not defined in this shape).
//   bool set(int index, int32_t v);
//   bool set(Axis axis, int32_t v);
//   bool set<Axis>(int32_t v);
//
//   // Returns the shape's layout.
//   static constexpr Layout layout;
//
//   // Turns the specific shape into a generic shape.
//   Shape ToShape() const;
//
//   // Copies every axis of this shape from the given generic shape; returns
//   // false if the generic shape's layout lacks one of them.
//   bool Adopt(const Shape&);
//
// It also provides empty(), DimensionsProduct(), LinearIndex(),
// CopyAllGivenAxis(), CopyAllDefinedAxis(), CopyMatchingAxis() and an
// AbslHashValue overload.
//
template <Layout L>
struct StrongShape;
189 
// Single-value shapes; both SCALAR and LINEAR store one Axis::VALUE (`v`).
using Scalar = StrongShape<Layout::SCALAR>;
using Linear = StrongShape<Layout::LINEAR>;
// Spatial-only shapes.
using HW = StrongShape<Layout::HW>;
using HWD = StrongShape<Layout::HWD>;

// Common tensor shapes for CNN models working with images.
using CHW = StrongShape<Layout::CHW>;
using HWC = StrongShape<Layout::HWC>;
using HWDC = StrongShape<Layout::HWDC>;
using BHWC = StrongShape<Layout::BHWC>;
using BHWDC = StrongShape<Layout::BHWDC>;

// Tensor shapes used for convolution_2d weights
// (O = output channels, I = input channels, H = height, W = width).
using OIHW = StrongShape<Layout::OIHW>;
using OHWI = StrongShape<Layout::OHWI>;
using IHWO = StrongShape<Layout::IHWO>;
using IOHW = StrongShape<Layout::IOHW>;

// Tensor shape used for convolution_3d weights (D = depth).
using OHWDI = StrongShape<Layout::OHWDI>;

// -----------------------------------------------------------------------------
// Everything below is an internal implementation detail.
// -----------------------------------------------------------------------------
214 
215 namespace internal_shape {
216 
// Maps an Axis to the storage StrongShape uses for it. Only the
// specializations generated by TFLITE_GPU_AXIS_TRAITS below are defined.
template <Axis T>
struct AxisTraits;

// Defines AxisTraits<Axis::AxisName>::dimension_holder_type: a struct with a
// public int32_t data member named `HolderName` (this is what gives a
// StrongShape its named members such as `.h` or `.c`) and two protected call
// operators that StrongShapeImpl uses to read and write that member.
// Comments are kept outside the macro body because a `//` comment would
// swallow the line-continuation backslash.
#define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName)    \
  template <>                                           \
  struct AxisTraits<Axis::AxisName> {                   \
    struct Holder {                                     \
      int32_t HolderName;                               \
                                                        \
     protected:                                         \
      int32_t operator()() const { return HolderName; } \
      void operator()(int32_t v) { HolderName = v; }    \
    };                                                  \
                                                        \
    using dimension_holder_type = Holder;               \
  }

TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
TFLITE_GPU_AXIS_TRAITS(BATCH, b);
TFLITE_GPU_AXIS_TRAITS(VALUE, v);
TFLITE_GPU_AXIS_TRAITS(DEPTH, d);

#undef TFLITE_GPU_AXIS_TRAITS
244 
245 template <int N, Axis... As>
246 struct StrongShapeImpl;
247 
248 template <int N>
249 struct StrongShapeImpl<N> {
250   static constexpr int size() { return N; }
251 
252   static constexpr Axis axis(int) { return Axis::UNKNOWN; }
253 
254   static constexpr int index(Axis) { return -1; }
255 
256   static constexpr bool has(Axis) { return false; }
257 
258   int32_t get(Axis) const { return -1; }
259 
260   int32_t get(int) const { return -1; }
261 
262   template <Axis B>
263   int32_t get() const {
264     return -1;
265   }
266 
267   bool set(Axis, int32_t) { return false; }
268 
269   bool set(int, int32_t) { return false; }
270 
271   template <Axis B>
272   bool set(int32_t) {
273     return false;
274   }
275 };
276 
277 // Used to deduce number of axis, and to be a child of a proper holder to
278 // provide access to the dimension by name
279 template <int N, Axis A, Axis... As>
280 struct StrongShapeImpl<N, A, As...>
281     : public AxisTraits<A>::dimension_holder_type,
282       public StrongShapeImpl<N + 1, As...> {
283   using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;
284 
285   using rest_type = StrongShapeImpl<N + 1, As...>;
286 
287   StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}
288 
289   template <typename... Ts>
290   explicit StrongShapeImpl(int32_t t, Ts... ts)
291       : dimension_holder_type{t}, rest_type(ts...) {}
292 
293   static constexpr Axis axis(int index) {
294     return index == N ? A : rest_type::axis(index);
295   }
296 
297   static constexpr int index(Axis axis) {
298     return axis == A ? N : rest_type::index(axis);
299   }
300 
301   static constexpr bool has(Axis axis) {
302     return axis == A ? true : rest_type::has(axis);
303   }
304 
305   int32_t get(Axis axis) const {
306     return axis == A ? dimension_holder_type::operator()()
307                      : rest_type::get(axis);
308   }
309 
310   template <Axis B>
311   int32_t get() const {
312     return B == A ? dimension_holder_type::operator()()
313                   : rest_type::template get<B>();
314   }
315 
316   int32_t get(int index) const {
317     return index == N ? dimension_holder_type::operator()()
318                       : rest_type::get(index);
319   }
320 
321   bool set(Axis axis, int32_t t) {
322     if (axis == A) {
323       dimension_holder_type::operator()(t);
324       return true;
325     }
326     return rest_type::set(axis, t);
327   }
328 
329   bool set(int index, int32_t t) {
330     if (index == N) {
331       dimension_holder_type::operator()(t);
332       return true;
333     }
334     return rest_type::set(index, t);
335   }
336 
337   template <Axis B>
338   bool set(int32_t t) {
339     if (A == B) {
340       dimension_holder_type::operator()(t);
341       return true;
342     }
343     return rest_type::template set<B>(t);
344   }
345 };
346 
// Maps a Layout to the StrongShapeImpl instantiation that stores its axes;
// StrongShape<L> derives from LayoutTraits<L>::strong_shape_type.
template <Layout T>
struct LayoutTraits;

// Defines LayoutTraits<Layout::LayoutName> for the axes listed in storage
// order; the first listed axis gets index 0.
#define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...)              \
  template <>                                                  \
  struct LayoutTraits<Layout::LayoutName> {                    \
    using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
  }

TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH);
TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
                         Axis::INPUT_CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
                         Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS,
                         Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
                         Axis::OUTPUT_CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
                         Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                         Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                         Axis::DEPTH, Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT,
                         Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS);

#undef TFLITE_GPU_LAYOUT_TRAITS

// UNKNOWN has no axes: size() is 0 and every axis/index lookup fails.
template <>
struct LayoutTraits<Layout::UNKNOWN> {
  using strong_shape_type = StrongShapeImpl<0>;
};
385 
386 template <Axis A>
387 struct DimensionGetterFixedAxisFunc {
388   template <Layout T>
389   int32_t operator()() const {
390     constexpr int i = GetAxisIndex<T>(A);
391     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
392   }
393   const Shape* l;
394 };
395 
396 struct DimensionGetterFunc {
397   template <Layout T>
398   int32_t operator()() const {
399     int i = GetAxisIndex<T>(axis);
400     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
401   }
402   Axis axis;
403   const Shape* l;
404 };
405 
406 template <Axis A>
407 struct DimensionSetterFixedAxisFunc {
408   template <Layout T>
409   bool operator()() const {
410     constexpr int i = GetAxisIndex<T>(A);
411     if (i >= 0 && i < l->dimensions.size()) {
412       l->dimensions[i] = v;
413       return true;
414     }
415     return false;
416   }
417   Shape* l;
418   int32_t v;
419 };
420 
421 struct DimensionSetterFunc {
422   template <Layout T>
423   bool operator()() const {
424     int i = GetAxisIndex<T>(axis);
425     if (i >= 0 && i < l->dimensions.size()) {
426       l->dimensions[i] = v;
427       return true;
428     }
429     return false;
430   }
431   Axis axis;
432   Shape* l;
433   int32_t v;
434 };
435 
436 template <Layout L>
437 struct ToShapeFunc {
438   template <Layout T>
439   bool operator()() const {
440     for (int i = 0; i < StrongShape<L>::size(); ++i) {
441       int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
442       if (index < 0) return false;
443       shape->set(i, l.dimensions[index]);
444     }
445     return true;
446   }
447 
448   StrongShape<L>* shape;
449   const Shape& l;
450 };
451 
452 }  // namespace internal_shape
453 
454 // template <Axis... As>
455 template <Layout L>
456 struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type {
457   using strong_shape_type =
458       typename internal_shape::LayoutTraits<L>::strong_shape_type;
459   StrongShape() = default;
460 
461   template <typename... Ts>
462   explicit StrongShape(Ts... t) : strong_shape_type(t...) {}
463 
464   constexpr static Layout layout = L;
465 
466   bool operator==(const StrongShape<L>& shape) const {
467     // TODO(akulik): implement better alternative.
468     return this->ToShape() == shape.ToShape();
469   }
470 
471   bool operator!=(const StrongShape<L>& shape) const {
472     // TODO(akulik): implement better alternative.
473     return this->ToShape() != shape.ToShape();
474   }
475   bool empty() const { return DimensionsProduct() == 0; }
476 
477   // Turns StrongShape into generic shape.
478   Shape ToShape() const {
479     std::vector<int32_t> dimensions(StrongShape::size());
480     for (int i = 0; i < StrongShape::size(); ++i) {
481       dimensions[i] = StrongShape::get(i);
482     }
483     return Shape(L, std::move(dimensions));
484   }
485 
486   // @return all dimensions multiplied
487   int64_t DimensionsProduct() const {
488     int64_t product = 1;
489     for (int i = 0; i < StrongShape::size(); ++i) {
490       product *= StrongShape::get(i);
491     }
492     return product;
493   }
494 
495   // Translates given coordinates of the layout into a linear index assuming
496   // dimensions are sorted in tensor access order e.g. if you access
497   // foobar[i][j][k] order of coordinates should be i,j,k.
498   int64_t LinearIndex(
499       const std::array<int32_t, StrongShape::size()>& coordinates) const {
500     int64_t index = coordinates[0];
501     for (int i = 1; i < StrongShape::size(); ++i) {
502       index = index * StrongShape::get(i) + coordinates[i];
503     }
504     return index;
505   }
506 
507   // Copies all dimensions from the given generic shape into specific shape.
508   // It requires shape to have all axis defined in the given
509   // StrongShape. For example:
510   //   - If this shape is OHWI but given shape is OIHW, Adopt will copy all
511   //     dimensions and return true.
512   //   - If this shape is OIHW but input shape is HW, Adopt will copy H and W
513   //     dimensions and return true, but if this shape is HW and given shape
514   //     OIHW, then Adopt will return false because not all axis are present in
515   //     the input shape.
516   //
517   // @return false if generic shape is not compatible.
518   bool Adopt(const Shape& shape) {
519     return DispatchByLayout(shape.layout,
520                             internal_shape::ToShapeFunc<L>{this, shape});
521   }
522 
523   // For all axis defined in a given shape copies values to this shape.
524   // Therefore, it is possible to copy dimensions from CHW to BCHW, but not
525   // the other way around.
526   //
527   // BCHW bchw;
528   // CHW chw;
529   // bchw.CopyAllGivenAxis(chw);  --> true
530   // chw.CopyAllGivenAxis(bchw);  --> false
531   //
532   // @return false if axis in source shape is not defined here, thus value
533   //         was not copied.
534   template <Layout B>
535   bool CopyAllGivenAxis(const StrongShape<B>& source) {
536     for (int i = 0; i < source.size(); ++i) {
537       if (!StrongShape::set(source.axis(i), source.get(i))) {
538         return false;
539       }
540     }
541     return true;
542   }
543 
544   // For all axis defined in this shape copies values from the given shape.
545   //
546   // BCHW bchw;
547   // CHW chw;
548   // bchw.CopyAllDefinedAxis(chw);  --> false
549   // chw.CopyAllDefinedAxis(bchw);  --> true
550   //
551   // @return false if given shape does not have axis defined here,
552   //         therefore a value was not copied.
553   template <Layout B>
554   bool CopyAllDefinedAxis(const StrongShape<B>& source) {
555     for (int i = 0; i < StrongShape::size(); ++i) {
556       int source_index = source.index(StrongShape::axis(i));
557       if (source_index < 0) {
558         return false;
559       }
560       StrongShape::set(i, source.get(source_index));  // always true
561     }
562     return true;
563   }
564 
565   // Copies values only for matching axis.
566   template <Layout B>
567   void CopyMatchingAxis(const StrongShape<B>& source) {
568     for (int i = 0; i < StrongShape::size(); ++i) {
569       StrongShape::set(source.axis(i), source.get(i));
570     }
571   }
572 
573   // AbslHash function for using in flat hash containers.
574   template <typename H>
575   friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) {
576     for (size_t i = 0; i < strong_shape.size(); ++i) {
577       hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
578     }
579     return hash_state;
580   }
581 };
582 
583 template <Layout T>
584 inline std::string ToString(const StrongShape<T>& s) {
585   return ToString(s.ToShape());
586 }
587 
588 template <Layout L>
589 constexpr Layout StrongShape<L>::layout;
590 
591 template <class F>
592 auto DispatchByLayout(Layout type, F f)
593     -> decltype(f.template operator()<Layout::UNKNOWN>()) {
594   switch (type) {
595     case Layout::HW:
596       return f.template operator()<Layout::HW>();
597     case Layout::HWD:
598       return f.template operator()<Layout::HWD>();
599     case Layout::HWC:
600       return f.template operator()<Layout::HWC>();
601     case Layout::HWDC:
602       return f.template operator()<Layout::HWDC>();
603     case Layout::CHW:
604       return f.template operator()<Layout::CHW>();
605     case Layout::OIHW:
606       return f.template operator()<Layout::OIHW>();
607     case Layout::IOHW:
608       return f.template operator()<Layout::IOHW>();
609     case Layout::OHWI:
610       return f.template operator()<Layout::OHWI>();
611     case Layout::IHWO:
612       return f.template operator()<Layout::IHWO>();
613     case Layout::LINEAR:
614       return f.template operator()<Layout::LINEAR>();
615     case Layout::SCALAR:
616       return f.template operator()<Layout::SCALAR>();
617     case Layout::BHWC:
618       return f.template operator()<Layout::BHWC>();
619     case Layout::BHWDC:
620       return f.template operator()<Layout::BHWDC>();
621     case Layout::OHWDI:
622       return f.template operator()<Layout::OHWDI>();
623     case Layout::UNKNOWN:
624       return f.template operator()<Layout::UNKNOWN>();
625   }
626 }
627 
628 template <Layout T>
629 constexpr int Size() {
630   return StrongShape<T>::size();
631 }
632 
633 template <Layout T>
634 constexpr Axis GetAxis(int index) {
635   return StrongShape<T>::axis(index);
636 }
637 
638 template <Layout T>
639 constexpr int GetAxisIndex(Axis axis) {
640   return StrongShape<T>::index(axis);
641 }
642 
643 template <Layout T>
644 constexpr bool HasAxis(Axis axis) {
645   return StrongShape<T>::has(axis);
646 }
647 
648 template <Axis D>
649 inline int32_t Shape::get() const {
650   return DispatchByLayout(
651       layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
652 }
653 
654 inline int32_t Shape::get(Axis axis) const {
655   return DispatchByLayout(layout,
656                           internal_shape::DimensionGetterFunc{axis, this});
657 }
658 
659 template <Axis D>
660 inline bool Shape::set(int32_t t) {
661   return DispatchByLayout(
662       layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
663 }
664 
665 inline bool Shape::set(Axis axis, int32_t t) {
666   return DispatchByLayout(layout,
667                           internal_shape::DimensionSetterFunc{axis, this, t});
668 }
669 
670 }  // namespace gpu
671 }  // namespace tflite
672 
673 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
674