1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/shape_util.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <optional>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/hash/hash.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/strings/strip.h"
38 #include "tensorflow/compiler/xla/index_util.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/overflow_util.h"
41 #include "tensorflow/compiler/xla/permutation_util.h"
42 #include "tensorflow/compiler/xla/primitive_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/platform/logging.h"
48
49 namespace xla {
50
51 using absl::StrAppend;
52 using absl::StrCat;
53
54 namespace {
// An array indexed by PrimitiveType that gives the byte size of each element
// of that primitive type, or 0 if the type has no fixed element size
// (invalid, TUPLE, OPAQUE_TYPE, or TOKEN).
58 constexpr uint8_t primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
59 0, // PRIMITIVE_TYPE_INVALID = 0,
60 sizeof(int8_t), // PRED = 1
61 sizeof(int8_t), // S8 = 2
62 sizeof(int16_t), // S16 = 3
63 sizeof(int32_t), // S32 = 4
64 sizeof(int64_t), // S64 = 5
65 sizeof(uint8_t), // U8 = 6
66 sizeof(uint16_t), // U16 = 7
67 sizeof(uint32_t), // U32 = 8
68 sizeof(uint64_t), // U64 = 9
69 sizeof(float) / 2, // F16 = 10
70 sizeof(float), // F32 = 11
71 sizeof(double), // F64 = 12
72 0, // TUPLE = 13
73 0, // OPAQUE_TYPE = 14
74 sizeof(complex64), // C64 = 15
75 sizeof(float) / 2, // BF16 = 16
76 0, // TOKEN = 17
77 sizeof(complex128) // C128 = 18
78 };
79 constexpr int64_t kAnnotationPrintInterval = 5;
80 } // namespace
81
std::string ShapeIndex::ToString() const {
83 return StrCat("{", absl::StrJoin(*this, ","), "}");
84 }
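// Illustrative usage (a sketch, not part of the library):
//   ShapeIndex({1, 2}).ToString();  // "{1,2}"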
85
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
87 out << shape_index.ToString();
88 return out;
89 }
90
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
92 PrimitiveType primitive_type) {
93 return primitive_util::IsArrayType(primitive_type);
94 }
95
96 namespace {
97 // Constructs and returns the new shape with the given minor_to_major order in
98 // its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
100 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
101 absl::Span<const int64_t> minor_to_major,
102 absl::Span<const DimLevelType> dim_level_types,
103 absl::Span<const Tile> tiles, int64_t element_size_in_bits,
104 int64_t memory_space) {
105 if (dimensions.size() != minor_to_major.size()) {
106 return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
107 dimensions.size(), minor_to_major.size());
108 }
109 if (element_type == OPAQUE_TYPE || element_type == TUPLE ||
110 element_type == TOKEN) {
111 return InvalidArgument("Unsupported element type: %s",
112 PrimitiveType_Name(element_type));
113 }
114 TF_ASSIGN_OR_RETURN(Shape shape,
115 ShapeUtil::MakeValidatedShape(element_type, dimensions));
116 if (element_size_in_bits ==
117 ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
118 // Only set element_size_in_bits if it's different from the default value.
119 element_size_in_bits = 0;
120 }
121 *shape.mutable_layout() =
122 LayoutUtil::MakeLayout(minor_to_major, dim_level_types, tiles,
123 element_size_in_bits, memory_space);
124 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
125 return shape;
126 }
127
128 template <typename T>
const T& Deref(const T* ptr) {
130 DCHECK(ptr != nullptr);
131 return *ptr;
132 }
133
134 template <typename T>
const T& Deref(const T& ref) {
136 return ref;
137 }
138
139 template <typename ShapePtrOrRef>
Shape MakeTupleShapeImpl(absl::Span<ShapePtrOrRef> shapes) {
141 Shape result;
142 result.set_element_type(TUPLE);
143 result.mutable_tuple_shapes()->reserve(shapes.size());
144 for (const auto& shape : shapes) {
145 ShapeUtil::AppendShapeToTuple(Deref(shape), &result);
146 }
147 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
148 return result;
149 }
150
151 } // namespace
152
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
154 bool equal = Shape::Equal()(lhs, rhs);
155
156 if (!equal && VLOG_IS_ON(3)) {
157 VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
158 << ", rhs = " << rhs.ShortDebugString();
159 }
160
161 return equal;
162 }
163
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
165 const Shape& rhs) {
166 bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
167 if (!equal && VLOG_IS_ON(3)) {
168 VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
169 << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
170 }
171
172 return equal;
173 }
174
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
176 const Shape& rhs) {
177 bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
178 if (!equal && VLOG_IS_ON(3)) {
179 VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
180 << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
181 }
182
183 return equal;
184 }
185
/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
187 const Shape& rhs) {
188 bool equal = true;
189 ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
190 equal &= IndexIsValid(rhs, index);
191 });
192 ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
193 equal &= IndexIsValid(lhs, index);
194 });
195
196 return equal;
197 }
198
/* static */ int64_t ShapeUtil::TrueRank(const Shape& shape) {
200 int64_t accum = 0;
201 for (int64_t dimension : shape.dimensions()) {
// We do not count degenerate (size 1) dimensions.
203 if (dimension != 1) {
204 accum += 1;
205 }
206 }
207 return accum;
208 }
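// Illustrative example (a sketch using functions from this file):
//   Shape s = ShapeUtil::MakeShape(F32, {1, 3, 1, 5});
//   ShapeUtil::TrueRank(s);  // 2: only the size-3 and size-5 dimensions count.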
209
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
211 absl::Span<const int64_t> dimensions,
212 Shape* shape) {
213 const int eint = static_cast<int>(element_type);
214 int64_t dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
215 ? primitive_byte_size[eint]
216 : 0); // Out of range: force a failure
217 if (dense_shape_size <= 0) {
218 return false;
219 }
220
221 // Verify that array-based lookup is consistent with public API.
222 DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
223 << element_type;
224
225 shape->set_element_type(element_type);
226 const int ndims = dimensions.size();
227 auto layout = shape->mutable_layout();
228 auto* minor_to_major = layout->mutable_minor_to_major();
229 for (int i = 0; i < ndims; i++) {
230 const int64_t d = dimensions[i];
231 if (d < 0) {
232 return false;
233 }
234 dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
235 if (dense_shape_size < 0) {
236 return false;
237 }
238
239 shape->add_dimensions(d);
240 minor_to_major->push_back(ndims - 1 - i);
241 }
242 return true;
243 }
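// Illustrative example of the behavior above (a sketch, via the public
// MakeShape/MakeValidatedShape wrappers defined later in this file):
//   Shape s = ShapeUtil::MakeShape(F32, {2, 3, 4});
//   s.layout().minor_to_major();                    // {2, 1, 0} (default layout)
//   ShapeUtil::MakeValidatedShape(F32, {-1}).ok();  // false (negative dimension)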
244
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
246 std::initializer_list<Shape> parameters, Shape result) {
247 ProgramShape program_shape;
248 for (const Shape& shape : parameters) {
249 *program_shape.add_parameters() = shape;
250 }
251 *program_shape.mutable_result() = std::move(result);
252 return program_shape;
253 }
254
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
256 absl::Span<const int64_t> dimensions) {
257 Shape shape;
258 CHECK(FillNewShape(element_type, dimensions, &shape));
259 return shape;
260 }
261
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
263 return MakeShape(element_type, {});
264 }
265
/* static */ Shape ShapeUtil::MakeShape(
267 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
268 const std::vector<bool>& dynamic_dimensions) {
269 return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
270 .ValueOrDie();
271 }
272
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
274 const Shape& shape) {
275 Shape output = shape;
276 output.clear_dynamic_dimensions();
277 return output;
278 }
279
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
281 PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
282 Shape shape;
283 if (!FillNewShape(element_type, dimensions, &shape)) {
284 return InvalidArgument("invalid shape type=%d, dims=[%s]",
285 static_cast<int>(element_type),
286 absl::StrJoin(dimensions, ","));
287 }
288 return shape;
289 }
290
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
292 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
293 const std::vector<bool>& dynamic_dimensions) {
294 if (dynamic_dimensions.size() != dimensions.size()) {
295 return InvalidArgument(
296 "dynamic dimensions size %d did not match number of dimensions %d",
297 dynamic_dimensions.size(), dimensions.size());
298 }
299
300 Shape shape;
301 if (!FillNewShape(element_type, dimensions, &shape)) {
302 return InvalidArgument("invalid shape type=%d, dims=[%s]",
303 static_cast<int>(element_type),
304 absl::StrJoin(dimensions, ","));
305 }
306 for (int i = 0, n = dimensions.size(); i < n; i++) {
307 shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
308 }
309 return shape;
310 }
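// Illustrative example (a sketch): marking dimension 1 as dynamic.
//   auto s = ShapeUtil::MakeValidatedShape(F32, {2, 3}, {false, true});
//   // ok; ShapeUtil::HumanString(*s) is "f32[2,<=3]"
//   ShapeUtil::MakeValidatedShape(F32, {2, 3}, {true}).ok();  // false: size mismatch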
311
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
313 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
314 absl::Span<const int64_t> minor_to_major,
315 absl::Span<const DimLevelType> dim_level_types,
316 absl::Span<const Tile> tiles, int64_t element_size_in_bits,
317 int64_t memory_space) {
318 auto ret = MakeShapeWithLayoutInternal(element_type, dimensions,
319 minor_to_major, dim_level_types, tiles,
320 element_size_in_bits, memory_space);
321 if (!ret.ok()) LOG(ERROR) << ret.status();
322 return ret.ValueOrDie();
323 }
324
/* static */ Shape ShapeUtil::MoveDimToMajor(const Shape& shape, int64_t dim) {
326 if (shape.IsTuple()) {
327 std::vector<Shape> result_shapes;
328 result_shapes.reserve(shape.tuple_shapes_size());
329 for (const Shape& s : shape.tuple_shapes()) {
330 result_shapes.push_back(MoveDimToMajor(s, dim));
331 }
332 return ShapeUtil::MakeTupleShape(result_shapes);
333 }
334
335 Shape ret = shape;
336 if (!ret.has_layout()) {
337 LayoutUtil::SetToDefaultLayout(&ret);
338 }
339 *ret.mutable_layout() = LayoutUtil::MoveDimToMajor(ret.layout(), dim);
340 DimensionVector minor_to_major;
341 for (int64_t d : LayoutUtil::MinorToMajor(ret)) {
342 if (d != dim) {
343 minor_to_major.push_back(d);
344 }
345 }
346 minor_to_major.push_back(dim);
347 *ret.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
348 return ret;
349 }
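// Illustrative example (a sketch): making the most-minor dimension most major.
//   Shape s = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {2, 1, 0});
//   ShapeUtil::MoveDimToMajor(s, 2).layout().minor_to_major();  // {1, 0, 2}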
350
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
352 PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
353 std::vector<int64_t> layout(dimensions.size());
354 std::iota(layout.rbegin(), layout.rend(), static_cast<int64_t>(0));
355 return MakeShapeWithLayout(element_type, dimensions, layout);
356 }
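// Illustrative example (a sketch): the layout is {rank-1, ..., 1, 0}, i.e.
// dimension 0 is most major.
//   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 3, 4});
//   s.layout().minor_to_major();  // {2, 1, 0}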
357
358 /* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
360 const Shape& shape) {
361 std::vector<int64_t> dims(shape.dimensions_size());
362 for (int i = 0; i < shape.dimensions_size(); ++i) {
363 int dim = i;
364 if (shape.has_layout()) {
365 dim = LayoutUtil::Major(shape.layout(), dim);
366 }
367 dims[i] = shape.dimensions(dim);
368 }
369 Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
370 // Since the physical layout is kept the same, the tiles and element size are
371 // the same also.
372 if (shape.has_layout()) {
373 new_shape.mutable_layout()->mutable_tiles()->assign(
374 shape.layout().tiles().begin(), shape.layout().tiles().end());
375 new_shape.mutable_layout()->set_element_size_in_bits(
376 shape.layout().element_size_in_bits());
377 }
378 for (int i = 0; i < shape.dimensions_size(); ++i) {
379 new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
380 }
381 return new_shape;
382 }
383
/* static */ Status ShapeUtil::PopulateShape(
385 PrimitiveType element_type, absl::Span<const int64_t> dimensions,
386 Shape* shape) {
387 shape->Clear();
388 shape->set_element_type(element_type);
389 for (int64_t dimension : dimensions) {
390 shape->add_dimensions(dimension);
391 }
392 LayoutUtil::SetToDefaultLayout(shape);
393 return ValidateShape(*shape);
394 }
395
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
397 Shape result = original;
398 result.clear_dynamic_dimensions();
399 return result;
400 }
401
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
403 return MakeTupleShapeImpl(shapes);
404 }
405
/* static */ Shape ShapeUtil::MakeTupleShapeWithPtrs(
407 absl::Span<const Shape* const> shapes) {
408 return MakeTupleShapeImpl(shapes);
409 }
410
/* static */ Shape ShapeUtil::MakeMaybeTupleShape(
412 absl::Span<const Shape> shapes) {
413 if (shapes.size() == 1) {
414 return shapes[0];
415 }
416 return MakeTupleShape(shapes);
417 }
418
/* static */ Shape ShapeUtil::MakeOpaqueShape() {
420 Shape result;
421 result.set_element_type(OPAQUE_TYPE);
422 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
423 return result;
424 }
425
/* static */ Shape ShapeUtil::MakeTokenShape() {
427 Shape result;
428 result.set_element_type(TOKEN);
429 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
430 return result;
431 }
432
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
434 Shape* tuple_shape) {
435 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
436 *tuple_shape->add_tuple_shapes() = shape;
437 }
438
/* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64_t index,
440 Shape* tuple_shape) {
441 CHECK(index < tuple_shape->tuple_shapes_size());
442 *tuple_shape->mutable_tuple_shapes(index) = shape;
443 }
444
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
446 ShapeIndexView index,
447 int64_t dim,
448 bool is_dynamic) {
449 if (index.empty()) {
450 CHECK(!shape->IsTuple());
451 shape->set_dynamic_dimension(dim, is_dynamic);
452 return;
453 }
454
455 UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
456 index.subspan(1), dim, is_dynamic);
457 }
458
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
460 CHECK(LayoutUtil::IsDenseArray(*shape));
461 if (shape->has_layout()) {
462 shape->mutable_layout()->add_minor_to_major(shape->rank());
463 }
464 shape->add_dimensions(bound);
465 TF_DCHECK_OK(ValidateShape(*shape));
466 }
467
/* static */ void ShapeUtil::AppendMinorDimension(int bound, Shape* shape) {
469 CHECK(LayoutUtil::IsDenseArray(*shape));
470
471 if (shape->has_layout()) {
472 // Bump up all values in the layout by one.
473 for (int dim_idx = 0; dim_idx < shape->layout().minor_to_major_size();
474 dim_idx++) {
475 int layout_idx = shape->layout().minor_to_major(dim_idx);
476 shape->mutable_layout()->set_minor_to_major(dim_idx, layout_idx + 1);
477 }
478 // Then we can safely add zero.
479 shape->mutable_layout()->add_minor_to_major(0);
480 }
481 shape->add_dimensions(bound);
482 TF_DCHECK_OK(ValidateShape(*shape));
483 }
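// Illustrative example (a sketch): appending a minor dimension shifts every
// existing minor_to_major entry up by one and inserts 0 as the new minor.
//   Shape s = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
//   ShapeUtil::AppendMinorDimension(4, &s);
//   // s.dimensions() is {2, 3, 4}; s.layout().minor_to_major() is {1, 2, 0}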
484
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
486 const Shape& from) {
487 CHECK_EQ(to->rank(), from.rank());
488 for (int64_t i = 0; i < from.rank(); ++i) {
489 to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
490 }
491 TF_DCHECK_OK(ValidateShape(*to));
492 }
493
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
495 return primitive_util::IsIntegralType(shape.element_type());
496 }
497
/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
499 int32_t bits) {
500 return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
501 }
502
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
504 if (!shape.IsArray()) {
505 return false;
506 }
507 return primitive_util::BitWidth(shape.element_type()) == bits;
508 }
509
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
511 switch (shape.element_type()) {
512 case S8:
513 case S16:
514 case S32:
515 case S64:
516 case F16:
517 case BF16:
518 case F32:
519 case F64:
520 return true;
521
522 case PRED:
523 case U8:
524 case U16:
525 case U32:
526 case U64:
527 case C64:
528 case C128:
529 case TUPLE:
530 case OPAQUE_TYPE:
531 case TOKEN:
532 return false;
533
534 default:
535 LOG(FATAL) << "Unhandled element type " << shape.element_type();
536 }
537 }
538
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
540 return primitive_util::IsComplexType(shape.element_type());
541 }
542
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
544 return primitive_util::IsFloatingPointType(shape.element_type());
545 }
546
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
548 return shape.IsTuple() &&
549 absl::c_any_of(shape.tuple_shapes(),
550 [](const Shape& s) { return s.IsTuple(); });
551 }
552
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
554 return shape.IsTuple() && TupleElementCount(shape) == 0;
555 }
556
/* static */ int64_t ShapeUtil::TupleElementCount(const Shape& shape) {
558 CHECK(shape.IsTuple()) << HumanString(shape);
559 return shape.tuple_shapes_size();
560 }
561
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
563 int64_t index) {
564 CHECK(shape.IsTuple());
565 CHECK_GT(TupleElementCount(shape), index);
566 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
567 return shape.tuple_shapes(index);
568 }
569
/* static */ int64_t ShapeUtil::SubshapeCount(const Shape& shape) {
571 int64_t n = 0;
572 ForEachSubshape(shape, [&](const Shape& literal_subshape,
573 const ShapeIndex& index) { ++n; });
574 return n;
575 }
576
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64_t start,
578 int64_t limit) {
579 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
580 CHECK(tuple.IsTuple());
581 CHECK_LE(start, TupleElementCount(tuple));
582 CHECK_LE(limit, TupleElementCount(tuple));
583
584 std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
585 tuple.tuple_shapes().begin() + limit);
586 return MakeTupleShape(new_elements);
587 }
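// Illustrative example (a sketch):
//   Shape t = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(S32, {3}),
//        ShapeUtil::MakeShape(PRED, {})});
//   ShapeUtil::HumanString(ShapeUtil::SliceTuple(t, 1, 3));  // "(s32[3], pred[])"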
588
589 // Returns the shape of a real or imaginary component.
/* static */ Shape ShapeUtil::ComplexComponentShape(
591 const Shape& complex_shape) {
592 CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
593 return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
594 complex_shape.element_type()));
595 }
596
/* static */ int64_t ShapeUtil::ElementsIn(const Shape& shape) {
598 DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
599 DCHECK_EQ(shape.dimensions_size(), shape.rank());
600 if (shape.dimensions().size() == 1) {
601 return shape.dimensions()[0];
602 }
603 return std::accumulate<decltype(shape.dimensions().begin()), int64_t>(
604 shape.dimensions().begin(), shape.dimensions().end(), 1LL,
605 std::multiplies<int64_t>());
606 }
607
/* static */ int64_t ShapeUtil::ElementsInRecursive(const Shape& shape) {
609 CHECK(shape.IsArray() || shape.IsTuple());
610 if (shape.IsArray()) {
611 return ElementsIn(shape);
612 }
613 int64_t count = 0;
614 for (const Shape& element_shape : shape.tuple_shapes()) {
615 count += ElementsInRecursive(element_shape);
616 }
617 return count;
618 }
619
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
621 PrimitiveType primitive_type) {
622 if (shape.element_type() == primitive_type) {
623 return true;
624 }
625 for (const Shape& element_shape : shape.tuple_shapes()) {
626 if (HasPrimitiveType(element_shape, primitive_type)) {
627 return true;
628 }
629 }
630 return false;
631 }
632
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
634 return shape.IsArray() && ElementsIn(shape) == 0;
635 }
636
/* static */ bool ShapeUtil::IsScalarWithElementType(
638 const Shape& shape, PrimitiveType element_type) {
639 return IsScalar(shape) && shape.element_type() == element_type;
640 }
641
/* static */ std::string ShapeUtil::HumanString(const Shape& shape) {
643 if (shape.IsTuple()) {
644 std::string text = "(";
645 const auto& tuple_shapes = shape.tuple_shapes();
646 for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
647 const Shape& elem_shape = tuple_shapes[i];
648 if (i != 0) {
649 StrAppend(&text, ", ");
650 if (i % kAnnotationPrintInterval == 0) {
651 StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
652 }
653 }
654 StrAppend(&text, HumanString(elem_shape));
655 }
656 text += ")";
657 return text;
658 }
659 std::vector<std::string> dim_elements;
660 const auto dimensions_size = shape.dimensions_size();
661 dim_elements.reserve(dimensions_size);
662 for (int i = 0; i < dimensions_size; ++i) {
663 if (shape.is_dynamic_dimension(i)) {
664 dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
665 } else {
666 dim_elements.push_back(StrCat(shape.dimensions(i)));
667 }
668 }
669 return StrCat(
670 primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
671 absl::StrJoin(dim_elements, ","), "]");
672 }
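// Illustrative examples (a sketch):
//   HumanString(MakeShape(F32, {2, 3}))   -> "f32[2,3]"
//   HumanString(MakeShape(PRED, {}))      -> "pred[]"
//   HumanString(MakeTupleShape({MakeShape(F32, {2}), MakeShape(S32, {})}))
//                                         -> "(f32[2], s32[])"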
673
/* static */ std::string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
675 if (shape.IsTuple()) {
676 std::string text = "(";
677 const auto& tuple_shapes = shape.tuple_shapes();
678 for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
679 const Shape& elem_shape = tuple_shapes[i];
680 if (i != 0) {
681 StrAppend(&text, ", ");
682 if (i % kAnnotationPrintInterval == 0) {
683 StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
684 }
685 }
686 StrAppend(&text, HumanStringWithLayout(elem_shape));
687 }
688 text += ")";
689 return text;
690 }
691 std::string result = HumanString(shape);
692 if (shape.has_layout()) {
693 if (IsScalar(shape)) {
694 std::string layout_str = LayoutUtil::HumanString(shape.layout());
695 // Don't print "{}" as layout for scalars.
696 if (layout_str != "{}") {
697 StrAppend(&result, layout_str);
698 }
699 } else if (shape.IsArray()) {
700 StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
701 }
702 }
703 return result;
704 }
705
/* static */ std::string ShapeUtil::HumanString(
707 const ProgramShape& program_shape) {
708 std::vector<std::string> parameters;
709 const auto& shape_parameters = program_shape.parameters();
710 parameters.reserve(shape_parameters.size());
711 for (const auto& shape : shape_parameters) {
712 const int i = parameters.size();
713 parameters.push_back(StrCat(i < program_shape.parameter_names_size()
714 ? program_shape.parameter_names(i)
715 : "(unknown)",
716 ": ", HumanString(shape)));
717 }
718 return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
719 HumanString(program_shape.result()));
720 }
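// Illustrative example (a sketch): a program shape with one f32[2] parameter
// named "x" and an f32[2] result renders as "(x: f32[2]) -> f32[2]"; a
// parameter without a recorded name is printed as "(unknown)".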
721
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
723 const Shape& rhs) {
724 CHECK(lhs.IsArray());
725 CHECK(rhs.IsArray());
726 return absl::c_equal(lhs.dimensions(), rhs.dimensions());
727 }
728
/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
730 CHECK(lhs.IsArray());
731 CHECK(rhs.IsArray());
732 return lhs.rank() == rhs.rank();
733 }
734
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
736 return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
737 }
738
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
740 const Shape& rhs) {
741 return Shape::Equal()
742 .IgnoreDynamicDimension()
743 .IgnoreElementType()
744 .IgnoreLayout()(lhs, rhs);
745 }
746
/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
748 const Shape& rhs) {
749 return Shape::Equal()
750 .IgnoreElementType()
751 .IgnoreLayout()
752 .IgnoreDimensions()
753 .IgnoreDynamicDimension()(lhs, rhs);
754 }
755
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
757 const Shape& rhs) {
758 return Shape::Equal()
759 .IgnoreDynamicDimension()
760 .IgnoreFpPrecision()
761 .IgnoreLayout()(lhs, rhs);
762 }
763
/* static */ int64_t ShapeUtil::GetDimension(const Shape& shape,
765 int64_t dimension_number) {
766 return shape.dimensions(GetDimensionNumber(shape, dimension_number));
767 }
768
/* static */ int64_t ShapeUtil::GetDimensionNumber(const Shape& shape,
770 int64_t dimension_number) {
771 if (dimension_number < 0) {
772 dimension_number += shape.rank();
773 }
774 CHECK_GE(dimension_number, 0);
775 return dimension_number;
776 }
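// Illustrative example (a sketch): negative dimension numbers count from the
// back, Python-style.
//   Shape s = ShapeUtil::MakeShape(F32, {2, 3, 4});
//   ShapeUtil::GetDimensionNumber(s, -1);  // 2
//   ShapeUtil::GetDimension(s, -1);        // 4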
777
/* static */ int64_t ShapeUtil::ByteSizeOfPrimitiveType(
779 PrimitiveType primitive_type) {
780 switch (primitive_type) {
781 case PRED:
782 return sizeof(int8_t);
783 case S8:
784 return sizeof(int8_t);
785 case S16:
786 return sizeof(int16_t);
787 case S32:
788 return sizeof(int32_t);
789 case S64:
790 return sizeof(int64_t);
791 case U8:
792 return sizeof(uint8_t);
793 case U16:
794 return sizeof(uint16_t);
795 case U32:
796 return sizeof(uint32_t);
797 case U64:
798 return sizeof(uint64_t);
799 case BF16:
800 return sizeof(float) / 2;
801 case F16:
802 return sizeof(float) / 2;
803 case F32:
804 return sizeof(float);
805 case F64:
806 return sizeof(double);
807 case C64:
808 return sizeof(complex64);
809 case C128:
810 return sizeof(complex128);
811 case TOKEN:
812 // Tokens require no space.
813 return 0;
814 case TUPLE:
815 case OPAQUE_TYPE:
816 LOG(FATAL) << PrimitiveType_Name(primitive_type)
817 << " primitive type has no definitive size";
818 default:
819 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
820 }
821 }
822
/* static */ int64_t ShapeUtil::ByteSizeOf(const Shape& shape,
824 int64_t pointer_size) {
825 TF_DCHECK_OK(ValidateShape(shape));
826 if (shape.element_type() == TUPLE) {
827 return ByteSizeOfTupleIndexTable(shape, pointer_size);
828 } else if (shape.IsArray()) {
829 return ByteSizeOfElements(shape);
830 } else if (shape.element_type() == TOKEN) {
831 return 0;
832 } else if (shape.element_type() == OPAQUE_TYPE) {
833 CHECK_GT(pointer_size, 0);
834 return pointer_size;
835 }
836 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
837 << " primitive type has no definitive size";
838 }
839
/* static */ int64_t ShapeUtil::ByteSizeOfTupleIndexTable(
841 const Shape& shape, int64_t pointer_size) {
842 TF_DCHECK_OK(ValidateShape(shape));
843 CHECK_EQ(TUPLE, shape.element_type());
844 CHECK_GT(pointer_size, 0);
845 return pointer_size * shape.tuple_shapes_size();
846 }
847
/* static */ int64_t ShapeUtil::ByteSizeOfElements(const Shape& shape) {
849 TF_DCHECK_OK(ValidateShape(shape));
850 int64_t allocated_element_count;
851
852 CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
853 allocated_element_count = ElementsIn(shape);
854 return allocated_element_count *
855 ByteSizeOfPrimitiveType(shape.element_type());
856 }
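// Illustrative example (a sketch):
//   ShapeUtil::ByteSizeOfElements(ShapeUtil::MakeShape(F32, {2, 3}));  // 6 * 4 = 24
//   ShapeUtil::ByteSizeOfPrimitiveType(C64);                           // 8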
857
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
859 const Shape& shape) {
860 if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
861 !PrimitiveType_IsValid(shape.element_type())) {
862 return InvalidArgument("shape has invalid element type: %s",
863 shape.ShortDebugString());
864 }
865 if (shape.element_type() == TUPLE) {
866 if (shape.dimensions_size() != 0) {
867 return InvalidArgument("tuples must not have dimensions specified");
868 }
869 for (auto& element_shape : shape.tuple_shapes()) {
870 TF_RETURN_IF_ERROR(
871 ValidateShapeWithOptionalLayoutInternal(element_shape));
872 }
873 return OkStatus();
874 }
875
876 // Non-tuple shape.
877 if (shape.tuple_shapes_size() > 0) {
878 return InvalidArgument("non-tuple shape has tuple_shapes field");
879 }
880
// Tokens and opaques should not have layout or dimensions.
882 if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
883 if (shape.dimensions_size() != 0) {
884 return InvalidArgument(
885 "shape has %s element type, but has dimensions field: %s",
886 primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
887 shape.ShortDebugString());
888 }
889 if (shape.has_layout()) {
890 return InvalidArgument(
891 "shape has %s element type, but has layout field: %s",
892 primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
893 shape.ShortDebugString());
894 }
895 return OkStatus();
896 }
897
898 for (int64_t i = 0; i < shape.rank(); ++i) {
899 int64_t dimension = shape.dimensions(i);
900 if (dimension < 0) {
901 return InvalidArgument(
902 "shape's dimensions must not be < 0; dimension at index %d was %d", i,
903 dimension);
904 }
905 }
906
907 TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
908 return OkStatus();
909 }
910
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
912 VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
913
914 if (!shape.IsArray()) {
915 return OkStatus();
916 }
917
918 int64_t shape_size = [&]() {
919 int64_t dense_shape_size = 1;
920 if (shape.dimensions().empty()) {
921 return dense_shape_size;
922 }
923
924 absl::Span<const int64_t> shape_max_dimensions = shape.dimensions();
925 for (int64_t dim : shape_max_dimensions) {
926 dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
927 if (dense_shape_size < 0) {
928 return dense_shape_size;
929 }
930 }
931 dense_shape_size = MultiplyWithoutOverflow(
932 dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
933 return dense_shape_size;
934 }();
935
936 if (shape_size < 0) {
937 return InvalidArgument("Shape %s size may overflow int64_t.",
938 ShapeUtil::HumanString(shape));
939 }
940
941 VLOG(3) << "Shape size is valid: " << shape_size;
942 return OkStatus();
943 }
944
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
946 const Shape& shape) {
947 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
948
949 return LayoutUtil::ValidateLayoutInShape(shape,
950 /*allow_missing_layouts=*/true);
951 }
952
/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
954 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
955
956 return LayoutUtil::ValidateLayoutInShape(shape);
957 }
958
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
960 PrimitiveType type) {
961 if (original.IsTuple()) {
962 std::vector<Shape> new_operands;
963 new_operands.reserve(original.tuple_shapes_size());
964 for (const Shape& operand : original.tuple_shapes()) {
965 new_operands.push_back(ChangeElementType(operand, type));
966 }
967 return MakeTupleShape(new_operands);
968 } else {
969 Shape new_shape = original;
970 new_shape.set_element_type(type);
971 return new_shape;
972 }
973 }
974
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
976 ShapeIndexView index) {
977 const Shape* subshape = &shape;
978 for (auto i : index) {
979 if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
980 return false;
981 }
982 subshape = &subshape->tuple_shapes(i);
983 }
984 return true;
985 }
986
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
988 ShapeIndexView index) {
989 const Shape* return_shape = &shape;
990 for (auto i : index) {
991 CHECK(return_shape->IsTuple())
992 << "Invalid index " << ShapeIndex(index) << " for shape " << shape;
993 return_shape = &return_shape->tuple_shapes(i);
994 }
995 return *return_shape;
996 }
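// Illustrative example (a sketch):
//   Shape t = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {2}),
//        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})});
//   ShapeUtil::GetSubshape(t, {1, 0});   // the s32[] leaf
//   ShapeUtil::IndexIsValid(t, {0, 0});  // false: element 0 is not a tuple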
997
/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
999 const Shape& shape, ShapeIndexView index) {
1000 const Shape* return_shape = &shape;
1001 for (auto i : index) {
1002 if (!return_shape->IsTuple() || i < 0 ||
1003 i >= return_shape->tuple_shapes_size()) {
1004 return InvalidArgument(
1005 "Shape index %s not a valid subshape index for tuple with shape %s",
1006 ShapeIndex(index).ToString(), shape.DebugString());
1007 }
1008 return_shape = &return_shape->tuple_shapes(i);
1009 }
1010 return return_shape;
1011 }
1012
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
1014 ShapeIndexView index) {
1015 Shape* return_shape = shape;
1016 for (auto i : index) {
1017 CHECK(return_shape->IsTuple());
1018 return_shape = return_shape->mutable_tuple_shapes(i);
1019 }
1020 return return_shape;
1021 }
1022
1023 /* static */
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
1025 return !GetSubshape(shape, index).IsTuple();
1026 }
1027
/* static */ int64_t ShapeUtil::GetLeafCount(const Shape& shape) {
1029 if (!shape.IsTuple()) {
1030 return 1;
1031 }
1032 int64_t count = 0;
1033 for (const Shape& subshape : shape.tuple_shapes()) {
1034 count += GetLeafCount(subshape);
1035 }
1036 return count;
1037 }
1038
/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
1040 const Shape& shape) {
1041 std::vector<IndexedShape> leaves;
1042 ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
1043 if (IsLeafIndex(shape, index)) {
1044 leaves.emplace_back(index, sub_shape);
1045 }
1046 });
1047 return leaves;
1048 }
1049
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
1051 CHECK(shape.IsArray());
1052 return absl::c_linear_search(shape.dimensions(), 1);
1053 }
1054
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
1056 return FilterDimensions(
1057 [&](int64_t dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
1058 }
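// Illustrative example (a sketch):
//   Shape s = ShapeUtil::MakeShape(F32, {1, 3, 1, 5});
//   ShapeUtil::HasDegenerateDimensions(s);                            // true
//   ShapeUtil::HumanString(ShapeUtil::DropDegenerateDimensions(s));   // "f32[3,5]"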
1059
1060 namespace {
1061
1062 // Helper for ForEachSubshape which visits the subshapes of the given shape in
1063 // DFS pre-order starting with the index.
Status ForEachSubshapeHelper(const Shape& shape,
1065 const ShapeUtil::StatusVisitorFunction& func,
1066 ShapeIndex* index) {
1067 TF_RETURN_IF_ERROR(func(shape, *index));
1068 if (shape.IsTuple()) {
1069 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1070 index->push_back(i);
1071 TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
1072 ShapeUtil::GetTupleElementShape(shape, i), func, index));
1073 index->pop_back();
1074 }
1075 }
1076 return OkStatus();
1077 }
1078
1079 // Helper for ForEachMutableSubshape which visits the subshapes of the given
1080 // shape in DFS pre-order starting with the index.
Status ForEachMutableSubshapeHelper(
1082 Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
1083 ShapeIndex* index) {
1084 TF_RETURN_IF_ERROR(func(shape, *index));
1085 if (shape->IsTuple()) {
1086 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
1087 index->push_back(i);
1088 TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
1089 shape->mutable_tuple_shapes(i), func, index));
1090 index->pop_back();
1091 }
1092 }
1093 return OkStatus();
1094 }
1095
1096 } // namespace
1097
/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
1099 const VisitorFunction& func) {
1100 ShapeIndex index;
1101 ForEachSubshapeHelper(
1102 shape,
1103 [&func](const Shape& subshape, const ShapeIndex& index) {
1104 func(subshape, index);
1105 return OkStatus();
1106 },
1107 &index)
1108 .IgnoreError();
1109 }
1110
/* static */ void ShapeUtil::ForEachMutableSubshape(
1112 Shape* shape, const MutatingVisitorFunction& func) {
1113 ShapeIndex index;
1114 ForEachMutableSubshapeHelper(
1115 shape,
1116 [&func](Shape* subshape, const ShapeIndex& index) {
1117 func(subshape, index);
1118 return OkStatus();
1119 },
1120 &index)
1121 .IgnoreError();
1122 }
1123
/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
1125 const Shape& shape, const StatusVisitorFunction& func) {
1126 ShapeIndex index;
1127 return ForEachSubshapeHelper(shape, func, &index);
1128 }
1129
/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
1131 Shape* shape, const MutatingStatusVisitorFunction& func) {
1132 ShapeIndex index;
1133 return ForEachMutableSubshapeHelper(shape, func, &index);
1134 }
1135
/* static */ Shape ShapeUtil::PermuteDimensions(
1137 absl::Span<const int64_t> permutation, const Shape& shape) {
1138 Shape new_shape = shape;
1139 new_shape.clear_dimensions();
1140 for (auto dim : Permute(shape.dimensions(), permutation)) {
1141 new_shape.add_dimensions(dim);
1142 }
1143 auto inv_permutation = InversePermutation(permutation);
1144 for (int64_t i = 0; i < shape.rank(); i++) {
1145 new_shape.set_dynamic_dimension(inv_permutation[i],
1146 shape.is_dynamic_dimension(i));
1147 }
1148
1149 // If `shape` has a layout, by contract we choose a new layout such that the
1150 // transpose defined by this permutation is a bitcast.
1151 //
1152 // Some formalism helps to understand the correct way to do this. We're going
1153 // to do algebra in the group of permutations of the dimensions of `shape`.
1154 //
1155 // Since the order of `shape`'s dimensions is not permuted relative to itself,
1156 // `shape`'s list of dimensions is isomorphic to the identity I.
1157 //
1158 // Let `shape`'s layout be L. A layout is a permutation which maps a
1159 // minor-to-major physical dimension ordering to a shape's logical dimension
1160 // ordering. Therefore the inverse of a layout maps from logical to physical
1161 // dims, and so the physical ordering of I is simply L'.I = L', where L' is
1162 // the inverse of L.
1163 //
1164 // Let the argument `permutation` be P. This is a permutation over `shape`'s
1165 // dimensions, so our return value will be a shape with dims P.I = P. Our
1166 // goal is to construct a layout permutation L* for this shape. The physical
1167 // dimension ordering of this returned shape must be the same as that of the
1168 // original shape, namely L'.
1169 //
1170 // Our returned shape has dims P and layout L*, so its in-memory ordering is
1171 // L*'.P. Setting this equal to L' and solving for L*, we get:
1172 //
1173 // L*'.P = L' =>
1174 // L*' = L'P' =>
1175 // L* = P.L
1176 //
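// Worked example (illustrative sketch): for dims [2,3,4] with layout {0,1,2}
// (dimension 0 most minor) and permutation {1,2,0}, the returned shape has
// dims [3,4,2] and minor_to_major {2,0,1}; in both shapes the physical
// minor-to-major ordering of the dimension sizes is 2, 3, 4.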
1177 if (shape.has_layout()) {
1178 CHECK(LayoutUtil::IsDenseArray(shape));
1179 Layout* new_layout = new_shape.mutable_layout();
1180 new_layout->clear_minor_to_major();
1181 for (auto index : ComposePermutations(inv_permutation,
1182 shape.layout().minor_to_major())) {
1183 new_layout->add_minor_to_major(index);
1184 }
1185 // The permutation accepted by TransposeIsBitcast is the inverse of the
1186 // permutation here.
1187 CHECK(TransposeIsBitcast(shape, new_shape, permutation))
1188 << "shape=" << HumanStringWithLayout(shape)
1189 << ", new_shape=" << HumanStringWithLayout(new_shape)
1190 << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
1191 }
1192 return new_shape;
1193 }
1194
1195 /* static */ std::optional<ShapeUtil::ShapeEqualityDescriptor>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
1197 const Shape& shape_post) {
1198 CHECK(shape_pre.IsArray());
1199 CHECK(shape_post.IsArray());
1200
1201 std::vector<int64_t> deleted_indices;
1202 std::vector<int64_t> inserted_indices;
// Returns false if any input/output dimension between prior_unmodified_dim_pair
// and unmodified_dim_pair has size >1. Otherwise, returns true and appends
// the degenerate input/output dimensions in the gap to
// deleted_indices/inserted_indices respectively.
1207 auto check_modified_dims =
1208 [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
1209 std::pair<int64_t, int64_t> prior_unmodified_dim_pair,
1210 std::pair<int64_t, int64_t> unmodified_dim_pair) {
1211 for (int64_t modified_input_dim = prior_unmodified_dim_pair.first + 1;
1212 modified_input_dim < unmodified_dim_pair.first;
1213 ++modified_input_dim) {
1214 if (shape_pre.dimensions(modified_input_dim) > 1) {
1215 return false;
1216 }
1217 deleted_indices.push_back(modified_input_dim);
1218 }
1219 for (int64_t modified_output_dim = prior_unmodified_dim_pair.second + 1;
1220 modified_output_dim < unmodified_dim_pair.second;
1221 ++modified_output_dim) {
1222 if (shape_post.dimensions(modified_output_dim) > 1) {
1223 return false;
1224 }
1225 inserted_indices.push_back(modified_output_dim);
1226 }
1227 return true;
1228 };
1229
1230 std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
1231 DimensionsUnmodifiedByReshape(shape_pre, shape_post);
1232 // Returns nil if the reshape modifies any non-degenerate input/output
1233 // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
1234 // dimensions, so we only need to check whether dimensions in the gaps (thus
1235 // modified) have size >1.
1236 for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
1237 // Check (modified) dimensions between unmodified_dims[i-1] and
1238 // unmodified_dims[i].
1239 auto prior_unmodified_dim_pair =
1240 i > 0 ? unmodified_dims[i - 1] : std::pair<int64_t, int64_t>(-1, -1);
1241 auto unmodified_dim_pair =
1242 i < unmodified_dims.size()
1243 ? unmodified_dims[i]
1244 : std::make_pair(shape_pre.rank(), shape_post.rank());
1245 if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
1246 return std::nullopt;
1247 }
1248 }
1249
1250 return ShapeEqualityDescriptor{deleted_indices, inserted_indices};
1251 }
1252
1253 /* static */ std::vector<std::pair<int64_t, int64_t>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1255 const Shape& output_shape) {
1256 CHECK(input_shape.IsArray());
1257 CHECK(output_shape.IsArray());
1258
1259 // Unmodified dimensions are merely common factors of rank 1.
1260 auto common_factors =
1261 CommonFactors(input_shape.dimensions(), output_shape.dimensions());
1262 for (size_t i = 0; i < common_factors.size() - 1;) {
1263 if (1 != common_factors[i + 1].first - common_factors[i].first ||
1264 1 != common_factors[i + 1].second - common_factors[i].second) {
1265 common_factors.erase(common_factors.begin() + i);
1266 } else {
1267 ++i;
1268 }
1269 }
1270 // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1271 common_factors.pop_back();
1272 return std::vector<std::pair<int64_t, int64_t>>(common_factors.begin(),
1273 common_factors.end());
1274 }
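// Illustrative example (a sketch): reshaping f32[2,3,4] to f32[6,4] leaves
// only the trailing size-4 dimension unmodified, so this returns {{2, 1}}.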
1275
1276 /* static */ std::optional<std::vector<int64_t>>
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1278 const Shape& from_shape, const Shape& to_shape,
1279 absl::Span<const int64_t> input_dim_indices) {
1280 if (!std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())) {
1281 return std::nullopt;
1282 }
1283
1284 std::vector<int64_t> output_dim_indices;
1285 std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
1286 ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1287 size_t i = 0; // index to unmodified_dims
1288 for (int64_t input_dim_index : input_dim_indices) {
1289 // Search unmodified_dims for input_dim_index. We can search from the last
1290 // matching position because input_dim_indices is guaranteed to be sorted.
1291 while (i < unmodified_dims.size() &&
1292 unmodified_dims[i].first < input_dim_index) {
1293 ++i;
1294 }
1295 if (i >= unmodified_dims.size() ||
1296 unmodified_dims[i].first != input_dim_index) {
1297 return std::nullopt;
1298 }
1299 output_dim_indices.push_back(unmodified_dims[i].second);
1300 }
1301 return output_dim_indices;
1302 }
1303
/* static */ bool ShapeUtil::TransposeIsBitcast(
1305 const Shape& input_shape, const Shape& output_shape,
1306 absl::Span<const int64_t> dimension_mapping) {
1307 CHECK(LayoutUtil::IsDenseArray(input_shape)) << input_shape.ToString(true);
1308 CHECK(LayoutUtil::IsDenseArray(output_shape)) << output_shape.ToString(true);
1309 CHECK(input_shape.has_layout()) << input_shape.ToString(true);
1310 CHECK(output_shape.has_layout()) << output_shape.ToString(true);
1311
1312 if (!SameElementType(input_shape, output_shape)) {
1313 return false;
1314 }
1315
// Check that the transpose permutes the positions of each dimension in the
1317 // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1318 // input_positions = apply(dimension_mapping, output_positions)
1319 //
1320 // Because the positions of each dimension are the inverse permutation of the
1321 // minor-to-major order, the above check is equivalent to
1322 // inverse(input_dimensions) =
1323 // apply(dimension_mapping, inverse(output_dimensions))
1324 // # `I` indicates identity permutation.
1325 // apply(input_dimensions, I) =
1326 // apply(dimension_mapping, apply(output_dimensions, I))
1327 // apply(input_dimensions, I) =
1328 // apply((dimension_mapping * output_dimensions), I)
1329 // input_dimensions = dimension_mapping * output_dimensions
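// Worked example (illustrative sketch): a row-major f32[2,3]
// (minor_to_major {1,0}) transposed with mapping {1,0} into f32[3,2] with
// minor_to_major {0,1} is a bitcast: ComposePermutations({1,0}, {0,1}) is
// {1,0}, which matches the input's minor_to_major.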
1330 return absl::c_equal(
1331 ComposePermutations(dimension_mapping,
1332 output_shape.layout().minor_to_major()),
1333 input_shape.layout().minor_to_major());
1334 }
1335
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
1337 const Shape& output_shape) {
1338 CHECK(LayoutUtil::IsDenseArray(input_shape)) << input_shape.ToString(true);
1339 CHECK(LayoutUtil::IsDenseArray(output_shape)) << output_shape.ToString(true);
1340 CHECK(input_shape.has_layout()) << input_shape.ToString(true);
1341 CHECK(output_shape.has_layout()) << output_shape.ToString(true);
1342
1343 if (!SameElementType(input_shape, output_shape)) {
1344 return false;
1345 }
1346
1347 if (ElementsIn(input_shape) != ElementsIn(output_shape)) {
1348 VLOG(3) << "input_shape=" << input_shape.ShortDebugString()
1349 << ", output_shape=" << output_shape.ShortDebugString();
1350 return false;
1351 }
1352 if (ElementsIn(input_shape) == 0) {
1353 return true;
1354 }
1355
1356 // TL;DR: The rest of the method checks that the reshape does not change the
1357 // physical location of any unit input or output index. Unit indices have
1358 // exactly one dimension that equals 1 and other dimensions 0. This condition
1359 // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
1360 // reshape shouldn't change the physical location of any element. It is also a
1361 // sufficient condition as is proved below (note: many details are omitted for
1362 // space).
1363 //
1364 // Definitions:
1365 //
1366 // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
1367 // the size of i-th least significant dimension of IS or OS (this is opposite
1368 // to how we define the index of Shape::dimensions()).
1369 //
1370 // * Given an input or output index I, denote by p(I) I's physical linear
1371 // index (or physical index for short) and l(I) I's logical linear index (or
1372 // logical index for short).
1373 //
1374 // * Given a logical index k, denote by II(k) the input index whose linear
1375 // index is k, and OI(k) the corresponding output index.
1376 //
1377 // * Denote by IT[i] the increment of physical index if i-th dimension of the
1378 // input index is increased by 1. Similarly, OT[i] means the increment if i-th
1379 // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
1380 // is a function of IS or OS and the layout, and not dependent on the specific
1381 // input or output index.
1382 //
1383 // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
1384 // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
1385 // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
1386 // to prove, with every increment on k, the above formula still holds.
1387 //
1388 // First, suppose reshaping from IS to OS is non-factorizable (we discuss
1389 // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
1390 // there exists (i,j) such that
1391 //
1392 // 0<=i<=|IS|
1393 // 0<=j<=|OS|
1394 // |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
1395 // product(IS[i], IS[i+1], ..., IS[|IS|-1])
1396 // = product(OS[j], OS[j+1], ..., OS[|OS|-1])
1397 //
1398 // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
1399 // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
1400 // unit indices which are already tested. This also means IT[0]=OT[0]
1401 // because p(II(1))=IT[0] and p(OI(1))=OT[0].
1402 //
1403 // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
1404 // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
1405 // to the output physical.
1406 //
1407 // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
1408 // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
1409 // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
1410 // logical index k-1 to logical index k, dimension 1 of the input index
1411 // is increased by 1 and dimension 0 is reset to 0 thus decreased by
1412 // IS[0]-1. Therefore, the physical input index is increased by
1413 //
1414 // p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
1415 //
1416 // Because IS[0]<OS[0], the only change to the output index is that its
1417 // dimension 0 is increased by one. Therefore,
1418 //
1419 // p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
1420 //
1421 // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
1422 // p(II(k))=p(OI(k)). Therefore,
1423 // IT[1] - (IS[0]-1) * IT[0] = IT[0]
1424 // IT[1] = IS[0] * IT[0]
1425 // In other words, input dimension 1 is immediately more major than input
1426 // dimension 0. We can now conceptually collapse these two dimensions because
1427 // an increment in the logical index affecting only these two dimensions maps
1428 // to IT[0] in the physical index.
1429 //
1430 // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
1431 // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
1432 // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
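  // As a concrete illustration of the check below (example added for clarity,
  // not part of the original proof): reshaping F32[2,3] with layout {1,0} into
  // F32[6] with layout {0} is a bitcast. The input unit index (0,1) has
  // logical linear index 1, which maps to output index (1); both have physical
  // linear index 1. Likewise, (1,0) maps to output index (3), and both have
  // physical linear index 3. If the input layout were {0,1} instead, (0,1)
  // would have physical linear index 2 while output index (1) has physical
  // linear index 1, so that reshape is not a bitcast.
  //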
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major have the same
    // "dimensions" as input_shape/output_shape and the dimension-0-major
    // layout. These two shapes are used for conversion between logical linear
    // indices and multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), input_shape.dimensions());
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), output_shape.dimensions());

    for (int64_t input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64_t> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64_t logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64_t> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check that input_unit_index and output_index have the same physical
      // linear index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}

/* static */ std::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return std::nullopt;
    }

    std::vector<int64_t> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // Map each dimension number in the layout of the degenerate-dimension-free
    // output back to the corresponding dimension number of the original output
    // shape, skipping the size-one dimensions.
    absl::InlinedVector<int64_t, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64_t i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    for (int64_t& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64_t i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  auto common_factors =
      CommonFactors(input_shape.dimensions(), output_shape.dimensions());
  const int64_t input_rank = input_shape.rank();
  DimensionVector input_to_factor(input_rank);
  for (int64_t pos = 0; pos < common_factors.size() - 1; ++pos) {
    const int64_t input_start = common_factors[pos].first;
    const int64_t input_end = common_factors[pos + 1].first;
    int64_t input_physical =
        PositionInContainer(input_shape.layout().minor_to_major(), input_start);
    input_to_factor[input_start] = pos;
    for (int64_t i = input_start + 1; i < input_end; ++i) {
      --input_physical;
      if (input_physical < 0 ||
          input_shape.layout().minor_to_major(input_physical) != i) {
        return std::nullopt;
      }
      input_to_factor[i] = pos;
    }
  }

  int64_t output_rank = output_shape.rank();
  DimensionVector output_layout;
  output_layout.reserve(output_rank);
  int64_t input_minor = 0;
  while (output_layout.size() < output_rank) {
    const int64_t input_dim = input_shape.layout().minor_to_major(input_minor);
    const int64_t common_factor = input_to_factor[input_dim];
    const auto start_factor = common_factors[common_factor];
    const auto end_factor = common_factors[common_factor + 1];
    for (int64_t dim = end_factor.second - 1; dim >= start_factor.second;
         --dim) {
      output_layout.push_back(dim);
    }
    input_minor += end_factor.first - start_factor.first;
  }

  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), output_shape.dimensions(), output_layout);
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
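
// Illustrative usage sketch for AlignLayouts (added for exposition; the
// values below simply trace the algorithm above). Given an input shape with a
// layout and the desired output dimensions, AlignLayouts proposes an output
// layout that makes the reshape a bitcast, or returns std::nullopt when no
// such layout exists:
//
//   Shape input = ShapeUtil::MakeShapeWithLayout(F32, {4, 6}, {0, 1});
//   Shape output = ShapeUtil::MakeShape(F32, {2, 2, 6});
//   std::optional<Shape> aligned = ShapeUtil::AlignLayouts(input, output);
//   // aligned->layout() has minor_to_major {1, 0, 2}, and
//   // ShapeUtil::ReshapeIsBitcast(input, *aligned) holds.
//
// By contrast, F32[8,16] with layout {0, 1} cannot be reshaped to [16, 8] as a
// bitcast under any output layout, so AlignLayouts returns std::nullopt there.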

/* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete,
                                              Shape shape) {
  CHECK(shape.IsArray());
  shape.DeleteDimension(dim_to_delete);
  return shape;
}

/* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
    const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
    const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
  bool compatible = true;
  xla::ShapeUtil::ForEachSubshape(
      dynamic_shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
        if (compatible) {
          auto subshape_result = TryGetSubshape(bounded_shape, index);
          if (subshape_result.ok()) {
            const Shape* bounded_sub_shape = std::move(subshape_result).value();
            if (sub_shape.IsTuple()) {
              if (!bounded_sub_shape->IsTuple()) {
                compatible = false;
              }
            } else {
              if (bounded_sub_shape->IsTuple()) {
                compatible = false;
              } else if (!sub_shape.is_static() &&
                         !DynamicArrayShapeIsCompatible(sub_shape,
                                                        *bounded_sub_shape)) {
                compatible = false;
              }
            }
          } else {
            compatible = false;
          }
        }
      });
  return compatible;
}

/* static */ Shape ShapeUtil::DeleteDimensions(
    absl::Span<int64_t const> dims_to_delete, Shape shape) {
  std::vector<int64_t> dims_to_delete_v(dims_to_delete.begin(),
                                        dims_to_delete.end());
  absl::c_sort(dims_to_delete_v, std::greater<int64_t>());
  for (int64_t dim : dims_to_delete_v) {
    shape = DeleteDimension(dim, shape);
  }
  return shape;
}

/* static */ Shape ShapeUtil::FilterDimensions(
    const std::function<bool(int64_t)>& p, Shape shape) {
  CHECK(shape.IsArray());
  std::vector<int64_t> dims_to_delete;
  for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) {
    if (!p(i)) {
      dims_to_delete.push_back(i);
    }
  }
  return DeleteDimensions(dims_to_delete, shape);
}

// Returns the indices of the first elements of all consecutive subarrays of
// the given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64_t> xs) {
  std::vector<size_t> is = {0};
  for (size_t i = 1; i < xs.size(); ++i) {
    if (1 != xs[i] - xs[i - 1]) {
      is.push_back(i);
    }
  }
  return is;
}

// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
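// For example (illustrative): given segs = {0, 2} and a shape with dimensions
// {3, 4, 5, 6}, the result has dimensions {12, 30} and a descending layout.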
static Shape MergeDimensions(absl::Span<const size_t> segs,
                             const Shape& shape) {
  std::vector<int64_t> dimensions;
  const auto size = segs.size();
  dimensions.reserve(size);
  for (size_t i = 1; i <= size; ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64_t>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}

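// Note added for exposition: FindTranspose021 checks whether, after collapsing
// runs of physically contiguous dimensions, the physical layout of `b` is a
// (0, 2, 1) transposition of the physical layout of `a`, and if so returns the
// collapsed sizes in 0-2-1 order. For example (illustrative), for
// a = F32[32,64] with layout {1,0} and b = F32[32,64] with layout {0,1}, the
// physical-dimension permutation is {1,0} and the result is {1, 64, 32}.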
/* static */ std::optional<Vector3> ShapeUtil::FindTranspose021(
    const Shape& a, const Shape& b) {
  if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
    return std::nullopt;
  }

  std::vector<int64_t> permutation(a.dimensions().size());
  absl::Span<const int64_t> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64_t> major_to_minor_a(minor_to_major_a.rbegin(),
                                        minor_to_major_a.rend());
  absl::Span<const int64_t> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64_t> major_to_minor_b(minor_to_major_b.rbegin(),
                                        minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64_t> normalized_dims = normalized_shape.dimensions();
    Vector3 dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }
    return dims_021;
  }

  return std::nullopt;
}

Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
  ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
    if (subshape->IsArray()) {
      subshape->mutable_layout()->clear_tiles();
      subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
    }
  });
  return s;
}

/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
                                            const Shape& to) {
  return HigherPrecisionElementType(from, to) == to.element_type();
}

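// Added illustration for ByteStrides below: the most minor dimension gets a
// stride equal to the element size, and each successively more major dimension
// gets the previous stride times the previous dimension's extent. For example
// (illustrative), F32[2,3,4] with minor_to_major {2, 1, 0} yields byte strides
// {48, 16, 4}.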
/*static*/
Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span<int64_t> strides) {
  TF_RET_CHECK(shape.IsArray());
  TF_RET_CHECK(shape.has_layout());
  TF_RET_CHECK(shape.dimensions_size() == strides.size());

  int64_t stride = ByteSizeOfPrimitiveType(shape.element_type());
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return OkStatus();
}

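// Added illustration for ArraySize below: each physical dimension covered by
// the top-level tile is rounded up to a multiple of the corresponding tile
// extent before multiplying. For example (illustrative), F32[10, 200] with
// minor_to_major {1, 0} and a single (8, 128) tile is padded to 16 x 256
// elements, i.e. 16 * 256 * 4 = 16384 bytes, compared with 8000 bytes of
// untiled data.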
/*static*/ int64_t ShapeUtil::ArraySize(const Shape& shape) {
  CHECK(shape.IsArray());
  CHECK(!shape.layout().tiles().empty());

  auto tile_dimensions = shape.layout().tiles(0).dimensions();
  auto shape_dimensions = shape.dimensions();
  auto minor_to_major = shape.layout().minor_to_major();
  int64_t shape_dim_size = shape_dimensions.size();
  int64_t tile_dim_size = tile_dimensions.size();

  // Use the top-level tile for shape size calculation. We assume the
  // sub-tiles won't cause additional padding.
  int64_t num_of_elements = 1;
  int64_t dim = 0;
  for (dim = 0; dim < tile_dim_size; dim++) {
    int64_t dim_size =
        dim < shape_dim_size ? shape_dimensions[minor_to_major[dim]] : 1;
    num_of_elements *=
        RoundUpTo(dim_size, tile_dimensions[tile_dim_size - dim - 1]);
  }
  for (; dim < shape_dim_size; dim++) {
    int64_t dim_size = shape_dimensions[minor_to_major[dim]];
    num_of_elements *= dim_size;
  }
  if (shape.layout().element_size_in_bits() != 0) {
    const int64_t kBitsPerByte = 8;
    return CeilOfRatio(num_of_elements * shape.layout().element_size_in_bits(),
                       static_cast<int64_t>(kBitsPerByte));
  }
  return num_of_elements * ByteSizeOfPrimitiveType(shape.element_type());
}

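// Added illustration for ArrayDataSize below: the size is derived from the
// linear index of the last element plus one. For example (illustrative),
// F32[10, 200] with a dense minor_to_major {1, 0} layout has its last element
// at linear index 1999, so ArrayDataSize returns 2000 * 4 = 8000 bytes.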
/*static*/ int64_t ShapeUtil::ArrayDataSize(const Shape& shape) {
  CHECK(shape.IsArray());
  absl::InlinedVector<int64_t, 4> indices;
  for (int64_t dim : shape.dimensions()) {
    indices.push_back(dim - 1);
  }
  int64_t size = LayoutUtil::LinearIndex(shape, indices) + 1;
  if (shape.layout().element_size_in_bits() != 0) {
    const int64_t kBitsPerByte = 8;
    return CeilOfRatio(size * shape.layout().element_size_in_bits(),
                       static_cast<int64_t>(kBitsPerByte));
  }
  return (size * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()));
}

}  // namespace xla