1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Shapes are protobuf messages, so this utility header offers a bunch of 17 // functionality for querying / poking at them. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ 20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ 21 22 #include <algorithm> 23 #include <functional> 24 #include <initializer_list> 25 #include <optional> 26 #include <string> 27 #include <tuple> 28 #include <utility> 29 #include <vector> 30 31 #include "absl/base/macros.h" 32 #include "absl/container/inlined_vector.h" 33 #include "absl/types/span.h" 34 #include "tensorflow/compiler/xla/layout_util.h" 35 #include "tensorflow/compiler/xla/primitive_util.h" 36 #include "tensorflow/compiler/xla/shape.h" 37 #include "tensorflow/compiler/xla/status_macros.h" 38 #include "tensorflow/compiler/xla/statusor.h" 39 #include "tensorflow/compiler/xla/types.h" 40 #include "tensorflow/compiler/xla/util.h" 41 #include "tensorflow/compiler/xla/xla_data.pb.h" 42 #include "tensorflow/core/lib/core/threadpool.h" 43 #include "tensorflow/core/platform/cpu_info.h" 44 #include "tensorflow/core/platform/env.h" 45 46 namespace xla { 47 48 // A view into a ShapeIndex below, with the cheap/easy ability to consume the 49 // value at the front of the view. 50 // 51 // NB! 
// ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
using ShapeIndexView = absl::Span<const int64_t>;

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, ie {}.
struct ShapeIndex : public absl::InlinedVector<int64_t, 2> {
  using InlinedVector::InlinedVector;

  ShapeIndex() = default;  // Needed to make MSVC work for some reason.

  // Copies the elements of a (non-owning) view into an owning index.
  explicit ShapeIndex(ShapeIndexView view)
      : ShapeIndex(view.begin(), view.end()) {}

  // push_front is O(n), but shapes don't usually have a ton of dimensions.
  void push_front(int64_t value) { insert(begin(), value); }
  void pop_front() { erase(begin()); }

  // Human-readable form, e.g. "{1,2}".
  std::string ToString() const;
};

