/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <initializer_list>
#include <string>
#include <tuple>

#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class ShapeIndexView;

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
//
// ShapeIndex is a trivial wrapper around an inlined vector with a minimum
// number of methods implemented.
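//
// Example (illustrative sketch): paired with ShapeUtil::GetSubshape (declared
// below), a ShapeIndex addresses a nested element. For the nested tuple shape
// (a, (b, c, d), e) above, assuming it is held in a Shape named `nested_shape`:
//
//   const Shape& d_shape = ShapeUtil::GetSubshape(nested_shape, {1, 2});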
class ShapeIndex {
 public:
  ShapeIndex() = default;
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
  template <typename InputIt>
  ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}

  explicit ShapeIndex(ShapeIndexView v);

  bool empty() const { return indices_.empty(); }
  size_t size() const { return indices_.size(); }
  void push_back(int64_t value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // push_front is O(n), but shapes don't usually have a ton of dimensions.
  void push_front(int64_t value) { indices_.insert(indices_.begin(), value); }
  void pop_front() { indices_.erase(indices_.begin()); }

  using container_type = absl::InlinedVector<int64, 2>;

  container_type::const_iterator begin() const { return indices_.begin(); }
  container_type::const_iterator end() const { return indices_.end(); }
  container_type::iterator begin() { return indices_.begin(); }
  container_type::iterator end() { return indices_.end(); }

  const int64* data() const { return indices_.data(); }

  int64 back() const { return indices_.back(); }
  int64& back() { return indices_.back(); }

  int64 front() const { return indices_.front(); }
  int64& front() { return indices_.front(); }

  const int64& operator[](size_t i) const { return indices_[i]; }
  int64& operator[](size_t i) { return indices_[i]; }

  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  string ToString() const;

  template <typename H>
  friend H AbslHashValue(H h, const ShapeIndex& index) {
    return H::combine(std::move(h), index.indices_);
  }

 private:
  container_type indices_;
};

// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
// value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
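//
// Example (illustrative sketch):
//
//   ShapeIndex index = {1, 2};
//   ShapeIndexView view(index);   // views {1, 2}
//   int64 first = view.front();   // 1
//   view = view.ConsumeFront();   // now views {2}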
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64_t offset = 0)
      : indices_(shape_index.data() + offset, shape_index.size() - offset) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return indices_.begin(); }
  iterator end() const { return indices_.end(); }
  int64 size() const { return indices_.size(); }
  bool empty() const { return indices_.empty(); }
  int64 front() const {
    CHECK(!empty());
    return indices_.front();
  }
  int64 back() const {
    CHECK(!empty());
    return indices_.back();
  }
  ShapeIndexView ConsumeFront() const {
    ShapeIndexView result = *this;
    result.indices_.remove_prefix(1);
    return result;
  }
  ShapeIndexView ConsumeBack() const {
    ShapeIndexView result = *this;
    result.indices_.remove_suffix(1);
    return result;
  }
  ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }

  bool operator==(const ShapeIndexView& other) const;
  bool operator!=(const ShapeIndexView& other) const;

  string ToString() const;

  // Returns true if this shape index starts with 'prefix'.
  bool StartsWith(ShapeIndexView prefix) const;

 private:
  absl::Span<const int64> indices_;
};

inline ShapeIndex::ShapeIndex(ShapeIndexView v)
    : ShapeIndex(v.begin(), v.end()) {}

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Data structure which describes the coordinates and the shape of a
  // tuple-shaped sub-shape.
  struct IndexedShape {
    IndexedShape() = default;
    IndexedShape(ShapeIndex index, Shape shape)
        : index(std::move(index)), shape(std::move(shape)) {}
    ShapeIndex index;
    Shape shape;
  };

  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1.
  // Precondition: shape.IsArray()
  static int64 ElementsIn(const Shape& shape);

  // As ElementsIn(), but recurses through tuples.
  static int64 ElementsInRecursive(const Shape& shape);

  // Returns true if shape has the primitive type, recurses through tuples.
  static bool HasPrimitiveType(const Shape& shape,
                               PrimitiveType primitive_type);

  // Returns true if 'shape' is an array with zero elements.
  static bool IsZeroElementArray(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape.  The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example, a
  // tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
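  //
  // Illustrative sketch of the intended sizes (assuming 8-byte pointers and a
  // default layout with no padding):
  //   ByteSizeOf(f32[2,3])                             == 24  (6 elements * 4 bytes)
  //   ByteSizeOf((f32[2], s32[4]), /*pointer_size=*/8) == 16  (2 pointers * 8 bytes)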
  static int64 ByteSizeOf(const Shape& shape, int64_t pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: shape.IsArray()
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64_t pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. Shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
  // size also includes padding if present in the layout.
  static int64 ByteSizeOfElements(const Shape& shape);

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string of the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Returns whether the LHS and RHS shapes have the same dimensions; note: does
  // not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the LHS and RHS shapes have the same rank; note: does
  // not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameRank(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    if (ElementIsFloating(a) && ElementIsFloating(b)) {
      return true;
    }
    return ShapeUtil::SameElementType(a, b);
  }

  // Returns the higher-precision element type if a and b are both floating
  // point types; otherwise, checks that they have the same element type
  // and returns it.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    return primitive_util::HigherPrecisionType(a.element_type(),
                                               b.element_type());
  }

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // Returns true if the tuple tree shapes and leaf ranks are identical.
  // Leaf dimensions, element type, and layout are ignored. Tuple elements are
  // compared recursively for compatibility.
  static bool CompatibleKind(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allow one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // As Equal, but does not compare the element type.
  static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
  static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Two shapes have the same structure if all subshape indices of lhs are
  // present in rhs and vice versa.
  // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
  // (S32, (F32[3], S32[2])) as their structures are both (,(,)).
  //
  // In contrast, (F32, (F32, F32)) is structurally different from
  // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
  // ((,),).
  static bool EqualStructure(const Shape& lhs, const Shape& rhs);

  // Returns the number of dimensions for which the dimension is not (trivially)
  // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
  // fluff. Note that zero dimensions are included in the true rank, e.g.,
  // f32[3,0,1] has a true rank of 2D.
  static int64 TrueRank(const Shape& shape);

  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  static bool IsScalar(const Shape& shape) {
    return shape.IsArray() && shape.rank() == 0;
  }
  static bool IsEffectiveScalar(const Shape& shape) {
    return shape.IsArray() && TrueRank(shape) == 0;
  }

  // Returns whether "shape" is a scalar (array) with the given element_type.
  static bool IsScalarWithElementType(const Shape& shape,
                                      PrimitiveType element_type);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64 GetDimension(const Shape& shape, int64_t dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a positive dimension number for any given
  // dimension_number (which itself can be negative).
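  //
  // For example, for a rank-3 shape f32[7,8,9]:
  //   GetDimensionNumber(shape, -1) == 2
  //   GetDimension(shape, -1)       == 9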
  static int64 GetDimensionNumber(const Shape& shape, int64_t dimension_number);

  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Returns a shape with same dimensions but with all dimensions set to static.
  static Shape MakeStaticShape(const Shape& original);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(absl::Span<const Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Creates a token shape. Values of this shape are used for ordering
  // side-effecting operations.
  static Shape MakeTokenShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Update a subshape of a tuple.
  static void UpdateTupleShape(const Shape& shape, int64_t index,
                               Shape* tuple_shape);

  // Update the dynamic dimension for a shape. This shape can be a nested tuple.
  static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
                                     int64_t dim, bool is_dynamic);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Copy the dynamic dimensions property from one shape to another.
  static void CopyDynamicDimensions(Shape* to, const Shape& from);

  // Returns an empty tuple shape. Can be used as a sentinel Shape value.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Checks whether the shape is initialized.
  static bool IsInitialized(const Shape& shape) {
    return shape.element_type() != PRIMITIVE_TYPE_INVALID;
  }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
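  //
  // For example, MakeShape(F32, {2, 3}) makes the array shape f32[2,3].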
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions);

  // Make a scalar shape with given primitive type.
  static Shape MakeScalarShape(PrimitiveType element_type);

  // Constructs a new shape with the given element type and sequence of
  // potentially dynamic dimensions. A true value in 'dynamic_dimensions'
  // indicates that the respective dimension is dynamic. If a dimension is
  // dynamic then the respective value in 'dimensions' is an upper bound on
  // the dimension size. 'dimensions' and 'dynamic_dimensions' must be the
  // same size.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions,
                         const std::vector<bool>& dynamic_dimensions);

  // Constructs a new shape with the given element type and sequence of
  // dimensions. Method checks if the element type is valid and the shape's
  // size fits in std::numeric_limits<int64>::max().
  static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
                                            absl::Span<const int64> dimensions);
  static StatusOr<Shape> MakeValidatedShape(
      PrimitiveType element_type, absl::Span<const int64> dimensions,
      const std::vector<bool>& dynamic_dimensions);

  // Creates a Shape with element type corresponding to T and the given
  // dimensions.
  template <typename T>
  static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
    return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
                                dimensions);
  }

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
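  //
  // Illustrative sketch: MakeShapeWithLayout(F32, /*dimensions=*/{42, 12},
  // /*minor_to_major=*/{0, 1}) makes the shape f32[42,12]{0,1}, in which
  // dimension 0 is the minor (fastest-varying) dimension in memory.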
  static Shape MakeShapeWithLayout(PrimitiveType element_type,
                                   absl::Span<const int64> dimensions,
                                   absl::Span<const int64> minor_to_major,
                                   absl::Span<const Tile> tiles = {},
                                   int64_t element_size_in_bits = 0,
                                   int64_t memory_space = 0);

  // Constructs a new shape with the given dimension `dim` as the most major
  // dimension in the layout. If the shape does not have a layout, assumes a
  // default layout. If the shape is a tuple, apply this to all the leaf shapes
  // of the tuple.
  static Shape MoveDimToMajor(const Shape& shape, int64_t dim);

  // Returns the same shape except with all dimensions set to be static.
  static Shape MakeShapeWithStaticDimensions(const Shape& shape);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type, absl::Span<const int64> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static Status PopulateShape(PrimitiveType element_type,
                              absl::Span<const int64> dimensions, Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates that the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);

  // Returns whether the given primitive type corresponds to an array shape.
  static bool IsArrayPrimitiveType(PrimitiveType primitive_type);

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64_t index);

  // Returns the number of elements, recursively, in the given shape.
  static int64 SubshapeCount(const Shape& shape);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64_t start, int64_t limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Returns true if the given shape has a subshape at the given index.
  static bool IndexIsValid(const Shape& shape, ShapeIndexView index);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument. The non-Try variants check fail if index is
  // invalid.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
                                               ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

  // Returns whether the given index in the given shape is a leaf element of the
  // shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);

  // Returns the number of leaves in the shape.
  static int64 GetLeafCount(const Shape& shape);

  // Retrieves all the leaf shapes and their indexes, in the order walked by
  // the ForEachSubshape() API.
  static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);

  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
  static bool HasDegenerateDimensions(const Shape& shape);

  // Drops any degenerate dimensions (i.e. dimensions of size 1)
  static Shape DropDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[i] = argument.dimensions[permutation[i]].
  //
  // Postcondition: For any valid permutation,
  //
  //   !HasLayout(shape) ||
  //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
  //                      permutation).
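  //
  // For example, following the formula above, PermuteDimensions({2, 0, 1},
  // f32[10,20,30]) yields f32[30,10,20], since result.dimensions[0] =
  // argument.dimensions[2], and so on.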
  static Shape PermuteDimensions(absl::Span<const int64> permutation,
                                 const Shape& shape);

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether the given reshape instruction leaves the dimensions at the
  // given input indices unmodified, and if so, returns their output indices.
  //
  // Example:
  //   input_dim_indices = {2, 3}
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {1, 3}
  //
  // Precondition: input_dim_indices is sorted.
  static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
      const Shape& from_shape, const Shape& to_shape,
      absl::Span<const int64> input_dim_indices);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool TransposeIsBitcast(const Shape& input_shape,
                                 const Shape& output_shape,
                                 absl::Span<const int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and absl::nullopt
  // otherwise.
  static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
                                            const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64_t dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // For examples:
  // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64_t)>& p,
                                Shape shape);

  // Returns true if `dynamic_shape` has dimensions that are less than or equal
  // to the corresponding dimensions of `bounded_shape`. Shapes must be arrays.
  static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
                                            const xla::Shape& bounded_shape);

  // Same as DynamicArrayShapeIsCompatible() but supports tuples.
  static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                                       const xla::Shape& bounded_shape);

  // Iterates through all the shape indexes, in minor to major order,
  // starting from the base indexes, incrementing by the incr steps, up to
  // count (index[i] < base[i] + count[i]), and calls the visitor_function
  // with the current index. The visitor_function should return true if it
  // wants to continue, or false otherwise.
  //
  // visitor_function must be a callable of type
  // StatusOr<bool>(absl::Span<int64>) or compatible.
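  //
  // Illustrative sketch using the bool-returning ForEachIndex wrapper declared
  // below, visiting every other column of a f32[4,6] array:
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {4, 6});
  //   ShapeUtil::ForEachIndex(
  //       shape, /*base=*/{0, 0}, /*count=*/{4, 6}, /*incr=*/{1, 2},
  //       [&](absl::Span<const int64> index) {
  //         // index[0] is the row, index[1] is the (even) column.
  //         return true;  // keep iterating
  //       });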
  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       absl::Span<const int64> base,
                                       absl::Span<const int64> count,
                                       absl::Span<const int64> incr,
                                       const FnType& visitor_function) {
    return ForEachIndexInternal(shape, base, count, incr, visitor_function);
  }

  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
  struct IndexIterationSpace {
    std::vector<int64> index_base;
    std::vector<int64> index_count;
    std::vector<int64> index_incr;
  };

  template <typename FnTy>
  static Status ForEachIndexWithStatus(
      const Shape& shape, const IndexIterationSpace& iteration_space,
      FnTy&& function) {
    return ShapeUtil::ForEachIndexWithStatus(
        shape, iteration_space.index_base, iteration_space.index_count,
        iteration_space.index_incr, std::forward<FnTy>(function));
  }

  template <typename FnType>
  static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
                           absl::Span<const int64> count,
                           absl::Span<const int64> incr,
                           const FnType& visitor_function) {
    ForEachIndexWithStatus(shape, base, count, incr,
                           [&](absl::Span<const int64> indices) {
                             return StatusOr<bool>(visitor_function(indices));
                           })
        .IgnoreError();
  }

  // These convenience wrappers don't take `base`, `count` and `incr`
  // explicitly, but iterate over every element in `shape` instead.

  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       const FnType& visitor_function) {
    std::vector<int64> base(shape.dimensions_size());
    std::vector<int64> incr(shape.dimensions_size(), 1);
    return ForEachIndexWithStatus(shape, base,
                                  /*count=*/AsInt64Slice(shape.dimensions()),
                                  incr, visitor_function);
  }

  template <typename FnType>
  static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
    ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
      return StatusOr<bool>(visitor_function(indices));
    }).IgnoreError();
  }

  // A parallel version of ForEachIndex(WithStatus). This can only be used if
  // the visitor_function is thread-safe and the order of iteration does not
  // matter.
  //
  // visitor_function must be a callable of type
  // void(Span<int64>) or compatible.
  template <typename FnType>
  static void ForEachIndexParallel(const Shape& shape,
                                   absl::Span<const int64> base,
                                   absl::Span<const int64> count,
                                   absl::Span<const int64> incr,
                                   const FnType& visitor_function) {
    // The parallel version of ForEachIndexInternal can never fail.
    CHECK(ForEachIndexInternal(
              shape, base, count, incr,
              [&visitor_function](
                  absl::Span<const int64> indexes) -> StatusOr<bool> {
                visitor_function(indexes);
                return true;
              },
              /*parallel=*/true)
              .ok());
  }

  // Compute a hash for `shape`.
  static size_t Hash(const Shape& shape);

  // About 0-2-1 transpose:
  //
  // If a shape can be viewed as three logical components 0-1-2 in the order of
  // major to minor, a 0-2-1-transpose changes the order of such logical
  // components to 0-2-1. We call the shape being transposed the input shape and
  // the transposed shape the output shape. The logical view of the input/output
  // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
  // normalized shapes. The original input/output shapes are called unnormalized
  // shapes.
  //
  // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
  // normalized shape of `b` or the 0-2-1 shape.
  static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                             const Shape& b);

  // Strips device-specific information, namely tiling and memory-space
  // information, from a shape.
  static Shape DeviceShapeToHostShape(Shape s);

  // Returns true iff element type of shape `from` can be safely upcasted to
  // element type of shape `to`.
  static bool ElementCanUpcast(const Shape& from, const Shape& to);

  // Computes byte strides of an array shape `shape`. `shape` must have a
  // layout. Ignores tiling. `strides` must have size equal to the number of
  // dimensions of `shape`.
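  //
  // Illustrative sketch: for f32[2,3] with minor_to_major layout {1,0} and no
  // tiling, the byte strides would be {12, 4} (dimension 1 is minor and
  // advances by one 4-byte element; dimension 0 advances by a full row of 3).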
  static Status ByteStrides(const Shape& shape, absl::Span<int64_t> strides);

 private:
  // Fills *shape. Returns true on success.
  // REQUIRES: *shape is empty.
  static bool FillNewShape(PrimitiveType element_type,
                           absl::Span<const int64> dimensions, Shape* shape);

  // Validates the shape size is sane. This makes sure it's safe to do
  // calculations in int64 without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  template <typename FnType>
  static Status ForEachIndexInternal(const Shape& shape,
                                     absl::Span<const int64> base,
                                     absl::Span<const int64> count,
                                     absl::Span<const int64> incr,
                                     const FnType& visitor_function,
                                     bool parallel = false) {
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return Status::OK();
    }
    CHECK_EQ(shape.rank(), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64_t rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64_t n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    const int kNumThreads = tensorflow::port::MaxParallelism();
    absl::optional<tensorflow::thread::ThreadPool> pool;
    if (parallel) {
      pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
    }

    tensorflow::mutex mu;
    Status status;  // Guarded by mu

    while (n < rank) {
      if (pool != absl::nullopt) {
        pool->Schedule([indexes, &visitor_function, &mu, &status] {
          StatusOr<bool> result = visitor_function(indexes);
          if (!result.ok()) {
            tensorflow::mutex_lock lock(mu);
            status = status.ok() ? result.status() : status;
          }
        });
      } else {
        TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
        if (!should_continue) {
          break;
        }
      }
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64_t dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }

    // Waits for the scheduled work to complete.
    pool.reset();
    return status;
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_