• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21 
22 #include <algorithm>
23 #include <functional>
24 #include <initializer_list>
25 #include <optional>
26 #include <string>
27 #include <tuple>
28 #include <utility>
29 #include <vector>
30 
31 #include "absl/base/macros.h"
32 #include "absl/container/inlined_vector.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/primitive_util.h"
36 #include "tensorflow/compiler/xla/shape.h"
37 #include "tensorflow/compiler/xla/status_macros.h"
38 #include "tensorflow/compiler/xla/statusor.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/threadpool.h"
43 #include "tensorflow/core/platform/cpu_info.h"
44 #include "tensorflow/core/platform/env.h"
45 
46 namespace xla {
47 
48 // A view into a ShapeIndex below, with the cheap/easy ability to consume the
49 // value at the front of the view.
50 //
51 // NB! ShapeIndexView does not own the memory backing the index array.
52 // The memory backing the index array should be owned by an object
53 // that lives longer than the ShapeIndexView instances pointing into
54 // it.
55 using ShapeIndexView = absl::Span<const int64_t>;
56 
57 // An index for specifying a particular nested subshape within a shape. Used in
58 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
59 // structures (trees) and ShapeIndex defines a path through the tree where each
60 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
61 // shape. For a non-nested tuple, an index has a single element. For example,
62 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
63 // {1} corresponds to array b. For a nested tuple, the index can have more than
64 // one element. For the nested tuple (a, (b, c, d), e) below are the values
65 // corresponding to the given indices:
66 //
67 //   index {0}    : array a
68 //   index {1, 2} : array d
69 //   index {2}    : array e
70 //   index {0, 0} : invalid index (element at {0} is an array not a tuple)
71 //
72 // For indexing into array shapes, the index is always trivially empty, ie {}.
73 struct ShapeIndex : public absl::InlinedVector<int64_t, 2> {
74   using InlinedVector::InlinedVector;
75 
76   ShapeIndex() = default;  // Needed to make MSVC work for some reason.
ShapeIndexShapeIndex77   explicit ShapeIndex(ShapeIndexView view)
78       : ShapeIndex(view.begin(), view.end()) {}
79 
80   // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_frontShapeIndex81   void push_front(int64_t value) { insert(begin(), value); }
pop_frontShapeIndex82   void pop_front() { erase(begin()); }
83 
84   std::string ToString() const;
85 };
86 
87 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
88 
89 // Namespaced collection of (static) shape utilities.
90 //
91 // These are all effectively convenience functions for testing/tweaking proto
92 // properties, which do invariant checks before / after the operation.
93 class ShapeUtil {
94  public:
95   // Data structure which describes the coordinates and the shape, of a tuple
96   // shaped sub-shape.
97   struct IndexedShape {
98     IndexedShape() = default;
IndexedShapeIndexedShape99     IndexedShape(ShapeIndex index, Shape shape)
100         : index(std::move(index)), shape(std::move(shape)) {}
101     ShapeIndex index;
102     Shape shape;
103   };
104 
105   // Returns the number of elements are contained within the provided shape;
106   // e.g. for rank 0 (scalars) the result is always 1.
107   // Precondition: shape.IsArray()
108   static int64_t ElementsIn(const Shape& shape);
109 
110   // As ElementsIn(), but recurses through tuples.
111   static int64_t ElementsInRecursive(const Shape& shape);
112 
113   // Returns true if shape has the primitive type, recurses through tuples.
114   static bool HasPrimitiveType(const Shape& shape,
115                                PrimitiveType primitive_type);
116 
117   // Returns true if 'shape' is an array with zero elements.
118   static bool IsZeroElementArray(const Shape& shape);
119 
120   // Returns the number of bytes required for an allocation of shape.  The
121   // |pointer_size| parameter is used for calculating the size of tuple
122   // shapes. This includes only the size of the top-level buffer. For example, a
123   // tuple is stored as an array of pointers to other buffers. In this case,
124   // this method only returns the size of the pointer array.
125   static int64_t ByteSizeOf(const Shape& shape, int64_t pointer_size = -1);
126 
127   // Returns the number of bytes used to store the primitive_type.
128   //
129   // Precondition: shape.IsArray()
130   static int64_t ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
131 
132   // Returns the number of bytes required to store the tuple member pointers for
133   // a allocation of shape. The `shape` must be a TUPLE shape, and
134   // `pointer_size` must be larger than zero.
135   static int64_t ByteSizeOfTupleIndexTable(const Shape& shape,
136                                            int64_t pointer_size);
137 
138   // Returns the number of bytes required for the elements in an allocation of
139   // `shape`, which must be an array shape. Shapes use a separate
140   // memory location for each element, and so for these shapes,
141   // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
142   // size also includes padding if present in the layout.
143   static int64_t ByteSizeOfElements(const Shape& shape);
144 
145   // Returns a human-readable string that represents the given shape, with or
146   // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
147   static std::string HumanString(const Shape& shape);
148   static std::string HumanStringWithLayout(const Shape& shape);
149 
150   // As above, but for program shapes, returns a string for the form:
151   //
152   // (param_name: f32[42x12], ...) -> f32[24x42]
153   static std::string HumanString(const ProgramShape& program_shape);
154 
155   // Returns whether the LHS and RHS shapes have the same dimensions; note: does
156   // not check element type.
157   // Precondition: IsArray(lhs) && IsArray(rhs)
158   static bool SameDimensions(const Shape& lhs, const Shape& rhs);
159 
160   // Returns whether the LHS and RHS shapes have the same rank; note: does
161   // not check element type.
162   // Precondition: IsArray(lhs) && IsArray(rhs)
163   static bool SameRank(const Shape& lhs, const Shape& rhs);
164 
165   // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)166   static bool SameElementType(const Shape& lhs, const Shape& rhs) {
167     return lhs.element_type() == rhs.element_type();
168   }
169 
170   // As SameElementType, but allows floating point types to have different
171   // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)172   static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
173                                                  const Shape& b) {
174     if (ElementIsFloating(a) && ElementIsFloating(b)) {
175       return true;
176     }
177     return ShapeUtil::SameElementType(a, b);
178   }
179 
180   // Returns the higher-precision element type if a and b are both floating
181   // point types; otherwise, checks that they have the same element type
182   // and returns it.
HigherPrecisionElementType(const Shape & a,const Shape & b)183   static PrimitiveType HigherPrecisionElementType(const Shape& a,
184                                                   const Shape& b) {
185     return primitive_util::HigherPrecisionType(a.element_type(),
186                                                b.element_type());
187   }
188 
189   // Returns true if the rank, dimension sizes, and element type are
190   // identical. Layout is ignored. Tuple elements are compared recursively for
191   // compatibility.
192   static bool Compatible(const Shape& lhs, const Shape& rhs);
193 
194   // Returns true if the rank and dimension sizes are identical. Element type
195   // and layout are ignored. Tuple elements are compared recursively for
196   // compatibility.
197   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
198 
199   // Returns true if the tuple tree shapes and leaf ranks are identical.
200   // Leaf dimensions, element type, and layout are ignored. Tuple elements are
201   // compared recursively for compatibility.
202   static bool CompatibleKind(const Shape& lhs, const Shape& rhs);
203 
204   // As Compatible, but allow one of lhs and rhs to be BF16 while the other
205   // being F32. Tuple elements are compared recursively for compatibility.
206   static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
207 
208   // Returns whether the lhs and rhs shapes are identical.
209   static bool Equal(const Shape& lhs, const Shape& rhs);
210 
211   // As Equal, but does not compare the element type.
212   static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
213 
214   // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
215   static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
216 
217   // Two shapes have same structure if all subshape indices of lhs are presented
218   // on rhs and vice versa.
219   // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
220   // (S32, (F32[3], S32[2])) as their structures are both (,(,))
221   //
222   // In contrast, (F32, (F32, F32)) is structurally different from
223   // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
224   // ((,),)
225   static bool EqualStructure(const Shape& lhs, const Shape& rhs);
226 
227   // Returns the number of dimensions for which the dimension is not (trivially)
228   // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
229   // fluff. Note that zero dimensions are included in the true rank, e.g.,
230   // f32[3,0,1] has a true rank of 2D.
231   static int64_t TrueRank(const Shape& shape);
232 
233   static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
234                                        Shape result);
235 
236   ////////////////////
237   // Scalar-specific
238 
IsScalar(const Shape & shape)239   static bool IsScalar(const Shape& shape) {
240     return shape.IsArray() && shape.rank() == 0;
241   }
IsEffectiveScalar(const Shape & shape)242   static bool IsEffectiveScalar(const Shape& shape) {
243     return shape.IsArray() && TrueRank(shape) == 0;
244   }
245 
246   // Returns whether "shape" is a scalar (array) with the given element_type.
247   static bool IsScalarWithElementType(const Shape& shape,
248                                       PrimitiveType element_type);
249 
250   // Extracts the size of the shape's dimension at dimension number
251   // GetDimensionNumber(dimension_number).
252   static int64_t GetDimension(const Shape& shape, int64_t dimension_number);
253 
254   // Resolves a dimension number, supporting negative indexing.
255   //
256   // Negative indexing has similar semantics to Python. For an N-dimensional
257   // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
258   // N-2, and so on.
259   //
260   // This function always returns a positive dimension number for any given
261   // dimension_number (which itself can be negative).
262   static int64_t GetDimensionNumber(const Shape& shape,
263                                     int64_t dimension_number);
264 
265   // Returns a shape with the same dimensions as the original, but with the
266   // element type changed to type.
267   static Shape ChangeElementType(const Shape& original, PrimitiveType type);
268 
269   // Retursn a shape with same dimensions but with all dimensions set to static.
270   static Shape MakeStaticShape(const Shape& original);
271 
272   // Creates a tuple shape from a slice of element shapes within the tuple.
273   static Shape MakeTupleShape(absl::Span<const Shape> shapes);
274   static Shape MakeTupleShapeWithPtrs(absl::Span<const Shape* const> shapes);
275 
276   // Creates a tuple shape from a slice of element shapes within the tuple. If
277   // only one shape is passed, returns that.
278   static Shape MakeMaybeTupleShape(absl::Span<const Shape> shapes);
279 
280   // Creates an opaque shape. These are generally used for threading a context
281   // into a custom operation.
282   static Shape MakeOpaqueShape();
283 
284   // Creates a token shape. Values of this shape are used for ordering
285   // side-effecting operations.
286   static Shape MakeTokenShape();
287 
288   // Appends a shape to the given tuple.
289   static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
290 
291   // Update a subshape of a tuple.
292   static void UpdateTupleShape(const Shape& shape, int64_t index,
293                                Shape* tuple_shape);
294 
295   // Update the dynamic dimension for a shape. This shape can be a nested tuple.
296   static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
297                                      int64_t dim, bool is_dynamic);
298 
299   // Appends a major dimension to the shape with the given bound.
300   static void AppendMajorDimension(int bound, Shape* shape);
301 
302   // Appends a minor dimension to the shape with the given bound.
303   static void AppendMinorDimension(int bound, Shape* shape);
304 
305   // Copy the dynamic dimensions property from one shape to another.
306   static void CopyDynamicDimensions(Shape* to, const Shape& from);
307 
308   // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()309   static Shape MakeNil() { return MakeTupleShape({}); }
310 
311   // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)312   static bool IsInitialized(const Shape& shape) {
313     return shape.element_type() != PRIMITIVE_TYPE_INVALID;
314   }
315 
316   // Constructs a new shape with the given element type and sequence of
317   // dimensions.
318   static Shape MakeShape(PrimitiveType element_type,
319                          absl::Span<const int64_t> dimensions);
320 
321   // Make a scalar shape with given primitive type.
322   static Shape MakeScalarShape(PrimitiveType element_type);
323 
324   // Constructs a new shape with the given element type and sequence of
325   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
326   // with a true value that the respective dimension is dynamic. If the
327   // dimension is dynamic then the respective value in 'dimension' is an upper
328   // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
329   // the same size.
330   static Shape MakeShape(PrimitiveType element_type,
331                          absl::Span<const int64_t> dimensions,
332                          const std::vector<bool>& dynamic_dimensions);
333 
334   // Constructs a new shape with the given element type and sequence of
335   // dimensions. Method checks if the element type is valid and the shape's
336   // size fits in std::numeric_limits<int64_t>::max().
337   static StatusOr<Shape> MakeValidatedShape(
338       PrimitiveType element_type, absl::Span<const int64_t> dimensions);
339   static StatusOr<Shape> MakeValidatedShape(
340       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
341       const std::vector<bool>& dynamic_dimensions);
342 
343   // Creates a Shape with element type corresponding to T and the given
344   // dimensions
345   template <typename T>
MakeShapeWithType(absl::Span<const int64_t> dimensions)346   static Shape MakeShapeWithType(absl::Span<const int64_t> dimensions) {
347     return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
348                                 dimensions);
349   }
350 
351   // Constructs a new shape with the given minor_to_major order in its Layout.
352   // Returns a value shape such that shape.has_layout().
353   static Shape MakeShapeWithLayout(
354       PrimitiveType element_type, absl::Span<const int64_t> dimensions,
355       absl::Span<const int64_t> minor_to_major,
356       absl::Span<const DimLevelType> dim_level_types = {},
357       absl::Span<const Tile> tiles = {}, int64_t element_size_in_bits = 0,
358       int64_t memory_space = 0);
359 
360   // Constructs a new shape with the given dimension `dim` as the most major
361   // dimension in the layout. If the shape does not have a layout, assumes a
362   // default layout. If the shape is a tuple, apply this to all the leaf shapes
363   // of the tuple.
364   static Shape MoveDimToMajor(const Shape& shape, int64_t dim);
365 
366   // Returns the same shape except with all dimensions set to be static.
367   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
368 
369   // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
370   static Shape MakeShapeWithDescendingLayout(
371       PrimitiveType element_type, absl::Span<const int64_t> dimensions);
372 
373   // Returns a new Shape based on the given Shape with low-dimension-major
374   // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
375   // rearranged so that it has the same in-memory layout as the given shape.
376   //
377   // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
378   static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
379       const Shape& shape);
380 
381   // As MakeShape, but the object to write to is passed in.
382   static Status PopulateShape(PrimitiveType element_type,
383                               absl::Span<const int64_t> dimensions,
384                               Shape* shape);
385 
386   // Validates that the provided shape satisfies invariants.
387   static Status ValidateShape(const Shape& shape);
388 
389   // Validates the provided shape satisfies invariants, except those that
390   // pertain to layout.
391   //
392   // Layout is optional for client-provided shapes, so that the compiler may
393   // determine and assign an optimized layout.
394   static Status ValidateShapeWithOptionalLayout(const Shape& shape);
395 
396   // Returns whether the element type of the shape is integral (signed or
397   // unsigned). Note that predicates are not considered integral here, since
398   // they are logical values.
399   static bool ElementIsIntegral(const Shape& shape);
400 
401   // Returns whether the element type of the shape is floating point.
402   static bool ElementIsFloating(const Shape& shape);
403 
404   // Returns whether the element type of the shape is complex.
405   static bool ElementIsComplex(const Shape& shape);
406 
407   // Returns whether the element type has the given bit width.
408   static bool ElementHasBitWidth(const Shape& shape, int bits);
409 
410   // Returns whether the element type of the shape is integral and has
411   // the specified number of bits.
412   static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
413 
414   // Returns whether the element type of the shape is signed. Note
415   // that floating point numbers are signed.
416   static bool ElementIsSigned(const Shape& shape);
417 
418   // Returns whether the given primitive type corresponds to an array shape.
419   static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
420 
421   // Returns whether the shape is a tuple with at least one element which is
422   // also a tuple.
423   static bool IsNestedTuple(const Shape& shape);
424 
425   // Returns true if shape is an empty tuple.
426   static bool IsEmptyTuple(const Shape& shape);
427 
428   // Returns the number of elements in the given tuple shape.
429   // Precondition: IsTuple(shape)
430   static int64_t TupleElementCount(const Shape& shape);
431 
432   // Returns the tuple element shape at given index.
433   // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
434   static const Shape& GetTupleElementShape(const Shape& shape, int64_t index);
435 
436   // Returns the number of elements, recursively, in the given shape.
437   static int64_t SubshapeCount(const Shape& shape);
438 
439   // Slices tuple elements in the range [start, limit) and returns a new tuple
440   // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
441   static Shape SliceTuple(const Shape& tuple, int64_t start, int64_t limit);
442 
443   // Returns the shape of the real/imaginary components of the given complex
444   // shape.
445   static Shape ComplexComponentShape(const Shape& complex_shape);
446 
447   // Returns true if the given shape has a subshape at the given index.
448   static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
449 
450   // GetSubshape and GetMutableSubshape return a particular nested Shape within
451   // the given Shape argument. The non-Try variants check fail if index is
452   // invalid.
453   static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
454   static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
455                                                ShapeIndexView index);
456   static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
457 
458   // Returns whether the given index in the given shape is a leaf element of the
459   // shape.
460   static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
461 
462   // Returns the number of leaves in the shape.
463   static int64_t GetLeafCount(const Shape& shape);
464 
465   // Retrieves all the leaf shapes and their indexes, in the order walked by
466   // the ForEachSubshape() API.
467   static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
468 
469   // Calls the given visitor function for each subshape of the given shape.
470   // Subshapes are visited in DFS pre-order starting with the entire shape
471   // (index {}).
472   using VisitorFunction = std::function<void(const Shape& /*subshape*/,
473                                              const ShapeIndex& /*index*/)>;
474   static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
475   using MutatingVisitorFunction =
476       std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
477   static void ForEachMutableSubshape(Shape* shape,
478                                      const MutatingVisitorFunction& func);
479 
480   // Variants of ForEach(Mutable)Subshape which propagate Status from the
481   // visitor function.
482   using StatusVisitorFunction = std::function<Status(
483       const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
484   static Status ForEachSubshapeWithStatus(const Shape& shape,
485                                           const StatusVisitorFunction& func);
486   using MutatingStatusVisitorFunction =
487       std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
488   static Status ForEachMutableSubshapeWithStatus(
489       Shape* shape, const MutatingStatusVisitorFunction& func);
490 
491   // Returns true if `shape` (which must be an array) with degenerate dimensions
492   // (dimensions with bound 1).
493   static bool HasDegenerateDimensions(const Shape& shape);
494 
495   // Drops any degenerate dimensions (i.e. dimensions of size 1)
496   static Shape DropDegenerateDimensions(const Shape& shape);
497 
498   // Permutes the dimensions by the given permutation, so
499   // return_value.dimensions[i] = argument.dimensions[permutation[i]].
500   //
501   // Postcondition: For any valid permutation,
502   //
503   //   !HasLayout(shape) ||
504   //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
505   //                      permutation).
506   static Shape PermuteDimensions(absl::Span<const int64_t> permutation,
507                                  const Shape& shape);
508 
509   // Describes how we can go from shape A to shape B by inserting degenerate
510   // 1-sized dimensions in `added_dimensions` and removing degenerate 1-sized
511   // dimensions from B in `removed_dimensions`.
512   //
513   // Only exists if shapes A and B only differ by degenerate dimensions.
514   struct ShapeEqualityDescriptor {
515     std::vector<int64_t> deleted_dimensions;
516     std::vector<int64_t> inserted_dimensions;
517   };
518 
519   // If we can go from `shape_pre` to `shape_post` by merely inserting or
520   // deleting 1-sized dimensions, return the indices in `shape_pre` of the
521   // deleted dimensions and the indices in `dims_post` of the inserted
522   // dimensions.
523   // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
524   // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
525   // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
526   // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
527   // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
528   // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
529   static std::optional<ShapeEqualityDescriptor>
530   InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
531                                     const Shape& shape_post);
532 
533   // Suppose a reshape transforms input_shape to output shape. Returns a vector
534   // of pairs that indicate the input and output dimensions that this reshape
535   // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
536   // the returned vector, the reshape transforms any input index whose I-th
537   // dimension is x to an output index whose O-th dimension is x too.
538   //
539   // Post-condition: the returned vector is sorted (by both input and output
540   // dimensions because input and output dimensions have the same order).
541   //
542   // Example:
543   //   input  shape = T[a, b, x, y, cd]
544   //   output shape = T[ab, x, 1, y, c, d]
545   //   return value = {{2, 1}, {3, 3}}
546   //
547   //   The two pairs represent the input and output dimension of size x and
548   //   those of size y.
549   static std::vector<std::pair<int64_t, int64_t>> DimensionsUnmodifiedByReshape(
550       const Shape& input_shape, const Shape& output_shape);
551 
552   // Return whether the given reshape instruction leaves the dimensions at the
553   // given input indices unmodified, and returns their output indices.
554   //
555   // Example:
556   //   input_dim_indices = {2, 3}
557   //   input  shape = T[a, b, x, y, cd]
558   //   output shape = T[ab, x, 1, y, c, d]
559   //   return value = {1, 3}
560   //
561   // Precondition: input_dim_indices is sorted.
562   static std::optional<std::vector<int64_t>> ReshapeLeavesDimensionsUnmodified(
563       const Shape& from_shape, const Shape& to_shape,
564       absl::Span<const int64_t> input_dim_indices);
565 
566   // Returns whether a transpose from input_shape to output_shape with dimension
567   // mapping "dimension_mapping" produces a result which is bit-wise identical
568   // to its input and thus may be replaced with a bitcast.
569   //
570   // Precondition: Both input_shape and output_shape have explicit layouts.
571   static bool TransposeIsBitcast(const Shape& input_shape,
572                                  const Shape& output_shape,
573                                  absl::Span<const int64_t> dimension_mapping);
574 
575   // Returns whether a reshape from `input_shape` to `output_shape` is a
576   // bitcast, when minor_to_major in layout is considered.
577   //
578   // Precondition: Both input_shape and output_shape have explicit layouts.
579   static bool ReshapeIsBitcast(const Shape& input_shape,
580                                const Shape& output_shape);
581 
582   // Find a physical layout for 'output_shape' such that
583   // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
584   // true (where 'output_shape_with_layout' is 'output_shape' with the found
585   // layout). The layout of 'input_shape' is kept fixed. Returns
586   // 'output_shape_with_layout' if such a layout can be found, and an error
587   // otherwise.
588   static std::optional<Shape> AlignLayouts(const Shape& input_shape,
589                                            const Shape& output_shape);
590 
591   // Returns a shape with the given dimension deleted.
592   // For example:
593   // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
594   static Shape DeleteDimension(int64_t dim_to_delete, Shape shape);
595 
596   // Returns a shape with dimensions in `to_drop` dropped.
597   static Shape DeleteDimensions(absl::Span<int64_t const> dims_to_delete,
598                                 Shape shape);
599 
600   // Returns a shape with all the dimensions of the input shape for which `p`
601   // returns true.
602   // For examples:
603   // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
604   // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
605   static Shape FilterDimensions(const std::function<bool(int64_t)>& p,
606                                 Shape shape);
607 
608   // Returns true if `dynamic_shape` has dimensions that are less-equal to the
609   // "bounded_shape". Shapes must be arrays.
610   static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
611                                             const xla::Shape& bounded_shape);
612 
613   // Same as DynamicArrayShapeIsCompatible() but supports tuples.
614   static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
615                                        const xla::Shape& bounded_shape);
616 
617   // Iterates through all the shape indexes, in minor to major order,
618   // starting from the base indexes, incrementing by the incr steps, up to
619   // count (index[i] < base[i] + count[i]), and calls the visitor_function
620   // with the current index. The visitor_function visitor function should
621   // return true if it wants to continue, or false otherwise.
622   //
623   // visitor_function must be a callable of type
624   // StatusOr<bool>(absl::Span<int64_t>) or compatible.
625   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)626   static Status ForEachIndexWithStatus(const Shape& shape,
627                                        absl::Span<const int64_t> base,
628                                        absl::Span<const int64_t> count,
629                                        absl::Span<const int64_t> incr,
630                                        const FnType& visitor_function) {
631     return ForEachIndexInternal(
632         shape, base, count, incr,
633         [&visitor_function](absl::Span<const int64_t> indexes,
634                             int /*thread_id*/) {
635           return visitor_function(indexes);
636         });
637   }
638 
  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
  // Bundles the (base, count, incr) triple describing an index walk; one
  // entry per dimension in each vector.
  struct IndexIterationSpace {
    std::vector<int64_t> index_base;   // Starting index per dimension.
    std::vector<int64_t> index_count;  // Extent: iterate while index < base + count.
    std::vector<int64_t> index_incr;   // Step size per dimension.
  };
