/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <initializer_list>
#include <string>

#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class ShapeIndexView;

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array, not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
//
// ShapeIndex is a trivial wrapper around an inlined vector
// (absl::InlinedVector) with a minimum number of methods implemented.
class ShapeIndex {
 public:
  ShapeIndex() = default;
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
  template <typename InputIt>
  ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}

  explicit ShapeIndex(ShapeIndexView v);

  bool empty() const { return indices_.empty(); }
  size_t size() const { return indices_.size(); }
  void push_back(int64 value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // push_front is O(n), but shapes don't usually have a ton of dimensions.
  void push_front(int64 value) { indices_.insert(indices_.begin(), value); }

  using container_type = absl::InlinedVector<int64, 2>;

  container_type::const_iterator begin() const { return indices_.begin(); }
  container_type::const_iterator end() const { return indices_.end(); }
  container_type::iterator begin() { return indices_.begin(); }
  container_type::iterator end() { return indices_.end(); }

  const int64* data() const { return indices_.data(); }

  int64 back() const { return indices_.back(); }
  int64& back() { return indices_.back(); }

  const int64& operator[](size_t i) const { return indices_[i]; }
  int64& operator[](size_t i) { return indices_[i]; }

  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  string ToString() const;

  template <typename H>
  friend H AbslHashValue(H h, const ShapeIndex& index) {
    return H::combine(std::move(h), index.indices_);
  }

 private:
  container_type indices_;
};
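
// Illustrative usage sketch (not part of the original header): building the
// index {1, 2} from the nested-tuple example above.
//
//   ShapeIndex index;      // {}     - refers to the shape as a whole
//   index.push_back(1);    // {1}    - the inner tuple (b, c, d)
//   index.push_back(2);    // {1, 2} - array d inside that tuple
//   // Equivalent one-liner:
//   ShapeIndex same_index = {1, 2};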

// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
// value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
      : indices_(shape_index.data() + offset, shape_index.size() - offset) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return indices_.begin(); }
  iterator end() const { return indices_.end(); }
  int64 size() const { return indices_.size(); }
  bool empty() const { return indices_.empty(); }
  int64 front() const {
    CHECK(!empty());
    return indices_.front();
  }
  int64 back() const {
    CHECK(!empty());
    return indices_.back();
  }
  ShapeIndexView ConsumeFront() const {
    ShapeIndexView result = *this;
    result.indices_.remove_prefix(1);
    return result;
  }
  ShapeIndexView ConsumeBack() const {
    ShapeIndexView result = *this;
    result.indices_.remove_suffix(1);
    return result;
  }
  ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }

  bool operator==(const ShapeIndexView& other) const;
  bool operator!=(const ShapeIndexView& other) const;

  string ToString() const;

  // Returns true if this shape index starts with 'prefix'.
  bool StartsWith(ShapeIndexView prefix) const;

 private:
  absl::Span<const int64> indices_;
};
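
// Illustrative sketch (not part of the original header): walking an index by
// repeatedly consuming its front element. The view must not outlive `index`.
//
//   ShapeIndex index = {1, 2};
//   for (ShapeIndexView view(index); !view.empty(); view = view.ConsumeFront()) {
//     int64 element = view.front();  // 1 on the first iteration, then 2.
//   }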

inline ShapeIndex::ShapeIndex(ShapeIndexView v)
    : ShapeIndex(v.begin(), v.end()) {}

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Data structure which describes the coordinates and the shape of a
  // tuple-shaped sub-shape.
  struct IndexedShape {
    IndexedShape() = default;
    IndexedShape(ShapeIndex index, Shape shape)
        : index(std::move(index)), shape(std::move(shape)) {}
    ShapeIndex index;
    Shape shape;
  };

  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1.
  // Precondition: shape.IsArray()
  static int64 ElementsIn(const Shape& shape);

  // As ElementsIn(), but recurses through tuples.
  static int64 ElementsInRecursive(const Shape& shape);
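
  // Illustrative example (not part of the original header): for an array shape
  // the element count is the product of its dimensions, and the recursive
  // variant sums over tuple leaves.
  //
  //   ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, {3, 4}));  // 12
  //   ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, {}));      // 1 (scalar)
  //   ShapeUtil::ElementsInRecursive(ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {3, 4}),
  //        ShapeUtil::MakeShape(S32, {2})}));                    // 12 + 2 = 14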

  // Returns true if the shape has the given primitive type; recurses through
  // tuples.
  static bool HasPrimitiveType(const Shape& shape,
                               PrimitiveType primitive_type);

  // Returns true if 'shape' is an array with zero elements.
  static bool IsZeroElementArray(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape.  The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example, a
  // tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: shape.IsArray()
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers
  // for an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64 pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. Shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
  // size also includes padding if present in the layout.
  static int64 ByteSizeOfElements(const Shape& shape);

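  // Illustrative example (not part of the original header), assuming 4-byte
  // f32 elements, a default unpadded layout, and an 8-byte pointer size:
  //
  //   Shape arr = ShapeUtil::MakeShape(F32, {4, 4});
  //   ShapeUtil::ByteSizeOfElements(arr);                 // 16 * 4 = 64 bytes
  //   Shape tup = ShapeUtil::MakeTupleShape({arr, arr});
  //   ShapeUtil::ByteSizeOfTupleIndexTable(
  //       tup, /*pointer_size=*/8);                       // 2 * 8 = 16 bytes
  //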
  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string in the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Returns whether the LHS and RHS shapes have the same dimensions; note that
  // this does not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    if (ElementIsFloating(a) && ElementIsFloating(b)) {
      return true;
    }
    return ShapeUtil::SameElementType(a, b);
  }

  // Returns the higher-precision element type if a and b are both floating
  // point types; otherwise, checks that they have the same element type
  // and returns it.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    if (SameElementType(a, b)) {
      return a.element_type();
    }
    return primitive_util::BitWidth(a.element_type()) <
                   primitive_util::BitWidth(b.element_type())
               ? b.element_type()
               : a.element_type();
  }

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allows one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // As Equal, but does not compare the element type.
  static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Equal, but allows one of lhs and rhs to be F16 while the other is F32.
  static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns the number of dimensions for which the dimension is not
  // (trivially) 1; e.g., f32[2x1x1] has a true rank of 1, the other dimensions
  // are just fluff. Note that zero dimensions are included in the true rank,
  // e.g., f32[3,0,1] has a true rank of 2.
  static int64 TrueRank(const Shape& shape);

  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  static bool IsScalar(const Shape& shape) {
    return shape.IsArray() && shape.rank() == 0;
  }
  static bool IsEffectiveScalar(const Shape& shape) {
    return shape.IsArray() && TrueRank(shape) == 0;
  }

  // Returns whether "shape" is a scalar (array) with the given element_type.
  static bool IsScalarWithElementType(const Shape& shape,
                                      PrimitiveType element_type);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64 GetDimension(const Shape& shape, int64 dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a non-negative dimension number for any given
  // dimension_number (which itself can be negative).
  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);

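  // Illustrative example (not part of the original header), using a rank-3
  // shape f32[7, 8, 9]:
  //
  //   Shape s = ShapeUtil::MakeShape(F32, {7, 8, 9});
  //   ShapeUtil::GetDimensionNumber(s, -1);  // 2
  //   ShapeUtil::GetDimension(s, -1);        // 9, same as GetDimension(s, 2)
  //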
  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(absl::Span<const Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Creates a token shape. Values of this shape are used for ordering
  // side-effecting operations.
  static Shape MakeTokenShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Update a subshape of a tuple.
  static void UpdateTupleShape(const Shape& shape, int64 index,
                               Shape* tuple_shape);

  // Update the dynamic dimension for a shape. This shape can be a nested tuple.
  static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
                                     int64 dim, bool is_dynamic);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Returns an empty tuple shape. Can be used as a sentinel Shape value.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Checks whether the shape is initialized.
  static bool IsInitialized(const Shape& shape) {
    return shape.element_type() != PRIMITIVE_TYPE_INVALID;
  }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions);

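  // Illustrative example (not part of the original header): composing array
  // and tuple shapes with these factory functions.
  //
  //   Shape vec = ShapeUtil::MakeShape(F32, {16});        // f32[16]
  //   Shape mat = ShapeUtil::MakeShape(S32, {2, 3});      // s32[2,3]
  //   Shape tup = ShapeUtil::MakeTupleShape({vec, mat});  // (f32[16], s32[2,3])
  //   Shape nil = ShapeUtil::MakeNil();                   // () - the empty tuple
  //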
  // Makes a scalar shape with the given primitive type.
  static Shape MakeScalarShape(PrimitiveType element_type);

  // Constructs a new shape with the given element type and sequence of
  // potentially dynamic dimensions. A true value in 'dynamic_dimensions'
  // indicates that the respective dimension is dynamic. If a dimension is
  // dynamic then the respective value in 'dimensions' is an upper bound on the
  // dimension size. 'dimensions' and 'dynamic_dimensions' must be the same
  // size.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions,
                         const std::vector<bool>& dynamic_dimensions);

  // Constructs a new shape with the given element type and sequence of
  // dimensions. The method checks that the element type is valid and that the
  // shape's size fits in std::numeric_limits<int64>::max().
  static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
                                            absl::Span<const int64> dimensions);
  static StatusOr<Shape> MakeValidatedShape(
      PrimitiveType element_type, absl::Span<const int64> dimensions,
      const std::vector<bool>& dynamic_dimensions);

  // Creates a Shape with the element type corresponding to T and the given
  // dimensions.
  template <typename T>
  static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
    return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
                                dimensions);
  }

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(PrimitiveType element_type,
                                   absl::Span<const int64> dimensions,
                                   absl::Span<const int64> minor_to_major,
                                   absl::Span<const Tile> tiles = {},
                                   int64 element_size_in_bits = 0,
                                   int64 memory_space = 0);

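  // Illustrative example (not part of the original header): the same logical
  // f32[2, 3] array with two different physical layouts.
  //
  //   // Minor-to-major {1, 0}: dimension 1 varies fastest (row-major).
  //   Shape row_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  //   // Minor-to-major {0, 1}: dimension 0 varies fastest (column-major).
  //   Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
  //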
  // Returns the same shape except with all dimensions set to be static.
  static Shape MakeShapeWithStaticDimensions(const Shape& shape);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type, absl::Span<const int64> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static Status PopulateShape(PrimitiveType element_type,
                              absl::Span<const int64> dimensions, Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates that the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);

  // Returns whether the given primitive type corresponds to an array shape.
  static bool IsArrayPrimitiveType(PrimitiveType primitive_type);

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at the given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);

  // Returns the number of subshapes, recursively, contained in the given
  // shape (including the shape itself).
  static int64 SubshapeCount(const Shape& shape);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Returns true if the given shape has a subshape at the given index.
  static bool IndexIsValid(const Shape& shape, ShapeIndexView index);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument. The non-Try variants CHECK-fail if the index is
  // invalid.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
                                               ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

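  // Illustrative example (not part of the original header), assuming `shape`
  // holds the nested tuple (a, (b, c, d), e) from the ShapeIndex docs above:
  //
  //   const Shape& inner = ShapeUtil::GetSubshape(shape, {1});     // (b, c, d)
  //   const Shape& d = ShapeUtil::GetSubshape(shape, {1, 2});      // array d
  //   // An empty index {} refers to `shape` itself.
  //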
  // Returns whether the given index in the given shape is a leaf element of the
  // shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);

  // Returns the number of leaves in the shape.
  static int64 GetLeafCount(const Shape& shape);

  // Retrieves all the leaf shapes and their indexes, in the order walked by
  // the ForEachSubshape() API.
  static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

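  // Illustrative sketch (not part of the original header): counting the array
  // leaves of a (possibly nested) tuple shape `shape` with a visitor lambda.
  //
  //   int64 num_arrays = 0;
  //   ShapeUtil::ForEachSubshape(
  //       shape, [&](const Shape& subshape, const ShapeIndex& index) {
  //         if (subshape.IsArray()) {
  //           ++num_arrays;
  //         }
  //       });
  //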
  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);

  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
  static bool HasDegenerateDimensions(const Shape& shape);

  // Drops any degenerate dimensions (i.e. dimensions of size 1).
  static Shape DropDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[permutation[i]] = argument.dimensions[i].
  //
  // Postcondition: For any valid permutation,
  //
  //   !HasLayout(shape) ||
  //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
  //                      InversePermutation(permutation)).
  static Shape PermuteDimensions(absl::Span<const int64> permutation,
                                 const Shape& shape);

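  // Illustrative example (not part of the original header), following the
  // definition above; `shape_2x3` is a hypothetical f32[2, 3] shape.
  //
  //   // With permutation {1, 0}: result.dimensions[1] = 2 and
  //   // result.dimensions[0] = 3, so the result is f32[3, 2].
  //   Shape transposed = ShapeUtil::PermuteDimensions({1, 0}, shape_2x3);
  //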
  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether the given reshape leaves the dimensions at the given input
  // indices unmodified, and if so returns their corresponding output indices.
  //
  // Example:
  //   input_dim_indices = {2, 3}
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {1, 3}
  //
  // Precondition: input_dim_indices is sorted.
  static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
      const Shape& from_shape, const Shape& to_shape,
      absl::Span<const int64> input_dim_indices);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool TransposeIsBitcast(const Shape& input_shape,
                                 const Shape& output_shape,
                                 absl::Span<const int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Finds a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and absl::nullopt
  // otherwise.
  static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
                                            const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64 dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // For example:
  // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64)>& p,
                                Shape shape);

  // Iterates through all the shape indexes, in minor to major order, starting
  // from the base indexes, incrementing by the incr steps, up to count
  // (index[i] < base[i] + count[i]), and calls the visitor_function with the
  // current index.
  // The visitor_function should return true if it wants to continue, or false
  // otherwise.
  //
  // visitor_function must be a callable of type
  // StatusOr<bool>(absl::Span<int64>) or compatible.
  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       absl::Span<const int64> base,
                                       absl::Span<const int64> count,
                                       absl::Span<const int64> incr,
                                       const FnType& visitor_function) {
    return ForEachIndexInternal(shape, base, count, incr, visitor_function);
  }

  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
  struct IndexIterationSpace {
    std::vector<int64> index_base;
    std::vector<int64> index_count;
    std::vector<int64> index_incr;
  };

  template <typename FnTy>
  static Status ForEachIndexWithStatus(
      const Shape& shape, const IndexIterationSpace& iteration_space,
      FnTy&& function) {
    return ShapeUtil::ForEachIndexWithStatus(
        shape, iteration_space.index_base, iteration_space.index_count,
        iteration_space.index_incr, std::forward<FnTy>(function));
  }

  template <typename FnType>
  static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
                           absl::Span<const int64> count,
                           absl::Span<const int64> incr,
                           const FnType& visitor_function) {
    ForEachIndexWithStatus(shape, base, count, incr,
                           [&](absl::Span<const int64> indices) {
                             return StatusOr<bool>(visitor_function(indices));
                           })
        .IgnoreError();
  }

  // These convenience wrappers don't take `base`, `count` and `incr`
  // explicitly, but iterate over every element in `shape` instead.

  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       const FnType& visitor_function) {
    std::vector<int64> base(shape.dimensions_size());
    std::vector<int64> incr(shape.dimensions_size(), 1);
    return ForEachIndexWithStatus(shape, base,
                                  /*count=*/AsInt64Slice(shape.dimensions()),
                                  incr, visitor_function);
  }

  template <typename FnType>
  static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
    ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
      return StatusOr<bool>(visitor_function(indices));
    }).IgnoreError();
  }

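  // Illustrative sketch (not part of the original header): counting the
  // indices visited over an array shape `shape`; the visitor returns true to
  // keep iterating, and the total equals ElementsIn(shape).
  //
  //   int64 visited = 0;
  //   ShapeUtil::ForEachIndex(shape, [&](absl::Span<const int64> indices) {
  //     ++visited;
  //     return true;  // keep iterating
  //   });
  //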
  // A parallel version of ForEachIndex(WithStatus). This can only be used if
  // the visitor_function is thread-safe and the order of iteration does not
  // matter.
  //
  // visitor_function must be a callable of type
  // void(Span<int64>) or compatible.
  template <typename FnType>
  static void ForEachIndexParallel(const Shape& shape,
                                   absl::Span<const int64> base,
                                   absl::Span<const int64> count,
                                   absl::Span<const int64> incr,
                                   const FnType& visitor_function) {
    // The parallel version of ForEachIndexInternal can never fail.
    CHECK(ForEachIndexInternal(
              shape, base, count, incr,
              [&visitor_function](
                  absl::Span<const int64> indexes) -> StatusOr<bool> {
                visitor_function(indexes);
                return true;
              },
              /*parallel=*/true)
              .ok());
  }

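  // Illustrative sketch (not part of the original header): the visitor for the
  // parallel variant returns void and must be thread-safe. `zeros`, `dims`,
  // and `ones` are hypothetical int64 vectors of length shape.rank().
  //
  //   std::atomic<int64> visited{0};
  //   ShapeUtil::ForEachIndexParallel(
  //       shape, /*base=*/zeros, /*count=*/dims, /*incr=*/ones,
  //       [&](absl::Span<const int64> indices) { ++visited; });
  //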
  // Compute a hash for `shape`.
  static size_t Hash(const Shape& shape);

  // About 0-2-1 transpose:
  //
  // If a shape can be viewed as three logical components 0-1-2 in the order of
  // major to minor, a 0-2-1-transpose changes the order of such logical
  // components to 0-2-1. We call the shape being transposed the input shape and
  // the transposed shape the output shape. The logical view of the input/output
  // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
  // normalized shapes. The original input/output shapes are called unnormalized
  // shapes.
  //
  // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
  // normalized shape of `b` or the 0-2-1 shape.
  static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                             const Shape& b);

 private:
  // Validates that the shape size is sane. This makes sure it's safe to do
  // calculations in int64 without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  template <typename FnType>
  static Status ForEachIndexInternal(const Shape& shape,
                                     absl::Span<const int64> base,
                                     absl::Span<const int64> count,
                                     absl::Span<const int64> incr,
                                     const FnType& visitor_function,
                                     bool parallel = false) {
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return Status::OK();
    }
    CHECK_EQ(shape.rank(), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64 n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    const int kNumThreads = tensorflow::port::MaxParallelism();
    absl::optional<tensorflow::thread::ThreadPool> pool;
    if (parallel) {
      pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
    }

    tensorflow::mutex mu;
    Status status;  // Guarded by mu

    while (n < rank) {
      if (pool != absl::nullopt) {
        // Parallel mode: hand the current index off to the thread pool. The
        // lambda captures `indexes` by value so each task sees its own copy.
        pool->Schedule([indexes, &visitor_function, &mu, &status] {
          StatusOr<bool> result = visitor_function(indexes);
          if (!result.ok()) {
            tensorflow::mutex_lock lock(mu);
            status = status.ok() ? result.status() : status;
          }
        });
      } else {
        TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
        if (!should_continue) {
          break;
        }
      }
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64 dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }

    // Waits for the scheduled work to complete.
    pool.reset();
    return status;
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_