1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21
22 #include <initializer_list>
23 #include <string>
24
25 #include "absl/base/macros.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/types/optional.h"
28 #include "absl/types/span.h"
29 #include "tensorflow/compiler/xla/layout_util.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/shape.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/threadpool.h"
38 #include "tensorflow/core/platform/cpu_info.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/types.h"
43
44 namespace xla {
45
46 class ShapeIndexView;
47
48 // An index for specifying a particular nested subshape within a shape. Used in
49 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
50 // structures (trees) and ShapeIndex defines a path through the tree where each
51 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
52 // shape. For a non-nested tuple, an index has a single element. For example,
53 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
54 // {1} corresponds to array b. For a nested tuple, the index can have more than
55 // one element. For the nested tuple (a, (b, c, d), e) below are the values
56 // corresponding to the given indices:
57 //
58 // index {0} : array a
59 // index {1, 2} : array d
60 // index {2} : array e
61 // index {0, 0} : invalid index (element at {0} is an array not a tuple)
62 //
63 // For indexing into array shapes, the index is always trivially empty, ie {}.
64 //
65 // ShapeIndex is a trivial wrapper around std::vector with a minimum number of
66 // methods implemented.
67 class ShapeIndex {
68 public:
69 ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init)70 ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
71 template <typename InputIt>
ShapeIndex(InputIt start,InputIt end)72 ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
73
74 explicit ShapeIndex(ShapeIndexView v);
75
empty()76 bool empty() const { return indices_.empty(); }
size()77 size_t size() const { return indices_.size(); }
push_back(int64 value)78 void push_back(int64 value) { indices_.push_back(value); }
pop_back()79 void pop_back() { indices_.pop_back(); }
80
81 // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_front(int64 value)82 void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
83
84 using container_type = absl::InlinedVector<int64, 2>;
85
begin()86 container_type::const_iterator begin() const { return indices_.begin(); }
end()87 container_type::const_iterator end() const { return indices_.end(); }
begin()88 container_type::iterator begin() { return indices_.begin(); }
end()89 container_type::iterator end() { return indices_.end(); }
90
data()91 const int64* data() const { return indices_.data(); }
92
back()93 int64 back() const { return indices_.back(); }
back()94 int64& back() { return indices_.back(); }
95
96 const int64& operator[](size_t i) const { return indices_[i]; }
97 int64& operator[](size_t i) { return indices_[i]; }
98
99 bool operator==(const ShapeIndex& other) const {
100 return indices_ == other.indices_;
101 }
102 bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
103 bool operator<(const ShapeIndex& other) const {
104 return indices_ < other.indices_;
105 }
106
107 string ToString() const;
108
109 template <typename H>
AbslHashValue(H h,const ShapeIndex & index)110 friend H AbslHashValue(H h, const ShapeIndex& index) {
111 return H::combine(std::move(h), index.indices_);
112 }
113
114 private:
115 container_type indices_;
116 };
117
118 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the
119 // value at the front of the view.
120 //
121 // NB! ShapeIndexView does not own the memory backing the index array.
122 // The memory backing the index array should be owned by an object
123 // that lives longer than the ShapeIndexView instances pointing into
124 // it.
125 class ShapeIndexView {
126 public:
127 ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
128 : indices_(shape_index.data() + offset, shape_index.size() - offset) {
129 CHECK_LE(offset, shape_index.size());
130 }
ShapeIndexView(std::initializer_list<int64> indices)131 ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
132 ShapeIndexView(const ShapeIndexView& other) = default;
133
134 using iterator = const int64*;
135
begin()136 iterator begin() const { return indices_.begin(); }
end()137 iterator end() const { return indices_.end(); }
size()138 int64 size() const { return indices_.size(); }
empty()139 bool empty() const { return indices_.empty(); }
front()140 int64 front() const {
141 CHECK(!empty());
142 return indices_.front();
143 }
back()144 int64 back() const {
145 CHECK(!empty());
146 return indices_.back();
147 }
ConsumeFront()148 ShapeIndexView ConsumeFront() const {
149 ShapeIndexView result = *this;
150 result.indices_.remove_prefix(1);
151 return result;
152 }
ConsumeBack()153 ShapeIndexView ConsumeBack() const {
154 ShapeIndexView result = *this;
155 result.indices_.remove_suffix(1);
156 return result;
157 }
ToShapeIndex()158 ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
159
160 bool operator==(const ShapeIndexView& other) const;
161 bool operator!=(const ShapeIndexView& other) const;
162
163 string ToString() const;
164
165 // Returns true if this shape index starts with 'prefix'.
166 bool StartsWith(ShapeIndexView prefix) const;
167
168 private:
169 absl::Span<const int64> indices_;
170 };
171
ShapeIndex(ShapeIndexView v)172 inline ShapeIndex::ShapeIndex(ShapeIndexView v)
173 : ShapeIndex(v.begin(), v.end()) {}
174
175 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
176 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
177
178 // Namespaced collection of (static) shape utilities.
179 //
180 // These are all effectively convenience functions for testing/tweaking proto
181 // properties, which do invariant checks before / after the operation.
182 class ShapeUtil {
183 public:
184 // Data structure which describes the coordinates and the shape, of a tuple
185 // shaped sub-shape.
186 struct IndexedShape {
187 IndexedShape() = default;
IndexedShapeIndexedShape188 IndexedShape(ShapeIndex index, Shape shape)
189 : index(std::move(index)), shape(std::move(shape)) {}
190 ShapeIndex index;
191 Shape shape;
192 };
193
// Returns the number of elements that are contained within the provided
// shape; e.g. for rank 0 (scalars) the result is always 1.
196 // Precondition: shape.IsArray()
197 static int64 ElementsIn(const Shape& shape);
198
199 // As ElementsIn(), but recurses through tuples.
200 static int64 ElementsInRecursive(const Shape& shape);
201
202 // Returns true if shape has the primitive type, recurses through tuples.
203 static bool HasPrimitiveType(const Shape& shape,
204 PrimitiveType primitive_type);
205
206 // Returns true if 'shape' is an array with zero elements.
207 static bool IsZeroElementArray(const Shape& shape);
208
209 // Returns the number of bytes required for an allocation of shape. The
210 // |pointer_size| parameter is used for calculating the size of tuple
211 // shapes. This includes only the size of the top-level buffer. For example, a
212 // tuple is stored as an array of pointers to other buffers. In this case,
213 // this method only returns the size of the pointer array.
214 static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
215
216 // Returns the number of bytes used to store the primitive_type.
217 //
218 // Precondition: shape.IsArray()
219 static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
220
// Returns the number of bytes required to store the tuple member pointers for
// an allocation of shape. The `shape` must be a TUPLE shape, and
// `pointer_size` must be larger than zero.
224 static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
225 int64 pointer_size);
226
227 // Returns the number of bytes required for the elements in an allocation of
228 // `shape`, which must be an array shape. Shapes use a separate
229 // memory location for each element, and so for these shapes,
230 // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
231 // size also includes padding if present in the layout.
232 static int64 ByteSizeOfElements(const Shape& shape);
233
234 // Returns a human-readable string that represents the given shape, with or
235 // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
236 static string HumanString(const Shape& shape);
237 static string HumanStringWithLayout(const Shape& shape);
238
239 // As above, but for program shapes, returns a string for the form:
240 //
241 // (param_name: f32[42x12], ...) -> f32[24x42]
242 static string HumanString(const ProgramShape& program_shape);
243
244 // Returns whether the LHS and RHS shapes have the same dimensions; note: does
245 // not check element type.
246 // Precondition: IsArray(lhs) && IsArray(rhs)
247 static bool SameDimensions(const Shape& lhs, const Shape& rhs);
248
249 // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)250 static bool SameElementType(const Shape& lhs, const Shape& rhs) {
251 return lhs.element_type() == rhs.element_type();
252 }
253
254 // As SameElementType, but allows floating point types to have different
255 // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)256 static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
257 const Shape& b) {
258 if (ElementIsFloating(a) && ElementIsFloating(b)) {
259 return true;
260 }
261 return ShapeUtil::SameElementType(a, b);
262 }
263
264 // Returns the higher-precision element type if a and b are both floating
265 // point types; otherwise, checks that they have the same element type
266 // and returns it.
HigherPrecisionElementType(const Shape & a,const Shape & b)267 static PrimitiveType HigherPrecisionElementType(const Shape& a,
268 const Shape& b) {
269 if (SameElementType(a, b)) {
270 return a.element_type();
271 }
272 return primitive_util::BitWidth(a.element_type()) <
273 primitive_util::BitWidth(b.element_type())
274 ? b.element_type()
275 : a.element_type();
276 }
277
278 // Returns true if the rank, dimension sizes, and element type are
279 // identical. Layout is ignored. Tuple elements are compared recursively for
280 // compatibility.
281 static bool Compatible(const Shape& lhs, const Shape& rhs);
282
283 // Returns true if the rank and dimension sizes are identical. Element type
284 // and layout are ignored. Tuple elements are compared recursively for
285 // compatibility.
286 static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
287
// As Compatible, but allow one of lhs and rhs to be BF16 while the other is
// F32. Tuple elements are compared recursively for compatibility.
290 static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
291
292 // Returns whether the lhs and rhs shapes are identical.
293 static bool Equal(const Shape& lhs, const Shape& rhs);
294
295 // As Equal, but does not compare the element type.
296 static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
297
298 // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
299 static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
300
301 // Returns the number of dimensions for which the dimension is not (trivially)
302 // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
303 // fluff. Note that zero dimensions are included in the true rank, e.g.,
304 // f32[3,0,1] has a true rank of 2D.
305 static int64 TrueRank(const Shape& shape);
306
307 static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
308 Shape result);
309
310 ////////////////////
311 // Scalar-specific
312
IsScalar(const Shape & shape)313 static bool IsScalar(const Shape& shape) {
314 return shape.IsArray() && shape.rank() == 0;
315 }
IsEffectiveScalar(const Shape & shape)316 static bool IsEffectiveScalar(const Shape& shape) {
317 return shape.IsArray() && TrueRank(shape) == 0;
318 }
319
320 // Returns whether "shape" is a scalar (array) with the given element_type.
321 static bool IsScalarWithElementType(const Shape& shape,
322 PrimitiveType element_type);
323
324 // Extracts the size of the shape's dimension at dimension number
325 // GetDimensionNumber(dimension_number).
326 static int64 GetDimension(const Shape& shape, int64 dimension_number);
327
328 // Resolves a dimension number, supporting negative indexing.
329 //
330 // Negative indexing has similar semantics to Python. For an N-dimensional
331 // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
332 // N-2, and so on.
333 //
// This function always returns a non-negative dimension number for any given
// dimension_number (which itself can be negative).
336 static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
337
338 // Returns a shape with the same dimensions as the original, but with the
339 // element type changed to type.
340 static Shape ChangeElementType(const Shape& original, PrimitiveType type);
341
342 // Creates a tuple shape from a slice of element shapes within the tuple.
343 static Shape MakeTupleShape(absl::Span<const Shape> shapes);
344
345 // Creates an opaque shape. These are generally used for threading a context
346 // into a custom operation.
347 static Shape MakeOpaqueShape();
348
349 // Creates a token shape. Values of this shape are used for ordering
350 // side-effecting operations.
351 static Shape MakeTokenShape();
352
353 // Appends a shape to the given tuple.
354 static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
355
356 // Update a subshape of a tuple.
357 static void UpdateTupleShape(const Shape& shape, int64 index,
358 Shape* tuple_shape);
359
360 // Update the dynamic dimension for a shape. This shape can be a nested tuple.
361 static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
362 int64 dim, bool is_dynamic);
363
364 // Appends a major dimension to the shape with the given bound.
365 static void AppendMajorDimension(int bound, Shape* shape);
366
367 // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()368 static Shape MakeNil() { return MakeTupleShape({}); }
369
370 // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)371 static bool IsInitialized(const Shape& shape) {
372 return shape.element_type() != PRIMITIVE_TYPE_INVALID;
373 }
374
375 // Constructs a new shape with the given element type and sequence of
376 // dimensions.
377 static Shape MakeShape(PrimitiveType element_type,
378 absl::Span<const int64> dimensions);
379
380 // Make a scalar shape with given primitive type.
381 static Shape MakeScalarShape(PrimitiveType element_type);
382
383 // Constructs a new shape with the given element type and sequence of
384 // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
385 // with a true value that the respective dimension is dynamic. If the
386 // dimension is dynamic then the respective value in 'dimension' is an upper
387 // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
388 // the same size.
389 static Shape MakeShape(PrimitiveType element_type,
390 absl::Span<const int64> dimensions,
391 const std::vector<bool>& dynamic_dimensions);
392
393 // Constructs a new shape with the given element type and sequence of
394 // dimensions. Method checks if the element type is valid and the shape's
395 // size fits in std::numeric_limits<int64>::max().
396 static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
397 absl::Span<const int64> dimensions);
398 static StatusOr<Shape> MakeValidatedShape(
399 PrimitiveType element_type, absl::Span<const int64> dimensions,
400 const std::vector<bool>& dynamic_dimensions);
401
402 // Creates a Shape with element type corresponding to T and the given
403 // dimensions
404 template <typename T>
MakeShapeWithType(absl::Span<const int64> dimensions)405 static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
406 return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
407 dimensions);
408 }
409
410 // Constructs a new shape with the given minor_to_major order in its Layout.
411 // Returns a value shape such that shape.has_layout().
412 static Shape MakeShapeWithLayout(PrimitiveType element_type,
413 absl::Span<const int64> dimensions,
414 absl::Span<const int64> minor_to_major,
415 absl::Span<const Tile> tiles = {},
416 int64 element_size_in_bits = 0,
417 int64 memory_space = 0);
418
419 // Returns the same shape except with all dimensions set to be static.
420 static Shape MakeShapeWithStaticDimensions(const Shape& shape);
421
422 // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
423 static Shape MakeShapeWithDescendingLayout(
424 PrimitiveType element_type, absl::Span<const int64> dimensions);
425
// Returns a new Shape based on the given Shape with low-dimension-major
// layout (i.e. {n, n-1, ..., 0}, row-major like C), and with the dimensions
// rearranged so that it has the same in-memory layout as the given shape.
429 //
430 // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
431 static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
432 const Shape& shape);
433
434 // As MakeShape, but the object to write to is passed in.
435 static Status PopulateShape(PrimitiveType element_type,
436 absl::Span<const int64> dimensions, Shape* shape);
437
438 // Validates that the provided shape satisfies invariants.
439 static Status ValidateShape(const Shape& shape);
440
441 // Validates the provided shape satisfies invariants, except those that
442 // pertain to layout.
443 //
444 // Layout is optional for client-provided shapes, so that the compiler may
445 // determine and assign an optimized layout.
446 static Status ValidateShapeWithOptionalLayout(const Shape& shape);
447
448 // Returns whether the element type of the shape is integral (signed or
449 // unsigned). Note that predicates are not considered integral here, since
450 // they are logical values.
451 static bool ElementIsIntegral(const Shape& shape);
452
453 // Returns whether the element type of the shape is floating point.
454 static bool ElementIsFloating(const Shape& shape);
455
456 // Returns whether the element type of the shape is complex.
457 static bool ElementIsComplex(const Shape& shape);
458
459 // Returns whether the element type has the given bit width.
460 static bool ElementHasBitWidth(const Shape& shape, int bits);
461
462 // Returns whether the element type of the shape is integral and has
463 // the specified number of bits.
464 static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
465
466 // Returns whether the element type of the shape is signed. Note
467 // that floating point numbers are signed.
468 static bool ElementIsSigned(const Shape& shape);
469
470 // Returns whether the given primitive type corresponds to an array shape.
471 static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
472
473 // Returns whether the shape is a tuple with at least one element which is
474 // also a tuple.
475 static bool IsNestedTuple(const Shape& shape);
476
477 // Returns true if shape is an empty tuple.
478 static bool IsEmptyTuple(const Shape& shape);
479
480 // Returns the number of elements in the given tuple shape.
481 // Precondition: IsTuple(shape)
482 static int64 TupleElementCount(const Shape& shape);
483
484 // Returns the tuple element shape at given index.
485 // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
486 static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
487
488 // Returns the number of elements, recursively, in the given shape.
489 static int64 SubshapeCount(const Shape& shape);
490
491 // Slices tuple elements in the range [start, limit) and returns a new tuple
492 // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
493 static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
494
495 // Returns the shape of the real/imaginary components of the given complex
496 // shape.
497 static Shape ComplexComponentShape(const Shape& complex_shape);
498
499 // Returns true if the given shape has a subshape at the given index.
500 static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
501
502 // GetSubshape and GetMutableSubshape return a particular nested Shape within
503 // the given Shape argument. The non-Try variants check fail if index is
504 // invalid.
505 static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
506 static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
507 ShapeIndexView index);
508 static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
509
510 // Returns whether the given index in the given shape is a leaf element of the
511 // shape.
512 static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
513
514 // Returns the number of leaves in the shape.
515 static int64 GetLeafCount(const Shape& shape);
516
517 // Retrieves all the leaf shapes and their indexes, in the order walked by
518 // the ForEachSubshape() API.
519 static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
520
521 // Calls the given visitor function for each subshape of the given shape.
522 // Subshapes are visited in DFS pre-order starting with the entire shape
523 // (index {}).
524 using VisitorFunction = std::function<void(const Shape& /*subshape*/,
525 const ShapeIndex& /*index*/)>;
526 static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
527 using MutatingVisitorFunction =
528 std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
529 static void ForEachMutableSubshape(Shape* shape,
530 const MutatingVisitorFunction& func);
531
532 // Variants of ForEach(Mutable)Subshape which propagate Status from the
533 // visitor function.
534 using StatusVisitorFunction = std::function<Status(
535 const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
536 static Status ForEachSubshapeWithStatus(const Shape& shape,
537 const StatusVisitorFunction& func);
538 using MutatingStatusVisitorFunction =
539 std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
540 static Status ForEachMutableSubshapeWithStatus(
541 Shape* shape, const MutatingStatusVisitorFunction& func);
542
// Returns true if `shape` (which must be an array) has degenerate dimensions
// (dimensions with bound 1).
545 static bool HasDegenerateDimensions(const Shape& shape);
546
547 // Drops any degenerate dimensions (i.e. dimensions of size 1)
548 static Shape DropDegenerateDimensions(const Shape& shape);
549
550 // Permutes the dimensions by the given permutation, so
551 // return_value.dimensions[permutation[i]] = argument.dimensions[i].
552 //
553 // Postcondition: For any valid permutation,
554 //
555 // !HasLayout(shape) ||
556 // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
557 // InversePermutation(permutation)).
558 static Shape PermuteDimensions(absl::Span<const int64> permutation,
559 const Shape& shape);
560
// If we can go from `shape_pre` to `shape_post` by merely inserting or
// deleting 1-sized dimensions, return the indices in `shape_pre` of the
// deleted dimensions and the indices in `shape_post` of the inserted
// dimensions.
565 // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
566 // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
567 // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
568 // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
569 // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
570 // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
571 static std::tuple<bool, std::vector<int64>, std::vector<int64>>
572 InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
573 const Shape& shape_post);
574
575 // Suppose a reshape transforms input_shape to output shape. Returns a vector
576 // of pairs that indicate the input and output dimensions that this reshape
577 // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
578 // the returned vector, the reshape transforms any input index whose I-th
579 // dimension is x to an output index whose O-th dimension is x too.
580 //
581 // Post-condition: the returned vector is sorted (by both input and output
582 // dimensions because input and output dimensions have the same order).
583 //
584 // Example:
585 // input shape = T[a, b, x, y, cd]
586 // output shape = T[ab, x, 1, y, c, d]
587 // return value = {{2, 1}, {3, 3}}
588 //
589 // The two pairs represent the input and output dimension of size x and
590 // those of size y.
591 static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
592 const Shape& input_shape, const Shape& output_shape);
593
594 // Return whether the given reshape instruction leaves the dimensions at the
595 // given input indices unmodified, and returns their output indices.
596 //
597 // Example:
598 // input_dim_indices = {2, 3}
599 // input shape = T[a, b, x, y, cd]
600 // output shape = T[ab, x, 1, y, c, d]
601 // return value = {1, 3}
602 //
603 // Precondition: input_dim_indices is sorted.
604 static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
605 const Shape& from_shape, const Shape& to_shape,
606 absl::Span<const int64> input_dim_indices);
607
608 // Returns whether a transpose from input_shape to output_shape with dimension
609 // mapping "dimension_mapping" produces a result which is bit-wise identical
610 // to its input and thus may be replaced with a bitcast.
611 //
612 // Precondition: Both input_shape and output_shape have explicit layouts.
613 static bool TransposeIsBitcast(const Shape& input_shape,
614 const Shape& output_shape,
615 absl::Span<const int64> dimension_mapping);
616
617 // Returns whether a reshape from "input_shape" to "output_shape" is a
618 // bitcast.
619 //
620 // Precondition: Both input_shape and output_shape have explicit layouts.
621 static bool ReshapeIsBitcast(const Shape& input_shape,
622 const Shape& output_shape);
623
624 // Find a physical layout for 'output_shape' such that
625 // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
626 // true (where 'output_shape_with_layout' is 'output_shape' with the found
// layout). The layout of 'input_shape' is kept fixed. Returns
// 'output_shape_with_layout' if such a layout can be found, and
// absl::nullopt otherwise.
630 static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
631 const Shape& output_shape);
632
633 // Returns a shape with the given dimension deleted.
634 // For example:
635 // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
636 static Shape DeleteDimension(int64 dim_to_delete, Shape shape);
637
638 // Returns a shape with all the dimensions of the input shape for which `p`
639 // returns true.
// For example:
641 // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
642 // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
643 static Shape FilterDimensions(const std::function<bool(int64)>& p,
644 Shape shape);
645
646 // Iterates through all the shape indexes, in minor to major order, starting
647 // from the base indexes, incrementing by the incr steps, up to count
648 // (index[i] < base[i] + count[i]), and calls the visitor_function with the
649 // current index.
650 // The visitor_function visitor function should return true if it wants to
651 // continue, or false otherwise.
652 //
653 // visitor_function must be a callable of type
654 // StatusOr<bool>(absl::Span<int64>) or compatible.
655 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)656 static Status ForEachIndexWithStatus(const Shape& shape,
657 absl::Span<const int64> base,
658 absl::Span<const int64> count,
659 absl::Span<const int64> incr,
660 const FnType& visitor_function) {
661 return ForEachIndexInternal(shape, base, count, incr, visitor_function);
662 }
663
664 // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
665 struct IndexIterationSpace {
666 std::vector<int64> index_base;
667 std::vector<int64> index_count;
668 std::vector<int64> index_incr;
669 };
670
671 template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)672 static Status ForEachIndexWithStatus(
673 const Shape& shape, const IndexIterationSpace& iteration_space,
674 FnTy&& function) {
675 return ShapeUtil::ForEachIndexWithStatus(
676 shape, iteration_space.index_base, iteration_space.index_count,
677 iteration_space.index_incr, std::forward<FnTy>(function));
678 }
679
680 template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)681 static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
682 absl::Span<const int64> count,
683 absl::Span<const int64> incr,
684 const FnType& visitor_function) {
685 ForEachIndexWithStatus(shape, base, count, incr,
686 [&](absl::Span<const int64> indices) {
687 return StatusOr<bool>(visitor_function(indices));
688 })
689 .IgnoreError();
690 }
691
692 // These convenience wrappers don't take `base`, `count` and `incr`
693 // explicitly, but iterate over every element in `shape` instead.
694
695 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)696 static Status ForEachIndexWithStatus(const Shape& shape,
697 const FnType& visitor_function) {
698 std::vector<int64> base(shape.dimensions_size());
699 std::vector<int64> incr(shape.dimensions_size(), 1);
700 return ForEachIndexWithStatus(shape, base,
701 /*count=*/AsInt64Slice(shape.dimensions()),
702 incr, visitor_function);
703 }
704
705 template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)706 static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
707 ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
708 return StatusOr<bool>(visitor_function(indices));
709 }).IgnoreError();
710 }
711
// A parallel version of ForEachIndex(WithStatus). This can only be used if
// the visitor_function is thread-safe and the order of iteration does not
// matter.
//
// visitor_function must be a callable of type
// void(absl::Span<const int64>) or compatible. Any value it returns is
// ignored, so the iteration cannot be stopped early.
718 template <typename FnType>
ForEachIndexParallel(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)719 static void ForEachIndexParallel(const Shape& shape,
720 absl::Span<const int64> base,
721 absl::Span<const int64> count,
722 absl::Span<const int64> incr,
723 const FnType& visitor_function) {
724 // The parallel version of ForEachIndexInternal can never fail.
725 CHECK(ForEachIndexInternal(
726 shape, base, count, incr,
727 [&visitor_function](
728 absl::Span<const int64> indexes) -> StatusOr<bool> {
729 visitor_function(indexes);
730 return true;
731 },
732 /*parallel=*/true)
733 .ok());
734 }
735
736 // Compute a hash for `shape`.
737 static size_t Hash(const Shape& shape);
738
// About 0-2-1 transposes:
//
// If a shape can be viewed as three logical components 0-1-2 in the order of
// major to minor, a 0-2-1 transpose changes the order of those logical
// components to 0-2-1. We call the shape being transposed the input shape and
// the transposed shape the output shape. The logical views of the input and
// output shapes for the transpose are called the 0-1-2 and 0-2-1 shapes, or
// the normalized shapes. The original input/output shapes are called the
// unnormalized shapes.
//
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, returns the dimensions of the
// normalized shape of `b` (i.e. the 0-2-1 shape). Presumably returns
// absl::nullopt otherwise, as the optional return type suggests; confirm
// against the implementation.
751 static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
752 const Shape& b);
753
754 private:
// Validates that the shape size is sane. This makes sure it is safe to do
// calculations in int64 without overflowing.
757 static Status ValidateShapeSize(const Shape& shape);
758
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public methods.
761 static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
762
763 template <typename FnType>
764 static Status ForEachIndexInternal(const Shape& shape,
765 absl::Span<const int64> base,
766 absl::Span<const int64> count,
767 absl::Span<const int64> incr,
768 const FnType& visitor_function,
769 bool parallel = false) {
770 if (ShapeUtil::IsZeroElementArray(shape)) {
771 return Status::OK();
772 }
773 CHECK_EQ(shape.rank(), base.size());
774 CHECK_EQ(incr.size(), base.size());
775 CHECK_EQ(count.size(), base.size());
776 const int64 rank = LayoutUtil::MinorToMajor(shape).size();
777 // Allows handling R0 arrays, such that the visitor function will be called
778 // once with the proper empty indexes.
779 int64 n = -1;
780 std::vector<int64> indexes(base.begin(), base.end());
781 const int kNumThreads = tensorflow::port::MaxParallelism();
782 absl::optional<tensorflow::thread::ThreadPool> pool;
783 if (parallel) {
784 pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
785 }
786
787 tensorflow::mutex mu;
788 Status status; // Guarded by mu
789
790 while (n < rank) {
791 if (pool != absl::nullopt) {
792 pool->Schedule([indexes, &visitor_function, &mu, &status] {
793 StatusOr<bool> result = visitor_function(indexes);
794 if (!result.ok()) {
795 tensorflow::mutex_lock lock(mu);
796 status = status.ok() ? result.status() : status;
797 }
798 });
799 } else {
800 TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
801 if (!should_continue) {
802 break;
803 }
804 }
805 // Increments dimensions in minor to major order.
806 for (n = 0; n < rank; ++n) {
807 int64 dim = LayoutUtil::Minor(shape.layout(), n);
808 indexes[dim] += incr[dim];
809 if (indexes[dim] < base[dim] + count[dim]) {
810 break;
811 }
812 indexes[dim] = base[dim];
813 }
814 }
815
816 // Waits for the scheduled work to complete.
817 pool.reset();
818 return status;
819 }
820
821 TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
822 };
823
824 } // namespace xla
825
826 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
827