645 
646   template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)647   static Status ForEachIndexWithStatus(
648       const Shape& shape, const IndexIterationSpace& iteration_space,
649       FnTy&& function) {
650     return ShapeUtil::ForEachIndexWithStatus(
651         shape, iteration_space.index_base, iteration_space.index_count,
652         iteration_space.index_incr, std::forward<FnTy>(function));
653   }
654 
655   template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)656   static void ForEachIndex(const Shape& shape, absl::Span<const int64_t> base,
657                            absl::Span<const int64_t> count,
658                            absl::Span<const int64_t> incr,
659                            const FnType& visitor_function) {
660     ForEachIndexWithStatus(shape, base, count, incr,
661                            [&](absl::Span<const int64_t> indices) {
662                              return StatusOr<bool>(visitor_function(indices));
663                            })
664         .IgnoreError();
665   }
666 
667   // These convenience wrappers don't take `base`, `count` and `incr`
668   // explicitly, but iterate over every element in `shape` instead.
669 
670   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)671   static Status ForEachIndexWithStatus(const Shape& shape,
672                                        const FnType& visitor_function) {
673     std::vector<int64_t> base(shape.dimensions_size());
674     std::vector<int64_t> incr(shape.dimensions_size(), 1);
675     return ForEachIndexWithStatus(shape, base,
676                                   /*count=*/shape.dimensions(), incr,
677                                   visitor_function);
678   }
679 
680   template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)681   static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
682     ForEachIndexWithStatus(shape, [&](absl::Span<const int64_t> indices) {
683       return StatusOr<bool>(visitor_function(indices));
684     }).IgnoreError();
685   }
686 
687   // A parallel version of ForEachIndex(WithStatus). This can only be used if
688   // the visitor_function is thread-safe and the order of iteration does not
689   // matter.
690   //
691   // visitor_function must be a callable of type
692   // void(Span<int64_t>, int thread_id) or compatible.
693   template <typename FnType>
ForEachIndexParallel(const Shape & shape,absl::Span<const int64_t> base,absl::Span<const int64_t> count,absl::Span<const int64_t> incr,const FnType & visitor_function)694   static void ForEachIndexParallel(const Shape& shape,
695                                    absl::Span<const int64_t> base,
696                                    absl::Span<const int64_t> count,
697                                    absl::Span<const int64_t> incr,
698                                    const FnType& visitor_function) {
699     // The parallel version of ForEachIndexInternal can never fail.
700     CHECK(ForEachIndexInternal(
701               shape, base, count, incr,
702               [&visitor_function](absl::Span<const int64_t> indexes,
703                                   int thread_id) -> StatusOr<bool> {
704                 visitor_function(indexes, thread_id);
705                 return true;
706               },
707               /*parallel=*/true)
708               .ok());
709   }
710   // Convenience wrapper which doesn't take `base`, `count` and `incr`
711   // explicitly, but iterates over every element in `shape` instead.
712   template <typename FnType>
ForEachIndexParallel(const Shape & shape,const FnType & visitor_function)713   static void ForEachIndexParallel(const Shape& shape,
714                                    const FnType& visitor_function) {
715     std::vector<int64_t> base(shape.dimensions_size());
716     std::vector<int64_t> incr(shape.dimensions_size(), 1);
717     return ForEachIndexParallel(shape, base,
718                                 /*count=*/shape.dimensions(), incr,
719                                 visitor_function);
720   }
721 
  // About 0-2-1 transpose:
  //
  // If a shape can be viewed as three logical components 0-1-2 in the order of
  // major to minor, a 0-2-1-transpose changes the order of such logical
  // components to 0-2-1. We call the shape being transposed the input shape and
  // the transposed shape the output shape. The logical view of the input/output
  // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
  // normalized shapes. The original input/output shapes are called unnormalized
  // shapes.
  //
  // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
  // normalized shape of `b` or the 0-2-1 shape.
  // NOTE(review): presumably returns std::nullopt when `b` is not such a
  // transpose of `a` -- confirm against the definition in the .cc file.
  static std::optional<Vector3> FindTranspose021(const Shape& a,
                                                 const Shape& b);
