• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
18 
19 #include <stddef.h>
20 #include <stdint.h>
21 
22 #include <array>
23 #include <functional>
24 #include <numeric>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 namespace tflite {
30 namespace gpu {
31 
// Named dimension kinds that can appear in a Layout. The numeric values are
// spelled out explicitly so that they stay stable.
enum class Axis {
  UNKNOWN = 0,          // Not a valid axis; returned by failed lookups.
  CHANNELS = 1,         // Stored as field `c` in a StrongShape.
  INPUT_CHANNELS = 2,   // Stored as field `i`.
  OUTPUT_CHANNELS = 3,  // Stored as field `o`.
  HEIGHT = 4,           // Stored as field `h`.
  WIDTH = 5,            // Stored as field `w`.
  BATCH = 6,            // Stored as field `b`.
  VALUE = 7,            // Stored as field `v`; used by SCALAR and LINEAR.
  DEPTH = 8,            // Stored as field `d`.
};

// Converts an Axis to a string; the implementation lives outside this header.
std::string ToString(Axis axis);
45 
// A Layout names a fixed, ordered set of axes. Each letter stands for one axis
// (B=batch, H=height, W=width, D=depth, C=channels, O=output channels,
// I=input channels), listed from the first axis to the last.
enum class Layout {
  UNKNOWN = 0,  // No axes.
  SCALAR = 1,   // A single VALUE axis.
  LINEAR = 2,   // A single VALUE axis.
  HW = 3,
  CHW = 4,
  HWC = 5,
  OIHW = 6,
  OHWI = 7,
  IHWO = 8,
  IOHW = 9,
  BHWC = 10,
  HWDC = 11,
  BHWDC = 12,
  HWD = 13,
  OHWDI = 14,
};

// Converts a Layout to a string; the implementation lives outside this header.
std::string ToString(Layout layout);
66 
67 // Returns number of axis for the fixed layout.
68 template <Layout T>
69 constexpr int Size();
70 
71 // Returns number of axis for the given layout.
72 int Size(Layout layout);
73 
74 // Returns Axis for the given index and fixed layout.
75 template <Layout T>
76 constexpr Axis GetAxis(int index);
77 
78 // Returns axis for the given layout and index.
79 Axis GetAxis(Layout layout, int32_t index);
80 
81 // Returns axis index for the given axis and fixed layout.
82 template <Layout T>
83 constexpr int GetAxisIndex(Axis axis);
84 
85 // Returns axis index for the given layout and axis.
86 int GetAxisIndex(Layout layout, Axis axis);
87 
88 // Checks if fixed layout has given axis
89 template <Layout T>
90 constexpr bool HasAxis(Axis axis);
91 
92 // Checks if given layout has given axis
93 bool HasAxis(Layout layout, Axis axis);
94 
95 // Stores Layout(axis set and order) and value for dimensions.
96 struct Shape {
ShapeShape97   Shape() : layout(Layout::UNKNOWN), dimensions() {}
98 
ShapeShape99   explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}
100 
ShapeShape101   Shape(Layout t, std::vector<int32_t> d)
102       : layout(t), dimensions(std::move(d)) {}
103 
104   bool operator==(const Shape& other) const {
105     return (layout == other.layout) && (dimensions == other.dimensions);
106   }
107 
108   bool operator!=(const Shape& other) const { return !operator==(other); }
109 
110   // All methods below are matching same methods defined in StrongShape to
111   // make sure generic algorithms work both ways.
112 
113   // Returns back a dimension or -1 if it is not found.
114   template <Axis D>
115   int32_t get() const;
116   int32_t get(Axis axis) const;
117 
118   template <Axis D>
119   bool set(int32_t t);
120   bool set(Axis axis, int32_t t);
121 
axisShape122   Axis axis(int index) const { return GetAxis(layout, index); }
123 
indexShape124   int index(Axis axis) const { return GetAxisIndex(layout, axis); }
125 
hasShape126   bool has(Axis axis) const { return HasAxis(layout, axis); }
127 
DimensionsProductShape128   int64_t DimensionsProduct() const {
129     return std::accumulate(dimensions.begin(), dimensions.end(), 1ll,
130                            std::multiplies<int64_t>());
131   }
132 
133   Layout layout = Layout::UNKNOWN;
134 
135   std::vector<int32_t> dimensions;
136 };
137 
138 std::string ToString(const Shape& s);
139 
140 // StrongShape provides convenient explicit access to dimensions stored in
141 // shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors.
142 //
143 // There is a conversion possible both ways between Shape and StrongShape.
144 //
145 //   OIHW oihw;  // specific shape
146 //   Shape l = oihw.ToShape();
147 //
148 //   OHWI other;  // notice not the same but compatible shape.
149 //   if (!other.Adopt(l)) {
150 //     // error handling
151 //   }
152 //
153 // StrongShape supports the following set of operations:
154 //
155 //   // Returns number of axis in the shape class.
156 //   static constexpr int size();
157 //
158 //   // Returns Axis for the given index or Axis::UNKNOWN if index
159 //   // falls outside of the defined range in this shape.
160 //   static constexpr Axis axis(int index);
161 //
162 //   // Returns index for the given axis or -1 if axis is not defined in this
163 //   // shape.
164 //   static constexpr int index(Axis axis);
165 //
166 //   // Getters
167 //   int32_t get(int index) const;
168 //   int32_t get(Axis axis) const;
169 //   int32_t get<Axis>() const;
170 //
171 //   // Setters that return false if set was not successful.
172 //   bool set(int index, int32_t v);
173 //   bool set(Axis axis, int32_t v);
174 //   bool set<Axis>(int32_t v);
175 //
176 //   // Returns shape's layout.
177 //   static const Layout layout;
178 //
179 //   // Turns specific shape into generic shape.
180 //   Shape ToShape() const;
181 //
182 //   // Copies all dimensions from the given shape.
183 //   bool Adopt(const Shape&);
184 //
185 template <Layout L>
186 struct StrongShape;
187 
188 using Scalar = StrongShape<Layout::SCALAR>;
189 using Linear = StrongShape<Layout::LINEAR>;
190 using HW = StrongShape<Layout::HW>;
191 using HWD = StrongShape<Layout::HWD>;
192 
193 // Common tensor shape for CNN models working with images.
194 using CHW = StrongShape<Layout::CHW>;
195 using HWC = StrongShape<Layout::HWC>;
196 using HWDC = StrongShape<Layout::HWDC>;
197 using BHWC = StrongShape<Layout::BHWC>;
198 using BHWDC = StrongShape<Layout::BHWDC>;
199 
200 // Tensor shape used in convolution_2d weights.
201 using OIHW = StrongShape<Layout::OIHW>;
202 using OHWI = StrongShape<Layout::OHWI>;
203 using IHWO = StrongShape<Layout::IHWO>;
204 using IOHW = StrongShape<Layout::IOHW>;
205 
206 // Tensor shape used in convolution_3d weights.
207 using OHWDI = StrongShape<Layout::OHWDI>;
208 
209 // -----------------------------------------------------------------------------
210 // Everything below are internal implementation details.
211 // -----------------------------------------------------------------------------
212 
213 namespace internal_shape {
214 
215 template <Axis T>
216 struct AxisTraits;
217 
218 #define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName)    \
219   template <>                                           \
220   struct AxisTraits<Axis::AxisName> {                   \
221     struct Holder {                                     \
222       int32_t HolderName;                               \
223                                                         \
224      protected:                                         \
225       int32_t operator()() const { return HolderName; } \
226       void operator()(int32_t v) { HolderName = v; }    \
227     };                                                  \
228                                                         \
229     using dimension_holder_type = Holder;               \
230   }
231 
232 TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
233 TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
234 TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
235 TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
236 TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
237 TFLITE_GPU_AXIS_TRAITS(BATCH, b);
238 TFLITE_GPU_AXIS_TRAITS(VALUE, v);
239 TFLITE_GPU_AXIS_TRAITS(DEPTH, d);
240 
241 #undef TFLITE_GPU_AXIS_TRAITS
242 
243 template <int N, Axis... As>
244 struct StrongShapeImpl;
245 
246 template <int N>
247 struct StrongShapeImpl<N> {
248   static constexpr int size() { return N; }
249 
250   static constexpr Axis axis(int) { return Axis::UNKNOWN; }
251 
252   static constexpr int index(Axis) { return -1; }
253 
254   static constexpr bool has(Axis) { return false; }
255 
256   int32_t get(Axis) const { return -1; }
257 
258   int32_t get(int) const { return -1; }
259 
260   template <Axis B>
261   int32_t get() const {
262     return -1;
263   }
264 
265   bool set(Axis, int32_t) { return false; }
266 
267   bool set(int, int32_t) { return false; }
268 
269   template <Axis B>
270   bool set(int32_t) {
271     return false;
272   }
273 };
274 
275 // Used to deduce number of axis, and to be a child of a proper holder to
276 // provide access to the dimension by name
277 template <int N, Axis A, Axis... As>
278 struct StrongShapeImpl<N, A, As...>
279     : public AxisTraits<A>::dimension_holder_type,
280       public StrongShapeImpl<N + 1, As...> {
281   using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;
282 
283   using rest_type = StrongShapeImpl<N + 1, As...>;
284 
285   StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}
286 
287   template <typename... Ts>
288   explicit StrongShapeImpl(int32_t t, Ts... ts)
289       : dimension_holder_type{t}, rest_type(ts...) {}
290 
291   static constexpr Axis axis(int index) {
292     return index == N ? A : rest_type::axis(index);
293   }
294 
295   static constexpr int index(Axis axis) {
296     return axis == A ? N : rest_type::index(axis);
297   }
298 
299   static constexpr bool has(Axis axis) {
300     return axis == A ? true : rest_type::has(axis);
301   }
302 
303   int32_t get(Axis axis) const {
304     return axis == A ? dimension_holder_type::operator()()
305                      : rest_type::get(axis);
306   }
307 
308   template <Axis B>
309   int32_t get() const {
310     return B == A ? dimension_holder_type::operator()()
311                   : rest_type::template get<B>();
312   }
313 
314   int32_t get(int index) const {
315     return index == N ? dimension_holder_type::operator()()
316                       : rest_type::get(index);
317   }
318 
319   bool set(Axis axis, int32_t t) {
320     if (axis == A) {
321       dimension_holder_type::operator()(t);
322       return true;
323     }
324     return rest_type::set(axis, t);
325   }
326 
327   bool set(int index, int32_t t) {
328     if (index == N) {
329       dimension_holder_type::operator()(t);
330       return true;
331     }
332     return rest_type::set(index, t);
333   }
334 
335   template <Axis B>
336   bool set(int32_t t) {
337     if (A == B) {
338       dimension_holder_type::operator()(t);
339       return true;
340     }
341     return rest_type::template set<B>(t);
342   }
343 };
344 
345 template <Layout T>
346 struct LayoutTraits;
347 
348 #define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...)              \
349   template <>                                                  \
350   struct LayoutTraits<Layout::LayoutName> {                    \
351     using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
352   }
353 
354 TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
355 TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH);
356 TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
357                          Axis::INPUT_CHANNELS);
358 TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
359                          Axis::HEIGHT, Axis::WIDTH);
360 TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS,
361                          Axis::HEIGHT, Axis::WIDTH);
362 TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
363                          Axis::OUTPUT_CHANNELS);
364 TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
365 TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
366 TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
367                          Axis::CHANNELS);
368 TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
369 TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
370 TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
371                          Axis::CHANNELS);
372 TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
373                          Axis::DEPTH, Axis::CHANNELS);
374 TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT,
375                          Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS);
376 
377 #undef TFLITE_GPU_LAYOUT_TRAITS
378 
379 template <>
380 struct LayoutTraits<Layout::UNKNOWN> {
381   using strong_shape_type = StrongShapeImpl<0>;
382 };
383 
384 template <Axis A>
385 struct DimensionGetterFixedAxisFunc {
386   template <Layout T>
387   int32_t operator()() const {
388     constexpr int i = GetAxisIndex<T>(A);
389     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
390   }
391   const Shape* l;
392 };
393 
394 struct DimensionGetterFunc {
395   template <Layout T>
396   int32_t operator()() const {
397     int i = GetAxisIndex<T>(axis);
398     return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
399   }
400   Axis axis;
401   const Shape* l;
402 };
403 
404 template <Axis A>
405 struct DimensionSetterFixedAxisFunc {
406   template <Layout T>
407   bool operator()() const {
408     constexpr int i = GetAxisIndex<T>(A);
409     if (i >= 0 && i < l->dimensions.size()) {
410       l->dimensions[i] = v;
411       return true;
412     }
413     return false;
414   }
415   Shape* l;
416   int32_t v;
417 };
418 
419 struct DimensionSetterFunc {
420   template <Layout T>
421   bool operator()() const {
422     int i = GetAxisIndex<T>(axis);
423     if (i >= 0 && i < l->dimensions.size()) {
424       l->dimensions[i] = v;
425       return true;
426     }
427     return false;
428   }
429   Axis axis;
430   Shape* l;
431   int32_t v;
432 };
433 
434 template <Layout L>
435 struct ToShapeFunc {
436   template <Layout T>
437   bool operator()() const {
438     for (int i = 0; i < StrongShape<L>::size(); ++i) {
439       int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
440       if (index < 0) return false;
441       shape->set(i, l.dimensions[index]);
442     }
443     return true;
444   }
445 
446   StrongShape<L>* shape;
447   const Shape& l;
448 };
449 
450 }  // namespace internal_shape
451 
452 // template <Axis... As>
453 template <Layout L>
454 struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type {
455   using strong_shape_type =
456       typename internal_shape::LayoutTraits<L>::strong_shape_type;
457   StrongShape() = default;
458 
459   template <typename... Ts>
460   explicit StrongShape(Ts... t) : strong_shape_type(t...) {}
461 
462   constexpr static Layout layout = L;
463 
464   bool operator==(const StrongShape<L>& shape) const {
465     // TODO(akulik): implement better alternative.
466     return this->ToShape() == shape.ToShape();
467   }
468 
469   bool operator!=(const StrongShape<L>& shape) const {
470     // TODO(akulik): implement better alternative.
471     return this->ToShape() != shape.ToShape();
472   }
473   bool empty() const { return DimensionsProduct() == 0; }
474 
475   // Turns StrongShape into generic shape.
476   Shape ToShape() const {
477     std::vector<int32_t> dimensions(StrongShape::size());
478     for (int i = 0; i < StrongShape::size(); ++i) {
479       dimensions[i] = StrongShape::get(i);
480     }
481     return Shape(L, std::move(dimensions));
482   }
483 
484   // @return all dimensions multiplied
485   int64_t DimensionsProduct() const {
486     int64_t product = 1;
487     for (int i = 0; i < StrongShape::size(); ++i) {
488       product *= StrongShape::get(i);
489     }
490     return product;
491   }
492 
493   // Translates given coordinates of the layout into a linear index assuming
494   // dimensions are sorted in tensor access order e.g. if you access
495   // foobar[i][j][k] order of coordinates should be i,j,k.
496   int64_t LinearIndex(
497       const std::array<int32_t, StrongShape::size()>& coordinates) const {
498     int64_t index = coordinates[0];
499     for (int i = 1; i < StrongShape::size(); ++i) {
500       index = index * StrongShape::get(i) + coordinates[i];
501     }
502     return index;
503   }
504 
505   // Copies all dimensions from the given generic shape into specific shape.
506   // It requires shape to have all axis defined in the given
507   // StrongShape. For example:
508   //   - If this shape is OHWI but given shape is OIHW, Adopt will copy all
509   //     dimensions and return true.
510   //   - If this shape is OIHW but input shape is HW, Adopt will copy H and W
511   //     dimensions and return true, but if this shape is HW and given shape
512   //     OIHW, then Adopt will return false because not all axis are present in
513   //     the input shape.
514   //
515   // @return false if generic shape is not compatible.
516   bool Adopt(const Shape& shape) {
517     return DispatchByLayout(shape.layout,
518                             internal_shape::ToShapeFunc<L>{this, shape});
519   }
520 
521   // For all axis defined in a given shape copies values to this shape.
522   // Therefore, it is possible to copy dimensions from CHW to BCHW, but not
523   // the other way around.
524   //
525   // BCHW bchw;
526   // CHW chw;
527   // bchw.CopyAllGivenAxis(chw);  --> true
528   // chw.CopyAllGivenAxis(bchw);  --> false
529   //
530   // @return false if axis in source shape is not defined here, thus value
531   //         was not copied.
532   template <Layout B>
533   bool CopyAllGivenAxis(const StrongShape<B>& source) {
534     for (int i = 0; i < source.size(); ++i) {
535       if (!StrongShape::set(source.axis(i), source.get(i))) {
536         return false;
537       }
538     }
539     return true;
540   }
541 
542   // For all axis defined in this shape copies values from the given shape.
543   //
544   // BCHW bchw;
545   // CHW chw;
546   // bchw.CopyAllDefinedAxis(chw);  --> false
547   // chw.CopyAllDefinedAxis(bchw);  --> true
548   //
549   // @return false if given shape does not have axis defined here,
550   //         therefore a value was not copied.
551   template <Layout B>
552   bool CopyAllDefinedAxis(const StrongShape<B>& source) {
553     for (int i = 0; i < StrongShape::size(); ++i) {
554       int source_index = source.index(StrongShape::axis(i));
555       if (source_index < 0) {
556         return false;
557       }
558       StrongShape::set(i, source.get(source_index));  // always true
559     }
560     return true;
561   }
562 
563   // Copies values only for matching axis.
564   template <Layout B>
565   void CopyMatchingAxis(const StrongShape<B>& source) {
566     for (int i = 0; i < StrongShape::size(); ++i) {
567       StrongShape::set(source.axis(i), source.get(i));
568     }
569   }
570 
571   // AbslHash function for using in flat hash containers.
572   template <typename H>
573   friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) {
574     for (size_t i = 0; i < strong_shape.size(); ++i) {
575       hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
576     }
577     return hash_state;
578   }
579 };
580 
581 template <Layout T>
582 inline std::string ToString(const StrongShape<T>& s) {
583   return ToString(s.ToShape());
584 }
585 
586 template <Layout L>
587 constexpr Layout StrongShape<L>::layout;
588 
589 template <class F>
590 auto DispatchByLayout(Layout type, F f)
591     -> decltype(f.template operator()<Layout::UNKNOWN>()) {
592   switch (type) {
593     case Layout::HW:
594       return f.template operator()<Layout::HW>();
595     case Layout::HWD:
596       return f.template operator()<Layout::HWD>();
597     case Layout::HWC:
598       return f.template operator()<Layout::HWC>();
599     case Layout::HWDC:
600       return f.template operator()<Layout::HWDC>();
601     case Layout::CHW:
602       return f.template operator()<Layout::CHW>();
603     case Layout::OIHW:
604       return f.template operator()<Layout::OIHW>();
605     case Layout::IOHW:
606       return f.template operator()<Layout::IOHW>();
607     case Layout::OHWI:
608       return f.template operator()<Layout::OHWI>();
609     case Layout::IHWO:
610       return f.template operator()<Layout::IHWO>();
611     case Layout::LINEAR:
612       return f.template operator()<Layout::LINEAR>();
613     case Layout::SCALAR:
614       return f.template operator()<Layout::SCALAR>();
615     case Layout::BHWC:
616       return f.template operator()<Layout::BHWC>();
617     case Layout::BHWDC:
618       return f.template operator()<Layout::BHWDC>();
619     case Layout::OHWDI:
620       return f.template operator()<Layout::OHWDI>();
621     case Layout::UNKNOWN:
622       return f.template operator()<Layout::UNKNOWN>();
623   }
624 }
625 
626 template <Layout T>
627 constexpr int Size() {
628   return StrongShape<T>::size();
629 }
630 
631 template <Layout T>
632 constexpr Axis GetAxis(int index) {
633   return StrongShape<T>::axis(index);
634 }
635 
636 template <Layout T>
637 constexpr int GetAxisIndex(Axis axis) {
638   return StrongShape<T>::index(axis);
639 }
640 
641 template <Layout T>
642 constexpr bool HasAxis(Axis axis) {
643   return StrongShape<T>::has(axis);
644 }
645 
646 template <Axis D>
647 inline int32_t Shape::get() const {
648   return DispatchByLayout(
649       layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
650 }
651 
652 inline int32_t Shape::get(Axis axis) const {
653   return DispatchByLayout(layout,
654                           internal_shape::DimensionGetterFunc{axis, this});
655 }
656 
657 template <Axis D>
658 inline bool Shape::set(int32_t t) {
659   return DispatchByLayout(
660       layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
661 }
662 
663 inline bool Shape::set(Axis axis, int32_t t) {
664   return DispatchByLayout(layout,
665                           internal_shape::DimensionSetterFunc{axis, this, t});
666 }
667 
668 }  // namespace gpu
669 }  // namespace tflite
670 
671 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
672