1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 18 19 #include <sys/types.h> 20 21 #include <algorithm> 22 #include <array> 23 #include <functional> 24 #include <numeric> 25 #include <string> 26 #include <utility> 27 #include <vector> 28 29 #include "absl/hash/hash.h" 30 31 namespace tflite { 32 namespace gpu { 33 34 enum class Axis { 35 UNKNOWN = 0, 36 CHANNELS = 1, 37 INPUT_CHANNELS = 2, 38 OUTPUT_CHANNELS = 3, 39 HEIGHT = 4, 40 WIDTH = 5, 41 BATCH = 6, 42 VALUE = 7, 43 DEPTH = 8, 44 }; 45 46 std::string ToString(Axis t); 47 48 // Layout represents axis order. 49 enum class Layout { 50 UNKNOWN = 0, 51 SCALAR = 1, 52 LINEAR = 2, 53 HW = 3, 54 CHW = 4, 55 HWC = 5, 56 OIHW = 6, 57 OHWI = 7, 58 IHWO = 8, 59 IOHW = 9, 60 BHWC = 10, 61 HWDC = 11, 62 BHWDC = 12, 63 HWD = 13, 64 OHWDI = 14, 65 }; 66 67 std::string ToString(Layout l); 68 69 // Returns number of axis for the fixed layout. 70 template <Layout T> 71 constexpr int Size(); 72 73 // Returns number of axis for the given layout. 74 int Size(Layout layout); 75 76 // Returns Axis for the given index and fixed layout. 77 template <Layout T> 78 constexpr Axis GetAxis(int index); 79 80 // Returns axis for the given layout and index. 81 Axis GetAxis(Layout layout, int32_t index); 82 83 // Returns axis index for the given axis and fixed layout. 84 template <Layout T> 85 constexpr int GetAxisIndex(Axis axis); 86 87 // Returns axis index for the given layout and axis. 88 int GetAxisIndex(Layout layout, Axis axis); 89 90 // Checks if fixed layout has given axis 91 template <Layout T> 92 constexpr bool HasAxis(Axis axis); 93 94 // Checks if given layout has given axis 95 bool HasAxis(Layout layout, Axis axis); 96 97 // Stores Layout(axis set and order) and value for dimensions. 98 struct Shape { ShapeShape99 Shape() : layout(Layout::UNKNOWN), dimensions() {} 100 ShapeShape101 explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {} 102 ShapeShape103 Shape(Layout t, std::vector<int32_t> d) 104 : layout(t), dimensions(std::move(d)) {} 105 106 bool operator==(const Shape& other) const { 107 return (layout == other.layout) && (dimensions == other.dimensions); 108 } 109 110 bool operator!=(const Shape& other) const { return !operator==(other); } 111 112 // All methods below are matching same methods defined in StrongShape to 113 // make sure generic algorithms work both ways. 114 115 // Returns back a dimension or -1 if it is not found. 116 template <Axis D> 117 int32_t get() const; 118 int32_t get(Axis axis) const; 119 120 template <Axis D> 121 bool set(int32_t t); 122 bool set(Axis axis, int32_t t); 123 axisShape124 Axis axis(int index) const { return GetAxis(layout, index); } 125 indexShape126 int index(Axis axis) const { return GetAxisIndex(layout, axis); } 127 hasShape128 bool has(Axis axis) const { return HasAxis(layout, axis); } 129 DimensionsProductShape130 int64_t DimensionsProduct() const { 131 return std::accumulate(dimensions.begin(), dimensions.end(), 1ll, 132 std::multiplies<int64_t>()); 133 } 134 135 Layout layout = Layout::UNKNOWN; 136 137 std::vector<int32_t> dimensions; 138 }; 139 140 std::string ToString(const Shape& s); 141 142 // StrongShape provides convenient explicit access to dimensions stored in 143 // shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors. 144 // 145 // There is a conversion possible both ways between Shape and StrongShape. 146 // 147 // OIHW oihw; // specific shape 148 // Shape l = oihw.ToShape(); 149 // 150 // OHWI other; // notice not the same but compatible shape. 151 // if (!other.Adopt(l)) { 152 // // error handling 153 // } 154 // 155 // StrongShape supports the following set of operations: 156 // 157 // // Returns number of axis in the shape class. 158 // static constexpr int size(); 159 // 160 // // Returns Axis for the given index or Axis::UNKNOWN if index 161 // // falls outside of the defined range in this shape. 162 // static constexpr Axis axis(int index); 163 // 164 // // Returns index for the given axis or -1 if axis is not defined in this 165 // // shape. 166 // static constexpr int index(Axis axis); 167 // 168 // // Getters 169 // int32_t get(int index) const; 170 // int32_t get(Axis axis) const; 171 // int32_t get<Axis>() const; 172 // 173 // // Setters that return false if set was not successful. 174 // bool set(int index, int32_t v); 175 // bool set(Axis axis, int32_t v); 176 // bool set<Axis>(int32_t v); 177 // 178 // // Returns shape's layout. 179 // static const Layout layout; 180 // 181 // // Turns specific shape into generic shape. 182 // Shape ToShape() const; 183 // 184 // // Copies all dimensions from the given shape. 185 // bool Adopt(const Shape&); 186 // 187 template <Layout L> 188 struct StrongShape; 189 190 using Scalar = StrongShape<Layout::SCALAR>; 191 using Linear = StrongShape<Layout::LINEAR>; 192 using HW = StrongShape<Layout::HW>; 193 using HWD = StrongShape<Layout::HWD>; 194 195 // Common tensor shape for CNN models working with images. 196 using CHW = StrongShape<Layout::CHW>; 197 using HWC = StrongShape<Layout::HWC>; 198 using HWDC = StrongShape<Layout::HWDC>; 199 using BHWC = StrongShape<Layout::BHWC>; 200 using BHWDC = StrongShape<Layout::BHWDC>; 201 202 // Tensor shape used in convolution_2d weights. 203 using OIHW = StrongShape<Layout::OIHW>; 204 using OHWI = StrongShape<Layout::OHWI>; 205 using IHWO = StrongShape<Layout::IHWO>; 206 using IOHW = StrongShape<Layout::IOHW>; 207 208 // Tensor shape used in convolution_3d weights. 209 using OHWDI = StrongShape<Layout::OHWDI>; 210 211 // ----------------------------------------------------------------------------- 212 // Everything below are internal implementation details. 213 // ----------------------------------------------------------------------------- 214 215 namespace internal_shape { 216 217 template <Axis T> 218 struct AxisTraits; 219 220 #define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName) \ 221 template <> \ 222 struct AxisTraits<Axis::AxisName> { \ 223 struct Holder { \ 224 int32_t HolderName; \ 225 \ 226 protected: \ 227 int32_t operator()() const { return HolderName; } \ 228 void operator()(int32_t v) { HolderName = v; } \ 229 }; \ 230 \ 231 using dimension_holder_type = Holder; \ 232 } 233 234 TFLITE_GPU_AXIS_TRAITS(CHANNELS, c); 235 TFLITE_GPU_AXIS_TRAITS(HEIGHT, h); 236 TFLITE_GPU_AXIS_TRAITS(WIDTH, w); 237 TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i); 238 TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o); 239 TFLITE_GPU_AXIS_TRAITS(BATCH, b); 240 TFLITE_GPU_AXIS_TRAITS(VALUE, v); 241 TFLITE_GPU_AXIS_TRAITS(DEPTH, d); 242 243 #undef TFLITE_GPU_AXIS_TRAITS 244 245 template <int N, Axis... As> 246 struct StrongShapeImpl; 247 248 template <int N> 249 struct StrongShapeImpl<N> { 250 static constexpr int size() { return N; } 251 252 static constexpr Axis axis(int) { return Axis::UNKNOWN; } 253 254 static constexpr int index(Axis) { return -1; } 255 256 static constexpr bool has(Axis) { return false; } 257 258 int32_t get(Axis) const { return -1; } 259 260 int32_t get(int) const { return -1; } 261 262 template <Axis B> 263 int32_t get() const { 264 return -1; 265 } 266 267 bool set(Axis, int32_t) { return false; } 268 269 bool set(int, int32_t) { return false; } 270 271 template <Axis B> 272 bool set(int32_t) { 273 return false; 274 } 275 }; 276 277 // Used to deduce number of axis, and to be a child of a proper holder to 278 // provide access to the dimension by name 279 template <int N, Axis A, Axis... As> 280 struct StrongShapeImpl<N, A, As...> 281 : public AxisTraits<A>::dimension_holder_type, 282 public StrongShapeImpl<N + 1, As...> { 283 using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type; 284 285 using rest_type = StrongShapeImpl<N + 1, As...>; 286 287 StrongShapeImpl() : dimension_holder_type{0}, rest_type() {} 288 289 template <typename... Ts> 290 explicit StrongShapeImpl(int32_t t, Ts... ts) 291 : dimension_holder_type{t}, rest_type(ts...) {} 292 293 static constexpr Axis axis(int index) { 294 return index == N ? A : rest_type::axis(index); 295 } 296 297 static constexpr int index(Axis axis) { 298 return axis == A ? N : rest_type::index(axis); 299 } 300 301 static constexpr bool has(Axis axis) { 302 return axis == A ? true : rest_type::has(axis); 303 } 304 305 int32_t get(Axis axis) const { 306 return axis == A ? dimension_holder_type::operator()() 307 : rest_type::get(axis); 308 } 309 310 template <Axis B> 311 int32_t get() const { 312 return B == A ? dimension_holder_type::operator()() 313 : rest_type::template get<B>(); 314 } 315 316 int32_t get(int index) const { 317 return index == N ? dimension_holder_type::operator()() 318 : rest_type::get(index); 319 } 320 321 bool set(Axis axis, int32_t t) { 322 if (axis == A) { 323 dimension_holder_type::operator()(t); 324 return true; 325 } 326 return rest_type::set(axis, t); 327 } 328 329 bool set(int index, int32_t t) { 330 if (index == N) { 331 dimension_holder_type::operator()(t); 332 return true; 333 } 334 return rest_type::set(index, t); 335 } 336 337 template <Axis B> 338 bool set(int32_t t) { 339 if (A == B) { 340 dimension_holder_type::operator()(t); 341 return true; 342 } 343 return rest_type::template set<B>(t); 344 } 345 }; 346 347 template <Layout T> 348 struct LayoutTraits; 349 350 #define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...) \ 351 template <> \ 352 struct LayoutTraits<Layout::LayoutName> { \ 353 using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \ 354 } 355 356 TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH); 357 TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH); 358 TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, 359 Axis::INPUT_CHANNELS); 360 TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, 361 Axis::HEIGHT, Axis::WIDTH); 362 TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS, 363 Axis::HEIGHT, Axis::WIDTH); 364 TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, 365 Axis::OUTPUT_CHANNELS); 366 TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH); 367 TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS); 368 TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH, 369 Axis::CHANNELS); 370 TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE); 371 TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE); 372 TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, 373 Axis::CHANNELS); 374 TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, 375 Axis::DEPTH, Axis::CHANNELS); 376 TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, 377 Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS); 378 379 #undef TFLITE_GPU_LAYOUT_TRAITS 380 381 template <> 382 struct LayoutTraits<Layout::UNKNOWN> { 383 using strong_shape_type = StrongShapeImpl<0>; 384 }; 385 386 template <Axis A> 387 struct DimensionGetterFixedAxisFunc { 388 template <Layout T> 389 int32_t operator()() const { 390 constexpr int i = GetAxisIndex<T>(A); 391 return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; 392 } 393 const Shape* l; 394 }; 395 396 struct DimensionGetterFunc { 397 template <Layout T> 398 int32_t operator()() const { 399 int i = GetAxisIndex<T>(axis); 400 return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; 401 } 402 Axis axis; 403 const Shape* l; 404 }; 405 406 template <Axis A> 407 struct DimensionSetterFixedAxisFunc { 408 template <Layout T> 409 bool operator()() const { 410 constexpr int i = GetAxisIndex<T>(A); 411 if (i >= 0 && i < l->dimensions.size()) { 412 l->dimensions[i] = v; 413 return true; 414 } 415 return false; 416 } 417 Shape* l; 418 int32_t v; 419 }; 420 421 struct DimensionSetterFunc { 422 template <Layout T> 423 bool operator()() const { 424 int i = GetAxisIndex<T>(axis); 425 if (i >= 0 && i < l->dimensions.size()) { 426 l->dimensions[i] = v; 427 return true; 428 } 429 return false; 430 } 431 Axis axis; 432 Shape* l; 433 int32_t v; 434 }; 435 436 template <Layout L> 437 struct ToShapeFunc { 438 template <Layout T> 439 bool operator()() const { 440 for (int i = 0; i < StrongShape<L>::size(); ++i) { 441 int index = GetAxisIndex<T>(StrongShape<L>::axis(i)); 442 if (index < 0) return false; 443 shape->set(i, l.dimensions[index]); 444 } 445 return true; 446 } 447 448 StrongShape<L>* shape; 449 const Shape& l; 450 }; 451 452 } // namespace internal_shape 453 454 // template <Axis... As> 455 template <Layout L> 456 struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type { 457 using strong_shape_type = 458 typename internal_shape::LayoutTraits<L>::strong_shape_type; 459 StrongShape() = default; 460 461 template <typename... Ts> 462 explicit StrongShape(Ts... t) : strong_shape_type(t...) {} 463 464 constexpr static Layout layout = L; 465 466 bool operator==(const StrongShape<L>& shape) const { 467 // TODO(akulik): implement better alternative. 468 return this->ToShape() == shape.ToShape(); 469 } 470 471 bool operator!=(const StrongShape<L>& shape) const { 472 // TODO(akulik): implement better alternative. 473 return this->ToShape() != shape.ToShape(); 474 } 475 bool empty() const { return DimensionsProduct() == 0; } 476 477 // Turns StrongShape into generic shape. 478 Shape ToShape() const { 479 std::vector<int32_t> dimensions(StrongShape::size()); 480 for (int i = 0; i < StrongShape::size(); ++i) { 481 dimensions[i] = StrongShape::get(i); 482 } 483 return Shape(L, std::move(dimensions)); 484 } 485 486 // @return all dimensions multiplied 487 int64_t DimensionsProduct() const { 488 int64_t product = 1; 489 for (int i = 0; i < StrongShape::size(); ++i) { 490 product *= StrongShape::get(i); 491 } 492 return product; 493 } 494 495 // Translates given coordinates of the layout into a linear index assuming 496 // dimensions are sorted in tensor access order e.g. if you access 497 // foobar[i][j][k] order of coordinates should be i,j,k. 498 int64_t LinearIndex( 499 const std::array<int32_t, StrongShape::size()>& coordinates) const { 500 int64_t index = coordinates[0]; 501 for (int i = 1; i < StrongShape::size(); ++i) { 502 index = index * StrongShape::get(i) + coordinates[i]; 503 } 504 return index; 505 } 506 507 // Copies all dimensions from the given generic shape into specific shape. 508 // It requires shape to have all axis defined in the given 509 // StrongShape. For example: 510 // - If this shape is OHWI but given shape is OIHW, Adopt will copy all 511 // dimensions and return true. 512 // - If this shape is OIHW but input shape is HW, Adopt will copy H and W 513 // dimensions and return true, but if this shape is HW and given shape 514 // OIHW, then Adopt will return false because not all axis are present in 515 // the input shape. 516 // 517 // @return false if generic shape is not compatible. 518 bool Adopt(const Shape& shape) { 519 return DispatchByLayout(shape.layout, 520 internal_shape::ToShapeFunc<L>{this, shape}); 521 } 522 523 // For all axis defined in a given shape copies values to this shape. 524 // Therefore, it is possible to copy dimensions from CHW to BCHW, but not 525 // the other way around. 526 // 527 // BCHW bchw; 528 // CHW chw; 529 // bchw.CopyAllGivenAxis(chw); --> true 530 // chw.CopyAllGivenAxis(bchw); --> false 531 // 532 // @return false if axis in source shape is not defined here, thus value 533 // was not copied. 534 template <Layout B> 535 bool CopyAllGivenAxis(const StrongShape<B>& source) { 536 for (int i = 0; i < source.size(); ++i) { 537 if (!StrongShape::set(source.axis(i), source.get(i))) { 538 return false; 539 } 540 } 541 return true; 542 } 543 544 // For all axis defined in this shape copies values from the given shape. 545 // 546 // BCHW bchw; 547 // CHW chw; 548 // bchw.CopyAllDefinedAxis(chw); --> false 549 // chw.CopyAllDefinedAxis(bchw); --> true 550 // 551 // @return false if given shape does not have axis defined here, 552 // therefore a value was not copied. 553 template <Layout B> 554 bool CopyAllDefinedAxis(const StrongShape<B>& source) { 555 for (int i = 0; i < StrongShape::size(); ++i) { 556 int source_index = source.index(StrongShape::axis(i)); 557 if (source_index < 0) { 558 return false; 559 } 560 StrongShape::set(i, source.get(source_index)); // always true 561 } 562 return true; 563 } 564 565 // Copies values only for matching axis. 566 template <Layout B> 567 void CopyMatchingAxis(const StrongShape<B>& source) { 568 for (int i = 0; i < StrongShape::size(); ++i) { 569 StrongShape::set(source.axis(i), source.get(i)); 570 } 571 } 572 573 // AbslHash function for using in flat hash containers. 574 template <typename H> 575 friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) { 576 for (size_t i = 0; i < strong_shape.size(); ++i) { 577 hash_state = H::combine(std::move(hash_state), strong_shape.get(i)); 578 } 579 return hash_state; 580 } 581 }; 582 583 template <Layout T> 584 inline std::string ToString(const StrongShape<T>& s) { 585 return ToString(s.ToShape()); 586 } 587 588 template <Layout L> 589 constexpr Layout StrongShape<L>::layout; 590 591 template <class F> 592 auto DispatchByLayout(Layout type, F f) 593 -> decltype(f.template operator()<Layout::UNKNOWN>()) { 594 switch (type) { 595 case Layout::HW: 596 return f.template operator()<Layout::HW>(); 597 case Layout::HWD: 598 return f.template operator()<Layout::HWD>(); 599 case Layout::HWC: 600 return f.template operator()<Layout::HWC>(); 601 case Layout::HWDC: 602 return f.template operator()<Layout::HWDC>(); 603 case Layout::CHW: 604 return f.template operator()<Layout::CHW>(); 605 case Layout::OIHW: 606 return f.template operator()<Layout::OIHW>(); 607 case Layout::IOHW: 608 return f.template operator()<Layout::IOHW>(); 609 case Layout::OHWI: 610 return f.template operator()<Layout::OHWI>(); 611 case Layout::IHWO: 612 return f.template operator()<Layout::IHWO>(); 613 case Layout::LINEAR: 614 return f.template operator()<Layout::LINEAR>(); 615 case Layout::SCALAR: 616 return f.template operator()<Layout::SCALAR>(); 617 case Layout::BHWC: 618 return f.template operator()<Layout::BHWC>(); 619 case Layout::BHWDC: 620 return f.template operator()<Layout::BHWDC>(); 621 case Layout::OHWDI: 622 return f.template operator()<Layout::OHWDI>(); 623 case Layout::UNKNOWN: 624 return f.template operator()<Layout::UNKNOWN>(); 625 } 626 } 627 628 template <Layout T> 629 constexpr int Size() { 630 return StrongShape<T>::size(); 631 } 632 633 template <Layout T> 634 constexpr Axis GetAxis(int index) { 635 return StrongShape<T>::axis(index); 636 } 637 638 template <Layout T> 639 constexpr int GetAxisIndex(Axis axis) { 640 return StrongShape<T>::index(axis); 641 } 642 643 template <Layout T> 644 constexpr bool HasAxis(Axis axis) { 645 return StrongShape<T>::has(axis); 646 } 647 648 template <Axis D> 649 inline int32_t Shape::get() const { 650 return DispatchByLayout( 651 layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this}); 652 } 653 654 inline int32_t Shape::get(Axis axis) const { 655 return DispatchByLayout(layout, 656 internal_shape::DimensionGetterFunc{axis, this}); 657 } 658 659 template <Axis D> 660 inline bool Shape::set(int32_t t) { 661 return DispatchByLayout( 662 layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t}); 663 } 664 665 inline bool Shape::set(Axis axis, int32_t t) { 666 return DispatchByLayout(layout, 667 internal_shape::DimensionSetterFunc{axis, this, t}); 668 } 669 670 } // namespace gpu 671 } // namespace tflite 672 673 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ 674