• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21 
22 #include <initializer_list>
23 #include <string>
24 #include <tuple>
25 
26 #include "absl/base/macros.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/types/optional.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/platform/cpu_info.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 class ShapeIndexView;
48 
49 // An index for specifying a particular nested subshape within a shape. Used in
50 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
51 // structures (trees) and ShapeIndex defines a path through the tree where each
52 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
53 // shape. For a non-nested tuple, an index has a single element. For example,
54 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
55 // {1} corresponds to array b. For a nested tuple, the index can have more than
56 // one element. For the nested tuple (a, (b, c, d), e) below are the values
57 // corresponding to the given indices:
58 //
59 //   index {0}    : array a
60 //   index {1, 2} : array d
61 //   index {2}    : array e
62 //   index {0, 0} : invalid index (element at {0} is an array not a tuple)
63 //
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
65 //
// ShapeIndex is a thin wrapper around absl::InlinedVector<int64, 2> (its
// container_type) with a minimal set of methods implemented.
68 class ShapeIndex {
69  public:
70   ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init)71   ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
72   template <typename InputIt>
ShapeIndex(InputIt start,InputIt end)73   ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
74 
75   explicit ShapeIndex(ShapeIndexView v);
76 
empty()77   bool empty() const { return indices_.empty(); }
size()78   size_t size() const { return indices_.size(); }
push_back(int64 value)79   void push_back(int64 value) { indices_.push_back(value); }
pop_back()80   void pop_back() { indices_.pop_back(); }
81 
82   // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_front(int64 value)83   void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
84 
85   using container_type = absl::InlinedVector<int64, 2>;
86 
begin()87   container_type::const_iterator begin() const { return indices_.begin(); }
end()88   container_type::const_iterator end() const { return indices_.end(); }
begin()89   container_type::iterator begin() { return indices_.begin(); }
end()90   container_type::iterator end() { return indices_.end(); }
91 
data()92   const int64* data() const { return indices_.data(); }
93 
back()94   int64 back() const { return indices_.back(); }
back()95   int64& back() { return indices_.back(); }
96 
97   const int64& operator[](size_t i) const { return indices_[i]; }
98   int64& operator[](size_t i) { return indices_[i]; }
99 
100   bool operator==(const ShapeIndex& other) const {
101     return indices_ == other.indices_;
102   }
103   bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
104   bool operator<(const ShapeIndex& other) const {
105     return indices_ < other.indices_;
106   }
107 
108   string ToString() const;
109 
110   template <typename H>
AbslHashValue(H h,const ShapeIndex & index)111   friend H AbslHashValue(H h, const ShapeIndex& index) {
112     return H::combine(std::move(h), index.indices_);
113   }
114 
115  private:
116   container_type indices_;
117 };
118 
119 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the
120 // value at the front of the view.
121 //
122 // NB! ShapeIndexView does not own the memory backing the index array.
123 // The memory backing the index array should be owned by an object
124 // that lives longer than the ShapeIndexView instances pointing into
125 // it.
126 class ShapeIndexView {
127  public:
128   ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
129       : indices_(shape_index.data() + offset, shape_index.size() - offset) {
130     CHECK_LE(offset, shape_index.size());
131   }
ShapeIndexView(std::initializer_list<int64> indices)132   ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
133   ShapeIndexView(const ShapeIndexView& other) = default;
134 
135   using iterator = const int64*;
136 
begin()137   iterator begin() const { return indices_.begin(); }
end()138   iterator end() const { return indices_.end(); }
size()139   int64 size() const { return indices_.size(); }
empty()140   bool empty() const { return indices_.empty(); }
front()141   int64 front() const {
142     CHECK(!empty());
143     return indices_.front();
144   }
back()145   int64 back() const {
146     CHECK(!empty());
147     return indices_.back();
148   }
ConsumeFront()149   ShapeIndexView ConsumeFront() const {
150     ShapeIndexView result = *this;
151     result.indices_.remove_prefix(1);
152     return result;
153   }
ConsumeBack()154   ShapeIndexView ConsumeBack() const {
155     ShapeIndexView result = *this;
156     result.indices_.remove_suffix(1);
157     return result;
158   }
ToShapeIndex()159   ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
160 
161   bool operator==(const ShapeIndexView& other) const;
162   bool operator!=(const ShapeIndexView& other) const;
163 
164   string ToString() const;
165 
166   // Returns true if this shape index starts with 'prefix'.
167   bool StartsWith(ShapeIndexView prefix) const;
168 
169  private:
170   absl::Span<const int64> indices_;
171 };
172 
ShapeIndex(ShapeIndexView v)173 inline ShapeIndex::ShapeIndex(ShapeIndexView v)
174     : ShapeIndex(v.begin(), v.end()) {}
175 
176 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
177 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
178 
179 // Namespaced collection of (static) shape utilities.
180 //
181 // These are all effectively convenience functions for testing/tweaking proto
182 // properties, which do invariant checks before / after the operation.
183 class ShapeUtil {
184  public:
185   // Data structure which describes the coordinates and the shape, of a tuple
186   // shaped sub-shape.
187   struct IndexedShape {
188     IndexedShape() = default;
IndexedShapeIndexedShape189     IndexedShape(ShapeIndex index, Shape shape)
190         : index(std::move(index)), shape(std::move(shape)) {}
191     ShapeIndex index;
192     Shape shape;
193   };
194 
  // Returns the number of elements contained within the provided shape;
196   // e.g. for rank 0 (scalars) the result is always 1.
197   // Precondition: shape.IsArray()
198   static int64 ElementsIn(const Shape& shape);
199 
200   // As ElementsIn(), but recurses through tuples.
201   static int64 ElementsInRecursive(const Shape& shape);
202 
203   // Returns true if shape has the primitive type, recurses through tuples.
204   static bool HasPrimitiveType(const Shape& shape,
205                                PrimitiveType primitive_type);
206 
207   // Returns true if 'shape' is an array with zero elements.
208   static bool IsZeroElementArray(const Shape& shape);
209 
210   // Returns the number of bytes required for an allocation of shape.  The
211   // |pointer_size| parameter is used for calculating the size of tuple
212   // shapes. This includes only the size of the top-level buffer. For example, a
213   // tuple is stored as an array of pointers to other buffers. In this case,
214   // this method only returns the size of the pointer array.
215   static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
216 
217   // Returns the number of bytes used to store the primitive_type.
218   //
  // Precondition: presumably IsArrayPrimitiveType(primitive_type) — TODO
  // confirm (this function takes a PrimitiveType, not a Shape).
220   static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
221 
222   // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape. The `shape` must be a TUPLE shape, and
224   // `pointer_size` must be larger than zero.
225   static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
226                                          int64 pointer_size);
227 
228   // Returns the number of bytes required for the elements in an allocation of
229   // `shape`, which must be an array shape. Shapes use a separate
230   // memory location for each element, and so for these shapes,
231   // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
232   // size also includes padding if present in the layout.
233   static int64 ByteSizeOfElements(const Shape& shape);
234 
235   // Returns a human-readable string that represents the given shape, with or
236   // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
237   static string HumanString(const Shape& shape);
238   static string HumanStringWithLayout(const Shape& shape);
239 
240   // As above, but for program shapes, returns a string for the form:
241   //
242   // (param_name: f32[42x12], ...) -> f32[24x42]
243   static string HumanString(const ProgramShape& program_shape);
244 
245   // Returns whether the LHS and RHS shapes have the same dimensions; note: does
246   // not check element type.
247   // Precondition: IsArray(lhs) && IsArray(rhs)
248   static bool SameDimensions(const Shape& lhs, const Shape& rhs);
249 
250   // Returns whether the LHS and RHS shapes have the same rank; note: does
251   // not check element type.
252   // Precondition: IsArray(lhs) && IsArray(rhs)
253   static bool SameRank(const Shape& lhs, const Shape& rhs);
254 
255   // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)256   static bool SameElementType(const Shape& lhs, const Shape& rhs) {
257     return lhs.element_type() == rhs.element_type();
258   }
259 
260   // As SameElementType, but allows floating point types to have different
261   // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)262   static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
263                                                  const Shape& b) {
264     if (ElementIsFloating(a) && ElementIsFloating(b)) {
265       return true;
266     }
267     return ShapeUtil::SameElementType(a, b);
268   }
269 
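  // Returns whichever of a's and b's element types ranks higher, comparing in
  // order: floating-point exponent range, floating-point significand width
  // (non-floating types score -1 on both), bit width, and whether the type is
  // a signed integral type. If both rank equally, CHECK-fails unless they are
  // the same element type, and returns it.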
HigherPrecisionElementType(const Shape & a,const Shape & b)273   static PrimitiveType HigherPrecisionElementType(const Shape& a,
274                                                   const Shape& b) {
275     // Returns a tuple where the elements are lexicographically ordered in terms
276     // of importance.
277     auto type_properties = [](const Shape& shape) {
278       return std::make_tuple(
279           // Prefer floating point types with more range over other
280           // floating-point types or non-floating point types.
281           ElementIsFloating(shape)
282               ? primitive_util::OverflowExponent(shape.element_type())
283               : -1,
284           // Prefer floating point types with more precision over less precise
285           // types.
286           ElementIsFloating(shape)
287               ? primitive_util::SignificandWidth(shape.element_type())
288               : -1,
289           // Prefer wider types over narrower types.
290           primitive_util::BitWidth(shape.element_type()),
291           // Prefer signed integer types over unsigned integer types.
292           primitive_util::IsSignedIntegralType(shape.element_type()));
293     };
294     auto a_properties = type_properties(a);
295     auto b_properties = type_properties(b);
296     if (a_properties > b_properties) {
297       return a.element_type();
298     }
299     if (b_properties > a_properties) {
300       return b.element_type();
301     }
302     CHECK(SameElementType(a, b));
303     return a.element_type();
304   }
305 
306   // Returns true if the rank, dimension sizes, and element type are
307   // identical. Layout is ignored. Tuple elements are compared recursively for
308   // compatibility.
309   static bool Compatible(const Shape& lhs, const Shape& rhs);
310 
311   // Returns true if the rank and dimension sizes are identical. Element type
312   // and layout are ignored. Tuple elements are compared recursively for
313   // compatibility.
314   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
315 
316   // Returns true if the tuple tree shapes and leaf ranks are identical.
317   // Leaf dimensions, element type, and layout are ignored. Tuple elements are
318   // compared recursively for compatibility.
319   static bool CompatibleKind(const Shape& lhs, const Shape& rhs);
320 
  // As Compatible, but allow one of lhs and rhs to be BF16 while the other is
  // F32. Tuple elements are compared recursively for compatibility.
323   static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
324 
325   // Returns whether the lhs and rhs shapes are identical.
326   static bool Equal(const Shape& lhs, const Shape& rhs);
327 
328   // As Equal, but does not compare the element type.
329   static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
330 
331   // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
332   static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
333 
  // Two shapes have the same structure if every subshape index of lhs is
  // present in rhs and vice versa.
336   // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
337   // (S32, (F32[3], S32[2])) as their structures are both (,(,))
338   //
339   // In contrast, (F32, (F32, F32)) is structurally different from
340   // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
341   // ((,),)
342   static bool EqualStructure(const Shape& lhs, const Shape& rhs);
343 
344   // Returns the number of dimensions for which the dimension is not (trivially)
345   // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
346   // fluff. Note that zero dimensions are included in the true rank, e.g.,
347   // f32[3,0,1] has a true rank of 2D.
348   static int64 TrueRank(const Shape& shape);
349 
350   static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
351                                        Shape result);
352 
353   ////////////////////
354   // Scalar-specific
355 
IsScalar(const Shape & shape)356   static bool IsScalar(const Shape& shape) {
357     return shape.IsArray() && shape.rank() == 0;
358   }
IsEffectiveScalar(const Shape & shape)359   static bool IsEffectiveScalar(const Shape& shape) {
360     return shape.IsArray() && TrueRank(shape) == 0;
361   }
362 
363   // Returns whether "shape" is a scalar (array) with the given element_type.
364   static bool IsScalarWithElementType(const Shape& shape,
365                                       PrimitiveType element_type);
366 
367   // Extracts the size of the shape's dimension at dimension number
368   // GetDimensionNumber(dimension_number).
369   static int64 GetDimension(const Shape& shape, int64 dimension_number);
370 
371   // Resolves a dimension number, supporting negative indexing.
372   //
373   // Negative indexing has similar semantics to Python. For an N-dimensional
374   // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
375   // N-2, and so on.
376   //
  // This function always returns a non-negative dimension number for any given
  // dimension_number (which itself can be negative).
379   static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
380 
381   // Returns a shape with the same dimensions as the original, but with the
382   // element type changed to type.
383   static Shape ChangeElementType(const Shape& original, PrimitiveType type);
384 
  // Returns a shape with the same dimensions, but with all of them marked
  // static (non-dynamic).
386   static Shape MakeStaticShape(const Shape& original);
387 
388   // Creates a tuple shape from a slice of element shapes within the tuple.
389   static Shape MakeTupleShape(absl::Span<const Shape> shapes);
390 
391   // Creates an opaque shape. These are generally used for threading a context
392   // into a custom operation.
393   static Shape MakeOpaqueShape();
394 
395   // Creates a token shape. Values of this shape are used for ordering
396   // side-effecting operations.
397   static Shape MakeTokenShape();
398 
399   // Appends a shape to the given tuple.
400   static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
401 
402   // Update a subshape of a tuple.
403   static void UpdateTupleShape(const Shape& shape, int64 index,
404                                Shape* tuple_shape);
405 
406   // Update the dynamic dimension for a shape. This shape can be a nested tuple.
407   static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
408                                      int64 dim, bool is_dynamic);
409 
410   // Appends a major dimension to the shape with the given bound.
411   static void AppendMajorDimension(int bound, Shape* shape);
412 
413   // Copy the dynamic dimensions property from one shape to another.
414   static void CopyDynamicDimensions(Shape* to, const Shape& from);
415 
416   // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()417   static Shape MakeNil() { return MakeTupleShape({}); }
418 
419   // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)420   static bool IsInitialized(const Shape& shape) {
421     return shape.element_type() != PRIMITIVE_TYPE_INVALID;
422   }
423 
424   // Constructs a new shape with the given element type and sequence of
425   // dimensions.
426   static Shape MakeShape(PrimitiveType element_type,
427                          absl::Span<const int64> dimensions);
428 
429   // Make a scalar shape with given primitive type.
430   static Shape MakeScalarShape(PrimitiveType element_type);
431 
432   // Constructs a new shape with the given element type and sequence of
433   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
434   // with a true value that the respective dimension is dynamic. If the
435   // dimension is dynamic then the respective value in 'dimension' is an upper
436   // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
437   // the same size.
438   static Shape MakeShape(PrimitiveType element_type,
439                          absl::Span<const int64> dimensions,
440                          const std::vector<bool>& dynamic_dimensions);
441 
442   // Constructs a new shape with the given element type and sequence of
443   // dimensions. Method checks if the element type is valid and the shape's
444   // size fits in std::numeric_limits<int64>::max().
445   static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
446                                             absl::Span<const int64> dimensions);
447   static StatusOr<Shape> MakeValidatedShape(
448       PrimitiveType element_type, absl::Span<const int64> dimensions,
449       const std::vector<bool>& dynamic_dimensions);
450 
451   // Creates a Shape with element type corresponding to T and the given
452   // dimensions
453   template <typename T>
MakeShapeWithType(absl::Span<const int64> dimensions)454   static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
455     return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
456                                 dimensions);
457   }
458 
459   // Constructs a new shape with the given minor_to_major order in its Layout.
460   // Returns a value shape such that shape.has_layout().
461   static Shape MakeShapeWithLayout(PrimitiveType element_type,
462                                    absl::Span<const int64> dimensions,
463                                    absl::Span<const int64> minor_to_major,
464                                    absl::Span<const Tile> tiles = {},
465                                    int64 element_size_in_bits = 0,
466                                    int64 memory_space = 0);
467 
468   // Returns the same shape except with all dimensions set to be static.
469   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
470 
471   // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
472   static Shape MakeShapeWithDescendingLayout(
473       PrimitiveType element_type, absl::Span<const int64> dimensions);
474 
475   // Returns a new Shape based on the given Shape with low-dimension-major
476   // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
477   // rearranged so that it has the same in-memory layout as the given shape.
478   //
479   // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
480   static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
481       const Shape& shape);
482 
483   // As MakeShape, but the object to write to is passed in.
484   static Status PopulateShape(PrimitiveType element_type,
485                               absl::Span<const int64> dimensions, Shape* shape);
486 
487   // Validates that the provided shape satisfies invariants.
488   static Status ValidateShape(const Shape& shape);
489 
490   // Validates the provided shape satisfies invariants, except those that
491   // pertain to layout.
492   //
493   // Layout is optional for client-provided shapes, so that the compiler may
494   // determine and assign an optimized layout.
495   static Status ValidateShapeWithOptionalLayout(const Shape& shape);
496 
497   // Returns whether the element type of the shape is integral (signed or
498   // unsigned). Note that predicates are not considered integral here, since
499   // they are logical values.
500   static bool ElementIsIntegral(const Shape& shape);
501 
502   // Returns whether the element type of the shape is floating point.
503   static bool ElementIsFloating(const Shape& shape);
504 
505   // Returns whether the element type of the shape is complex.
506   static bool ElementIsComplex(const Shape& shape);
507 
508   // Returns whether the element type has the given bit width.
509   static bool ElementHasBitWidth(const Shape& shape, int bits);
510 
511   // Returns whether the element type of the shape is integral and has
512   // the specified number of bits.
513   static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
514 
515   // Returns whether the element type of the shape is signed. Note
516   // that floating point numbers are signed.
517   static bool ElementIsSigned(const Shape& shape);
518 
519   // Returns whether the given primitive type corresponds to an array shape.
520   static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
521 
522   // Returns whether the shape is a tuple with at least one element which is
523   // also a tuple.
524   static bool IsNestedTuple(const Shape& shape);
525 
526   // Returns true if shape is an empty tuple.
527   static bool IsEmptyTuple(const Shape& shape);
528 
529   // Returns the number of elements in the given tuple shape.
530   // Precondition: IsTuple(shape)
531   static int64 TupleElementCount(const Shape& shape);
532 
533   // Returns the tuple element shape at given index.
534   // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
535   static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
536 
537   // Returns the number of elements, recursively, in the given shape.
538   static int64 SubshapeCount(const Shape& shape);
539 
540   // Slices tuple elements in the range [start, limit) and returns a new tuple
541   // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
542   static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
543 
544   // Returns the shape of the real/imaginary components of the given complex
545   // shape.
546   static Shape ComplexComponentShape(const Shape& complex_shape);
547 
548   // Returns true if the given shape has a subshape at the given index.
549   static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
550 
551   // GetSubshape and GetMutableSubshape return a particular nested Shape within
552   // the given Shape argument. The non-Try variants check fail if index is
553   // invalid.
554   static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
555   static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
556                                                ShapeIndexView index);
557   static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
558 
559   // Returns whether the given index in the given shape is a leaf element of the
560   // shape.
561   static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
562 
563   // Returns the number of leaves in the shape.
564   static int64 GetLeafCount(const Shape& shape);
565 
566   // Retrieves all the leaf shapes and their indexes, in the order walked by
567   // the ForEachSubshape() API.
568   static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
569 
570   // Calls the given visitor function for each subshape of the given shape.
571   // Subshapes are visited in DFS pre-order starting with the entire shape
572   // (index {}).
573   using VisitorFunction = std::function<void(const Shape& /*subshape*/,
574                                              const ShapeIndex& /*index*/)>;
575   static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
576   using MutatingVisitorFunction =
577       std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
578   static void ForEachMutableSubshape(Shape* shape,
579                                      const MutatingVisitorFunction& func);
580 
581   // Variants of ForEach(Mutable)Subshape which propagate Status from the
582   // visitor function.
583   using StatusVisitorFunction = std::function<Status(
584       const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
585   static Status ForEachSubshapeWithStatus(const Shape& shape,
586                                           const StatusVisitorFunction& func);
587   using MutatingStatusVisitorFunction =
588       std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
589   static Status ForEachMutableSubshapeWithStatus(
590       Shape* shape, const MutatingStatusVisitorFunction& func);
591 
  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
594   static bool HasDegenerateDimensions(const Shape& shape);
595 
596   // Drops any degenerate dimensions (i.e. dimensions of size 1)
597   static Shape DropDegenerateDimensions(const Shape& shape);
598 
599   // Permutes the dimensions by the given permutation, so
600   // return_value.dimensions[i] = argument.dimensions[permutation[i]].
601   //
602   // Postcondition: For any valid permutation,
603   //
604   //   !HasLayout(shape) ||
605   //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
606   //                      permutation).
607   static Shape PermuteDimensions(absl::Span<const int64> permutation,
608                                  const Shape& shape);
609 
610   // If we can go from `shape_pre` to `shape_post` by merely inserting or
611   // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
613   // dimensions.
614   // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
615   // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
616   // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
617   // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
618   // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
619   // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
620   static std::tuple<bool, std::vector<int64>, std::vector<int64>>
621   InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
622                                     const Shape& shape_post);
623 
624   // Suppose a reshape transforms input_shape to output shape. Returns a vector
625   // of pairs that indicate the input and output dimensions that this reshape
626   // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
627   // the returned vector, the reshape transforms any input index whose I-th
628   // dimension is x to an output index whose O-th dimension is x too.
629   //
630   // Post-condition: the returned vector is sorted (by both input and output
631   // dimensions because input and output dimensions have the same order).
632   //
633   // Example:
634   //   input  shape = T[a, b, x, y, cd]
635   //   output shape = T[ab, x, 1, y, c, d]
636   //   return value = {{2, 1}, {3, 3}}
637   //
638   //   The two pairs represent the input and output dimension of size x and
639   //   those of size y.
640   static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
641       const Shape& input_shape, const Shape& output_shape);
642 
643   // Return whether the given reshape instruction leaves the dimensions at the
644   // given input indices unmodified, and returns their output indices.
645   //
646   // Example:
647   //   input_dim_indices = {2, 3}
648   //   input  shape = T[a, b, x, y, cd]
649   //   output shape = T[ab, x, 1, y, c, d]
650   //   return value = {1, 3}
651   //
652   // Precondition: input_dim_indices is sorted.
653   static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
654       const Shape& from_shape, const Shape& to_shape,
655       absl::Span<const int64> input_dim_indices);
656 
657   // Returns whether a transpose from input_shape to output_shape with dimension
658   // mapping "dimension_mapping" produces a result which is bit-wise identical
659   // to its input and thus may be replaced with a bitcast.
660   //
661   // Precondition: Both input_shape and output_shape have explicit layouts.
662   static bool TransposeIsBitcast(const Shape& input_shape,
663                                  const Shape& output_shape,
664                                  absl::Span<const int64> dimension_mapping);
665 
666   // Returns whether a reshape from "input_shape" to "output_shape" is a
667   // bitcast.
668   //
669   // Precondition: Both input_shape and output_shape have explicit layouts.
670   static bool ReshapeIsBitcast(const Shape& input_shape,
671                                const Shape& output_shape);
672 
673   // Find a physical layout for 'output_shape' such that
674   // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
675   // true (where 'output_shape_with_layout' is 'output_shape' with the found
676   // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and
  // absl::nullopt otherwise.
679   static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
680                                             const Shape& output_shape);
681 
682   // Returns a shape with the given dimension deleted.
683   // For example:
684   // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
685   static Shape DeleteDimension(int64 dim_to_delete, Shape shape);
686 
687   // Returns a shape with all the dimensions of the input shape for which `p`
688   // returns true.
  // For example:
690   // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
691   // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
692   static Shape FilterDimensions(const std::function<bool(int64)>& p,
693                                 Shape shape);
694 
  // Returns true if the dimensions of `dynamic_shape` are less than or equal
  // to those of `bounded_shape`. Both shapes must be arrays.
697   static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
698                                             const xla::Shape& bounded_shape);
699 
700   // Same as DynamicArrayShapeIsCompatible() but supports tuples.
701   static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
702                                        const xla::Shape& bounded_shape);
703 
704   // Iterates through all the shape indexes, in minor to major order,
705   // starting from the base indexes, incrementing by the incr steps, up to
706   // count (index[i] < base[i] + count[i]), and calls the visitor_function
707   // with the current index. The visitor_function visitor function should
708   // return true if it wants to continue, or false otherwise.
709   //
710   // visitor_function must be a callable of type
711   // StatusOr<bool>(absl::Span<int64>) or compatible.
712   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)713   static Status ForEachIndexWithStatus(const Shape& shape,
714                                        absl::Span<const int64> base,
715                                        absl::Span<const int64> count,
716                                        absl::Span<const int64> incr,
717                                        const FnType& visitor_function) {
718     return ForEachIndexInternal(shape, base, count, incr, visitor_function);
719   }
720 
721   // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
722   struct IndexIterationSpace {
723     std::vector<int64> index_base;
724     std::vector<int64> index_count;
725     std::vector<int64> index_incr;
726   };
727 
728   template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)729   static Status ForEachIndexWithStatus(
730       const Shape& shape, const IndexIterationSpace& iteration_space,
731       FnTy&& function) {
732     return ShapeUtil::ForEachIndexWithStatus(
733         shape, iteration_space.index_base, iteration_space.index_count,
734         iteration_space.index_incr, std::forward<FnTy>(function));
735   }
736 
737   template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)738   static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
739                            absl::Span<const int64> count,
740                            absl::Span<const int64> incr,
741                            const FnType& visitor_function) {
742     ForEachIndexWithStatus(shape, base, count, incr,
743                            [&](absl::Span<const int64> indices) {
744                              return StatusOr<bool>(visitor_function(indices));
745                            })
746         .IgnoreError();
747   }
748 
749   // These convenience wrappers don't take `base`, `count` and `incr`
750   // explicitly, but iterate over every element in `shape` instead.
751 
752   template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)753   static Status ForEachIndexWithStatus(const Shape& shape,
754                                        const FnType& visitor_function) {
755     std::vector<int64> base(shape.dimensions_size());
756     std::vector<int64> incr(shape.dimensions_size(), 1);
757     return ForEachIndexWithStatus(shape, base,
758                                   /*count=*/AsInt64Slice(shape.dimensions()),
759                                   incr, visitor_function);
760   }
761 
762   template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)763   static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
764     ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
765       return StatusOr<bool>(visitor_function(indices));
766     }).IgnoreError();
767   }
768 
  // A parallel version of ForEachIndex(WithStatus). This can only be used if
  // the visitor_function is thread-safe and the order of iteration does not
  // matter.
  //
  // visitor_function must be a callable of type
  // void(absl::Span<const int64>) or compatible. Each visit is scheduled on a
  // thread pool and any value the visitor returns is discarded, so the walk
  // cannot be stopped early. This call returns only after every scheduled
  // visit has completed.
  template <typename FnType>
  static void ForEachIndexParallel(const Shape& shape,
                                   absl::Span<const int64> base,
                                   absl::Span<const int64> count,
                                   absl::Span<const int64> incr,
                                   const FnType& visitor_function) {
    // The parallel version of ForEachIndexInternal can never fail.
    // (The adapter below always returns `true`, never an error status.)
    CHECK(ForEachIndexInternal(
              shape, base, count, incr,
              [&visitor_function](
                  absl::Span<const int64> indexes) -> StatusOr<bool> {
                visitor_function(indexes);
                return true;
              },
              /*parallel=*/true)
              .ok());
  }
792 
793   // Compute a hash for `shape`.
794   static size_t Hash(const Shape& shape);
795 
796   // About 0-2-1 transpose:
797   //
798   // If a shape can be viewed as three logical components 0-1-2 in the order of
799   // major to minor, a 0-2-1-transpose changes the order of such logical
800   // components to 0-2-1. We call the shape being transposed the input shape and
801   // the transposed shape the output shape. The logical view of the input/output
802   // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
803   // normalized shapes. The original input/output shapes are called unnormalized
804   // shapes.
805   //
806   // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
807   // normalized shape of `b` or the 0-2-1 shape.
808   static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
809                                                              const Shape& b);
810 
811   // Strips device-specific information, namely tiling and memory-space
812   // information, from a shape.
813   static Shape DeviceShapeToHostShape(Shape s);
814 
  // Returns true iff the element type of shape `from` can be safely upcast to
  // the element type of shape `to`.
817   static bool ElementCanUpcast(const Shape& from, const Shape& to);
818 
819  private:
820   // Fills *shape. Returns true on success.
821   // REQUIRES: *shape is empty.
822   static bool FillNewShape(PrimitiveType element_type,
823                            absl::Span<const int64> dimensions, Shape* shape);
824 
  // Validates that the shape size is sane. This makes sure it's safe to do
826   // calculations in int64 without overflowing.
827   static Status ValidateShapeSize(const Shape& shape);
828 
829   // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public methods.
831   static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
832 
833   template <typename FnType>
834   static Status ForEachIndexInternal(const Shape& shape,
835                                      absl::Span<const int64> base,
836                                      absl::Span<const int64> count,
837                                      absl::Span<const int64> incr,
838                                      const FnType& visitor_function,
839                                      bool parallel = false) {
840     if (ShapeUtil::IsZeroElementArray(shape)) {
841       return Status::OK();
842     }
843     CHECK_EQ(shape.rank(), base.size());
844     CHECK_EQ(incr.size(), base.size());
845     CHECK_EQ(count.size(), base.size());
846     const int64 rank = LayoutUtil::MinorToMajor(shape).size();
847     // Allows handling R0 arrays, such that the visitor function will be called
848     // once with the proper empty indexes.
849     int64 n = -1;
850     std::vector<int64> indexes(base.begin(), base.end());
851     const int kNumThreads = tensorflow::port::MaxParallelism();
852     absl::optional<tensorflow::thread::ThreadPool> pool;
853     if (parallel) {
854       pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
855     }
856 
857     tensorflow::mutex mu;
858     Status status;  // Guarded by mu
859 
860     while (n < rank) {
861       if (pool != absl::nullopt) {
862         pool->Schedule([indexes, &visitor_function, &mu, &status] {
863           StatusOr<bool> result = visitor_function(indexes);
864           if (!result.ok()) {
865             tensorflow::mutex_lock lock(mu);
866             status = status.ok() ? result.status() : status;
867           }
868         });
869       } else {
870         TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
871         if (!should_continue) {
872           break;
873         }
874       }
875       // Increments dimensions in minor to major order.
876       for (n = 0; n < rank; ++n) {
877         int64 dim = LayoutUtil::Minor(shape.layout(), n);
878         indexes[dim] += incr[dim];
879         if (indexes[dim] < base[dim] + count[dim]) {
880           break;
881         }
882         indexes[dim] = base[dim];
883       }
884     }
885 
886     // Waits for the scheduled work to complete.
887     pool.reset();
888     return status;
889   }
890 
891   TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
892 };
893 
894 }  // namespace xla
895 
896 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
897