/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace tflite {
namespace gpu {

enum class Axis {
  UNKNOWN = 0,
  CHANNELS = 1,
  INPUT_CHANNELS = 2,
  OUTPUT_CHANNELS = 3,
  HEIGHT = 4,
  WIDTH = 5,
  BATCH = 6,
  VALUE = 7,
  DEPTH = 8,
};

std::string ToString(Axis t);

// Layout represents axis order.
enum class Layout {
  UNKNOWN = 0,
  SCALAR = 1,
  LINEAR = 2,
  HW = 3,
  CHW = 4,
  HWC = 5,
  OIHW = 6,
  OHWI = 7,
  IHWO = 8,
  IOHW = 9,
  BHWC = 10,
  HWDC = 11,
  BHWDC = 12,
  HWD = 13,
  OHWDI = 14,
};

std::string ToString(Layout l);

// Returns the number of axes for the fixed layout.
template <Layout T>
constexpr int Size();

// Returns the number of axes for the given layout.
int Size(Layout layout);

// Returns the Axis for the given index and fixed layout.
template <Layout T>
constexpr Axis GetAxis(int index);

// Returns the axis for the given layout and index.
Axis GetAxis(Layout layout, int32_t index);

// Returns the axis index for the given axis and fixed layout.
template <Layout T>
constexpr int GetAxisIndex(Axis axis);

// Returns the axis index for the given layout and axis.
int GetAxisIndex(Layout layout, Axis axis);

// Checks if the fixed layout has the given axis.
template <Layout T>
constexpr bool HasAxis(Axis axis);

// Checks if the given layout has the given axis.
bool HasAxis(Layout layout, Axis axis);

// Stores a Layout (axis set and order) and a value for each dimension.
struct Shape {
  Shape() : layout(Layout::UNKNOWN), dimensions() {}

  explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}

  Shape(Layout t, std::vector<int32_t> d)
      : layout(t), dimensions(std::move(d)) {}

  bool operator==(const Shape& other) const {
    return (layout == other.layout) && (dimensions == other.dimensions);
  }

  bool operator!=(const Shape& other) const { return !operator==(other); }

  // All methods below match the same methods defined in StrongShape so that
  // generic algorithms work with both.

  // Returns the dimension for the given axis, or -1 if the layout does not
  // have that axis or `dimensions` has no entry at the axis' index.
  template <Axis D>
  int32_t get() const;
  int32_t get(Axis axis) const;

  // Stores `t` for the given axis. Returns false (and changes nothing) if the
  // layout does not have that axis or `dimensions` has no entry for it.
  template <Axis D>
  bool set(int32_t t);
  bool set(Axis axis, int32_t t);

  Axis axis(int index) const { return GetAxis(layout, index); }

  int index(Axis axis) const { return GetAxisIndex(layout, axis); }

  bool has(Axis axis) const { return HasAxis(layout, axis); }

  // Product of all stored dimensions, accumulated in 64 bits (1 when empty).
  int64_t DimensionsProduct() const {
    // The init value fixes accumulate's accumulator type, so spell it as
    // int64_t to match the return type exactly.
    return std::accumulate(dimensions.begin(), dimensions.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  Layout layout = Layout::UNKNOWN;

  std::vector<int32_t> dimensions;
};

std::string ToString(const Shape& s);

// StrongShape provides convenient explicit access to dimensions stored in
// shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors.
//
// Conversion is possible both ways between Shape and StrongShape.
//
//   OIHW oihw;  // specific shape
//   Shape l = oihw.ToShape();
//
//   OHWI other;  // notice not the same but compatible shape.
//   if (!other.Adopt(l)) {
//     // error handling
//   }
//
// StrongShape supports the following set of operations:
//
//   // Returns the number of axes in the shape class.
//   static constexpr int size();
//
//   // Returns the Axis for the given index, or Axis::UNKNOWN if the index
//   // falls outside of the range defined by this shape.
//   static constexpr Axis axis(int index);
//
//   // Returns the index for the given axis, or -1 if the axis is not defined
//   // in this shape.
//   static constexpr int index(Axis axis);
//
//   // Getters
//   int32_t get(int index) const;
//   int32_t get(Axis axis) const;
//   int32_t get<Axis>() const;
//
//   // Setters that return false if set was not successful.
//   bool set(int index, int32_t v);
//   bool set(Axis axis, int32_t v);
//   bool set<Axis>(int32_t v);
//
//   // Returns the shape's layout.
//   static constexpr Layout layout;
//
//   // Turns a specific shape into a generic shape.
//   Shape ToShape() const;
//
//   // Copies all dimensions from the given shape.
//   bool Adopt(const Shape&);
//
template <Layout L>
struct StrongShape;

using Scalar = StrongShape<Layout::SCALAR>;
using Linear = StrongShape<Layout::LINEAR>;
using HW = StrongShape<Layout::HW>;
using HWD = StrongShape<Layout::HWD>;

// Common tensor shape for CNN models working with images.
using CHW = StrongShape<Layout::CHW>;
using HWC = StrongShape<Layout::HWC>;
using HWDC = StrongShape<Layout::HWDC>;
using BHWC = StrongShape<Layout::BHWC>;
using BHWDC = StrongShape<Layout::BHWDC>;

// Tensor shape used in convolution_2d weights.
using OIHW = StrongShape<Layout::OIHW>;
using OHWI = StrongShape<Layout::OHWI>;
using IHWO = StrongShape<Layout::IHWO>;
using IOHW = StrongShape<Layout::IOHW>;

// Tensor shape used in convolution_3d weights.
using OHWDI = StrongShape<Layout::OHWDI>;

// -----------------------------------------------------------------------------
// Everything below are internal implementation details.
// -----------------------------------------------------------------------------

namespace internal_shape {

template <Axis T>
struct AxisTraits;

// Defines, for one Axis, a holder struct with a public named member (e.g. `h`)
// and protected call operators used by StrongShapeImpl to read/write it.
#define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName)    \
  template <>                                           \
  struct AxisTraits<Axis::AxisName> {                   \
    struct Holder {                                     \
      int32_t HolderName;                               \
                                                        \
     protected:                                         \
      int32_t operator()() const { return HolderName; } \
      void operator()(int32_t v) { HolderName = v; }    \
    };                                                  \
                                                        \
    using dimension_holder_type = Holder;               \
  }

TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
TFLITE_GPU_AXIS_TRAITS(BATCH, b);
TFLITE_GPU_AXIS_TRAITS(VALUE, v);
TFLITE_GPU_AXIS_TRAITS(DEPTH, d);

#undef TFLITE_GPU_AXIS_TRAITS

template <int N, Axis... As>
struct StrongShapeImpl;

// Terminal case of the recursion: N is the total number of axes. Every lookup
// that reaches this level did not match any axis, so it reports failure.
template <int N>
struct StrongShapeImpl<N> {
  static constexpr int size() { return N; }

  static constexpr Axis axis(int) { return Axis::UNKNOWN; }

  static constexpr int index(Axis) { return -1; }

  static constexpr bool has(Axis) { return false; }

  int32_t get(Axis) const { return -1; }

  int32_t get(int) const { return -1; }

  template <Axis B>
  int32_t get() const {
    return -1;
  }

  bool set(Axis, int32_t) { return false; }

  bool set(int, int32_t) { return false; }

  template <Axis B>
  bool set(int32_t) {
    return false;
  }
};

// Used to deduce the number of axes, and to be a child of a proper holder to
// provide access to the dimension by name. Axis A sits at position N; lookups
// that do not match A are forwarded to the remaining axes.
template <int N, Axis A, Axis... As>
struct StrongShapeImpl<N, A, As...>
    : public AxisTraits<A>::dimension_holder_type,
      public StrongShapeImpl<N + 1, As...> {
  using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;

  using rest_type = StrongShapeImpl<N + 1, As...>;

  StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}

  // Values are assigned to axes in layout order; missing trailing values are
  // zero-initialized by the default constructor of the remaining levels.
  template <typename... Ts>
  explicit StrongShapeImpl(int32_t t, Ts... ts)
      : dimension_holder_type{t}, rest_type(ts...) {}

  static constexpr Axis axis(int index) {
    return index == N ? A : rest_type::axis(index);
  }

  static constexpr int index(Axis axis) {
    return axis == A ? N : rest_type::index(axis);
  }

  static constexpr bool has(Axis axis) {
    return axis == A ? true : rest_type::has(axis);
  }

  int32_t get(Axis axis) const {
    return axis == A ? dimension_holder_type::operator()()
                     : rest_type::get(axis);
  }

  template <Axis B>
  int32_t get() const {
    return B == A ? dimension_holder_type::operator()()
                  : rest_type::template get<B>();
  }

  int32_t get(int index) const {
    return index == N ? dimension_holder_type::operator()()
                      : rest_type::get(index);
  }

  bool set(Axis axis, int32_t t) {
    if (axis == A) {
      dimension_holder_type::operator()(t);
      return true;
    }
    return rest_type::set(axis, t);
  }

  bool set(int index, int32_t t) {
    if (index == N) {
      dimension_holder_type::operator()(t);
      return true;
    }
    return rest_type::set(index, t);
  }

  template <Axis B>
  bool set(int32_t t) {
    if (A == B) {
      dimension_holder_type::operator()(t);
      return true;
    }
    return rest_type::template set<B>(t);
  }
};

template <Layout T>
struct LayoutTraits;

#define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...)              \
  template <>                                                  \
  struct LayoutTraits<Layout::LayoutName> {                    \
    using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
  }

TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH);
TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
                         Axis::INPUT_CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS,
                         Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS,
                         Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
                         Axis::OUTPUT_CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
                         Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                         Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                         Axis::DEPTH, Axis::CHANNELS);
TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT,
                         Axis::WIDTH, Axis::DEPTH, Axis::INPUT_CHANNELS);

