1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21
22 #include <initializer_list>
23 #include <string>
24 #include <tuple>
25
26 #include "absl/base/macros.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/types/optional.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/platform/cpu_info.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46
47 class ShapeIndexView;
48
49 // An index for specifying a particular nested subshape within a shape. Used in
50 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
51 // structures (trees) and ShapeIndex defines a path through the tree where each
52 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
53 // shape. For a non-nested tuple, an index has a single element. For example,
54 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
55 // {1} corresponds to array b. For a nested tuple, the index can have more than
56 // one element. For the nested tuple (a, (b, c, d), e) below are the values
57 // corresponding to the given indices:
58 //
59 // index {0} : array a
60 // index {1, 2} : array d
61 // index {2} : array e
62 // index {0, 0} : invalid index (element at {0} is an array not a tuple)
63 //
64 // For indexing into array shapes, the index is always trivially empty, ie {}.
65 //
66 // ShapeIndex is a trivial wrapper around std::vector with a minimum number of
67 // methods implemented.
68 class ShapeIndex {
69 public:
70 ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init)71 ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
72 template <typename InputIt>
ShapeIndex(InputIt start,InputIt end)73 ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
74
75 explicit ShapeIndex(ShapeIndexView v);
76
empty()77 bool empty() const { return indices_.empty(); }
size()78 size_t size() const { return indices_.size(); }
push_back(int64_t value)79 void push_back(int64_t value) { indices_.push_back(value); }
pop_back()80 void pop_back() { indices_.pop_back(); }
81
82 // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_front(int64_t value)83 void push_front(int64_t value) { indices_.insert(indices_.begin(), value); }
pop_front()84 void pop_front() { indices_.erase(indices_.begin()); }
85
86 using container_type = absl::InlinedVector<int64, 2>;
87
begin()88 container_type::const_iterator begin() const { return indices_.begin(); }
end()89 container_type::const_iterator end() const { return indices_.end(); }
begin()90 container_type::iterator begin() { return indices_.begin(); }
end()91 container_type::iterator end() { return indices_.end(); }
92
data()93 const int64* data() const { return indices_.data(); }
94
back()95 int64 back() const { return indices_.back(); }
back()96 int64& back() { return indices_.back(); }
97
front()98 int64 front() const { return indices_.front(); }
front()99 int64& front() { return indices_.front(); }
100
101 const int64& operator[](size_t i) const { return indices_[i]; }
102 int64& operator[](size_t i) { return indices_[i]; }
103
104 bool operator==(const ShapeIndex& other) const {
105 return indices_ == other.indices_;
106 }
107 bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
108 bool operator<(const ShapeIndex& other) const {
109 return indices_ < other.indices_;
110 }
111
112 string ToString() const;
113
114 template <typename H>
AbslHashValue(H h,const ShapeIndex & index)115 friend H AbslHashValue(H h, const ShapeIndex& index) {
116 return H::combine(std::move(h), index.indices_);
117 }
118
119 private:
120 container_type indices_;
121 };
122
123 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the
124 // value at the front of the view.
125 //
126 // NB! ShapeIndexView does not own the memory backing the index array.
127 // The memory backing the index array should be owned by an object
128 // that lives longer than the ShapeIndexView instances pointing into
129 // it.
130 class ShapeIndexView {
131 public:
132 ShapeIndexView(const ShapeIndex& shape_index, int64_t offset = 0)
133 : indices_(shape_index.data() + offset, shape_index.size() - offset) {
134 CHECK_LE(offset, shape_index.size());
135 }
ShapeIndexView(std::initializer_list<int64> indices)136 ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
137 ShapeIndexView(const ShapeIndexView& other) = default;
138
139 using iterator = const int64*;
140
begin()141 iterator begin() const { return indices_.begin(); }
end()142 iterator end() const { return indices_.end(); }
size()143 int64 size() const { return indices_.size(); }
empty()144 bool empty() const { return indices_.empty(); }
front()145 int64 front() const {
146 CHECK(!empty());
147 return indices_.front();
148 }
back()149 int64 back() const {
150 CHECK(!empty());
151 return indices_.back();
152 }
ConsumeFront()153 ShapeIndexView ConsumeFront() const {
154 ShapeIndexView result = *this;
155 result.indices_.remove_prefix(1);
156 return result;
157 }
ConsumeBack()158 ShapeIndexView ConsumeBack() const {
159 ShapeIndexView result = *this;
160 result.indices_.remove_suffix(1);
161 return result;
162 }
ToShapeIndex()163 ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
164
165 bool operator==(const ShapeIndexView& other) const;
166 bool operator!=(const ShapeIndexView& other) const;
167
168 string ToString() const;
169
170 // Returns true if this shape index starts with 'prefix'.
171 bool StartsWith(ShapeIndexView prefix) const;
172
173 private:
174 absl::Span<const int64> indices_;
175 };
176
ShapeIndex(ShapeIndexView v)177 inline ShapeIndex::ShapeIndex(ShapeIndexView v)
178 : ShapeIndex(v.begin(), v.end()) {}
179
180 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
181 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
182
183 // Namespaced collection of (static) shape utilities.
184 //
185 // These are all effectively convenience functions for testing/tweaking proto
186 // properties, which do invariant checks before / after the operation.
187 class ShapeUtil {
188 public:
189 // Data structure which describes the coordinates and the shape, of a tuple
190 // shaped sub-shape.
191 struct IndexedShape {
192 IndexedShape() = default;
IndexedShapeIndexedShape193 IndexedShape(ShapeIndex index, Shape shape)
194 : index(std::move(index)), shape(std::move(shape)) {}
195 ShapeIndex index;
196 Shape shape;
197 };
198
199 // Returns the number of elements are contained within the provided shape;
200 // e.g. for rank 0 (scalars) the result is always 1.
201 // Precondition: shape.IsArray()
202 static int64 ElementsIn(const Shape& shape);
203
204 // As ElementsIn(), but recurses through tuples.
205 static int64 ElementsInRecursive(const Shape& shape);
206
207 // Returns true if shape has the primitive type, recurses through tuples.
208 static bool HasPrimitiveType(const Shape& shape,
209 PrimitiveType primitive_type);
210
211 // Returns true if 'shape' is an array with zero elements.
212 static bool IsZeroElementArray(const Shape& shape);
213
214 // Returns the number of bytes required for an allocation of shape. The
215 // |pointer_size| parameter is used for calculating the size of tuple
216 // shapes. This includes only the size of the top-level buffer. For example, a
217 // tuple is stored as an array of pointers to other buffers. In this case,
218 // this method only returns the size of the pointer array.
219 static int64 ByteSizeOf(const Shape& shape, int64_t pointer_size = -1);
220
221 // Returns the number of bytes used to store the primitive_type.
222 //
223 // Precondition: shape.IsArray()
224 static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
225
226 // Returns the number of bytes required to store the tuple member pointers for
227 // a allocation of shape. The `shape` must be a TUPLE shape, and
228 // `pointer_size` must be larger than zero.
229 static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
230 int64_t pointer_size);
231
232 // Returns the number of bytes required for the elements in an allocation of
233 // `shape`, which must be an array shape. Shapes use a separate
234 // memory location for each element, and so for these shapes,
235 // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
236 // size also includes padding if present in the layout.
237 static int64 ByteSizeOfElements(const Shape& shape);
238
239 // Returns a human-readable string that represents the given shape, with or
240 // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
241 static string HumanString(const Shape& shape);
242 static string HumanStringWithLayout(const Shape& shape);
243
244 // As above, but for program shapes, returns a string for the form:
245 //
246 // (param_name: f32[42x12], ...) -> f32[24x42]
247 static string HumanString(const ProgramShape& program_shape);
248
249 // Returns whether the LHS and RHS shapes have the same dimensions; note: does
250 // not check element type.
251 // Precondition: IsArray(lhs) && IsArray(rhs)
252 static bool SameDimensions(const Shape& lhs, const Shape& rhs);
253
254 // Returns whether the LHS and RHS shapes have the same rank; note: does
255 // not check element type.
256 // Precondition: IsArray(lhs) && IsArray(rhs)
257 static bool SameRank(const Shape& lhs, const Shape& rhs);
258
259 // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)260 static bool SameElementType(const Shape& lhs, const Shape& rhs) {
261 return lhs.element_type() == rhs.element_type();
262 }
263
264 // As SameElementType, but allows floating point types to have different
265 // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)266 static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
267 const Shape& b) {
268 if (ElementIsFloating(a) && ElementIsFloating(b)) {
269 return true;
270 }
271 return ShapeUtil::SameElementType(a, b);
272 }
273
274 // Returns the higher-precision element type if a and b are both floating
275 // point types; otherwise, checks that they have the same element type
276 // and returns it.
HigherPrecisionElementType(const Shape & a,const Shape & b)277 static PrimitiveType HigherPrecisionElementType(const Shape& a,
278 const Shape& b) {
279 return primitive_util::HigherPrecisionType(a.element_type(),
280 b.element_type());
281 }
282
283 // Returns true if the rank, dimension sizes, and element type are
284 // identical. Layout is ignored. Tuple elements are compared recursively for
285 // compatibility.
286 static bool Compatible(const Shape& lhs, const Shape& rhs);
287
288 // Returns true if the rank and dimension sizes are identical. Element type
289 // and layout are ignored. Tuple elements are compared recursively for
290 // compatibility.
291 static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
292
293 // Returns true if the tuple tree shapes and leaf ranks are identical.
294 // Leaf dimensions, element type, and layout are ignored. Tuple elements are
295 // compared recursively for compatibility.
296 static bool CompatibleKind(const Shape& lhs, const Shape& rhs);
297
298 // As Compatible, but allow one of lhs and rhs to be BF16 while the other
299 // being F32. Tuple elements are compared recursively for compatibility.
300 static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
301
302 // Returns whether the lhs and rhs shapes are identical.
303 static bool Equal(const Shape& lhs, const Shape& rhs);
304
305 // As Equal, but does not compare the element type.
306 static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
307
308 // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
309 static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
310
311 // Two shapes have same structure if all subshape indices of lhs are presented
312 // on rhs and vice versa.
313 // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
314 // (S32, (F32[3], S32[2])) as their structures are both (,(,))
315 //
316 // In contrast, (F32, (F32, F32)) is structurally different from
317 // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
318 // ((,),)
319 static bool EqualStructure(const Shape& lhs, const Shape& rhs);
320
321 // Returns the number of dimensions for which the dimension is not (trivially)
322 // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
323 // fluff. Note that zero dimensions are included in the true rank, e.g.,
324 // f32[3,0,1] has a true rank of 2D.
325 static int64 TrueRank(const Shape& shape);
326
327 static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
328 Shape result);
329
330 ////////////////////
331 // Scalar-specific
332
IsScalar(const Shape & shape)333 static bool IsScalar(const Shape& shape) {
334 return shape.IsArray() && shape.rank() == 0;
335 }
IsEffectiveScalar(const Shape & shape)336 static bool IsEffectiveScalar(const Shape& shape) {
337 return shape.IsArray() && TrueRank(shape) == 0;
338 }
339
340 // Returns whether "shape" is a scalar (array) with the given element_type.
341 static bool IsScalarWithElementType(const Shape& shape,
342 PrimitiveType element_type);
343
344 // Extracts the size of the shape's dimension at dimension number
345 // GetDimensionNumber(dimension_number).
346 static int64 GetDimension(const Shape& shape, int64_t dimension_number);
347
348 // Resolves a dimension number, supporting negative indexing.
349 //
350 // Negative indexing has similar semantics to Python. For an N-dimensional
351 // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
352 // N-2, and so on.
353 //
354 // This function always returns a positive dimension number for any given
355 // dimension_number (which itself can be negative).
356 static int64 GetDimensionNumber(const Shape& shape, int64_t dimension_number);
357
358 // Returns a shape with the same dimensions as the original, but with the
359 // element type changed to type.
360 static Shape ChangeElementType(const Shape& original, PrimitiveType type);
361
362 // Retursn a shape with same dimensions but with all dimensions set to static.
363 static Shape MakeStaticShape(const Shape& original);
364
365 // Creates a tuple shape from a slice of element shapes within the tuple.
366 static Shape MakeTupleShape(absl::Span<const Shape> shapes);
367
368 // Creates an opaque shape. These are generally used for threading a context
369 // into a custom operation.
370 static Shape MakeOpaqueShape();
371
372 // Creates a token shape. Values of this shape are used for ordering
373 // side-effecting operations.
374 static Shape MakeTokenShape();
375
376 // Appends a shape to the given tuple.
377 static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
378
379 // Update a subshape of a tuple.
380 static void UpdateTupleShape(const Shape& shape, int64_t index,
381 Shape* tuple_shape);
382
383 // Update the dynamic dimension for a shape. This shape can be a nested tuple.
384 static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
385 int64_t dim, bool is_dynamic);
386
387 // Appends a major dimension to the shape with the given bound.
388 static void AppendMajorDimension(int bound, Shape* shape);
389
390 // Copy the dynamic dimensions property from one shape to another.
391 static void CopyDynamicDimensions(Shape* to, const Shape& from);
392
393 // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()394 static Shape MakeNil() { return MakeTupleShape({}); }
395
396 // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)397 static bool IsInitialized(const Shape& shape) {
398 return shape.element_type() != PRIMITIVE_TYPE_INVALID;
399 }
400
401 // Constructs a new shape with the given element type and sequence of
402 // dimensions.
403 static Shape MakeShape(PrimitiveType element_type,
404 absl::Span<const int64> dimensions);
405
406 // Make a scalar shape with given primitive type.
407 static Shape MakeScalarShape(PrimitiveType element_type);
408
409 // Constructs a new shape with the given element type and sequence of
410 // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
411 // with a true value that the respective dimension is dynamic. If the
412 // dimension is dynamic then the respective value in 'dimension' is an upper
413 // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
414 // the same size.
415 static Shape MakeShape(PrimitiveType element_type,
416 absl::Span<const int64> dimensions,
417 const std::vector<bool>& dynamic_dimensions);
418
419 // Constructs a new shape with the given element type and sequence of
420 // dimensions. Method checks if the element type is valid and the shape's
421 // size fits in std::numeric_limits<int64>::max().
422 static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
423 absl::Span<const int64> dimensions);
424 static StatusOr<Shape> MakeValidatedShape(
425 PrimitiveType element_type, absl::Span<const int64> dimensions,
426 const std::vector<bool>& dynamic_dimensions);
427
428 // Creates a Shape with element type corresponding to T and the given
429 // dimensions
430 template <typename T>
MakeShapeWithType(absl::Span<const int64> dimensions)431 static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
432 return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
433 dimensions);
434 }
435
436 // Constructs a new shape with the given minor_to_major order in its Layout.
437 // Returns a value shape such that shape.has_layout().
438 static Shape MakeShapeWithLayout(PrimitiveType element_type,
439 absl::Span<const int64> dimensions,
440 absl::Span<const int64> minor_to_major,
441 absl::Span<const Tile> tiles = {},
442 int64_t element_size_in_bits = 0,
443 int64_t memory_space = 0);
444
445 // Constructs a new shape with the given dimension `dim` as the most major
446 // dimension in the layout. If the shape does not have a layout, assumes a
447 // default layout. If the shape is a tuple, apply this to all the leaf shapes
448 // of the tuple.
449 static Shape MoveDimToMajor(const Shape& shape, int64_t dim);
450
451 // Returns the same shape except with all dimensions set to be static.
452 static Shape MakeShapeWithStaticDimensions(const Shape& shape);
453
454 // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
455 static Shape MakeShapeWithDescendingLayout(
456 PrimitiveType element_type, absl::Span<const int64> dimensions);
457
458 // Returns a new Shape based on the given Shape with low-dimension-major
459 // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
460 // rearranged so that it has the same in-memory layout as the given shape.
461 //
462 // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
463 static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
464 const Shape& shape);
465
466 // As MakeShape, but the object to write to is passed in.
467 static Status PopulateShape(PrimitiveType element_type,
468 absl::Span<const int64> dimensions, Shape* shape);
469
470 // Validates that the provided shape satisfies invariants.
471 static Status ValidateShape(const Shape& shape);
472
473 // Validates the provided shape satisfies invariants, except those that
474 // pertain to layout.
475 //
476 // Layout is optional for client-provided shapes, so that the compiler may
477 // determine and assign an optimized layout.
478 static Status ValidateShapeWithOptionalLayout(const Shape& shape);
479
480 // Returns whether the element type of the shape is integral (signed or
481 // unsigned). Note that predicates are not considered integral here, since
482 // they are logical values.
483 static bool ElementIsIntegral(const Shape& shape);
484
485 // Returns whether the element type of the shape is floating point.
486 static bool ElementIsFloating(const Shape& shape);
487
488 // Returns whether the element type of the shape is complex.
489 static bool ElementIsComplex(const Shape& shape);
490
491 // Returns whether the element type has the given bit width.
492 static bool ElementHasBitWidth(const Shape& shape, int bits);
493
494 // Returns whether the element type of the shape is integral and has
495 // the specified number of bits.
496 static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
497
498 // Returns whether the element type of the shape is signed. Note
499 // that floating point numbers are signed.
500 static bool ElementIsSigned(const Shape& shape);
501
502 // Returns whether the given primitive type corresponds to an array shape.
503 static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
504
505 // Returns whether the shape is a tuple with at least one element which is
506 // also a tuple.
507 static bool IsNestedTuple(const Shape& shape);
508
509 // Returns true if shape is an empty tuple.
510 static bool IsEmptyTuple(const Shape& shape);
511
512 // Returns the number of elements in the given tuple shape.
513 // Precondition: IsTuple(shape)
514 static int64 TupleElementCount(const Shape& shape);
515
516 // Returns the tuple element shape at given index.
517 // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
518 static const Shape& GetTupleElementShape(const Shape& shape, int64_t index);
519
520 // Returns the number of elements, recursively, in the given shape.
521 static int64 SubshapeCount(const Shape& shape);
522
523 // Slices tuple elements in the range [start, limit) and returns a new tuple
524 // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
525 static Shape SliceTuple(const Shape& tuple, int64_t start, int64_t limit);
526
527 // Returns the shape of the real/imaginary components of the given complex
528 // shape.
529 static Shape ComplexComponentShape(const Shape& complex_shape);
530
531 // Returns true if the given shape has a subshape at the given index.
532 static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
533
534 // GetSubshape and GetMutableSubshape return a particular nested Shape within
535 // the given Shape argument. The non-Try variants check fail if index is
536 // invalid.
537 static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
538 static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
539 ShapeIndexView index);
540 static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
541
542 // Returns whether the given index in the given shape is a leaf element of the
543 // shape.
544 static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
545
546 // Returns the number of leaves in the shape.
547 static int64 GetLeafCount(const Shape& shape);
548
549 // Retrieves all the leaf shapes and their indexes, in the order walked by
550 // the ForEachSubshape() API.
551 static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
552
553 // Calls the given visitor function for each subshape of the given shape.
554 // Subshapes are visited in DFS pre-order starting with the entire shape
555 // (index {}).
556 using VisitorFunction = std::function<void(const Shape& /*subshape*/,
557 const ShapeIndex& /*index*/)>;
558 static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
559 using MutatingVisitorFunction =
560 std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
561 static void ForEachMutableSubshape(Shape* shape,
562 const MutatingVisitorFunction& func);
563
564 // Variants of ForEach(Mutable)Subshape which propagate Status from the
565 // visitor function.
566 using StatusVisitorFunction = std::function<Status(
567 const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
568 static Status ForEachSubshapeWithStatus(const Shape& shape,
569 const StatusVisitorFunction& func);
570 using MutatingStatusVisitorFunction =
571 std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
572 static Status ForEachMutableSubshapeWithStatus(
573 Shape* shape, const MutatingStatusVisitorFunction& func);
574
575 // Returns true if `shape` (which must be an array) with degenerate dimensions
576 // (dimensions with bound 1).
577 static bool HasDegenerateDimensions(const Shape& shape);
578
579 // Drops any degenerate dimensions (i.e. dimensions of size 1)
580 static Shape DropDegenerateDimensions(const Shape& shape);
581
582 // Permutes the dimensions by the given permutation, so
583 // return_value.dimensions[i] = argument.dimensions[permutation[i]].
584 //
585 // Postcondition: For any valid permutation,
586 //
587 // !HasLayout(shape) ||
588 // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
589 // permutation).
590 static Shape PermuteDimensions(absl::Span<const int64> permutation,
591 const Shape& shape);
592
593 // If we can go from `shape_pre` to `shape_post` by merely inserting or
594 // deleting 1-sized dimensions, return the indices in `shape_pre` of the
595 // deleted dimensions and the indices in `dims_post` of the inserted
596 // dimensions.
597 // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
598 // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
599 // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
600 // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
601 // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
602 // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
603 static std::tuple<bool, std::vector<int64>, std::vector<int64>>
604 InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
605 const Shape& shape_post);
606
607 // Suppose a reshape transforms input_shape to output shape. Returns a vector
608 // of pairs that indicate the input and output dimensions that this reshape
609 // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
610 // the returned vector, the reshape transforms any input index whose I-th
611 // dimension is x to an output index whose O-th dimension is x too.
612 //
613 // Post-condition: the returned vector is sorted (by both input and output
614 // dimensions because input and output dimensions have the same order).
615 //
616 // Example:
617 // input shape = T[a, b, x, y, cd]
618 // output shape = T[ab, x, 1, y, c, d]
619 // return value = {{2, 1}, {3, 3}}
620 //
621 // The two pairs represent the input and output dimension of size x and
622 // those of size y.
623 static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
624 const Shape& input_shape, const Shape& output_shape);
625
626 // Return whether the given reshape instruction leaves the dimensions at the
627 // given input indices unmodified, and returns their output indices.
628 //
629 // Example:
630 // input_dim_indices = {2, 3}
631 // input shape = T[a, b, x, y, cd]
632 // output shape = T[ab, x, 1, y, c, d]
633 // return value = {1, 3}
634 //
635 // Precondition: input_dim_indices is sorted.
636 static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
637 const Shape& from_shape, const Shape& to_shape,
638 absl::Span<const int64> input_dim_indices);
639
640 // Returns whether a transpose from input_shape to output_shape with dimension
641 // mapping "dimension_mapping" produces a result which is bit-wise identical
642 // to its input and thus may be replaced with a bitcast.
643 //
644 // Precondition: Both input_shape and output_shape have explicit layouts.
645 static bool TransposeIsBitcast(const Shape& input_shape,
646 const Shape& output_shape,
647 absl::Span<const int64> dimension_mapping);
648
649 // Returns whether a reshape from "input_shape" to "output_shape" is a
650 // bitcast.
651 //
652 // Precondition: Both input_shape and output_shape have explicit layouts.
653 static bool ReshapeIsBitcast(const Shape& input_shape,
654 const Shape& output_shape);
655
656 // Find a physical layout for 'output_shape' such that
657 // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
658 // true (where 'output_shape_with_layout' is 'output_shape' with the found
659 // layout). The layout of 'input_shape' is kept fixed. Returns
660 // 'output_shape_with_layout' if such a layout can be found, and an error
661 // otherwise.
662 static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
663 const Shape& output_shape);
664
665 // Returns a shape with the given dimension deleted.
666 // For example:
667 // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
668 static Shape DeleteDimension(int64_t dim_to_delete, Shape shape);
669
670 // Returns a shape with all the dimensions of the input shape for which `p`
671 // returns true.
672 // For examples:
673 // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
674 // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
675 static Shape FilterDimensions(const std::function<bool(int64_t)>& p,
676 Shape shape);
677
678 // Returns true if `dynamic_shape` has dimensions that are less-equal to the
679 // "bounded_shape". Shapes must be arrays.
680 static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
681 const xla::Shape& bounded_shape);
682
683 // Same as DynamicArrayShapeIsCompatible() but supports tuples.
684 static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
685 const xla::Shape& bounded_shape);
686
687 // Iterates through all the shape indexes, in minor to major order,
688 // starting from the base indexes, incrementing by the incr steps, up to
689 // count (index[i] < base[i] + count[i]), and calls the visitor_function
690 // with the current index. The visitor_function visitor function should
691 // return true if it wants to continue, or false otherwise.
692 //
693 // visitor_function must be a callable of type
694 // StatusOr<bool>(absl::Span<int64>) or compatible.
695 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)696 static Status ForEachIndexWithStatus(const Shape& shape,
697 absl::Span<const int64> base,
698 absl::Span<const int64> count,
699 absl::Span<const int64> incr,
700 const FnType& visitor_function) {
701 return ForEachIndexInternal(shape, base, count, incr, visitor_function);
702 }
703
// Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. It bundles
// the per-dimension `base`, `count` and `incr` arguments into one value.
struct IndexIterationSpace {
  std::vector<int64> index_base;   // Forwarded as `base`.
  std::vector<int64> index_count;  // Forwarded as `count`.
  std::vector<int64> index_incr;   // Forwarded as `incr`.
};
710
711 template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)712 static Status ForEachIndexWithStatus(
713 const Shape& shape, const IndexIterationSpace& iteration_space,
714 FnTy&& function) {
715 return ShapeUtil::ForEachIndexWithStatus(
716 shape, iteration_space.index_base, iteration_space.index_count,
717 iteration_space.index_incr, std::forward<FnTy>(function));
718 }
719
720 template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)721 static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
722 absl::Span<const int64> count,
723 absl::Span<const int64> incr,
724 const FnType& visitor_function) {
725 ForEachIndexWithStatus(shape, base, count, incr,
726 [&](absl::Span<const int64> indices) {
727 return StatusOr<bool>(visitor_function(indices));
728 })
729 .IgnoreError();
730 }
731
732 // These convenience wrappers don't take `base`, `count` and `incr`
733 // explicitly, but iterate over every element in `shape` instead.
734
735 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)736 static Status ForEachIndexWithStatus(const Shape& shape,
737 const FnType& visitor_function) {
738 std::vector<int64> base(shape.dimensions_size());
739 std::vector<int64> incr(shape.dimensions_size(), 1);
740 return ForEachIndexWithStatus(shape, base,
741 /*count=*/AsInt64Slice(shape.dimensions()),
742 incr, visitor_function);
743 }
744
745 template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)746 static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
747 ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
748 return StatusOr<bool>(visitor_function(indices));
749 }).IgnoreError();
750 }
751
752 // A parallel version of ForEachIndex(WithStatus). This can only be used if
753 // the visitor_function is thread-safe and the order of iteration does not
754 // matter.
755 //
756 // visitor_function must be a callable of type
757 // void(Span<int64>) or compatible.
758 template <typename FnType>
ForEachIndexParallel(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)759 static void ForEachIndexParallel(const Shape& shape,
760 absl::Span<const int64> base,
761 absl::Span<const int64> count,
762 absl::Span<const int64> incr,
763 const FnType& visitor_function) {
764 // The parallel version of ForEachIndexInternal can never fail.
765 CHECK(ForEachIndexInternal(
766 shape, base, count, incr,
767 [&visitor_function](
768 absl::Span<const int64> indexes) -> StatusOr<bool> {
769 visitor_function(indexes);
770 return true;
771 },
772 /*parallel=*/true)
773 .ok());
774 }
775
// Computes a hash for `shape`.
static size_t Hash(const Shape& shape);

