1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Shapes are protobuf messages, so this utility header offers a bunch of 17 // functionality for querying / poking at them. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ 20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ 21 22 #include <initializer_list> 23 #include <string> 24 25 #include "absl/base/macros.h" 26 #include "absl/container/inlined_vector.h" 27 #include "absl/types/optional.h" 28 #include "absl/types/span.h" 29 #include "tensorflow/compiler/xla/layout_util.h" 30 #include "tensorflow/compiler/xla/primitive_util.h" 31 #include "tensorflow/compiler/xla/shape.h" 32 #include "tensorflow/compiler/xla/status_macros.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/types.h" 35 #include "tensorflow/compiler/xla/util.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/core/threadpool.h" 38 #include "tensorflow/core/platform/cpu_info.h" 39 #include "tensorflow/core/platform/env.h" 40 #include "tensorflow/core/platform/macros.h" 41 #include "tensorflow/core/platform/mutex.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace xla { 45 46 // An index for specifying a particular nested subshape within a shape. Used in 47 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data 48 // structures (trees) and ShapeIndex defines a path through the tree where each 49 // element of ShapeIndex indexes into a tuple (or nested tuple) within the 50 // shape. For a non-nested tuple, an index has a single element. For example, 51 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index 52 // {1} corresponds to array b. For a nested tuple, the index can have more than 53 // one element. For the nested tuple (a, (b, c, d), e) below are the values 54 // corresponding to the given indices: 55 // 56 // index {0} : array a 57 // index {1, 2} : array d 58 // index {2} : array e 59 // index {0, 0} : invalid index (element at {0} is an array not a tuple) 60 // 61 // For indexing into array shapes, the index is always trivially empty, ie {}. 62 // 63 // ShapeIndex is a trivial wrapper around std::vector with a minimum number of 64 // methods implemented. 65 class ShapeIndex { 66 public: 67 ShapeIndex() = default; ShapeIndex(std::initializer_list<int64> init)68 ShapeIndex(std::initializer_list<int64> init) : indices_(init) {} 69 template <typename InputIt> ShapeIndex(InputIt start,InputIt end)70 ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} 71 empty()72 bool empty() const { return indices_.empty(); } size()73 size_t size() const { return indices_.size(); } push_back(int64 value)74 void push_back(int64 value) { indices_.push_back(value); } pop_back()75 void pop_back() { indices_.pop_back(); } 76 77 // push_front is O(n), but shapes don't usually have a ton of dimensions. push_front(int64 value)78 void push_front(int64 value) { indices_.insert(indices_.begin(), value); } 79 80 using container_type = absl::InlinedVector<int64, 2>; 81 begin()82 container_type::const_iterator begin() const { return indices_.begin(); } end()83 container_type::const_iterator end() const { return indices_.end(); } begin()84 container_type::iterator begin() { return indices_.begin(); } end()85 container_type::iterator end() { return indices_.end(); } 86 data()87 const int64* data() const { return indices_.data(); } 88 back()89 int64 back() const { return indices_.back(); } back()90 int64& back() { return indices_.back(); } 91 92 const int64& operator[](size_t i) const { return indices_[i]; } 93 int64& operator[](size_t i) { return indices_[i]; } 94 95 bool operator==(const ShapeIndex& other) const { 96 return indices_ == other.indices_; 97 } 98 bool operator!=(const ShapeIndex& other) const { return !(*this == other); } 99 bool operator<(const ShapeIndex& other) const { 100 return indices_ < other.indices_; 101 } 102 103 string ToString() const; 104 105 template <typename H> AbslHashValue(H h,const ShapeIndex & index)106 friend H AbslHashValue(H h, const ShapeIndex& index) { 107 return H::combine(std::move(h), index.indices_); 108 } 109 110 private: 111 container_type indices_; 112 }; 113 114 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the 115 // value at the front of the view. 116 // 117 // NB! ShapeIndexView does not own the memory backing the index array. 118 // The memory backing the index array should be owned by an object 119 // that lives longer than the ShapeIndexView instances pointing into 120 // it. 121 class ShapeIndexView { 122 public: 123 ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0) 124 : indices_(shape_index.data() + offset, shape_index.size() - offset) { 125 CHECK_LE(offset, shape_index.size()); 126 } ShapeIndexView(std::initializer_list<int64> indices)127 ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {} 128 ShapeIndexView(const ShapeIndexView& other) = default; 129 130 using iterator = const int64*; 131 begin()132 iterator begin() const { return indices_.begin(); } end()133 iterator end() const { return indices_.end(); } size()134 int64 size() const { return indices_.size(); } empty()135 bool empty() const { return indices_.empty(); } front()136 int64 front() const { 137 CHECK(!empty()); 138 return indices_.front(); 139 } ConsumeFront()140 ShapeIndexView ConsumeFront() const { 141 ShapeIndexView result = *this; 142 result.indices_.remove_prefix(1); 143 return result; 144 } ConsumeBack()145 ShapeIndexView ConsumeBack() const { 146 ShapeIndexView result = *this; 147 result.indices_.remove_suffix(1); 148 return result; 149 } ToShapeIndex()150 ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } 151 152 bool operator==(const ShapeIndexView& other) const; 153 bool operator!=(const ShapeIndexView& other) const; 154 155 string ToString() const; 156 157 // Returns true if this shape index starts with 'prefix'. 158 bool StartsWith(ShapeIndexView prefix) const; 159 160 private: 161 absl::Span<const int64> indices_; 162 }; 163 164 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); 165 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); 166 167 // Namespaced collection of (static) shape utilities. 168 // 169 // These are all effectively convenience functions for testing/tweaking proto 170 // properties, which do invariant checks before / after the operation. 171 class ShapeUtil { 172 public: 173 // Data structure which describes the coordinates and the shape, of a tuple 174 // shaped sub-shape. 175 struct IndexedShape { 176 IndexedShape() = default; IndexedShapeIndexedShape177 IndexedShape(ShapeIndex index, Shape shape) 178 : index(std::move(index)), shape(std::move(shape)) {} 179 ShapeIndex index; 180 Shape shape; 181 }; 182 183 // Returns the number of elements are contained within the provided shape; 184 // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes 185 // may not actually be able to store this number of elements. See 186 // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of 187 // elements that can be stored in a sparse shape. 188 // Precondition: shape.IsArray() 189 static int64 ElementsIn(const Shape& shape); 190 191 // As ElementsIn(), but recurses through tuples. 192 static int64 ElementsInRecursive(const Shape& shape); 193 194 // Returns true if shape has the primitive type, recurses through tuples. 195 static bool HasPrimitiveType(const Shape& shape, 196 PrimitiveType primitive_type); 197 198 // Returns true if 'shape' is an array with zero elements. 199 static bool IsZeroElementArray(const Shape& shape); 200 201 // Returns the number of bytes required for an allocation of shape. The 202 // |pointer_size| parameter is used for calculating the size of tuple 203 // shapes. This includes only the size of the top-level buffer. For example, a 204 // tuple is stored as an array of pointers to other buffers. In this case, 205 // this method only returns the size of the pointer array. 206 static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); 207 208 // Returns the number of bytes used to store the primitive_type. 209 // 210 // Precondition: shape.IsArray() 211 static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); 212 213 // Returns the number of bytes required to store the tuple member pointers for 214 // a allocation of shape. The `shape` must be a TUPLE shape, and 215 // `pointer_size` must be larger than zero. 216 static int64 ByteSizeOfTupleIndexTable(const Shape& shape, 217 int64 pointer_size); 218 219 // Returns the number of bytes required for the elements in an allocation of 220 // `shape`, which must be an array shape. The return value does not include 221 // the bytes needed to store sparse indices. Dense shapes use a separate 222 // memory location for each element, and so for these shapes, 223 // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this 224 // size also includes padding if present in the layout. For sparse shapes, 225 // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + 226 // ByteSizeOfSparseindices(shape)`. 227 static int64 ByteSizeOfElements(const Shape& shape); 228 229 // Returns the number of bytes required for the sparse indices in an 230 // allocation of shape. The shape must be an array shape. The return value 231 // does not include the bytes needed to store sparse indices. 232 static int64 ByteSizeOfSparseIndices(const Shape& shape); 233 234 // Returns a human-readable string that represents the given shape, with or 235 // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". 236 static string HumanString(const Shape& shape); 237 static string HumanStringWithLayout(const Shape& shape); 238 239 // As above, but for program shapes, returns a string for the form: 240 // 241 // (param_name: f32[42x12], ...) -> f32[24x42] 242 static string HumanString(const ProgramShape& program_shape); 243 244 // Returns whether the LHS and RHS shapes have the same dimensions; note: does 245 // not check element type. 246 // Precondition: IsArray(lhs) && IsArray(rhs) 247 static bool SameDimensions(const Shape& lhs, const Shape& rhs); 248 249 // Returns whether the lhs and rhs shapes have the same element type. SameElementType(const Shape & lhs,const Shape & rhs)250 static bool SameElementType(const Shape& lhs, const Shape& rhs) { 251 return lhs.element_type() == rhs.element_type(); 252 } 253 254 // As SameElementType, but allows floating point types to have different 255 // precisions. SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)256 static bool SameElementTypeIgnoringFpPrecision(const Shape& a, 257 const Shape& b) { 258 if (ElementIsFloating(a) && ElementIsFloating(b)) { 259 return true; 260 } 261 return ShapeUtil::SameElementType(a, b); 262 } 263 264 // Returns the higher-precision element type if a and b are both floating 265 // point types; otherwise, checks that they have the same element type 266 // and returns it. HigherPrecisionElementType(const Shape & a,const Shape & b)267 static PrimitiveType HigherPrecisionElementType(const Shape& a, 268 const Shape& b) { 269 if (SameElementType(a, b)) { 270 return a.element_type(); 271 } 272 CHECK(SameElementTypeIgnoringFpPrecision(a, b)); 273 return primitive_util::BitWidth(a.element_type()) < 274 primitive_util::BitWidth(b.element_type()) 275 ? b.element_type() 276 : a.element_type(); 277 } 278 279 // Returns true if the rank, dimension sizes, and element type are 280 // identical. Layout is ignored. Tuple elements are compared recursively for 281 // compatibility. 282 static bool Compatible(const Shape& lhs, const Shape& rhs); 283 284 // Returns true if the rank and dimension sizes are identical. Element type 285 // and layout are ignored. Tuple elements are compared recursively for 286 // compatibility. 287 static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); 288 289 // As Compatible, but allow one of lhs and rhs to be BF16 while the other 290 // being F32. Tuple elements are compared recursively for compatibility. 291 static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); 292 293 // Returns whether the lhs and rhs shapes are identical. 294 static bool Equal(const Shape& lhs, const Shape& rhs); 295 296 // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. 297 static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); 298 299 // Returns the number of dimensions for which the dimension is not (trivially) 300 // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just 301 // fluff. Note that zero dimensions are included in the true rank, e.g., 302 // f32[3,0,1] has a true rank of 2D. 303 static int64 TrueRank(const Shape& shape); 304 305 static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters, 306 Shape result); 307 308 //////////////////// 309 // Scalar-specific 310 IsScalar(const Shape & shape)311 static bool IsScalar(const Shape& shape) { 312 return shape.IsArray() && shape.rank() == 0; 313 } IsEffectiveScalar(const Shape & shape)314 static bool IsEffectiveScalar(const Shape& shape) { 315 return shape.IsArray() && TrueRank(shape) == 0; 316 } 317 318 // Returns whether "shape" is a scalar (array) with the given element_type. 319 static bool IsScalarWithElementType(const Shape& shape, 320 PrimitiveType element_type); 321 322 // Extracts the size of the shape's dimension at dimension number 323 // GetDimensionNumber(dimension_number). 324 static int64 GetDimension(const Shape& shape, int64 dimension_number); 325 326 // Resolves a dimension number, supporting negative indexing. 327 // 328 // Negative indexing has similar semantics to Python. For an N-dimensional 329 // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to 330 // N-2, and so on. 331 // 332 // This function always returns a positive dimension number for any given 333 // dimension_number (which itself can be negative). 334 static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number); 335 336 // Returns a shape with the same dimensions as the original, but with the 337 // element type changed to type. 338 static Shape ChangeElementType(const Shape& original, PrimitiveType type); 339 340 // Creates a tuple shape from a slice of element shapes within the tuple. 341 static Shape MakeTupleShape(absl::Span<const Shape> shapes); 342 343 // Creates an opaque shape. These are generally used for threading a context 344 // into a custom operation. 345 static Shape MakeOpaqueShape(); 346 347 // Creates a token shape. Values of this shape are used for ordering 348 // side-effecting operations. 349 static Shape MakeTokenShape(); 350 351 // Appends a shape to the given tuple. 352 static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); 353 354 // Appends a major dimension to the shape with the given bound. 355 static void AppendMajorDimension(int bound, Shape* shape); 356 357 // Returns an empty tuple shape. Can be used as a sentinel Shape value. MakeNil()358 static Shape MakeNil() { return MakeTupleShape({}); } 359 360 // Checks whether the shape is initialized. IsInitialized(const Shape & shape)361 static bool IsInitialized(const Shape& shape) { 362 return shape.element_type() != PRIMITIVE_TYPE_INVALID; 363 } 364 365 // Constructs a new shape with the given element type and sequence of 366 // dimensions. 367 static Shape MakeShape(PrimitiveType element_type, 368 absl::Span<const int64> dimensions); 369 370 // Constructs a new shape with the given element type and sequence of 371 // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates 372 // with a true value that the respective dimension is dynamic. If the 373 // dimension is dynamic then the respective value in 'dimension' is an upper 374 // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be 375 // the same size. 376 static Shape MakeShape(PrimitiveType element_type, 377 absl::Span<const int64> dimensions, 378 const std::vector<bool>& dynamic_dimensions); 379 380 // Constructs a new shape with the given element type and sequence of 381 // dimensions. Method checks if the element type is valid and the shape's 382 // size fits in std::numeric_limits<int64>::max(). 383 static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type, 384 absl::Span<const int64> dimensions); 385 static StatusOr<Shape> MakeValidatedShape( 386 PrimitiveType element_type, absl::Span<const int64> dimensions, 387 const std::vector<bool>& dynamic_dimensions); 388 389 // Creates a Shape with element type corresponding to T and the given 390 // dimensions 391 template <typename T> MakeShapeWithType(absl::Span<const int64> dimensions)392 static Shape MakeShapeWithType(absl::Span<const int64> dimensions) { 393 return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), 394 dimensions); 395 } 396 397 // Constructs a new shape with the given minor_to_major order in its Layout. 398 // Returns a value shape such that shape.has_layout(). 399 static Shape MakeShapeWithLayout(PrimitiveType element_type, 400 absl::Span<const int64> dimensions, 401 absl::Span<const int64> minor_to_major, 402 absl::Span<const Tile> tiles = {}, 403 int64 element_size_in_bits = 0); 404 405 static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, 406 absl::Span<const int64> dimensions, 407 int64 max_sparse_elements); 408 409 // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). 410 static Shape MakeShapeWithDescendingLayout( 411 PrimitiveType element_type, absl::Span<const int64> dimensions); 412 413 // Returns a new Shape based on the given Shape with low-dimension-major 414 // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions 415 // rearranged so that it has the same in-memory layout as the given shape. 416 // 417 // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}. 418 static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout( 419 const Shape& shape); 420 421 // As MakeShape, but the object to write to is passed in. 422 static Status PopulateShape(PrimitiveType element_type, 423 absl::Span<const int64> dimensions, Shape* shape); 424 425 // Validates that the provided shape satisfies invariants. 426 static Status ValidateShape(const Shape& shape); 427 428 // Validates the provided shape satisfies invariants, except those that 429 // pertain to layout. 430 // 431 // Layout is optional for client-provided shapes, so that the compiler may 432 // determine and assign an optimized layout. 433 static Status ValidateShapeWithOptionalLayout(const Shape& shape); 434 435 // Returns whether the element type of the shape is integral (signed or 436 // unsigned). Note that predicates are not considered integral here, since 437 // they are logical values. 438 static bool ElementIsIntegral(const Shape& shape); 439 440 // Returns whether the element type of the shape is floating point. 441 static bool ElementIsFloating(const Shape& shape); 442 443 // Returns whether the element type of the shape is complex. 444 static bool ElementIsComplex(const Shape& shape); 445 446 // Returns whether the element type has the given bit width. 447 static bool ElementHasBitWidth(const Shape& shape, int bits); 448 449 // Returns whether the element type of the shape is integral and has 450 // the specified number of bits. 451 static bool ElementIsIntegralWithBits(const Shape& shape, int bits); 452 453 // Returns whether the element type of the shape is signed. Note 454 // that floating point numbers are signed. 455 static bool ElementIsSigned(const Shape& shape); 456 457 // Returns whether the given primitive type corresponds to an array shape. 458 static bool IsArrayPrimitiveType(PrimitiveType primitive_type); 459 460 // Returns whether the shape is a tuple with at least one element which is 461 // also a tuple. 462 static bool IsNestedTuple(const Shape& shape); 463 464 // Returns true if shape is an empty tuple. 465 static bool IsEmptyTuple(const Shape& shape); 466 467 // Returns the number of elements in the given tuple shape. 468 // Precondition: IsTuple(shape) 469 static int64 TupleElementCount(const Shape& shape); 470 471 // Returns the tuple element shape at given index. 472 // Precondition: IsTuple(shape) && TupleElementCount(shape) > index 473 static const Shape& GetTupleElementShape(const Shape& shape, int64 index); 474 475 // Returns the number of elements, recursively, in the given shape. 476 static int64 SubshapeCount(const Shape& shape); 477 478 // Slices tuple elements in the range [start, limit) and returns a new tuple 479 // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). 480 static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); 481 482 // Returns the shape of the real/imaginary components of the given complex 483 // shape. 484 static Shape ComplexComponentShape(const Shape& complex_shape); 485 486 // Returns true if the given shape has a subshape at the given index. 487 static bool IndexIsValid(const Shape& shape, ShapeIndexView index); 488 489 // GetSubshape and GetMutableSubshape return a particular nested Shape within 490 // the given Shape argument. The non-Try variants check fail if index is 491 // invalid. 492 static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); 493 static StatusOr<const Shape*> TryGetSubshape(const Shape& shape, 494 ShapeIndexView index); 495 static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); 496 497 // Returns whether the given index in the given shape is a leaf element of the 498 // shape. 499 static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); 500 501 // Returns the number of leaves in the shape. 502 static int64 GetLeafCount(const Shape& shape); 503 504 // Retrieves all the leaf shapes and their indexes, in the order walked by 505 // the ForEachSubshape() API. 506 static std::vector<IndexedShape> GetLeafShapes(const Shape& shape); 507 508 // Calls the given visitor function for each subshape of the given shape. 509 // Subshapes are visited in DFS pre-order starting with the entire shape 510 // (index {}). 511 using VisitorFunction = std::function<void(const Shape& /*subshape*/, 512 const ShapeIndex& /*index*/)>; 513 static void ForEachSubshape(const Shape& shape, const VisitorFunction& func); 514 using MutatingVisitorFunction = 515 std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>; 516 static void ForEachMutableSubshape(Shape* shape, 517 const MutatingVisitorFunction& func); 518 519 // Variants of ForEach(Mutable)Subshape which propagate Status from the 520 // visitor function. 521 using StatusVisitorFunction = std::function<Status( 522 const Shape& /*subshape*/, const ShapeIndex& /*index*/)>; 523 static Status ForEachSubshapeWithStatus(const Shape& shape, 524 const StatusVisitorFunction& func); 525 using MutatingStatusVisitorFunction = 526 std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>; 527 static Status ForEachMutableSubshapeWithStatus( 528 Shape* shape, const MutatingStatusVisitorFunction& func); 529 530 // Returns true if `shape` (which must be an array) with degenerate dimensions 531 // (dimensions with bound 1). 532 static bool HasDegenerateDimensions(const Shape& shape); 533 534 // Drops any degenerate dimensions (i.e. dimensions of size 1) 535 static Shape DropDegenerateDimensions(const Shape& shape); 536 537 // Permutes the dimensions by the given permutation, so 538 // return_value.dimensions[permutation[i]] = argument.dimensions[i]. 539 // 540 // Postcondition: For any valid permutation, 541 // 542 // !HasLayout(shape) || 543 // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), 544 // InversePermutation(permutation)). 545 static Shape PermuteDimensions(absl::Span<const int64> permutation, 546 const Shape& shape); 547 548 // If we can go from `shape_pre` to `shape_post` by merely inserting or 549 // deleting 1-sized dimensions, return the indices in `shape_pre` of the 550 // deleted dimensions and the indices in `dims_post` of the inserted 551 // dimensions. 552 // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and 553 // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s 554 // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each 555 // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and 556 // `j`s less than `k` for all other `k`, we return the `i`s and `j`s. 557 // For another example, if `shape_pre = shape_post = {}`, we return `{}`. 558 static std::tuple<bool, std::vector<int64>, std::vector<int64>> 559 InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, 560 const Shape& shape_post); 561 562 // Suppose a reshape transforms input_shape to output shape. Returns a vector 563 // of pairs that indicate the input and output dimensions that this reshape 564 // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in 565 // the returned vector, the reshape transforms any input index whose I-th 566 // dimension is x to an output index whose O-th dimension is x too. 567 // 568 // Post-condition: the returned vector is sorted (by both input and output 569 // dimensions because input and output dimensions have the same order). 570 // 571 // Example: 572 // input shape = T[a, b, x, y, cd] 573 // output shape = T[ab, x, 1, y, c, d] 574 // return value = {{2, 1}, {3, 3}} 575 // 576 // The two pairs represent the input and output dimension of size x and 577 // those of size y. 578 static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape( 579 const Shape& input_shape, const Shape& output_shape); 580 581 // Returns whether a transpose from input_shape to output_shape with dimension 582 // mapping "dimension_mapping" produces a result which is bit-wise identical 583 // to its input and thus may be replaced with a bitcast. 584 // 585 // Precondition: Both input_shape and output_shape have explicit layouts. 586 static bool TransposeIsBitcast(const Shape& input_shape, 587 const Shape& output_shape, 588 absl::Span<const int64> dimension_mapping); 589 590 // Returns whether a reshape from "input_shape" to "output_shape" is a 591 // bitcast. 592 // 593 // Precondition: Both input_shape and output_shape have explicit layouts. 594 static bool ReshapeIsBitcast(const Shape& input_shape, 595 const Shape& output_shape); 596 597 // Find a physical layout for 'output_shape' such that 598 // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns 599 // true (where 'output_shape_with_layout' is 'output_shape' with the found 600 // layout). The layout of 'input_shape' is kept fixed. Returns 601 // 'output_shape_with_layout' if such a layout can be found, and an error 602 // otherwise. 603 static absl::optional<Shape> AlignLayouts(const Shape& input_shape, 604 const Shape& output_shape); 605 606 // Returns a shape with the given dimension deleted. 607 // For example: 608 // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` 609 static Shape DeleteDimension(int64 dim_to_delete, Shape shape); 610 611 // Returns a shape with all the dimensions of the input shape for which `p` 612 // returns true. 613 // For examples: 614 // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]` 615 // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]` 616 static Shape FilterDimensions(const std::function<bool(int64)>& p, 617 Shape shape); 618 619 // Iterates through all the shape indexes, in minor to major order, starting 620 // from the base indexes, incrementing by the incr steps, up to count 621 // (index[i] < base[i] + count[i]), and calls the visitor_function with the 622 // current index. 623 // The visitor_function visitor function should return true if it wants to 624 // continue, or false otherwise. 625 // 626 // visitor_function must be a callable of type 627 // StatusOr<bool>(Span<int64>) or compatible. 628 template <typename FnType> ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)629 static Status ForEachIndexWithStatus(const Shape& shape, 630 absl::Span<const int64> base, 631 absl::Span<const int64> count, 632 absl::Span<const int64> incr, 633 const FnType& visitor_function) { 634 return ForEachIndexInternal(shape, base, count, incr, visitor_function); 635 } 636 637 // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. 638 struct IndexIterationSpace { 639 std::vector<int64> index_base; 640 std::vector<int64> index_count; 641 std::vector<int64> index_incr; 642 }; 643 644 template <typename FnTy> ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)645 static Status ForEachIndexWithStatus( 646 const Shape& shape, const IndexIterationSpace& iteration_space, 647 FnTy&& function) { 648 return ShapeUtil::ForEachIndexWithStatus( 649 shape, iteration_space.index_base, iteration_space.index_count, 650 iteration_space.index_incr, std::forward<FnTy>(function)); 651 } 652 653 template <typename FnType> ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)654 static void ForEachIndex(const Shape& shape, absl::Span<const int64> base, 655 absl::Span<const int64> count, 656 absl::Span<const int64> incr, 657 const FnType& visitor_function) { 658 ForEachIndexWithStatus(shape, base, count, incr, 659 [&](absl::Span<const int64> indices) { 660 return StatusOr<bool>(visitor_function(indices)); 661 }) 662 .IgnoreError(); 663 } 664 665 // These convenience wrappers don't take `base`, `count` and `incr` 666 // explicitly, but iterate over every element in `shape` instead. 667 668 template <typename FnType> ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)669 static Status ForEachIndexWithStatus(const Shape& shape, 670 const FnType& visitor_function) { 671 std::vector<int64> base(shape.dimensions_size()); 672 std::vector<int64> incr(shape.dimensions_size(), 1); 673 return ForEachIndexWithStatus(shape, base, 674 /*count=*/AsInt64Slice(shape.dimensions()), 675 incr, visitor_function); 676 } 677 678 template <typename FnType> ForEachIndex(const Shape & shape,const FnType & visitor_function)679 static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { 680 ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) { 681 return StatusOr<bool>(visitor_function(indices)); 682 }).IgnoreError(); 683 } 684 685 // A parallel version of ForEachIndex(WithStatus). This can only be used if 686 // the visitor_function is thread-safe and the order of iteration does not 687 // matter. 688 // 689 // visitor_function must be a callable of type 690 // void(Span<int64>) or compatible. 691 template <typename FnType> ForEachIndexParallel(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)692 static void ForEachIndexParallel(const Shape& shape, 693 absl::Span<const int64> base, 694 absl::Span<const int64> count, 695 absl::Span<const int64> incr, 696 const FnType& visitor_function) { 697 // The parallel version of ForEachIndexInternal can never fail. 698 CHECK(ForEachIndexInternal( 699 shape, base, count, incr, 700 [&visitor_function]( 701 absl::Span<const int64> indexes) -> StatusOr<bool> { 702 visitor_function(indexes); 703 return true; 704 }, 705 /*parallel=*/true) 706 .ok()); 707 } 708 709 // Compute a hash for `shape`. 710 static size_t Hash(const Shape& shape); 711 712 private: 713 // Validates the shape size is sane. This makes sure it's safe to do 714 // calculations in int64 without overflowing. 715 static Status ValidateShapeSize(const Shape& shape); 716 717 // Validates all of the non-layout properties of the shape -- this is a helper 718 // used by both the layout-optional and layout-required public method. 719 static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); 720 721 template <typename FnType> 722 static Status ForEachIndexInternal(const Shape& shape, 723 absl::Span<const int64> base, 724 absl::Span<const int64> count, 725 absl::Span<const int64> incr, 726 const FnType& visitor_function, 727 bool parallel = false) { 728 if (ShapeUtil::IsZeroElementArray(shape)) { 729 return Status::OK(); 730 } 731 CHECK_EQ(shape.rank(), base.size()); 732 CHECK_EQ(incr.size(), base.size()); 733 CHECK_EQ(count.size(), base.size()); 734 const int64 rank = LayoutUtil::MinorToMajor(shape).size(); 735 // Allows handling R0 arrays, such that the visitor function will be called 736 // once with the proper empty indexes. 737 int64 n = -1; 738 std::vector<int64> indexes(base.begin(), base.end()); 739 const int kNumThreads = tensorflow::port::NumSchedulableCPUs(); 740 absl::optional<tensorflow::thread::ThreadPool> pool; 741 if (parallel) { 742 pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); 743 } 744 745 tensorflow::mutex mu; 746 Status status; // Guarded by mu 747 748 while (n < rank) { 749 if (pool != absl::nullopt) { 750 pool->Schedule([indexes, &visitor_function, &mu, &status] { 751 StatusOr<bool> result = visitor_function(indexes); 752 if (!result.ok()) { 753 tensorflow::mutex_lock lock(mu); 754 status = status.ok() ? result.status() : status; 755 } 756 }); 757 } else { 758 TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); 759 if (!should_continue) { 760 break; 761 } 762 } 763 // Increments dimensions in minor to major order. 764 for (n = 0; n < rank; ++n) { 765 int64 dim = LayoutUtil::Minor(shape.layout(), n); 766 indexes[dim] += incr[dim]; 767 if (indexes[dim] < base[dim] + count[dim]) { 768 break; 769 } 770 indexes[dim] = base[dim]; 771 } 772 } 773 774 // Waits for the scheduled work to complete. 775 pool.reset(); 776 return status; 777 } 778 779 TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); 780 }; 781 782 } // namespace xla 783 784 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ 785