#undef TFLITE_GPU_LAYOUT_TRAITS

template <>
struct LayoutTraits<Layout::UNKNOWN> {
  using strong_shape_type = StrongShapeImpl<0>;
};

// True if `index` addresses an existing entry of `shape.dimensions`. A
// negative index means "axis not in layout"; the upper bound guards against a
// generic Shape whose dimensions vector is shorter than its layout.
inline bool IsValidDimensionIndex(const Shape& shape, int index) {
  return index >= 0 && static_cast<size_t>(index) < shape.dimensions.size();
}

// Reads the dimension of compile-time axis A from *l, given layout T.
template <Axis A>
struct DimensionGetterFixedAxisFunc {
  template <Layout T>
  int32_t operator()() const {
    constexpr int i = GetAxisIndex<T>(A);
    return IsValidDimensionIndex(*l, i) ? l->dimensions[i] : -1;
  }
  const Shape* l;
};

// Reads the dimension of runtime `axis` from *l, given layout T.
struct DimensionGetterFunc {
  template <Layout T>
  int32_t operator()() const {
    const int i = GetAxisIndex<T>(axis);
    return IsValidDimensionIndex(*l, i) ? l->dimensions[i] : -1;
  }
  Axis axis;
  const Shape* l;
};

// Writes `v` into the dimension of compile-time axis A of *l, given layout T.
template <Axis A>
struct DimensionSetterFixedAxisFunc {
  template <Layout T>
  bool operator()() const {
    constexpr int i = GetAxisIndex<T>(A);
    if (!IsValidDimensionIndex(*l, i)) {
      return false;
    }
    l->dimensions[i] = v;
    return true;
  }
  Shape* l;
  int32_t v;
};