// About 0-2-1 transpose:
//
// If a shape can be viewed as three logical components 0-1-2 in the order of
// major to minor, a 0-2-1-transpose changes the order of such logical
// components to 0-2-1. We call the shape being transposed the input shape and
// the transposed shape the output shape. The logical views of the input/output
// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the
// normalized shapes. The original input/output shapes are called unnormalized
// shapes.
//
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, returns the dimensions of the
// normalized shape of `b`, i.e. the 0-2-1 shape. NOTE(review): the result is
// presumably nullopt when `b` is not such a transpose of `a` — confirm against
// the implementation.
static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                          const Shape& b);

// Strips device-specific information, namely tiling and memory-space
// information, from a shape.
static Shape DeviceShapeToHostShape(Shape s);

// Returns true iff the element type of shape `from` can be safely upcast to
// the element type of shape `to`.
static bool ElementCanUpcast(const Shape& from, const Shape& to);

// Computes the byte strides of the array shape `shape` into `strides`.
// `shape` must have a layout, and tiling is ignored. `strides` must have size
// equal to the number of dimensions of `shape`.
static Status ByteStrides(const Shape& shape, absl::Span<int64_t> strides);
806
private:
// Fills *shape. Returns true on success.
// REQUIRES: *shape is empty.
static bool FillNewShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions, Shape* shape);