// Streams the human-readable string form of `shape_index`.
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Data structure which describes the coordinates and the shape of a tuple
  // shaped sub-shape.
  struct IndexedShape {
    IndexedShape() = default;
    IndexedShape(ShapeIndex index, Shape shape)
        : index(std::move(index)), shape(std::move(shape)) {}
    ShapeIndex index;  // Path from the root shape to this subshape.
    Shape shape;       // The subshape found at `index`.
  };

  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1.
  // Precondition: shape.IsArray()
  static int64_t ElementsIn(const Shape& shape);

  // As ElementsIn(), but recurses through tuples.
  static int64_t ElementsInRecursive(const Shape& shape);

  // Returns true if shape has the primitive type, recurses through tuples.
  static bool HasPrimitiveType(const Shape& shape,
                               PrimitiveType primitive_type);

  // Returns true if 'shape' is an array with zero elements.
  static bool IsZeroElementArray(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape. The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example, a
  // tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
  static int64_t ByteSizeOf(const Shape& shape, int64_t pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: shape.IsArray()
  static int64_t ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape.
// The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64_t ByteSizeOfTupleIndexTable(const Shape& shape,
                                           int64_t pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. Shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
  // size also includes padding if present in the layout.
  static int64_t ByteSizeOfElements(const Shape& shape);

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static std::string HumanString(const Shape& shape);
  static std::string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string for the form:
  //
  //   (param_name: f32[42x12], ...) -> f32[24x42]
  static std::string HumanString(const ProgramShape& program_shape);

  // Returns whether the LHS and RHS shapes have the same dimensions; note: does
  // not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the LHS and RHS shapes have the same rank; note: does
  // not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameRank(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)172 static bool SameElementTypeIgnoringFpPrecision(const Shape& a, 173 const Shape& b) { 174 if (ElementIsFloating(a) && ElementIsFloating(b)) { 175 return true; 176 } 177 return ShapeUtil::SameElementType(a, b); 178 } 179 180 // Returns the higher-precision element type if a and b are both floating 181 // point types; otherwise, checks that they have the same element type 182 // and returns it. HigherPrecisionElementType(const Shape & a,const Shape & b)183 static PrimitiveType HigherPrecisionElementType(const Shape& a, 184 const Shape& b) { 185 return primitive_util::HigherPrecisionType(a.element_type(), 186 b.element_type()); 187 } 188 189 // Returns true if the rank, dimension sizes, and element type are 190 // identical. Layout is ignored. Tuple elements are compared recursively for 191 // compatibility. 192 static bool Compatible(const Shape& lhs, const Shape& rhs); 193 194 // Returns true if the rank and dimension sizes are identical. Element type 195 // and layout are ignored. Tuple elements are compared recursively for 196 // compatibility. 197 static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); 198 199 // Returns true if the tuple tree shapes and leaf ranks are identical. 200 // Leaf dimensions, element type, and layout are ignored. Tuple elements are 201 // compared recursively for compatibility. 202 static bool CompatibleKind(const Shape& lhs, const Shape& rhs); 203 204 // As Compatible, but allow one of lhs and rhs to be BF16 while the other 205 // being F32. Tuple elements are compared recursively for compatibility. 206 static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); 207 208 // Returns whether the lhs and rhs shapes are identical. 209 static bool Equal(const Shape& lhs, const Shape& rhs); 210 211 // As Equal, but does not compare the element type. 
static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
  static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Two shapes have the same structure if all subshape indices of lhs are
  // present on rhs and vice versa.
  // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
  // (S32, (F32[3], S32[2])) as their structures are both (,(,))
  //
  // In contrast, (F32, (F32, F32)) is structurally different from
  // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
  // ((,),)
  static bool EqualStructure(const Shape& lhs, const Shape& rhs);

  // Returns the number of dimensions for which the dimension is not (trivially)
  // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
  // fluff. Note that zero dimensions are included in the true rank, e.g.,
  // f32[3,0,1] has a true rank of 2D.
  static int64_t TrueRank(const Shape& shape);

  // Builds a ProgramShape from the given parameter shapes and result shape.
  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  // Returns true if `shape` is an array of rank 0, i.e. a scalar.
  static bool IsScalar(const Shape& shape) {
    return shape.IsArray() && shape.rank() == 0;
  }
  // Returns true if `shape` is an array whose TrueRank is 0, e.g. f32[1, 1]
  // counts as an effective scalar even though its rank is 2.
  static bool IsEffectiveScalar(const Shape& shape) {
    return shape.IsArray() && TrueRank(shape) == 0;
  }

  // Returns whether "shape" is a scalar (array) with the given element_type.
  static bool IsScalarWithElementType(const Shape& shape,
                                      PrimitiveType element_type);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64_t GetDimension(const Shape& shape, int64_t dimension_number);

  // Resolves a dimension number, supporting negative indexing.
//
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a positive dimension number for any given
  // dimension_number (which itself can be negative).
  static int64_t GetDimensionNumber(const Shape& shape,
                                    int64_t dimension_number);

  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Returns a shape with same dimensions but with all dimensions set to static.
  static Shape MakeStaticShape(const Shape& original);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(absl::Span<const Shape> shapes);
  static Shape MakeTupleShapeWithPtrs(absl::Span<const Shape* const> shapes);

  // Creates a tuple shape from a slice of element shapes within the tuple. If
  // only one shape is passed, returns that.
  static Shape MakeMaybeTupleShape(absl::Span<const Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Creates a token shape. Values of this shape are used for ordering
  // side-effecting operations.
  static Shape MakeTokenShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Update a subshape of a tuple.
  static void UpdateTupleShape(const Shape& shape, int64_t index,
                               Shape* tuple_shape);

  // Update the dynamic dimension for a shape. This shape can be a nested tuple.
static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
                                     int64_t dim, bool is_dynamic);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Appends a minor dimension to the shape with the given bound.
  static void AppendMinorDimension(int bound, Shape* shape);

  // Copy the dynamic dimensions property from one shape to another.
  static void CopyDynamicDimensions(Shape* to, const Shape& from);

  // Returns an empty tuple shape. Can be used as a sentinel Shape value.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Checks whether the shape is initialized (i.e. has a valid element type).
  static bool IsInitialized(const Shape& shape) {
    return shape.element_type() != PRIMITIVE_TYPE_INVALID;
  }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64_t> dimensions);

  // Make a scalar shape with given primitive type.
  static Shape MakeScalarShape(PrimitiveType element_type);

  // Constructs a new shape with the given element type and sequence of
  // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
  // with a true value that the respective dimension is dynamic. If the
  // dimension is dynamic then the respective value in 'dimensions' is an upper
  // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
  // the same size.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64_t> dimensions,
                         const std::vector<bool>& dynamic_dimensions);

  // Constructs a new shape with the given element type and sequence of
  // dimensions. Method checks if the element type is valid and the shape's
  // size fits in std::numeric_limits<int64_t>::max().
