/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <initializer_list>
#include <string>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
// index {0} : array a
// index {1, 2} : array d
// index {2} : array e
// index {0, 0} : invalid index (element at {0} is an array not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, ie {}.
//
// ShapeIndex is a trivial wrapper around std::vector with a minimum number of
// methods implemented.
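//
// Example (illustrative): the index {1, 2} for the nested tuple
// (a, (b, c, d), e) above could be built up incrementally,
//
//   ShapeIndex index;    // {}, refers to the whole shape
//   index.push_back(1);  // {1}, the inner tuple (b, c, d)
//   index.push_back(2);  // {1, 2}, array d
//
// or written directly as ShapeIndex{1, 2}.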
class ShapeIndex {
 public:
  ShapeIndex() = default;
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}

  bool empty() const { return indices_.empty(); }
  size_t size() const { return indices_.size(); }
  void push_back(int64 value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // push_front is linear in the current size (so building an index with
  // repeated push_front calls is quadratic), but shapes don't usually have a
  // ton of dimensions.
  void push_front(int64 value) { indices_.insert(indices_.begin(), value); }

  std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
  std::vector<int64>::const_iterator end() const { return indices_.end(); }
  std::vector<int64>::iterator begin() { return indices_.begin(); }
  std::vector<int64>::iterator end() { return indices_.end(); }

  const int64* data() const { return indices_.data(); }

  int64 back() const { return indices_.back(); }
  int64& back() { return indices_.back(); }

  const int64& operator[](size_t i) const { return indices_[i]; }
  int64& operator[](size_t i) { return indices_[i]; }

  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  string ToString() const;

 private:
  std::vector<int64> indices_;
};

// A view into a ShapeIndex as above, with the cheap/easy ability to consume
// the value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
      : ShapeIndexView(shape_index.data() + offset,
                       shape_index.data() + shape_index.size()) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices)
      : ShapeIndexView(indices.begin(), indices.end()) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return begin_; }
  iterator end() const { return end_; }
  int64 size() const { return std::distance(begin_, end_); }
  bool empty() const { return begin_ == end_; }
  int64 front() const {
    CHECK(!empty());
    return *begin_;
  }
  ShapeIndexView ConsumeFront() const {
    CHECK(!empty());
    auto new_begin = begin_;
    ++new_begin;
    return ShapeIndexView(new_begin, end_);
  }

  string ToString() const;

 private:
  ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}

  iterator begin_;
  iterator end_;
};

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
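
// Example of the intended consumption pattern: an illustrative sketch (the
// helper below is hypothetical, not part of this header) that walks an index
// one element at a time with front()/ConsumeFront(), descending into tuple
// element shapes via ShapeUtil::GetTupleElementShape (declared further down).
//
//   void WalkIndex(const Shape& shape, ShapeIndexView index) {
//     if (index.empty()) {
//       return;  // `shape` is the subshape the full index referred to.
//     }
//     const Shape& child =
//         ShapeUtil::GetTupleElementShape(shape, index.front());
//     WalkIndex(child, index.ConsumeFront());
//   }
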
// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
  // may not actually be able to store this number of elements. See
  // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
  // elements that can be stored in a sparse shape.
  // Precondition: !IsTuple(shape)
  static int64 ElementsIn(const Shape& shape);

  // Returns true if 'shape' has zero elements.
  static bool HasZeroElements(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape. The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For
  // example, a tuple is stored as an array of pointers to other buffers. In
  // this case, this method only returns the size of the pointer array.
  // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
  //               !ShapeUtil::IsOpaque(shape)
  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers
  // for an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64 pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. The return value does not include
  // the bytes needed to store sparse indices. Dense shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
  // size also includes padding if present in the layout. For sparse shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
  // ByteSizeOfSparseIndices(shape)`.
  static int64 ByteSizeOfElements(const Shape& shape);

  // Returns the number of bytes required for the sparse indices in an
  // allocation of shape. The shape must be an array shape. The return value
  // does not include the bytes needed to store the elements themselves.
  static int64 ByteSizeOfSparseIndices(const Shape& shape);

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string for the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Parses a ShapeUtil::HumanString-format shape string back into a shape
  // object.
  static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);

  // Returns whether the LHS and RHS shapes have the same dimensions; note:
  // does not check element type.
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    if (ElementIsFloating(a) && ElementIsFloating(b)) {
      return true;
    }
    return ShapeUtil::SameElementType(a, b);
  }
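
  // For example (illustrative, using MakeShape declared below): an f32 shape
  // and a bf16 shape differ under SameElementType but are treated as the same
  // under SameElementTypeIgnoringFpPrecision.
  //
  //   Shape a = ShapeUtil::MakeShape(F32, {2, 3});
  //   Shape b = ShapeUtil::MakeShape(BF16, {2, 3});
  //   ShapeUtil::SameElementType(a, b);                     // false
  //   ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b);  // true
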
  // Returns the higher-precision element type if a and b are both floating
  // point types; otherwise, checks that they have the same element type
  // and returns it.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    if (SameElementType(a, b)) {
      return a.element_type();
    }
    CHECK(SameElementTypeIgnoringFpPrecision(a, b));
    return primitive_util::BitWidth(a.element_type()) <
                   primitive_util::BitWidth(b.element_type())
               ? b.element_type()
               : a.element_type();
  }

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allows one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical protobufs.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // Returns the rank (number of dimensions) of the given shape.
  // Precondition: !IsTuple(shape)
  static int64 Rank(const Shape& shape);

  // Returns the number of dimensions for which the dimension is not
  // (trivially) 1. e.g., f32[2x1x1] has a true rank of 1D, the other
  // dimensions are just fluff. Note that zero dimensions are included in the
  // true rank, e.g., f32[3,0,1] has a true rank of 2D.
  static int64 TrueRank(const Shape& shape);

  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  static bool IsScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
  }
  static bool IsEffectiveScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
  }
  static bool IsScalarF32(const Shape& shape);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64 GetDimension(const Shape& shape, int64 dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a positive dimension number for any given
  // dimension_number (which itself can be negative).
  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
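
  // For example (illustrative): for a rank-3 shape such as f32[4,5,6],
  // GetDimensionNumber(shape, -1) resolves to dimension 2, so
  // GetDimension(shape, -1) returns 6, mirroring Python-style negative
  // indexing.
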
  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Returns an empty tuple shape. Can be used to indicate side-effects.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
  static Shape MakeShape(PrimitiveType element_type,
                         tensorflow::gtl::ArraySlice<int64> dimensions);

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<int64> minor_to_major);

  static Shape MakeShapeWithSparseLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      int64 max_sparse_elements);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static void PopulateShape(PrimitiveType element_type,
                            tensorflow::gtl::ArraySlice<int64> dimensions,
                            Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);
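
  // For example (illustrative): for a shape like s32[4], ElementIsIntegral,
  // ElementIsSigned, and ElementHasBitWidth(shape, 32) all hold; for f32[4],
  // ElementIsFloating holds instead, and ElementIsSigned still holds since
  // floating point numbers are signed.
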
  // Returns whether the shape is a tuple.
  static bool IsTuple(const Shape& shape) {
    return shape.element_type() == TUPLE;
  }

  // Returns whether the shape is an opaque value (i.e. an 'existential' typed
  // value that is passed to CustomCall operations).
  static bool IsOpaque(const Shape& shape) {
    return shape.element_type() == OPAQUE;
  }

  // Returns whether the shape is an array. Note that scalars are considered
  // arrays.
  static bool IsArray(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape);
  }

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns true if shape is an empty tuple, or is an array with no elements.
  static bool IsNil(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Shorthand for testing whether a shape is of a given element type and
  // sequence of dimensions.
  //
  // DEPRECATED: Use Equal() instead.
  static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                      std::initializer_list<int64> dimensions);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

  // Returns whether the given index in the given shape is a leaf element of
  // the shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
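
  // For example (illustrative): given a tuple shape t = (f32[2], (s32[4],
  // f32[3])), GetSubshape(t, {1, 0}) is the s32[4] subshape, IsLeafIndex(t,
  // {1, 0}) is true, and IsLeafIndex(t, {1}) is false because the subshape at
  // {1} is itself a tuple.
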
  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);
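
  // For example (illustrative): counting the array leaves of a shape with a
  // lambda visitor.
  //
  //   int64 num_leaves = 0;
  //   ShapeUtil::ForEachSubshape(
  //       shape, [&](const Shape& subshape, const ShapeIndex& index) {
  //         if (ShapeUtil::IsArray(subshape)) {
  //           ++num_leaves;
  //         }
  //       });
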
  // Removes all degenerate dimensions (size one) from the given shape. The
  // stripped minor_to_major preserves the relative ordering of non-degenerate
  // dimensions. The stripped shape has the property that the underlying
  // representation (bits in memory) for the stripped shape is the same as the
  // original shape modulo padding. Examples:
  //
  // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2}
  // stripped shape: F32 [2], minor_to_major = {0}
  //
  // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1}
  // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
  //
  // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
  // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
  //
  // input shape: F32 [1, 1], minor_to_major = {0, 1}
  // stripped shape: F32 [], minor_to_major = {}
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static Shape StripDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[permutation[i]] = argument.dimensions[i]
  static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
                                 const Shape& shape);

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output_shape. Returns a
  // vector of pairs that indicate the input and output dimensions that this
  // reshape doesn't logically (i.e. ignoring the layout) modify. For each
  // pair (I,O) in the returned vector, the reshape transforms any input index
  // whose I-th dimension is x to an output index whose O-th dimension is x
  // too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether a transpose from input_shape to output_shape with
  // dimension mapping "dimension_mapping" produces a result which is bit-wise
  // identical to its input and thus may be replaced with a bitcast.
  static bool TransposeIsBitcast(
      const Shape& input_shape, const Shape& output_shape,
      tensorflow::gtl::ArraySlice<int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and an error
  // otherwise.
  static tensorflow::gtl::optional<Shape> AlignLayouts(
      const Shape& input_shape, const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64 dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // For example:
  // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64)>& p,
                                Shape shape);

  // Iterates through all the shape indexes, in minor to major order, starting
  // from the base indexes, incrementing by the incr steps, up to count
  // (index[i] < base[i] + count[i]), and calls the visitor_function with the
  // current index. The visitor_function should return true if it wants to
  // continue, or false otherwise.
  //
  // visitor_function must be a callable of type
  // bool(const std::vector<int64>&) or compatible.
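  //
  // For example (illustrative): visiting every index of a rank-2 array shape
  // `s` (hypothetical variable) in unit steps and logging it.
  //
  //   ShapeUtil::ForEachIndex(
  //       s, /*base=*/{0, 0}, /*count=*/{s.dimensions(0), s.dimensions(1)},
  //       /*incr=*/{1, 1}, [](const std::vector<int64>& index) {
  //         LOG(INFO) << index[0] << "," << index[1];
  //         return true;  // keep iterating
  //       });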
  template <typename FnType>
  static void ForEachIndex(const Shape& shape,
                           tensorflow::gtl::ArraySlice<int64> base,
                           tensorflow::gtl::ArraySlice<int64> count,
                           tensorflow::gtl::ArraySlice<int64> incr,
                           const FnType& visitor_function) {
    if (ShapeUtil::HasZeroElements(shape)) {
      return;
    }
    CHECK_EQ(Rank(shape), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be
    // called once with the proper empty indexes.
    int64 n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    while (n < rank && visitor_function(indexes)) {
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64 dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }
  }

 private:
  // Validates all of the non-layout properties of the shape -- this is a
  // helper used by both the layout-optional and layout-required public
  // methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

std::ostream& operator<<(std::ostream& out, const Shape& shape);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_