/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <functional>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
//
// ShapeIndex is a trivial wrapper around std::vector with a minimum number of
// methods implemented.
class ShapeIndex {
 public:
  ShapeIndex() = default;
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}

  bool empty() const { return indices_.empty(); }
  size_t size() const { return indices_.size(); }
  void push_back(int64 value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // push_front is O(n^2), but shapes don't usually have a ton of dimensions.
  void push_front(int64 value) { indices_.insert(indices_.begin(), value); }

  std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
  std::vector<int64>::const_iterator end() const { return indices_.end(); }
  std::vector<int64>::iterator begin() { return indices_.begin(); }
  std::vector<int64>::iterator end() { return indices_.end(); }

  const int64* data() const { return indices_.data(); }

  int64 back() const { return indices_.back(); }
  int64& back() { return indices_.back(); }

  const int64& operator[](size_t i) const { return indices_[i]; }
  int64& operator[](size_t i) { return indices_[i]; }

  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  string ToString() const;

 private:
  std::vector<int64> indices_;
};
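
// Illustrative sketch (not part of this header): a ShapeIndex is typically
// grown and shrunk while walking a shape tree, e.g. in a recursive traversal.
//
//   ShapeIndex index;        // {}  -- refers to the whole shape
//   index.push_back(1);      // {1} -- second tuple element
//   index.push_back(2);      // {1, 2} -- third element of that nested tuple
//   index.pop_back();        // back to {1}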

// A view into a ShapeIndex as above, with the cheap/easy ability to consume
// the value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
      : ShapeIndexView(shape_index.data() + offset,
                       shape_index.data() + shape_index.size()) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices)
      : ShapeIndexView(indices.begin(), indices.end()) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return begin_; }
  iterator end() const { return end_; }
  int64 size() const { return std::distance(begin_, end_); }
  bool empty() const { return begin_ == end_; }
  int64 front() const {
    CHECK(!empty());
    return *begin_;
  }
  ShapeIndexView ConsumeFront() const {
    CHECK(!empty());
    auto new_begin = begin_;
    ++new_begin;
    return ShapeIndexView(new_begin, end_);
  }

  string ToString() const;

 private:
  ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}

  iterator begin_;
  iterator end_;
};
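
// Illustrative sketch (not part of this header): consuming a ShapeIndexView
// front-to-back, as recursive shape walkers commonly do. The view is built
// over a ShapeIndex that outlives it, per the ownership note above.
//
//   ShapeIndex index({1, 2});
//   ShapeIndexView view(index);  // does not own index's storage
//   while (!view.empty()) {
//     int64 front = view.front();  // 1, then 2
//     view = view.ConsumeFront();
//   }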

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
  // may not actually be able to store this number of elements. See
  // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
  // elements that can be stored in a sparse shape.
  // Precondition: !IsTuple(shape)
  static int64 ElementsIn(const Shape& shape);

  // Returns true if 'shape' has zero elements.
  static bool HasZeroElements(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape.  The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example,
  // a tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
  // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
  //               !ShapeUtil::IsOpaque(shape)
  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);

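  // Illustrative sketch (not part of this header): for a dense array shape,
  // the allocation size is the element count times the element size, e.g.
  //
  //   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {4, 3});
  //   int64 bytes = ShapeUtil::ByteSizeOf(s);  // 4 * 3 * sizeof(float) = 48
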
  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers
  // for an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64 pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. The return value does not include
  // the bytes needed to store sparse indices. Dense shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
  // size also includes padding if present in the layout. For sparse shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
  // ByteSizeOfSparseIndices(shape)`.
  static int64 ByteSizeOfElements(const Shape& shape);

  // Returns the number of bytes required for the sparse indices in an
  // allocation of shape. The shape must be an array shape. The return value
  // does not include the bytes needed to store the array elements.
  static int64 ByteSizeOfSparseIndices(const Shape& shape);

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string of the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Parses a ShapeUtil::HumanString-format shape string back into a shape
  // object.
  static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
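
  // Illustrative sketch (not part of this header): round-tripping a shape
  // through its human-readable form. ParseShapeString returns a StatusOr, so
  // the result must be checked before use.
  //
  //   StatusOr<Shape> parsed = ShapeUtil::ParseShapeString("f32[42x12]");
  //   if (parsed.ok()) {
  //     const Shape& shape = parsed.ValueOrDie();  // f32[42x12]
  //   }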

  // Returns whether the LHS and RHS shapes have the same dimensions; note:
  // does not check element type.
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    if (ElementIsFloating(a) && ElementIsFloating(b)) {
      return true;
    }
    return ShapeUtil::SameElementType(a, b);
  }

  // Returns the higher-precision element type if a and b are both floating
  // point types; otherwise, checks that they have the same element type
  // and returns it.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    if (SameElementType(a, b)) {
      return a.element_type();
    }
    CHECK(SameElementTypeIgnoringFpPrecision(a, b));
    return primitive_util::BitWidth(a.element_type()) <
                   primitive_util::BitWidth(b.element_type())
               ? b.element_type()
               : a.element_type();
  }

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allows one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical protobufs.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // Returns the rank (number of dimensions) of the given shape.
  // Precondition: !IsTuple(shape)
  static int64 Rank(const Shape& shape);

  // Returns the number of dimensions for which the dimension is not
  // (trivially) 1; e.g., f32[2x1x1] has a true rank of 1, the other
  // dimensions are just fluff. Note that zero dimensions are included in the
  // true rank, e.g., f32[3,0,1] has a true rank of 2.
  static int64 TrueRank(const Shape& shape);

  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  static bool IsScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
  }
  static bool IsEffectiveScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
  }
  static bool IsScalarF32(const Shape& shape);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64 GetDimension(const Shape& shape, int64 dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a non-negative dimension number for any
  // given dimension_number (which itself can be negative).
  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
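
  // Illustrative sketch (not part of this header): negative indexing counts
  // from the last dimension, as in Python.
  //
  //   Shape s = ShapeUtil::MakeShape(F32, {2, 3, 4});
  //   ShapeUtil::GetDimensionNumber(s, -1);  // == 2
  //   ShapeUtil::GetDimension(s, -1);        // == 4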

  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Returns an empty tuple shape. Can be used to indicate side-effects.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
  static Shape MakeShape(PrimitiveType element_type,
                         tensorflow::gtl::ArraySlice<int64> dimensions);

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<int64> minor_to_major);
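
  // Illustrative sketch (not part of this header): the same logical f32[2,3]
  // array in two physical layouts.
  //
  //   // Row-major (dimension 1 is minor-most, i.e. varies fastest):
  //   Shape row_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  //   // Column-major (dimension 0 is minor-most):
  //   Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});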

  static Shape MakeShapeWithSparseLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      int64 max_sparse_elements);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static void PopulateShape(PrimitiveType element_type,
                            tensorflow::gtl::ArraySlice<int64> dimensions,
                            Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates that the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);

  // Returns whether the shape is a tuple.
  static bool IsTuple(const Shape& shape) {
    return shape.element_type() == TUPLE;
  }

  // Returns whether the shape is an opaque value (i.e. an 'existential' typed
  // value that is passed to CustomCall operations).
  static bool IsOpaque(const Shape& shape) {
    return shape.element_type() == OPAQUE;
  }

  // Returns whether the shape is an array.  Note that scalars are considered
  // arrays.
  static bool IsArray(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape);
  }

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns true if shape is an empty tuple, or is an array with no elements.
  static bool IsNil(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Shorthand for testing whether a shape is of a given element type and
  // sequence of dimensions.
  //
  // DEPRECATED: Use Equal() instead.
  static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                      std::initializer_list<int64> dimensions);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

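  // Illustrative sketch (not part of this header): retrieving a nested
  // subshape with an index, per the ShapeIndex documentation above.
  //
  //   Shape t = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {2}),
  //        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {3})})});
  //   const Shape& inner = ShapeUtil::GetSubshape(t, {1, 0});  // s32[3]
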
  // Returns whether the given index in the given shape is a leaf element of
  // the shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

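  // Illustrative sketch (not part of this header): counting the leaf (array)
  // subshapes of a shape with a visitor lambda.
  //
  //   int64 num_leaves = 0;
  //   ShapeUtil::ForEachSubshape(
  //       shape, [&](const Shape& subshape, const ShapeIndex& index) {
  //         if (ShapeUtil::IsLeafIndex(shape, index)) {
  //           ++num_leaves;
  //         }
  //       });
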
  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);

  // Removes all degenerate dimensions (size one) from the given shape. The
  // stripped minor_to_major preserves the relative ordering of non-degenerate
  // dimensions. The stripped shape has the property that the underlying
  // representation (bits in memory) for the stripped shape is the same as the
  // original shape modulo padding. Examples:
  //
  // input shape:    F32 [1, 2, 1], minor_to_major = {0, 1, 2}
  // stripped shape: F32 [2], minor_to_major = {0}
  //
  // input shape:    F32 [6, 1, 5], minor_to_major = {2, 0, 1}
  // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
  //
  // input shape:    F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
  // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
  //
  // input shape:    F32 [1, 1], minor_to_major = {0, 1}
  // stripped shape: F32 [], minor_to_major = {}
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static Shape StripDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[permutation[i]] = argument.dimensions[i]
  static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
                                 const Shape& shape);

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  static bool TransposeIsBitcast(
      const Shape& input_shape, const Shape& output_shape,
      tensorflow::gtl::ArraySlice<int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and an error
  // otherwise.
  static tensorflow::gtl::optional<Shape> AlignLayouts(
      const Shape& input_shape, const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64 dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // Examples:
  // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64)>& p,
                                Shape shape);

  // Iterates through all the shape indexes, in minor to major order, starting
  // from the base indexes, incrementing by the incr steps, up to count
  // (index[i] < base[i] + count[i]), and calls the visitor_function with the
  // current index.
  // The visitor_function should return true if it wants to continue, or false
  // otherwise.
  //
  // visitor_function must be a callable of type bool(const std::vector<int64>&)
  // or compatible.
  template <typename FnType>
  static void ForEachIndex(const Shape& shape,
                           tensorflow::gtl::ArraySlice<int64> base,
                           tensorflow::gtl::ArraySlice<int64> count,
                           tensorflow::gtl::ArraySlice<int64> incr,
                           const FnType& visitor_function) {
    if (ShapeUtil::HasZeroElements(shape)) {
      return;
    }
    CHECK_EQ(Rank(shape), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64 n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    while (n < rank && visitor_function(indexes)) {
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64 dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }
  }
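
  // Illustrative sketch (not part of this header): visiting every index of a
  // shape with unit strides. For f32[2,3]{1,0} this enumerates all six (i, j)
  // index vectors, minor dimension fastest.
  //
  //   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 3});
  //   ShapeUtil::ForEachIndex(
  //       s, /*base=*/{0, 0}, /*count=*/{2, 3}, /*incr=*/{1, 1},
  //       [](const std::vector<int64>& index) {
  //         return true;  // keep iterating
  //       });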

 private:
  // Validates all of the non-layout properties of the shape -- this is a
  // helper used by both the layout-optional and layout-required public
  // methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

std::ostream& operator<<(std::ostream& out, const Shape& shape);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_