static StatusOr<Shape> MakeValidatedShape(
      PrimitiveType element_type, absl::Span<const int64_t> dimensions);
  static StatusOr<Shape> MakeValidatedShape(
      PrimitiveType element_type, absl::Span<const int64_t> dimensions,
      const std::vector<bool>& dynamic_dimensions);

  // Creates a Shape with element type corresponding to T and the given
  // dimensions.
  template <typename T>
  static Shape MakeShapeWithType(absl::Span<const int64_t> dimensions) {
    return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
                                dimensions);
  }

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(
      PrimitiveType element_type, absl::Span<const int64_t> dimensions,
      absl::Span<const int64_t> minor_to_major,
      absl::Span<const DimLevelType> dim_level_types = {},
      absl::Span<const Tile> tiles = {}, int64_t element_size_in_bits = 0,
      int64_t memory_space = 0);

  // Constructs a new shape with the given dimension `dim` as the most major
  // dimension in the layout. If the shape does not have a layout, assumes a
  // default layout. If the shape is a tuple, apply this to all the leaf shapes
  // of the tuple.
  static Shape MoveDimToMajor(const Shape& shape, int64_t dim);

  // Returns the same shape except with all dimensions set to be static.
  static Shape MakeShapeWithStaticDimensions(const Shape& shape);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type, absl::Span<const int64_t> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e.
// {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static Status PopulateShape(PrimitiveType element_type,
                              absl::Span<const int64_t> dimensions,
                              Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
static bool ElementIsSigned(const Shape& shape);

  // Returns whether the given primitive type corresponds to an array shape.
  static bool IsArrayPrimitiveType(PrimitiveType primitive_type);

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64_t TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64_t index);

  // Returns the number of subshapes, recursively, in the given shape.
  static int64_t SubshapeCount(const Shape& shape);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64_t start, int64_t limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Returns true if the given shape has a subshape at the given index.
  static bool IndexIsValid(const Shape& shape, ShapeIndexView index);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument. The non-Try variants check fail if index is
  // invalid.
static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
                                               ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

  // Returns whether the given index in the given shape is a leaf element of the
  // shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);

  // Returns the number of leaves in the shape.
  static int64_t GetLeafCount(const Shape& shape);

  // Retrieves all the leaf shapes and their indexes, in the order walked by
  // the ForEachSubshape() API.
  static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);

  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (i.e. dimensions with bound 1).