736 
  // Strips device-specific information, namely tiling and memory-space
  // information, from a shape. Takes `s` by value and returns the scrubbed
  // copy.
  static Shape DeviceShapeToHostShape(Shape s);

  // Returns true iff element type of shape `from` can be safely upcasted to
  // element type of shape `to`.
  static bool ElementCanUpcast(const Shape& from, const Shape& to);

  // Computes byte strides of an array shape `shape`, writing one stride per
  // dimension into `strides`. `shape` must have a layout. Ignores tiling.
  // `strides` must have size equal to the number of dimensions of `shape`.
  static Status ByteStrides(const Shape& shape, absl::Span<int64_t> strides);

  // Returns the array size in bytes (layout/tiling required), all paddings are
  // included.
  static int64_t ArraySize(const Shape& shape);

  // Returns the size of array data in bytes, ignoring the trailing padding
  // due to the tiling requirement.
  static int64_t ArrayDataSize(const Shape& shape);
757 
 private:
  // Fills *shape with the given element type and dimensions. Returns true on
  // success.
  // REQUIRES: *shape is empty.
  static bool FillNewShape(PrimitiveType element_type,
                           absl::Span<const int64_t> dimensions, Shape* shape);

  // Validates the shape size is sane. This makes sure it's safe to do
  // calculations in int64_t without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public method.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