// Validates that the shape size is sane. This makes sure it's safe to do
// calculations in int64 without overflowing.
static Status ValidateShapeSize(const Shape& shape);

// Validates all of the non-layout properties of the shape. This is a helper
// used by both the layout-optional and layout-required public methods.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
820
821 template <typename FnType>
822 static Status ForEachIndexInternal(const Shape& shape,
823 absl::Span<const int64> base,
824 absl::Span<const int64> count,
825 absl::Span<const int64> incr,
826 const FnType& visitor_function,
827 bool parallel = false) {
828 if (ShapeUtil::IsZeroElementArray(shape)) {
829 return Status::OK();
830 }
831 CHECK_EQ(shape.rank(), base.size());
832 CHECK_EQ(incr.size(), base.size());
833 CHECK_EQ(count.size(), base.size());
834 const int64_t rank = LayoutUtil::MinorToMajor(shape).size();
835 // Allows handling R0 arrays, such that the visitor function will be called
836 // once with the proper empty indexes.
837 int64_t n = -1;
838 std::vector<int64> indexes(base.begin(), base.end());
839 const int kNumThreads = tensorflow::port::MaxParallelism();
840 absl::optional<tensorflow::thread::ThreadPool> pool;
841 if (parallel) {
842 pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
843 }
844
845 tensorflow::mutex mu;
846 Status status; // Guarded by mu
847
848 while (n < rank) {
849 if (pool != absl::nullopt) {
850 pool->Schedule([indexes, &visitor_function, &mu, &status] {
851 StatusOr<bool> result = visitor_function(indexes);
852 if (!result.ok()) {
853 tensorflow::mutex_lock lock(mu);
854 status = status.ok() ? result.status() : status;
855 }
856 });
857 } else {
858 TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
859 if (!should_continue) {
860 break;
861 }
862 }
863 // Increments dimensions in minor to major order.
864 for (n = 0; n < rank; ++n) {
865 int64_t dim = LayoutUtil::Minor(shape.layout(), n);
866 indexes[dim] += incr[dim];
867 if (indexes[dim] < base[dim] + count[dim]) {
868 break;
869 }
870 indexes[dim] = base[dim];
871 }
872 }
873
874 // Waits for the scheduled work to complete.
875 pool.reset();
876 return status;
877 }
878
// Every ShapeUtil member visible in this section is static. The TensorFlow
// macro (from tensorflow/core/platform/macros.h) disables copy construction
// and copy assignment.
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
880 };
881
882 } // namespace xla
883
884 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
885