static bool HasDegenerateDimensions(const Shape& shape);

  // Drops any degenerate dimensions (i.e. dimensions of size 1).
  static Shape DropDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[i] = argument.dimensions[permutation[i]].
  //
  // Postcondition: For any valid permutation,
  //
  //   !HasLayout(shape) ||
  //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
  //                      permutation).
  static Shape PermuteDimensions(absl::Span<const int64_t> permutation,
                                 const Shape& shape);

  // Describes how we can go from shape A to shape B by inserting degenerate
  // 1-sized dimensions in `inserted_dimensions` and removing degenerate 1-sized
  // dimensions from A in `deleted_dimensions`.
  //
  // Only exists if shapes A and B only differ by degenerate dimensions.
  struct ShapeEqualityDescriptor {
    std::vector<int64_t> deleted_dimensions;
    std::vector<int64_t> inserted_dimensions;
  };

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::optional<ShapeEqualityDescriptor>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output shape.
// Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input shape  = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64_t, int64_t>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Return whether the given reshape instruction leaves the dimensions at the
  // given input indices unmodified, and returns their output indices.
  //
  // Example:
  //   input_dim_indices = {2, 3}
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {1, 3}
  //
  // Precondition: input_dim_indices is sorted.
  static std::optional<std::vector<int64_t>> ReshapeLeavesDimensionsUnmodified(
      const Shape& from_shape, const Shape& to_shape,
      absl::Span<const int64_t> input_dim_indices);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
static bool TransposeIsBitcast(const Shape& input_shape,
                                 const Shape& output_shape,
                                 absl::Span<const int64_t> dimension_mapping);

  // Returns whether a reshape from `input_shape` to `output_shape` is a
  // bitcast, when minor_to_major in layout is considered.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and std::nullopt
  // otherwise (the return type is std::optional, not a Status).
  static std::optional<Shape> AlignLayouts(const Shape& input_shape,
                                           const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // * `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64_t dim_to_delete, Shape shape);

  // Returns a shape with dimensions in `dims_to_delete` dropped.
  static Shape DeleteDimensions(absl::Span<int64_t const> dims_to_delete,
                                Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // For examples:
  // * `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // * `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64_t)>& p,
                                Shape shape);

  // Returns true if `dynamic_shape` has dimensions that are less-equal to the
  // "bounded_shape". Shapes must be arrays.
  static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
                                            const xla::Shape& bounded_shape);

  // Same as DynamicArrayShapeIsCompatible() but supports tuples.