// Writes `v` into the dimension of runtime `axis` of *l, given layout T.
struct DimensionSetterFunc {
  template <Layout T>
  bool operator()() const {
    const int i = GetAxisIndex<T>(axis);
    if (!IsValidDimensionIndex(*l, i)) {
      return false;
    }
    l->dimensions[i] = v;
    return true;
  }
  Axis axis;
  Shape* l;
  int32_t v;
};

// Copies dimensions from generic shape `l` (whose layout is T) into *shape,
// axis by axis in L's order. Stops and returns false on the first axis of L
// that T lacks or that `l.dimensions` does not store.
template <Layout L>
struct ToShapeFunc {
  template <Layout T>
  bool operator()() const {
    for (int i = 0; i < StrongShape<L>::size(); ++i) {
      const int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
      if (!IsValidDimensionIndex(l, index)) {
        return false;
      }
      shape->set(i, l.dimensions[index]);
    }
    return true;
  }

  StrongShape<L>* shape;
  const Shape& l;
};

}  // namespace internal_shape

template <Layout L>
struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type {
  using strong_shape_type =
      typename internal_shape::LayoutTraits<L>::strong_shape_type;
  StrongShape() = default;

  template <typename... Ts>
  explicit StrongShape(Ts... t) : strong_shape_type(t...) {}

  constexpr static Layout layout = L;

  // Both operands share layout L, so comparing the dimensions element-wise is
  // equivalent to comparing their ToShape() results, without allocating.
  bool operator==(const StrongShape<L>& shape) const {
    for (int i = 0; i < StrongShape::size(); ++i) {
      if (StrongShape::get(i) != shape.get(i)) {
        return false;
      }
    }
    return true;
  }

  bool operator!=(const StrongShape<L>& shape) const {
    return !(*this == shape);
  }

  // True if any dimension is zero (or the shape has no axes' product of 0).
  bool empty() const { return DimensionsProduct() == 0; }

  // Turns StrongShape into generic shape.
  Shape ToShape() const {
    std::vector<int32_t> dimensions(StrongShape::size());
    for (int i = 0; i < StrongShape::size(); ++i) {
      dimensions[i] = StrongShape::get(i);
    }
    return Shape(L, std::move(dimensions));
  }

  // @return all dimensions multiplied
  int64_t DimensionsProduct() const {
    int64_t product = 1;
    for (int i = 0; i < StrongShape::size(); ++i) {
      product *= StrongShape::get(i);
    }
    return product;
  }

  // Translates given coordinates of the layout into a linear index assuming
  // dimensions are sorted in tensor access order e.g. if you access
  // foobar[i][j][k] order of coordinates should be i,j,k.
  int64_t LinearIndex(
      const std::array<int32_t, StrongShape::size()>& coordinates) const {
    int64_t index = coordinates[0];
    for (int i = 1; i < StrongShape::size(); ++i) {
      index = index * StrongShape::get(i) + coordinates[i];
    }
    return index;
  }

  // Copies all dimensions from the given generic shape into this shape.
  // Every axis of this StrongShape must be present in the given shape's
  // layout (with a stored dimension); extra axes of the given shape are
  // ignored. For example:
  //  - If this shape is OHWI and the given shape is OIHW, Adopt copies all
  //    dimensions and returns true.
  //  - If this shape is HW and the given shape is OIHW, Adopt copies H and W
  //    and returns true; but if this shape is OIHW and the given shape is HW,
  //    Adopt returns false because O and I are missing from the input.
  //
  // Axes are visited in this shape's order; when false is returned, axes
  // visited before the missing one have already been overwritten.
  //
  // @return false if generic shape is not compatible.
  bool Adopt(const Shape& shape) {
    return DispatchByLayout(shape.layout,
                            internal_shape::ToShapeFunc<L>{this, shape});
  }

  // For all axes defined in a given shape copies values to this shape.
  // Therefore, it is possible to copy dimensions from CHW to BCHW, but not
  // the other way around.
  //
  //   BCHW bchw;
  //   CHW chw;
  //   bchw.CopyAllGivenAxis(chw);  --> true
  //   chw.CopyAllGivenAxis(bchw);  --> false
  //
  // @return false if axis in source shape is not defined here, thus value
  //         was not copied. Axes before that one have already been copied.
  template <Layout B>
  bool CopyAllGivenAxis(const StrongShape<B>& source) {
    for (int i = 0; i < source.size(); ++i) {
      if (!StrongShape::set(source.axis(i), source.get(i))) {
        return false;
      }
    }
    return true;
  }

  // For all axes defined in this shape copies values from the given shape.
  //
  //   BCHW bchw;
  //   CHW chw;
  //   bchw.CopyAllDefinedAxis(chw);  --> false
  //   chw.CopyAllDefinedAxis(bchw);  --> true
  //
  // @return false if given shape does not have axis defined here,
  //         therefore a value was not copied.
  template <Layout B>
  bool CopyAllDefinedAxis(const StrongShape<B>& source) {
    for (int i = 0; i < StrongShape::size(); ++i) {
      const int source_index = source.index(StrongShape::axis(i));
      if (source_index < 0) {
        return false;
      }
      StrongShape::set(i, source.get(source_index));  // always true
    }
    return true;
  }

  // Copies values only for matching axes: every axis of `source` that also
  // exists in this shape is copied; all other axes are left untouched.
  template <Layout B>
  void CopyMatchingAxis(const StrongShape<B>& source) {
    // Iterate over the source's axes: bounding the loop by this shape's size
    // would skip trailing source axes whenever source has more axes.
    for (int i = 0; i < source.size(); ++i) {
      StrongShape::set(source.axis(i), source.get(i));
    }
  }

  // AbslHash function for using in flat hash containers.
  template <typename H>
  friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) {
    for (int i = 0; i < strong_shape.size(); ++i) {
      hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
    }
    return hash_state;
  }
};