771 
  // Shared implementation behind the public ForEachIndex* entry points.
  // Walks the indexes of `shape` in minor-to-major order: starting from
  // `base`, stepping by `incr`, while index[d] < base[d] + count[d] in every
  // dimension `d`, and invokes visitor_function(indexes, thread_id) at each
  // step. In serial mode (parallel == false) the visitor's StatusOr<bool>
  // controls continuation and the first error aborts the walk. In parallel
  // mode each visit is scheduled onto a thread pool; a `false` continuation
  // result from the visitor is ignored (early termination is unsupported)
  // and the first error observed is the one returned after all tasks drain.
  template <typename FnType>
  static Status ForEachIndexInternal(const Shape& shape,
                                     absl::Span<const int64_t> base,
                                     absl::Span<const int64_t> count,
                                     absl::Span<const int64_t> incr,
                                     const FnType& visitor_function,
                                     bool parallel = false) {
    // Nothing to visit when the array has no elements.
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return OkStatus();
    }
    CHECK_EQ(shape.rank(), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64_t rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64_t n = -1;
    std::vector<int64_t> indexes(base.begin(), base.end());
    const int kNumThreads = tensorflow::port::MaxParallelism();
    std::optional<tensorflow::thread::ThreadPool> pool;
    if (parallel) {
      pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
    }

    absl::Mutex mu;
    Status status;  // Guarded by mu until the pool has drained.

    while (n < rank) {
      if (pool != std::nullopt) {
        // `indexes` is captured by value so each task sees a snapshot of the
        // index vector at scheduling time.
        pool->Schedule([indexes, &visitor_function, &mu, &status, &pool] {
          const int thread_id = pool->CurrentThreadId();
          StatusOr<bool> result = visitor_function(indexes, thread_id);
          if (!result.ok()) {
            absl::MutexLock lock(&mu);
            // Keep only the first error reported by any task.
            status = status.ok() ? result.status() : status;
          }
        });
      } else {
        TF_ASSIGN_OR_RETURN(bool should_continue,
                            visitor_function(indexes, /*thread_id=*/-1));
        if (!should_continue) {
          break;
        }
      }
      // Increments dimensions in minor to major order, odometer-style: a
      // dimension that wraps resets to its base and carries into the next
      // more-major dimension; `n == rank` means the walk is exhausted.
      for (n = 0; n < rank; ++n) {
        int64_t dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }

    // Waits for the scheduled work to complete.
    pool.reset();
    return status;
  }
831 
  // ShapeUtil only exposes static helpers; forbid copying.
  ShapeUtil(const ShapeUtil&) = delete;
  ShapeUtil& operator=(const ShapeUtil&) = delete;
834 };
835 
836 }  // namespace xla
837 
838 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
839