614 static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, 615 const xla::Shape& bounded_shape); 616 617 // Iterates through all the shape indexes, in minor to major order, 618 // starting from the base indexes, incrementing by the incr steps, up to 619 // count (index[i] < base[i] + count[i]), and calls the visitor_function 620 // with the current index. The visitor_function visitor function should 621 // return true if it wants to continue, or false otherwise. 622 // 623 // visitor_function must be a callable of type 624 // StatusOr<bool>(absl::Span<int64_t>) or compatible. 625 template <typename FnType> ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)626 static Status ForEachIndexWithStatus(const Shape& shape, 627 absl::Span<const int64_t> base, 628 absl::Span<const int64_t> count, 629 absl::Span<const int64_t> incr, 630 const FnType& visitor_function) { 631 return ForEachIndexInternal( 632 shape, base, count, incr, 633 [&visitor_function](absl::Span<const int64_t> indexes, 634 int /*thread_id*/) { 635 return visitor_function(indexes); 636 }); 637 } 638 639 // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. 
640 struct IndexIterationSpace { 641 std::vector<int64_t> index_base; 642 std::vector<int64_t> index_count; 643 std::vector<int64_t> index_incr; 644 }; 645 646 template <typename FnTy> ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)647 static Status ForEachIndexWithStatus( 648 const Shape& shape, const IndexIterationSpace& iteration_space, 649 FnTy&& function) { 650 return ShapeUtil::ForEachIndexWithStatus( 651 shape, iteration_space.index_base, iteration_space.index_count, 652 iteration_space.index_incr, std::forward<FnTy>(function)); 653 } 654 655 template <typename FnType> ForEachIndex(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)656 static void ForEachIndex(const Shape& shape, absl::Span<const int64_t> base, 657 absl::Span<const int64_t> count, 658 absl::Span<const int64_t> incr, 659 const FnType& visitor_function) { 660 ForEachIndexWithStatus(shape, base, count, incr, 661 [&](absl::Span<const int64_t> indices) { 662 return StatusOr<bool>(visitor_function(indices)); 663 }) 664 .IgnoreError(); 665 } 666 667 // These convenience wrappers don't take `base`, `count` and `incr` 668 // explicitly, but iterate over every element in `shape` instead. 
669 670 template <typename FnType> ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)671 static Status ForEachIndexWithStatus(const Shape& shape, 672 const FnType& visitor_function) { 673 std::vector<int64_t> base(shape.dimensions_size()); 674 std::vector<int64_t> incr(shape.dimensions_size(), 1); 675 return ForEachIndexWithStatus(shape, base, 676 /*count=*/shape.dimensions(), incr, 677 visitor_function); 678 } 679 680 template <typename FnType> ForEachIndex(const Shape & shape,const FnType & visitor_function)681 static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { 682 ForEachIndexWithStatus(shape, [&](absl::Span<const int64_t> indices) { 683 return StatusOr<bool>(visitor_function(indices)); 684 }).IgnoreError(); 685 } 686 687 // A parallel version of ForEachIndex(WithStatus). This can only be used if 688 // the visitor_function is thread-safe and the order of iteration does not 689 // matter. 690 // 691 // visitor_function must be a callable of type 692 // void(Span<int64_t>, int thread_id) or compatible. 693 template <typename FnType> ForEachIndexParallel(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)694 static void ForEachIndexParallel(const Shape& shape, 695 absl::Span<const int64_t> base, 696 absl::Span<const int64_t> count, 697 absl::Span<const int64_t> incr, 698 const FnType& visitor_function) { 699 // The parallel version of ForEachIndexInternal can never fail. 700 CHECK(ForEachIndexInternal( 701 shape, base, count, incr, 702 [&visitor_function](absl::Span<const int64_t> indexes, 703 int thread_id) -> StatusOr<bool> { 704 visitor_function(indexes, thread_id); 705 return true; 706 }, 707 /*parallel=*/true) 708 .ok()); 709 } 710 // Convenience wrapper which doesn't take `base`, `count` and `incr` 711 // explicitly, but iterates over every element in `shape` instead. 
712 template <typename FnType> ForEachIndexParallel(const Shape & shape,const FnType & visitor_function)713 static void ForEachIndexParallel(const Shape& shape, 714 const FnType& visitor_function) { 715 std::vector<int64_t> base(shape.dimensions_size()); 716 std::vector<int64_t> incr(shape.dimensions_size(), 1); 717 return ForEachIndexParallel(shape, base, 718 /*count=*/shape.dimensions(), incr, 719 visitor_function); 720 } 721 722 // About 0-2-1 transpose: 723 // 724 // If a shape can be viewed as three logical components 0-1-2 in the order of 725 // major to minor, a 0-2-1-transpose changes the order of such logical 726 // components to 0-2-1. We call the shape being transposed the input shape and 727 // the transposed shape the output shape. The logical view of the input/output 728 // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the 729 // normalized shapes. The original input/output shapes are called unnormalized 730 // shapes. 731 // 732 // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the 733 // normalized shape of `b` or the 0-2-1 shape. 734 static std::optional<Vector3> FindTranspose021(const Shape& a, 735 const Shape& b); 736 737 // Strips device-specific information, namely tiling and memory-space 738 // information, from a shape. 739 static Shape DeviceShapeToHostShape(Shape s); 740 741 // Returns true iff element type of shape `from` can be safely upcasted to 742 // element type of shape `to`. 743 static bool ElementCanUpcast(const Shape& from, const Shape& to); 744 745 // Computes byte strides of an array shape `shape`. `shape` must have a 746 // layout. Ignores tiling. `strides` must have size equal to the number of 747 // dimensions of `shape`. 748 static Status ByteStrides(const Shape& shape, absl::Span<int64_t> strides); 749 750 // Returns the array size in bytes (layout/tiling required), all paddings are 751 // included. 
  static int64_t ArraySize(const Shape& shape);

  // Returns the size of array data in bytes, ignoring the trailing padding
  // due to the tiling requirement.
  static int64_t ArrayDataSize(const Shape& shape);

 private:
  // Fills *shape. Returns true on success.
  // REQUIRES: *shape is empty.
  static bool FillNewShape(PrimitiveType element_type,
                           absl::Span<const int64_t> dimensions, Shape* shape);

  // Validates the shape size is sane. This makes sure it's safe to do
  // calculations in int64_t without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public method.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  // Shared implementation behind the ForEachIndex* family. Walks the index
  // space defined by `base`/`count`/`incr`, advancing indices in
  // minor-to-major order, and invokes `visitor_function(indexes, thread_id)`
  // for each index vector.
  //
  //  - Serial mode (`parallel` == false): `thread_id` is -1; a false return
  //    from the visitor stops the walk early, and an error status is
  //    propagated immediately via TF_ASSIGN_OR_RETURN.
  //  - Parallel mode: each visit is scheduled on a thread pool; the visitor's
  //    bool result is ignored (no early exit), and only the first error
  //    status encountered is retained.
  template <typename FnType>
  static Status ForEachIndexInternal(const Shape& shape,
                                     absl::Span<const int64_t> base,
                                     absl::Span<const int64_t> count,
                                     absl::Span<const int64_t> incr,
                                     const FnType& visitor_function,
                                     bool parallel = false) {
    // Nothing to visit in a zero-element array.
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return OkStatus();
    }
    CHECK_EQ(shape.rank(), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64_t rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64_t n = -1;
    std::vector<int64_t> indexes(base.begin(), base.end());
    const int kNumThreads = tensorflow::port::MaxParallelism();
    std::optional<tensorflow::thread::ThreadPool> pool;
    if (parallel) {
      pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
    }

    absl::Mutex mu;
    Status status;  // Guarded by mu

    while (n < rank) {
      if (pool != std::nullopt) {
        // Parallel path: `indexes` is captured by value because the vector is
        // mutated below while scheduled tasks may still be pending. Only the
        // first failure is recorded, under the mutex.
        pool->Schedule([indexes, &visitor_function, &mu, &status, &pool] {
          const int thread_id = pool->CurrentThreadId();
          StatusOr<bool> result = visitor_function(indexes, thread_id);
          if (!result.ok()) {
            absl::MutexLock lock(&mu);
            status = status.ok() ? result.status() : status;
          }
        });
      } else {
        // Serial path: a false result means "stop early"; errors propagate
        // out of this function immediately.
        TF_ASSIGN_OR_RETURN(bool should_continue,
                            visitor_function(indexes, /*thread_id=*/-1));
        if (!should_continue) {
          break;
        }
      }
      // Increments dimensions in minor to major order. When every dimension
      // has wrapped back to its base, the loop leaves n == rank, which
      // terminates the enclosing while loop.
      for (n = 0; n < rank; ++n) {
        int64_t dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }

    // Waits for the scheduled work to complete. Destroying the pool joins all
    // outstanding tasks, so reading `status` without holding `mu` afterwards
    // is safe.
    pool.reset();
    return status;
  }

  // Declaring (deleted) copy operations also suppresses the implicit default
  // constructor, so ShapeUtil cannot be instantiated; it is a namespace-like
  // collection of static helpers.
  ShapeUtil(const ShapeUtil&) = delete;
  ShapeUtil& operator=(const ShapeUtil&) = delete;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_