template <Layout T>
inline std::string ToString(const StrongShape<T>& s) {
  return ToString(s.ToShape());
}

template <Layout L>
constexpr Layout StrongShape<L>::layout;

// Calls `f.template operator()<T>()` where T is the compile-time Layout equal
// to the runtime `type`, and returns its result.
template <class F>
auto DispatchByLayout(Layout type, F f)
    -> decltype(f.template operator()<Layout::UNKNOWN>()) {
  // No `default:` label on purpose, so -Wswitch still flags new enumerators.
  switch (type) {
    case Layout::HW:
      return f.template operator()<Layout::HW>();
    case Layout::HWD:
      return f.template operator()<Layout::HWD>();
    case Layout::HWC:
      return f.template operator()<Layout::HWC>();
    case Layout::HWDC:
      return f.template operator()<Layout::HWDC>();
    case Layout::CHW:
      return f.template operator()<Layout::CHW>();
    case Layout::OIHW:
      return f.template operator()<Layout::OIHW>();
    case Layout::IOHW:
      return f.template operator()<Layout::IOHW>();
    case Layout::OHWI:
      return f.template operator()<Layout::OHWI>();
    case Layout::IHWO:
      return f.template operator()<Layout::IHWO>();
    case Layout::LINEAR:
      return f.template operator()<Layout::LINEAR>();
    case Layout::SCALAR:
      return f.template operator()<Layout::SCALAR>();
    case Layout::BHWC:
      return f.template operator()<Layout::BHWC>();
    case Layout::BHWDC:
      return f.template operator()<Layout::BHWDC>();
    case Layout::OHWDI:
      return f.template operator()<Layout::OHWDI>();
    case Layout::UNKNOWN:
      return f.template operator()<Layout::UNKNOWN>();
  }
  // A Layout value outside the enumerators (e.g. from a cast) would otherwise
  // fall off the end of this non-void function, which is undefined behavior.
  return f.template operator()<Layout::UNKNOWN>();
}

template <Layout T>
constexpr int Size() {
  return StrongShape<T>::size();
}

template <Layout T>
constexpr Axis GetAxis(int index) {
  return StrongShape<T>::axis(index);
}

template <Layout T>
constexpr int GetAxisIndex(Axis axis) {
  return StrongShape<T>::index(axis);
}

template <Layout T>
constexpr bool HasAxis(Axis axis) {
  return StrongShape<T>::has(axis);
}

template <Axis D>
inline int32_t Shape::get() const {
  return DispatchByLayout(
      layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
}

inline int32_t Shape::get(Axis axis) const {
  return DispatchByLayout(layout,
                          internal_shape::DimensionGetterFunc{axis, this});
}

template <Axis D>
inline bool Shape::set(int32_t t) {
  return DispatchByLayout(
      layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
}

inline bool Shape::set(Axis axis, int32_t t) {
  return DispatchByLayout(layout,
                          internal_shape::DimensionSetterFunc{axis, this, t});
}

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_