1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Shapes are protobuf messages, so this utility header offers a bunch of
17 // functionality for querying / poking at them.
18
19 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
21
22 #include <initializer_list>
23 #include <string>
24 #include <tuple>
25
26 #include "absl/base/macros.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/types/optional.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/platform/cpu_info.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46
47 class ShapeIndexView;
48
49 // An index for specifying a particular nested subshape within a shape. Used in
50 // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
51 // structures (trees) and ShapeIndex defines a path through the tree where each
52 // element of ShapeIndex indexes into a tuple (or nested tuple) within the
53 // shape. For a non-nested tuple, an index has a single element. For example,
54 // given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
55 // {1} corresponds to array b. For a nested tuple, the index can have more than
56 // one element. For the nested tuple (a, (b, c, d), e) below are the values
57 // corresponding to the given indices:
58 //
59 // index {0} : array a
60 // index {1, 2} : array d
61 // index {2} : array e
62 // index {0, 0} : invalid index (element at {0} is an array not a tuple)
63 //
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
65 //
// ShapeIndex is a trivial wrapper around absl::InlinedVector with a minimum
// number of methods implemented.
68 class ShapeIndex {
69 public:
70 ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init)71 ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
72 template <typename InputIt>
ShapeIndex(InputIt start,InputIt end)73 ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
74
75 explicit ShapeIndex(ShapeIndexView v);
76
empty()77 bool empty() const { return indices_.empty(); }
size()78 size_t size() const { return indices_.size(); }
push_back(int64 value)79 void push_back(int64 value) { indices_.push_back(value); }
pop_back()80 void pop_back() { indices_.pop_back(); }
81
82 // push_front is O(n), but shapes don't usually have a ton of dimensions.
push_front(int64 value)83 void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
84
85 using container_type = absl::InlinedVector<int64, 2>;
86
begin()87 container_type::const_iterator begin() const { return indices_.begin(); }
end()88 container_type::const_iterator end() const { return indices_.end(); }
begin()89 container_type::iterator begin() { return indices_.begin(); }
end()90 container_type::iterator end() { return indices_.end(); }
91
data()92 const int64* data() const { return indices_.data(); }
93
back()94 int64 back() const { return indices_.back(); }
back()95 int64& back() { return indices_.back(); }
96
97 const int64& operator[](size_t i) const { return indices_[i]; }
98 int64& operator[](size_t i) { return indices_[i]; }
99
100 bool operator==(const ShapeIndex& other) const {
101 return indices_ == other.indices_;
102 }
103 bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
104 bool operator<(const ShapeIndex& other) const {
105 return indices_ < other.indices_;
106 }
107
108 string ToString() const;
109
110 template <typename H>
AbslHashValue(H h,const ShapeIndex & index)111 friend H AbslHashValue(H h, const ShapeIndex& index) {
112 return H::combine(std::move(h), index.indices_);
113 }
114
115 private:
116 container_type indices_;
117 };
118
119 // A view into a ShapeIndex as above, with the cheap/easy ability to consume the
120 // value at the front of the view.
121 //
122 // NB! ShapeIndexView does not own the memory backing the index array.
123 // The memory backing the index array should be owned by an object
124 // that lives longer than the ShapeIndexView instances pointing into
125 // it.
126 class ShapeIndexView {
127 public:
128 ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
129 : indices_(shape_index.data() + offset, shape_index.size() - offset) {
130 CHECK_LE(offset, shape_index.size());
131 }
ShapeIndexView(std::initializer_list<int64> indices)132 ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
133 ShapeIndexView(const ShapeIndexView& other) = default;
134
135 using iterator = const int64*;
136
begin()137 iterator begin() const { return indices_.begin(); }
end()138 iterator end() const { return indices_.end(); }
size()139 int64 size() const { return indices_.size(); }
empty()140 bool empty() const { return indices_.empty(); }
front()141 int64 front() const {
142 CHECK(!empty());
143 return indices_.front();
144 }
back()145 int64 back() const {
146 CHECK(!empty());
147 return indices_.back();
148 }
ConsumeFront()149 ShapeIndexView ConsumeFront() const {
150 ShapeIndexView result = *this;
151 result.indices_.remove_prefix(1);
152 return result;
153 }
ConsumeBack()154 ShapeIndexView ConsumeBack() const {
155 ShapeIndexView result = *this;
156 result.indices_.remove_suffix(1);
157 return result;
158 }
ToShapeIndex()159 ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
160
161 bool operator==(const ShapeIndexView& other) const;
162 bool operator!=(const ShapeIndexView& other) const;
163
164 string ToString() const;
165
166 // Returns true if this shape index starts with 'prefix'.
167 bool StartsWith(ShapeIndexView prefix) const;
168
169 private:
170 absl::Span<const int64> indices_;
171 };
172
ShapeIndex(ShapeIndexView v)173 inline ShapeIndex::ShapeIndex(ShapeIndexView v)
174 : ShapeIndex(v.begin(), v.end()) {}
175
176 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
177 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
178
179 // Namespaced collection of (static) shape utilities.
180 //
181 // These are all effectively convenience functions for testing/tweaking proto
182 // properties, which do invariant checks before / after the operation.
183 class ShapeUtil {
184 public:
185 // Data structure which describes the coordinates and the shape, of a tuple
186 // shaped sub-shape.
187 struct IndexedShape {
188 IndexedShape() = default;
IndexedShapeIndexedShape189 IndexedShape(ShapeIndex index, Shape shape)
190 : index(std::move(index)), shape(std::move(shape)) {}
191 ShapeIndex index;
192 Shape shape;
193 };
194
  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1.
  // Precondition: shape.IsArray()
198 static int64 ElementsIn(const Shape& shape);
199
200 // As ElementsIn(), but recurses through tuples.
201 static int64 ElementsInRecursive(const Shape& shape);
202
203 // Returns true if shape has the primitive type, recurses through tuples.
204 static bool HasPrimitiveType(const Shape& shape,
205 PrimitiveType primitive_type);
206
207 // Returns true if 'shape' is an array with zero elements.
208 static bool IsZeroElementArray(const Shape& shape);
209
210 // Returns the number of bytes required for an allocation of shape. The
211 // |pointer_size| parameter is used for calculating the size of tuple
212 // shapes. This includes only the size of the top-level buffer. For example, a
213 // tuple is stored as an array of pointers to other buffers. In this case,
214 // this method only returns the size of the pointer array.
215 static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
216
217 // Returns the number of bytes used to store the primitive_type.
218 //
219 // Precondition: shape.IsArray()
220 static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
221
  // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
225 static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
226 int64 pointer_size);
227
228 // Returns the number of bytes required for the elements in an allocation of
229 // `shape`, which must be an array shape. Shapes use a separate
230 // memory location for each element, and so for these shapes,
231 // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
232 // size also includes padding if present in the layout.
233 static int64 ByteSizeOfElements(const Shape& shape);
234
235 // Returns a human-readable string that represents the given shape, with or
236 // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
237 static string HumanString(const Shape& shape);
238 static string HumanStringWithLayout(const Shape& shape);
239
240 // As above, but for program shapes, returns a string for the form:
241 //
242 // (param_name: f32[42x12], ...) -> f32[24x42]
243 static string HumanString(const ProgramShape& program_shape);
244
245 // Returns whether the LHS and RHS shapes have the same dimensions; note: does
246 // not check element type.
247 // Precondition: IsArray(lhs) && IsArray(rhs)
248 static bool SameDimensions(const Shape& lhs, const Shape& rhs);
249
250 // Returns whether the LHS and RHS shapes have the same rank; note: does
251 // not check element type.
252 // Precondition: IsArray(lhs) && IsArray(rhs)
253 static bool SameRank(const Shape& lhs, const Shape& rhs);
254
255 // Returns whether the lhs and rhs shapes have the same element type.
SameElementType(const Shape & lhs,const Shape & rhs)256 static bool SameElementType(const Shape& lhs, const Shape& rhs) {
257 return lhs.element_type() == rhs.element_type();
258 }
259
260 // As SameElementType, but allows floating point types to have different
261 // precisions.
SameElementTypeIgnoringFpPrecision(const Shape & a,const Shape & b)262 static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
263 const Shape& b) {
264 if (ElementIsFloating(a) && ElementIsFloating(b)) {
265 return true;
266 }
267 return ShapeUtil::SameElementType(a, b);
268 }
269
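  // Returns whichever of a's and b's element types ranks higher, comparing
  // lexicographically by: floating-point range (overflow exponent), then
  // floating-point precision (significand width), then bit width, then
  // signed integer over unsigned. Non-floating types use -1 for the first two
  // criteria. If neither type ranks higher, CHECK-fails unless both shapes have
  // the same element type, which is then returned.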
HigherPrecisionElementType(const Shape & a,const Shape & b)273 static PrimitiveType HigherPrecisionElementType(const Shape& a,
274 const Shape& b) {
275 // Returns a tuple where the elements are lexicographically ordered in terms
276 // of importance.
277 auto type_properties = [](const Shape& shape) {
278 return std::make_tuple(
279 // Prefer floating point types with more range over other
280 // floating-point types or non-floating point types.
281 ElementIsFloating(shape)
282 ? primitive_util::OverflowExponent(shape.element_type())
283 : -1,
284 // Prefer floating point types with more precision over less precise
285 // types.
286 ElementIsFloating(shape)
287 ? primitive_util::SignificandWidth(shape.element_type())
288 : -1,
289 // Prefer wider types over narrower types.
290 primitive_util::BitWidth(shape.element_type()),
291 // Prefer signed integer types over unsigned integer types.
292 primitive_util::IsSignedIntegralType(shape.element_type()));
293 };
294 auto a_properties = type_properties(a);
295 auto b_properties = type_properties(b);
296 if (a_properties > b_properties) {
297 return a.element_type();
298 }
299 if (b_properties > a_properties) {
300 return b.element_type();
301 }
302 CHECK(SameElementType(a, b));
303 return a.element_type();
304 }
305
306 // Returns true if the rank, dimension sizes, and element type are
307 // identical. Layout is ignored. Tuple elements are compared recursively for
308 // compatibility.
309 static bool Compatible(const Shape& lhs, const Shape& rhs);
310
311 // Returns true if the rank and dimension sizes are identical. Element type
312 // and layout are ignored. Tuple elements are compared recursively for
313 // compatibility.
314 static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
315
316 // Returns true if the tuple tree shapes and leaf ranks are identical.
317 // Leaf dimensions, element type, and layout are ignored. Tuple elements are
318 // compared recursively for compatibility.
319 static bool CompatibleKind(const Shape& lhs, const Shape& rhs);
320
  // As Compatible, but allow one of lhs and rhs to be BF16 while the other is
  // F32. Tuple elements are compared recursively for compatibility.
323 static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
324
325 // Returns whether the lhs and rhs shapes are identical.
326 static bool Equal(const Shape& lhs, const Shape& rhs);
327
328 // As Equal, but does not compare the element type.
329 static bool EqualIgnoringElementType(const Shape& lhs, const Shape& rhs);
330
331 // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
332 static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
333
  // Two shapes have the same structure if every subshape index of lhs is also
  // present in rhs and vice versa.
336 // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
337 // (S32, (F32[3], S32[2])) as their structures are both (,(,))
338 //
339 // In contrast, (F32, (F32, F32)) is structurally different from
340 // ((F32, F32), F32) as the former has structure (,(,)) while the latter has
341 // ((,),)
342 static bool EqualStructure(const Shape& lhs, const Shape& rhs);
343
344 // Returns the number of dimensions for which the dimension is not (trivially)
345 // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
346 // fluff. Note that zero dimensions are included in the true rank, e.g.,
347 // f32[3,0,1] has a true rank of 2D.
348 static int64 TrueRank(const Shape& shape);
349
350 static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
351 Shape result);
352
353 ////////////////////
354 // Scalar-specific
355
IsScalar(const Shape & shape)356 static bool IsScalar(const Shape& shape) {
357 return shape.IsArray() && shape.rank() == 0;
358 }
IsEffectiveScalar(const Shape & shape)359 static bool IsEffectiveScalar(const Shape& shape) {
360 return shape.IsArray() && TrueRank(shape) == 0;
361 }
362
363 // Returns whether "shape" is a scalar (array) with the given element_type.
364 static bool IsScalarWithElementType(const Shape& shape,
365 PrimitiveType element_type);
366
367 // Extracts the size of the shape's dimension at dimension number
368 // GetDimensionNumber(dimension_number).
369 static int64 GetDimension(const Shape& shape, int64 dimension_number);
370
371 // Resolves a dimension number, supporting negative indexing.
372 //
373 // Negative indexing has similar semantics to Python. For an N-dimensional
374 // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
375 // N-2, and so on.
376 //
377 // This function always returns a positive dimension number for any given
378 // dimension_number (which itself can be negative).
379 static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
380
381 // Returns a shape with the same dimensions as the original, but with the
382 // element type changed to type.
383 static Shape ChangeElementType(const Shape& original, PrimitiveType type);
384
  // Returns a shape with the same dimensions, but with every dimension static.
386 static Shape MakeStaticShape(const Shape& original);
387
388 // Creates a tuple shape from a slice of element shapes within the tuple.
389 static Shape MakeTupleShape(absl::Span<const Shape> shapes);
390
391 // Creates an opaque shape. These are generally used for threading a context
392 // into a custom operation.
393 static Shape MakeOpaqueShape();
394
395 // Creates a token shape. Values of this shape are used for ordering
396 // side-effecting operations.
397 static Shape MakeTokenShape();
398
399 // Appends a shape to the given tuple.
400 static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
401
402 // Update a subshape of a tuple.
403 static void UpdateTupleShape(const Shape& shape, int64 index,
404 Shape* tuple_shape);
405
406 // Update the dynamic dimension for a shape. This shape can be a nested tuple.
407 static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
408 int64 dim, bool is_dynamic);
409
410 // Appends a major dimension to the shape with the given bound.
411 static void AppendMajorDimension(int bound, Shape* shape);
412
413 // Copy the dynamic dimensions property from one shape to another.
414 static void CopyDynamicDimensions(Shape* to, const Shape& from);
415
416 // Returns an empty tuple shape. Can be used as a sentinel Shape value.
MakeNil()417 static Shape MakeNil() { return MakeTupleShape({}); }
418
419 // Checks whether the shape is initialized.
IsInitialized(const Shape & shape)420 static bool IsInitialized(const Shape& shape) {
421 return shape.element_type() != PRIMITIVE_TYPE_INVALID;
422 }
423
424 // Constructs a new shape with the given element type and sequence of
425 // dimensions.
426 static Shape MakeShape(PrimitiveType element_type,
427 absl::Span<const int64> dimensions);
428
429 // Make a scalar shape with given primitive type.
430 static Shape MakeScalarShape(PrimitiveType element_type);
431
432 // Constructs a new shape with the given element type and sequence of
433 // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
434 // with a true value that the respective dimension is dynamic. If the
435 // dimension is dynamic then the respective value in 'dimension' is an upper
436 // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be
437 // the same size.
438 static Shape MakeShape(PrimitiveType element_type,
439 absl::Span<const int64> dimensions,
440 const std::vector<bool>& dynamic_dimensions);
441
442 // Constructs a new shape with the given element type and sequence of
443 // dimensions. Method checks if the element type is valid and the shape's
444 // size fits in std::numeric_limits<int64>::max().
445 static StatusOr<Shape> MakeValidatedShape(PrimitiveType element_type,
446 absl::Span<const int64> dimensions);
447 static StatusOr<Shape> MakeValidatedShape(
448 PrimitiveType element_type, absl::Span<const int64> dimensions,
449 const std::vector<bool>& dynamic_dimensions);
450
451 // Creates a Shape with element type corresponding to T and the given
452 // dimensions
453 template <typename T>
MakeShapeWithType(absl::Span<const int64> dimensions)454 static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
455 return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
456 dimensions);
457 }
458
459 // Constructs a new shape with the given minor_to_major order in its Layout.
460 // Returns a value shape such that shape.has_layout().
461 static Shape MakeShapeWithLayout(PrimitiveType element_type,
462 absl::Span<const int64> dimensions,
463 absl::Span<const int64> minor_to_major,
464 absl::Span<const Tile> tiles = {},
465 int64 element_size_in_bits = 0,
466 int64 memory_space = 0);
467
468 // Returns the same shape except with all dimensions set to be static.
469 static Shape MakeShapeWithStaticDimensions(const Shape& shape);
470
  // Constructs a new shape with major-first layout, i.e. minor_to_major =
  // {n-1, ..., 1, 0} for a rank-n shape.
472 static Shape MakeShapeWithDescendingLayout(
473 PrimitiveType element_type, absl::Span<const int64> dimensions);
474
  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. minor_to_major = {n-1, ..., 1, 0} for rank n, like C /
  // row-major order), and with the dimensions rearranged so that it has the
  // same in-memory layout as the given shape.
478 //
479 // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
480 static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
481 const Shape& shape);
482
483 // As MakeShape, but the object to write to is passed in.
484 static Status PopulateShape(PrimitiveType element_type,
485 absl::Span<const int64> dimensions, Shape* shape);
486
487 // Validates that the provided shape satisfies invariants.
488 static Status ValidateShape(const Shape& shape);
489
490 // Validates the provided shape satisfies invariants, except those that
491 // pertain to layout.
492 //
493 // Layout is optional for client-provided shapes, so that the compiler may
494 // determine and assign an optimized layout.
495 static Status ValidateShapeWithOptionalLayout(const Shape& shape);
496
497 // Returns whether the element type of the shape is integral (signed or
498 // unsigned). Note that predicates are not considered integral here, since
499 // they are logical values.
500 static bool ElementIsIntegral(const Shape& shape);
501
502 // Returns whether the element type of the shape is floating point.
503 static bool ElementIsFloating(const Shape& shape);
504
505 // Returns whether the element type of the shape is complex.
506 static bool ElementIsComplex(const Shape& shape);
507
508 // Returns whether the element type has the given bit width.
509 static bool ElementHasBitWidth(const Shape& shape, int bits);
510
511 // Returns whether the element type of the shape is integral and has
512 // the specified number of bits.
513 static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
514
515 // Returns whether the element type of the shape is signed. Note
516 // that floating point numbers are signed.
517 static bool ElementIsSigned(const Shape& shape);
518
519 // Returns whether the given primitive type corresponds to an array shape.
520 static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
521
522 // Returns whether the shape is a tuple with at least one element which is
523 // also a tuple.
524 static bool IsNestedTuple(const Shape& shape);
525
526 // Returns true if shape is an empty tuple.
527 static bool IsEmptyTuple(const Shape& shape);
528
529 // Returns the number of elements in the given tuple shape.
530 // Precondition: IsTuple(shape)
531 static int64 TupleElementCount(const Shape& shape);
532
533 // Returns the tuple element shape at given index.
534 // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
535 static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
536
537 // Returns the number of elements, recursively, in the given shape.
538 static int64 SubshapeCount(const Shape& shape);
539
540 // Slices tuple elements in the range [start, limit) and returns a new tuple
541 // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
542 static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
543
544 // Returns the shape of the real/imaginary components of the given complex
545 // shape.
546 static Shape ComplexComponentShape(const Shape& complex_shape);
547
548 // Returns true if the given shape has a subshape at the given index.
549 static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
550
551 // GetSubshape and GetMutableSubshape return a particular nested Shape within
552 // the given Shape argument. The non-Try variants check fail if index is
553 // invalid.
554 static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
555 static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
556 ShapeIndexView index);
557 static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
558
559 // Returns whether the given index in the given shape is a leaf element of the
560 // shape.
561 static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
562
563 // Returns the number of leaves in the shape.
564 static int64 GetLeafCount(const Shape& shape);
565
566 // Retrieves all the leaf shapes and their indexes, in the order walked by
567 // the ForEachSubshape() API.
568 static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);
569
570 // Calls the given visitor function for each subshape of the given shape.
571 // Subshapes are visited in DFS pre-order starting with the entire shape
572 // (index {}).
573 using VisitorFunction = std::function<void(const Shape& /*subshape*/,
574 const ShapeIndex& /*index*/)>;
575 static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
576 using MutatingVisitorFunction =
577 std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
578 static void ForEachMutableSubshape(Shape* shape,
579 const MutatingVisitorFunction& func);
580
581 // Variants of ForEach(Mutable)Subshape which propagate Status from the
582 // visitor function.
583 using StatusVisitorFunction = std::function<Status(
584 const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
585 static Status ForEachSubshapeWithStatus(const Shape& shape,
586 const StatusVisitorFunction& func);
587 using MutatingStatusVisitorFunction =
588 std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
589 static Status ForEachMutableSubshapeWithStatus(
590 Shape* shape, const MutatingStatusVisitorFunction& func);
591
  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
594 static bool HasDegenerateDimensions(const Shape& shape);
595
596 // Drops any degenerate dimensions (i.e. dimensions of size 1)
597 static Shape DropDegenerateDimensions(const Shape& shape);
598
599 // Permutes the dimensions by the given permutation, so
600 // return_value.dimensions[i] = argument.dimensions[permutation[i]].
601 //
602 // Postcondition: For any valid permutation,
603 //
604 // !HasLayout(shape) ||
605 // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
606 // permutation).
607 static Shape PermuteDimensions(absl::Span<const int64> permutation,
608 const Shape& shape);
609
  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
614 // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
615 // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
616 // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
617 // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
618 // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
619 // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
620 static std::tuple<bool, std::vector<int64>, std::vector<int64>>
621 InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
622 const Shape& shape_post);
623
624 // Suppose a reshape transforms input_shape to output shape. Returns a vector
625 // of pairs that indicate the input and output dimensions that this reshape
626 // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
627 // the returned vector, the reshape transforms any input index whose I-th
628 // dimension is x to an output index whose O-th dimension is x too.
629 //
630 // Post-condition: the returned vector is sorted (by both input and output
631 // dimensions because input and output dimensions have the same order).
632 //
633 // Example:
634 // input shape = T[a, b, x, y, cd]
635 // output shape = T[ab, x, 1, y, c, d]
636 // return value = {{2, 1}, {3, 3}}
637 //
638 // The two pairs represent the input and output dimension of size x and
639 // those of size y.
640 static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
641 const Shape& input_shape, const Shape& output_shape);
642
643 // Return whether the given reshape instruction leaves the dimensions at the
644 // given input indices unmodified, and returns their output indices.
645 //
646 // Example:
647 // input_dim_indices = {2, 3}
648 // input shape = T[a, b, x, y, cd]
649 // output shape = T[ab, x, 1, y, c, d]
650 // return value = {1, 3}
651 //
652 // Precondition: input_dim_indices is sorted.
653 static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
654 const Shape& from_shape, const Shape& to_shape,
655 absl::Span<const int64> input_dim_indices);
656
657 // Returns whether a transpose from input_shape to output_shape with dimension
658 // mapping "dimension_mapping" produces a result which is bit-wise identical
659 // to its input and thus may be replaced with a bitcast.
660 //
661 // Precondition: Both input_shape and output_shape have explicit layouts.
662 static bool TransposeIsBitcast(const Shape& input_shape,
663 const Shape& output_shape,
664 absl::Span<const int64> dimension_mapping);
665
666 // Returns whether a reshape from "input_shape" to "output_shape" is a
667 // bitcast.
668 //
669 // Precondition: Both input_shape and output_shape have explicit layouts.
670 static bool ReshapeIsBitcast(const Shape& input_shape,
671 const Shape& output_shape);
672
  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and
  // absl::nullopt otherwise.
679 static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
680 const Shape& output_shape);
681
682 // Returns a shape with the given dimension deleted.
683 // For example:
684 // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
685 static Shape DeleteDimension(int64 dim_to_delete, Shape shape);
686
687 // Returns a shape with all the dimensions of the input shape for which `p`
688 // returns true.
689 // For examples:
690 // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
691 // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
692 static Shape FilterDimensions(const std::function<bool(int64)>& p,
693 Shape shape);
694
695 // Returns true if `dynamic_shape` has dimensions that are less-equal to the
696 // "bounded_shape". Shapes must be arrays.
697 static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape,
698 const xla::Shape& bounded_shape);
699
700 // Same as DynamicArrayShapeIsCompatible() but supports tuples.
701 static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
702 const xla::Shape& bounded_shape);
703
  // Iterates through all the shape indexes, in minor to major order,
  // starting from the base indexes, incrementing by the incr steps, up to
  // count (index[i] < base[i] + count[i]), and calls the visitor_function
  // with the current index. The visitor function should return true if it
  // wants to continue, or false otherwise.
  //
  // visitor_function must be a callable of type
  // StatusOr<bool>(absl::Span<const int64>) or compatible.
712 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)713 static Status ForEachIndexWithStatus(const Shape& shape,
714 absl::Span<const int64> base,
715 absl::Span<const int64> count,
716 absl::Span<const int64> incr,
717 const FnType& visitor_function) {
718 return ForEachIndexInternal(shape, base, count, incr, visitor_function);
719 }
720
// Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus: bundles
// the per-dimension `base`, `count` and `incr` arguments of the five-argument
// overload so they can be passed around as a single value.
struct IndexIterationSpace {
  std::vector<int64> index_base;   // Forwarded as `base`.
  std::vector<int64> index_count;  // Forwarded as `count`.
  std::vector<int64> index_incr;   // Forwarded as `incr`.
};
727
728 template <typename FnTy>
ForEachIndexWithStatus(const Shape & shape,const IndexIterationSpace & iteration_space,FnTy && function)729 static Status ForEachIndexWithStatus(
730 const Shape& shape, const IndexIterationSpace& iteration_space,
731 FnTy&& function) {
732 return ShapeUtil::ForEachIndexWithStatus(
733 shape, iteration_space.index_base, iteration_space.index_count,
734 iteration_space.index_incr, std::forward<FnTy>(function));
735 }
736
737 template <typename FnType>
ForEachIndex(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)738 static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
739 absl::Span<const int64> count,
740 absl::Span<const int64> incr,
741 const FnType& visitor_function) {
742 ForEachIndexWithStatus(shape, base, count, incr,
743 [&](absl::Span<const int64> indices) {
744 return StatusOr<bool>(visitor_function(indices));
745 })
746 .IgnoreError();
747 }
748
749 // These convenience wrappers don't take `base`, `count` and `incr`
750 // explicitly, but iterate over every element in `shape` instead.
751
752 template <typename FnType>
ForEachIndexWithStatus(const Shape & shape,const FnType & visitor_function)753 static Status ForEachIndexWithStatus(const Shape& shape,
754 const FnType& visitor_function) {
755 std::vector<int64> base(shape.dimensions_size());
756 std::vector<int64> incr(shape.dimensions_size(), 1);
757 return ForEachIndexWithStatus(shape, base,
758 /*count=*/AsInt64Slice(shape.dimensions()),
759 incr, visitor_function);
760 }
761
762 template <typename FnType>
ForEachIndex(const Shape & shape,const FnType & visitor_function)763 static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
764 ForEachIndexWithStatus(shape, [&](absl::Span<const int64> indices) {
765 return StatusOr<bool>(visitor_function(indices));
766 }).IgnoreError();
767 }
768
769 // A parallel version of ForEachIndex(WithStatus). This can only be used if
770 // the visitor_function is thread-safe and the order of iteration does not
771 // matter.
772 //
773 // visitor_function must be a callable of type
774 // void(Span<int64>) or compatible.
775 template <typename FnType>
ForEachIndexParallel(const Shape & shape,absl::Span<const int64> base,absl::Span<const int64> count,absl::Span<const int64> incr,const FnType & visitor_function)776 static void ForEachIndexParallel(const Shape& shape,
777 absl::Span<const int64> base,
778 absl::Span<const int64> count,
779 absl::Span<const int64> incr,
780 const FnType& visitor_function) {
781 // The parallel version of ForEachIndexInternal can never fail.
782 CHECK(ForEachIndexInternal(
783 shape, base, count, incr,
784 [&visitor_function](
785 absl::Span<const int64> indexes) -> StatusOr<bool> {
786 visitor_function(indexes);
787 return true;
788 },
789 /*parallel=*/true)
790 .ok());
791 }
792
// Computes a hash for `shape`.
static size_t Hash(const Shape& shape);

// About 0-2-1 transpose:
//
// If a shape can be viewed as three logical components 0-1-2 in the order of
// major to minor, a 0-2-1-transpose changes the order of such logical
// components to 0-2-1. We call the shape being transposed the input shape and
// the transposed shape the output shape. The logical views of the
// input/output shapes for the transpose are called the 0-1-2/0-2-1 shapes or
// the normalized shapes. The original input/output shapes are called
// unnormalized shapes.
//
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, returns the dimensions of the
// normalized shape of `b` (the 0-2-1 shape). The return type is optional;
// presumably it is empty when `b` is not such a transpose -- confirm against
// the implementation.
static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                           const Shape& b);

// Strips device-specific information, namely tiling and memory-space
// information, from a shape.
static Shape DeviceShapeToHostShape(Shape s);

// Returns true iff the element type of shape `from` can be safely upcast to
// the element type of shape `to`.
static bool ElementCanUpcast(const Shape& from, const Shape& to);
818
819 private:
// Fills *shape from `element_type` and `dimensions`. Returns true on success.
// REQUIRES: *shape is empty.
static bool FillNewShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions, Shape* shape);

// Validates that the shape size is sane. This makes sure it's safe to do
// calculations in int64 without overflowing.
static Status ValidateShapeSize(const Shape& shape);

// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public methods.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
832
833 template <typename FnType>
834 static Status ForEachIndexInternal(const Shape& shape,
835 absl::Span<const int64> base,
836 absl::Span<const int64> count,
837 absl::Span<const int64> incr,
838 const FnType& visitor_function,
839 bool parallel = false) {
840 if (ShapeUtil::IsZeroElementArray(shape)) {
841 return Status::OK();
842 }
843 CHECK_EQ(shape.rank(), base.size());
844 CHECK_EQ(incr.size(), base.size());
845 CHECK_EQ(count.size(), base.size());
846 const int64 rank = LayoutUtil::MinorToMajor(shape).size();
847 // Allows handling R0 arrays, such that the visitor function will be called
848 // once with the proper empty indexes.
849 int64 n = -1;
850 std::vector<int64> indexes(base.begin(), base.end());
851 const int kNumThreads = tensorflow::port::MaxParallelism();
852 absl::optional<tensorflow::thread::ThreadPool> pool;
853 if (parallel) {
854 pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
855 }
856
857 tensorflow::mutex mu;
858 Status status; // Guarded by mu
859
860 while (n < rank) {
861 if (pool != absl::nullopt) {
862 pool->Schedule([indexes, &visitor_function, &mu, &status] {
863 StatusOr<bool> result = visitor_function(indexes);
864 if (!result.ok()) {
865 tensorflow::mutex_lock lock(mu);
866 status = status.ok() ? result.status() : status;
867 }
868 });
869 } else {
870 TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
871 if (!should_continue) {
872 break;
873 }
874 }
875 // Increments dimensions in minor to major order.
876 for (n = 0; n < rank; ++n) {
877 int64 dim = LayoutUtil::Minor(shape.layout(), n);
878 indexes[dim] += incr[dim];
879 if (indexes[dim] < base[dim] + count[dim]) {
880 break;
881 }
882 indexes[dim] = base[dim];
883 }
884 }
885
886 // Waits for the scheduled work to complete.
887 pool.reset();
888 return status;
889 }
890
// Copy construction and assignment are disabled. Every member visible in this
// section is static, so ShapeUtil appears to be used purely as a collection of
// static helpers rather than as a value type.
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
892 };
893
894 } // namespace xla
895
896 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
897