• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape_util.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <optional>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/hash/hash.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/strings/strip.h"
38 #include "tensorflow/compiler/xla/index_util.h"
39 #include "tensorflow/compiler/xla/layout_util.h"
40 #include "tensorflow/compiler/xla/overflow_util.h"
41 #include "tensorflow/compiler/xla/permutation_util.h"
42 #include "tensorflow/compiler/xla/primitive_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/platform/logging.h"
48 
49 namespace xla {
50 
51 using absl::StrAppend;
52 using absl::StrCat;
53 
54 namespace {
55 // An array that is indexed by PrimitiveType, and returns
56 // the size of each element of that primitive type, or 0
57 // if the PrimitiveType is not a primitive type
58 constexpr uint8_t primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
59     0,                  // PRIMITIVE_TYPE_INVALID = 0,
60     sizeof(int8_t),     // PRED = 1
61     sizeof(int8_t),     // S8 = 2
62     sizeof(int16_t),    // S16 = 3
63     sizeof(int32_t),    // S32 = 4
64     sizeof(int64_t),    // S64 = 5
65     sizeof(uint8_t),    // U8 = 6
66     sizeof(uint16_t),   // U16 = 7
67     sizeof(uint32_t),   // U32 = 8
68     sizeof(uint64_t),   // U64 = 9
69     sizeof(float) / 2,  // F16 = 10
70     sizeof(float),      // F32 = 11
71     sizeof(double),     // F64 = 12
72     0,                  // TUPLE = 13
73     0,                  // OPAQUE_TYPE = 14
74     sizeof(complex64),  // C64 = 15
75     sizeof(float) / 2,  // BF16 = 16
76     0,                  // TOKEN = 17
77     sizeof(complex128)  // C128 = 18
78 };
79 constexpr int64_t kAnnotationPrintInterval = 5;
80 }  // namespace
81 
ToString() const82 std::string ShapeIndex::ToString() const {
83   return StrCat("{", absl::StrJoin(*this, ","), "}");
84 }
85 
operator <<(std::ostream & out,const ShapeIndex & shape_index)86 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
87   out << shape_index.ToString();
88   return out;
89 }
90 
IsArrayPrimitiveType(PrimitiveType primitive_type)91 /* static */ bool ShapeUtil::IsArrayPrimitiveType(
92     PrimitiveType primitive_type) {
93   return primitive_util::IsArrayType(primitive_type);
94 }
95 
96 namespace {
97 // Constructs and returns the new shape with the given minor_to_major order in
98 // its Layout.
MakeShapeWithLayoutInternal(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)99 StatusOr<Shape> MakeShapeWithLayoutInternal(
100     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
101     absl::Span<const int64_t> minor_to_major,
102     absl::Span<const DimLevelType> dim_level_types,
103     absl::Span<const Tile> tiles, int64_t element_size_in_bits,
104     int64_t memory_space) {
105   if (dimensions.size() != minor_to_major.size()) {
106     return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
107                            dimensions.size(), minor_to_major.size());
108   }
109   if (element_type == OPAQUE_TYPE || element_type == TUPLE ||
110       element_type == TOKEN) {
111     return InvalidArgument("Unsupported element type: %s",
112                            PrimitiveType_Name(element_type));
113   }
114   TF_ASSIGN_OR_RETURN(Shape shape,
115                       ShapeUtil::MakeValidatedShape(element_type, dimensions));
116   if (element_size_in_bits ==
117       ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
118     // Only set element_size_in_bits if it's different from the default value.
119     element_size_in_bits = 0;
120   }
121   *shape.mutable_layout() =
122       LayoutUtil::MakeLayout(minor_to_major, dim_level_types, tiles,
123                              element_size_in_bits, memory_space);
124   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
125   return shape;
126 }
127 
128 template <typename T>
Deref(const T * ptr)129 const T& Deref(const T* ptr) {
130   DCHECK(ptr != nullptr);
131   return *ptr;
132 }
133 
134 template <typename T>
Deref(const T & ref)135 const T& Deref(const T& ref) {
136   return ref;
137 }
138 
139 template <typename ShapePtrOrRef>
MakeTupleShapeImpl(absl::Span<ShapePtrOrRef> shapes)140 Shape MakeTupleShapeImpl(absl::Span<ShapePtrOrRef> shapes) {
141   Shape result;
142   result.set_element_type(TUPLE);
143   result.mutable_tuple_shapes()->reserve(shapes.size());
144   for (const auto& shape : shapes) {
145     ShapeUtil::AppendShapeToTuple(Deref(shape), &result);
146   }
147   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
148   return result;
149 }
150 
151 }  // namespace
152 
Equal(const Shape & lhs,const Shape & rhs)153 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
154   bool equal = Shape::Equal()(lhs, rhs);
155 
156   if (!equal && VLOG_IS_ON(3)) {
157     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
158             << ", rhs = " << rhs.ShortDebugString();
159   }
160 
161   return equal;
162 }
163 
EqualIgnoringElementType(const Shape & lhs,const Shape & rhs)164 /* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
165                                                       const Shape& rhs) {
166   bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
167   if (!equal && VLOG_IS_ON(3)) {
168     VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
169             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
170   }
171 
172   return equal;
173 }
174 
EqualIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)175 /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
176                                                       const Shape& rhs) {
177   bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
178   if (!equal && VLOG_IS_ON(3)) {
179     VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
180             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
181   }
182 
183   return equal;
184 }
185 
EqualStructure(const Shape & lhs,const Shape & rhs)186 /* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
187                                             const Shape& rhs) {
188   bool equal = true;
189   ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
190     equal &= IndexIsValid(rhs, index);
191   });
192   ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
193     equal &= IndexIsValid(lhs, index);
194   });
195 
196   return equal;
197 }
198 
TrueRank(const Shape & shape)199 /* static */ int64_t ShapeUtil::TrueRank(const Shape& shape) {
200   int64_t accum = 0;
201   for (int64_t dimension : shape.dimensions()) {
202     // We do not count zero dimensions.
203     if (dimension != 1) {
204       accum += 1;
205     }
206   }
207   return accum;
208 }
209 
FillNewShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,Shape * shape)210 /* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
211                                           absl::Span<const int64_t> dimensions,
212                                           Shape* shape) {
213   const int eint = static_cast<int>(element_type);
214   int64_t dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
215                                   ? primitive_byte_size[eint]
216                                   : 0);  // Out of range: force a failure
217   if (dense_shape_size <= 0) {
218     return false;
219   }
220 
221   // Verify that array-based lookup is consistent with public API.
222   DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
223       << element_type;
224 
225   shape->set_element_type(element_type);
226   const int ndims = dimensions.size();
227   auto layout = shape->mutable_layout();
228   auto* minor_to_major = layout->mutable_minor_to_major();
229   for (int i = 0; i < ndims; i++) {
230     const int64_t d = dimensions[i];
231     if (d < 0) {
232       return false;
233     }
234     dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
235     if (dense_shape_size < 0) {
236       return false;
237     }
238 
239     shape->add_dimensions(d);
240     minor_to_major->push_back(ndims - 1 - i);
241   }
242   return true;
243 }
244 
MakeProgramShape(std::initializer_list<Shape> parameters,Shape result)245 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
246     std::initializer_list<Shape> parameters, Shape result) {
247   ProgramShape program_shape;
248   for (const Shape& shape : parameters) {
249     *program_shape.add_parameters() = shape;
250   }
251   *program_shape.mutable_result() = std::move(result);
252   return program_shape;
253 }
254 
MakeShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions)255 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
256                                         absl::Span<const int64_t> dimensions) {
257   Shape shape;
258   CHECK(FillNewShape(element_type, dimensions, &shape));
259   return shape;
260 }
261 
MakeScalarShape(PrimitiveType element_type)262 /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
263   return MakeShape(element_type, {});
264 }
265 
MakeShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,const std::vector<bool> & dynamic_dimensions)266 /* static */ Shape ShapeUtil::MakeShape(
267     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
268     const std::vector<bool>& dynamic_dimensions) {
269   return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
270       .ValueOrDie();
271 }
272 
// Copy of `shape` with its dynamic-dimension flags cleared.
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape static_shape(shape);
  static_shape.clear_dynamic_dimensions();
  return static_shape;
}
279 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions)280 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
281     PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
282   Shape shape;
283   if (!FillNewShape(element_type, dimensions, &shape)) {
284     return InvalidArgument("invalid shape type=%d, dims=[%s]",
285                            static_cast<int>(element_type),
286                            absl::StrJoin(dimensions, ","));
287   }
288   return shape;
289 }
290 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,const std::vector<bool> & dynamic_dimensions)291 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
292     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
293     const std::vector<bool>& dynamic_dimensions) {
294   if (dynamic_dimensions.size() != dimensions.size()) {
295     return InvalidArgument(
296         "dynamic dimensions size %d did not match number of dimensions %d",
297         dynamic_dimensions.size(), dimensions.size());
298   }
299 
300   Shape shape;
301   if (!FillNewShape(element_type, dimensions, &shape)) {
302     return InvalidArgument("invalid shape type=%d, dims=[%s]",
303                            static_cast<int>(element_type),
304                            absl::StrJoin(dimensions, ","));
305   }
306   for (int i = 0, n = dimensions.size(); i < n; i++) {
307     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
308   }
309   return shape;
310 }
311 
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64_t> dimensions,absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)312 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
313     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
314     absl::Span<const int64_t> minor_to_major,
315     absl::Span<const DimLevelType> dim_level_types,
316     absl::Span<const Tile> tiles, int64_t element_size_in_bits,
317     int64_t memory_space) {
318   auto ret = MakeShapeWithLayoutInternal(element_type, dimensions,
319                                          minor_to_major, dim_level_types, tiles,
320                                          element_size_in_bits, memory_space);
321   if (!ret.ok()) LOG(ERROR) << ret.status();
322   return ret.ValueOrDie();
323 }
324 
// Returns a copy of `shape` whose layout makes logical dimension `dim` the
// most-major one, keeping the relative order of the remaining entries of
// minor_to_major. Tuples are processed element-wise; a shape without a layout
// starts from the default layout.
/* static */ Shape ShapeUtil::MoveDimToMajor(const Shape& shape, int64_t dim) {
  if (shape.IsTuple()) {
    std::vector<Shape> moved_elements;
    moved_elements.reserve(shape.tuple_shapes_size());
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      moved_elements.push_back(MoveDimToMajor(shape.tuple_shapes(i), dim));
    }
    return ShapeUtil::MakeTupleShape(moved_elements);
  }

  Shape ret = shape;
  if (!ret.has_layout()) {
    LayoutUtil::SetToDefaultLayout(&ret);
  }
  *ret.mutable_layout() = LayoutUtil::MoveDimToMajor(ret.layout(), dim);

  // Rebuild minor_to_major with `dim` moved to the last (most-major) slot.
  DimensionVector minor_to_major;
  for (int64_t d : LayoutUtil::MinorToMajor(ret)) {
    if (d == dim) continue;
    minor_to_major.push_back(d);
  }
  minor_to_major.push_back(dim);
  // NOTE(review): this replaces the whole layout with one built only from
  // minor_to_major, so any other fields set above (tiles, memory space, ...)
  // presumably fall back to MakeLayout's defaults — confirm this is intended.
  *ret.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  return ret;
}
350 
MakeShapeWithDescendingLayout(PrimitiveType element_type,absl::Span<const int64_t> dimensions)351 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
352     PrimitiveType element_type, absl::Span<const int64_t> dimensions) {
353   std::vector<int64_t> layout(dimensions.size());
354   std::iota(layout.rbegin(), layout.rend(), static_cast<int64_t>(0));
355   return MakeShapeWithLayout(element_type, dimensions, layout);
356 }
357 
358 /* static */ Shape
MakeShapeWithDescendingLayoutAndSamePhysicalLayout(const Shape & shape)359 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
360     const Shape& shape) {
361   std::vector<int64_t> dims(shape.dimensions_size());
362   for (int i = 0; i < shape.dimensions_size(); ++i) {
363     int dim = i;
364     if (shape.has_layout()) {
365       dim = LayoutUtil::Major(shape.layout(), dim);
366     }
367     dims[i] = shape.dimensions(dim);
368   }
369   Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
370   // Since the physical layout is kept the same, the tiles and element size are
371   // the same also.
372   if (shape.has_layout()) {
373     new_shape.mutable_layout()->mutable_tiles()->assign(
374         shape.layout().tiles().begin(), shape.layout().tiles().end());
375     new_shape.mutable_layout()->set_element_size_in_bits(
376         shape.layout().element_size_in_bits());
377   }
378   for (int i = 0; i < shape.dimensions_size(); ++i) {
379     new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
380   }
381   return new_shape;
382 }
383 
PopulateShape(PrimitiveType element_type,absl::Span<const int64_t> dimensions,Shape * shape)384 /* static */ Status ShapeUtil::PopulateShape(
385     PrimitiveType element_type, absl::Span<const int64_t> dimensions,
386     Shape* shape) {
387   shape->Clear();
388   shape->set_element_type(element_type);
389   for (int64_t dimension : dimensions) {
390     shape->add_dimensions(dimension);
391   }
392   LayoutUtil::SetToDefaultLayout(shape);
393   return ValidateShape(*shape);
394 }
395 
// Copy of `original` with its dynamic-dimension flags cleared; identical to
// MakeShapeWithStaticDimensions.
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
  return MakeShapeWithStaticDimensions(original);
}
401 
MakeTupleShape(absl::Span<const Shape> shapes)402 /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
403   return MakeTupleShapeImpl(shapes);
404 }
405 
MakeTupleShapeWithPtrs(absl::Span<const Shape * const> shapes)406 /* static */ Shape ShapeUtil::MakeTupleShapeWithPtrs(
407     absl::Span<const Shape* const> shapes) {
408   return MakeTupleShapeImpl(shapes);
409 }
410 
MakeMaybeTupleShape(absl::Span<const Shape> shapes)411 /* static */ Shape ShapeUtil::MakeMaybeTupleShape(
412     absl::Span<const Shape> shapes) {
413   if (shapes.size() == 1) {
414     return shapes[0];
415   }
416   return MakeTupleShape(shapes);
417 }
418 
MakeOpaqueShape()419 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
420   Shape result;
421   result.set_element_type(OPAQUE_TYPE);
422   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
423   return result;
424 }
425 
MakeTokenShape()426 /* static */ Shape ShapeUtil::MakeTokenShape() {
427   Shape result;
428   result.set_element_type(TOKEN);
429   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
430   return result;
431 }
432 
AppendShapeToTuple(const Shape & shape,Shape * tuple_shape)433 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
434                                                 Shape* tuple_shape) {
435   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
436   *tuple_shape->add_tuple_shapes() = shape;
437 }
438 
UpdateTupleShape(const Shape & shape,int64_t index,Shape * tuple_shape)439 /* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64_t index,
440                                               Shape* tuple_shape) {
441   CHECK(index < tuple_shape->tuple_shapes_size());
442   *tuple_shape->mutable_tuple_shapes(index) = shape;
443 }
444 
UpdateDynamicDimension(Shape * shape,ShapeIndexView index,int64_t dim,bool is_dynamic)445 /* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
446                                                     ShapeIndexView index,
447                                                     int64_t dim,
448                                                     bool is_dynamic) {
449   if (index.empty()) {
450     CHECK(!shape->IsTuple());
451     shape->set_dynamic_dimension(dim, is_dynamic);
452     return;
453   }
454 
455   UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
456                          index.subspan(1), dim, is_dynamic);
457 }
458 
AppendMajorDimension(int bound,Shape * shape)459 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
460   CHECK(LayoutUtil::IsDenseArray(*shape));
461   if (shape->has_layout()) {
462     shape->mutable_layout()->add_minor_to_major(shape->rank());
463   }
464   shape->add_dimensions(bound);
465   TF_DCHECK_OK(ValidateShape(*shape));
466 }
467 
AppendMinorDimension(int bound,Shape * shape)468 /* static */ void ShapeUtil::AppendMinorDimension(int bound, Shape* shape) {
469   CHECK(LayoutUtil::IsDenseArray(*shape));
470 
471   if (shape->has_layout()) {
472     // Bump up all values in the layout by one.
473     for (int dim_idx = 0; dim_idx < shape->layout().minor_to_major_size();
474          dim_idx++) {
475       int layout_idx = shape->layout().minor_to_major(dim_idx);
476       shape->mutable_layout()->set_minor_to_major(dim_idx, layout_idx + 1);
477     }
478     // Then we can safely add zero.
479     shape->mutable_layout()->add_minor_to_major(0);
480   }
481   shape->add_dimensions(bound);
482   TF_DCHECK_OK(ValidateShape(*shape));
483 }
484 
CopyDynamicDimensions(Shape * to,const Shape & from)485 /* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
486                                                    const Shape& from) {
487   CHECK_EQ(to->rank(), from.rank());
488   for (int64_t i = 0; i < from.rank(); ++i) {
489     to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
490   }
491   TF_DCHECK_OK(ValidateShape(*to));
492 }
493 
ElementIsIntegral(const Shape & shape)494 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
495   return primitive_util::IsIntegralType(shape.element_type());
496 }
497 
ElementIsIntegralWithBits(const Shape & shape,int32_t bits)498 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
499                                                        int32_t bits) {
500   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
501 }
502 
ElementHasBitWidth(const Shape & shape,int bits)503 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
504   if (!shape.IsArray()) {
505     return false;
506   }
507   return primitive_util::BitWidth(shape.element_type()) == bits;
508 }
509 
ElementIsSigned(const Shape & shape)510 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
511   switch (shape.element_type()) {
512     case S8:
513     case S16:
514     case S32:
515     case S64:
516     case F16:
517     case BF16:
518     case F32:
519     case F64:
520       return true;
521 
522     case PRED:
523     case U8:
524     case U16:
525     case U32:
526     case U64:
527     case C64:
528     case C128:
529     case TUPLE:
530     case OPAQUE_TYPE:
531     case TOKEN:
532       return false;
533 
534     default:
535       LOG(FATAL) << "Unhandled element type " << shape.element_type();
536   }
537 }
538 
ElementIsComplex(const Shape & shape)539 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
540   return primitive_util::IsComplexType(shape.element_type());
541 }
542 
ElementIsFloating(const Shape & shape)543 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
544   return primitive_util::IsFloatingPointType(shape.element_type());
545 }
546 
IsNestedTuple(const Shape & shape)547 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
548   return shape.IsTuple() &&
549          absl::c_any_of(shape.tuple_shapes(),
550                         [](const Shape& s) { return s.IsTuple(); });
551 }
552 
IsEmptyTuple(const Shape & shape)553 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
554   return shape.IsTuple() && TupleElementCount(shape) == 0;
555 }
556 
TupleElementCount(const Shape & shape)557 /* static */ int64_t ShapeUtil::TupleElementCount(const Shape& shape) {
558   CHECK(shape.IsTuple()) << HumanString(shape);
559   return shape.tuple_shapes_size();
560 }
561 
// Returns element `index` of the tuple `shape`. CHECK-fails unless `shape` is
// a tuple and 0 <= index < element count; negative indices previously slipped
// through the upper-bound check.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64_t index) {
  CHECK(shape.IsTuple());
  CHECK_GE(index, 0);
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  return shape.tuple_shapes(index);
}
569 
SubshapeCount(const Shape & shape)570 /* static */ int64_t ShapeUtil::SubshapeCount(const Shape& shape) {
571   int64_t n = 0;
572   ForEachSubshape(shape, [&](const Shape& literal_subshape,
573                              const ShapeIndex& index) { ++n; });
574   return n;
575 }
576 
// Returns a tuple of the elements [start, limit) of `tuple`.
// CHECK-fails unless 0 <= start <= limit <= element count. Previously,
// start > limit or a negative start built an invalid iterator range (UB).
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64_t start,
                                         int64_t limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  CHECK_GE(start, 0);
  CHECK_LE(start, limit);
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
588 
589 // Returns the shape of a real or imaginary component.
ComplexComponentShape(const Shape & complex_shape)590 /* static */ Shape ShapeUtil::ComplexComponentShape(
591     const Shape& complex_shape) {
592   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
593   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
594                                               complex_shape.element_type()));
595 }
596 
ElementsIn(const Shape & shape)597 /* static */ int64_t ShapeUtil::ElementsIn(const Shape& shape) {
598   DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
599   DCHECK_EQ(shape.dimensions_size(), shape.rank());
600   if (shape.dimensions().size() == 1) {
601     return shape.dimensions()[0];
602   }
603   return std::accumulate<decltype(shape.dimensions().begin()), int64_t>(
604       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
605       std::multiplies<int64_t>());
606 }
607 
ElementsInRecursive(const Shape & shape)608 /* static */ int64_t ShapeUtil::ElementsInRecursive(const Shape& shape) {
609   CHECK(shape.IsArray() || shape.IsTuple());
610   if (shape.IsArray()) {
611     return ElementsIn(shape);
612   }
613   int64_t count = 0;
614   for (const Shape& element_shape : shape.tuple_shapes()) {
615     count += ElementsInRecursive(element_shape);
616   }
617   return count;
618 }
619 
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)620 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
621                                               PrimitiveType primitive_type) {
622   if (shape.element_type() == primitive_type) {
623     return true;
624   }
625   for (const Shape& element_shape : shape.tuple_shapes()) {
626     if (HasPrimitiveType(element_shape, primitive_type)) {
627       return true;
628     }
629   }
630   return false;
631 }
632 
IsZeroElementArray(const Shape & shape)633 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
634   return shape.IsArray() && ElementsIn(shape) == 0;
635 }
636 
IsScalarWithElementType(const Shape & shape,PrimitiveType element_type)637 /* static */ bool ShapeUtil::IsScalarWithElementType(
638     const Shape& shape, PrimitiveType element_type) {
639   return IsScalar(shape) && shape.element_type() == element_type;
640 }
641 
HumanString(const Shape & shape)642 /* static */ std::string ShapeUtil::HumanString(const Shape& shape) {
643   if (shape.IsTuple()) {
644     std::string text = "(";
645     const auto& tuple_shapes = shape.tuple_shapes();
646     for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
647       const Shape& elem_shape = tuple_shapes[i];
648       if (i != 0) {
649         StrAppend(&text, ", ");
650         if (i % kAnnotationPrintInterval == 0) {
651           StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
652         }
653       }
654       StrAppend(&text, HumanString(elem_shape));
655     }
656     text += ")";
657     return text;
658   }
659   std::vector<std::string> dim_elements;
660   const auto dimensions_size = shape.dimensions_size();
661   dim_elements.reserve(dimensions_size);
662   for (int i = 0; i < dimensions_size; ++i) {
663     if (shape.is_dynamic_dimension(i)) {
664       dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
665     } else {
666       dim_elements.push_back(StrCat(shape.dimensions(i)));
667     }
668   }
669   return StrCat(
670       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
671       absl::StrJoin(dim_elements, ","), "]");
672 }
673 
HumanStringWithLayout(const Shape & shape)674 /* static */ std::string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
675   if (shape.IsTuple()) {
676     std::string text = "(";
677     const auto& tuple_shapes = shape.tuple_shapes();
678     for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
679       const Shape& elem_shape = tuple_shapes[i];
680       if (i != 0) {
681         StrAppend(&text, ", ");
682         if (i % kAnnotationPrintInterval == 0) {
683           StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
684         }
685       }
686       StrAppend(&text, HumanStringWithLayout(elem_shape));
687     }
688     text += ")";
689     return text;
690   }
691   std::string result = HumanString(shape);
692   if (shape.has_layout()) {
693     if (IsScalar(shape)) {
694       std::string layout_str = LayoutUtil::HumanString(shape.layout());
695       // Don't print "{}" as layout for scalars.
696       if (layout_str != "{}") {
697         StrAppend(&result, layout_str);
698       }
699     } else if (shape.IsArray()) {
700       StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
701     }
702   }
703   return result;
704 }
705 
HumanString(const ProgramShape & program_shape)706 /* static */ std::string ShapeUtil::HumanString(
707     const ProgramShape& program_shape) {
708   std::vector<std::string> parameters;
709   const auto& shape_parameters = program_shape.parameters();
710   parameters.reserve(shape_parameters.size());
711   for (const auto& shape : shape_parameters) {
712     const int i = parameters.size();
713     parameters.push_back(StrCat(i < program_shape.parameter_names_size()
714                                     ? program_shape.parameter_names(i)
715                                     : "(unknown)",
716                                 ": ", HumanString(shape)));
717   }
718   return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
719                 HumanString(program_shape.result()));
720 }
721 
SameDimensions(const Shape & lhs,const Shape & rhs)722 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
723                                             const Shape& rhs) {
724   CHECK(lhs.IsArray());
725   CHECK(rhs.IsArray());
726   return absl::c_equal(lhs.dimensions(), rhs.dimensions());
727 }
728 
SameRank(const Shape & lhs,const Shape & rhs)729 /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
730   CHECK(lhs.IsArray());
731   CHECK(rhs.IsArray());
732   return lhs.rank() == rhs.rank();
733 }
734 
Compatible(const Shape & lhs,const Shape & rhs)735 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
736   return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
737 }
738 
CompatibleIgnoringElementType(const Shape & lhs,const Shape & rhs)739 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
740                                                            const Shape& rhs) {
741   return Shape::Equal()
742       .IgnoreDynamicDimension()
743       .IgnoreElementType()
744       .IgnoreLayout()(lhs, rhs);
745 }
746 
CompatibleKind(const Shape & lhs,const Shape & rhs)747 /* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
748                                             const Shape& rhs) {
749   return Shape::Equal()
750       .IgnoreElementType()
751       .IgnoreLayout()
752       .IgnoreDimensions()
753       .IgnoreDynamicDimension()(lhs, rhs);
754 }
755 
CompatibleIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)756 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
757                                                            const Shape& rhs) {
758   return Shape::Equal()
759       .IgnoreDynamicDimension()
760       .IgnoreFpPrecision()
761       .IgnoreLayout()(lhs, rhs);
762 }
763 
GetDimension(const Shape & shape,int64_t dimension_number)764 /* static */ int64_t ShapeUtil::GetDimension(const Shape& shape,
765                                              int64_t dimension_number) {
766   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
767 }
768 
GetDimensionNumber(const Shape & shape,int64_t dimension_number)769 /* static */ int64_t ShapeUtil::GetDimensionNumber(const Shape& shape,
770                                                    int64_t dimension_number) {
771   if (dimension_number < 0) {
772     dimension_number += shape.rank();
773   }
774   CHECK_GE(dimension_number, 0);
775   return dimension_number;
776 }
777 
ByteSizeOfPrimitiveType(PrimitiveType primitive_type)778 /* static */ int64_t ShapeUtil::ByteSizeOfPrimitiveType(
779     PrimitiveType primitive_type) {
780   switch (primitive_type) {
781     case PRED:
782       return sizeof(int8_t);
783     case S8:
784       return sizeof(int8_t);
785     case S16:
786       return sizeof(int16_t);
787     case S32:
788       return sizeof(int32_t);
789     case S64:
790       return sizeof(int64_t);
791     case U8:
792       return sizeof(uint8_t);
793     case U16:
794       return sizeof(uint16_t);
795     case U32:
796       return sizeof(uint32_t);
797     case U64:
798       return sizeof(uint64_t);
799     case BF16:
800       return sizeof(float) / 2;
801     case F16:
802       return sizeof(float) / 2;
803     case F32:
804       return sizeof(float);
805     case F64:
806       return sizeof(double);
807     case C64:
808       return sizeof(complex64);
809     case C128:
810       return sizeof(complex128);
811     case TOKEN:
812       // Tokens require no space.
813       return 0;
814     case TUPLE:
815     case OPAQUE_TYPE:
816       LOG(FATAL) << PrimitiveType_Name(primitive_type)
817                  << " primitive type has no definitive size";
818     default:
819       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
820   }
821 }
822 
ByteSizeOf(const Shape & shape,int64_t pointer_size)823 /* static */ int64_t ShapeUtil::ByteSizeOf(const Shape& shape,
824                                            int64_t pointer_size) {
825   TF_DCHECK_OK(ValidateShape(shape));
826   if (shape.element_type() == TUPLE) {
827     return ByteSizeOfTupleIndexTable(shape, pointer_size);
828   } else if (shape.IsArray()) {
829     return ByteSizeOfElements(shape);
830   } else if (shape.element_type() == TOKEN) {
831     return 0;
832   } else if (shape.element_type() == OPAQUE_TYPE) {
833     CHECK_GT(pointer_size, 0);
834     return pointer_size;
835   }
836   LOG(FATAL) << PrimitiveType_Name(shape.element_type())
837              << " primitive type has no definitive size";
838 }
839 
ByteSizeOfTupleIndexTable(const Shape & shape,int64_t pointer_size)840 /* static */ int64_t ShapeUtil::ByteSizeOfTupleIndexTable(
841     const Shape& shape, int64_t pointer_size) {
842   TF_DCHECK_OK(ValidateShape(shape));
843   CHECK_EQ(TUPLE, shape.element_type());
844   CHECK_GT(pointer_size, 0);
845   return pointer_size * shape.tuple_shapes_size();
846 }
847 
ByteSizeOfElements(const Shape & shape)848 /* static */ int64_t ShapeUtil::ByteSizeOfElements(const Shape& shape) {
849   TF_DCHECK_OK(ValidateShape(shape));
850   int64_t allocated_element_count;
851 
852   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
853   allocated_element_count = ElementsIn(shape);
854   return allocated_element_count *
855          ByteSizeOfPrimitiveType(shape.element_type());
856 }
857 
ValidateShapeWithOptionalLayoutInternal(const Shape & shape)858 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
859     const Shape& shape) {
860   if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
861       !PrimitiveType_IsValid(shape.element_type())) {
862     return InvalidArgument("shape has invalid element type: %s",
863                            shape.ShortDebugString());
864   }
865   if (shape.element_type() == TUPLE) {
866     if (shape.dimensions_size() != 0) {
867       return InvalidArgument("tuples must not have dimensions specified");
868     }
869     for (auto& element_shape : shape.tuple_shapes()) {
870       TF_RETURN_IF_ERROR(
871           ValidateShapeWithOptionalLayoutInternal(element_shape));
872     }
873     return OkStatus();
874   }
875 
876   // Non-tuple shape.
877   if (shape.tuple_shapes_size() > 0) {
878     return InvalidArgument("non-tuple shape has tuple_shapes field");
879   }
880 
881   // Tokens and opaques can should not have layout or dimensions.
882   if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
883     if (shape.dimensions_size() != 0) {
884       return InvalidArgument(
885           "shape has %s element type, but has dimensions field: %s",
886           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
887           shape.ShortDebugString());
888     }
889     if (shape.has_layout()) {
890       return InvalidArgument(
891           "shape has %s element type, but has layout field: %s",
892           primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
893           shape.ShortDebugString());
894     }
895     return OkStatus();
896   }
897 
898   for (int64_t i = 0; i < shape.rank(); ++i) {
899     int64_t dimension = shape.dimensions(i);
900     if (dimension < 0) {
901       return InvalidArgument(
902           "shape's dimensions must not be < 0; dimension at index %d was %d", i,
903           dimension);
904     }
905   }
906 
907   TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
908   return OkStatus();
909 }
910 
ValidateShapeSize(const Shape & shape)911 /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
912   VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
913 
914   if (!shape.IsArray()) {
915     return OkStatus();
916   }
917 
918   int64_t shape_size = [&]() {
919     int64_t dense_shape_size = 1;
920     if (shape.dimensions().empty()) {
921       return dense_shape_size;
922     }
923 
924     absl::Span<const int64_t> shape_max_dimensions = shape.dimensions();
925     for (int64_t dim : shape_max_dimensions) {
926       dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
927       if (dense_shape_size < 0) {
928         return dense_shape_size;
929       }
930     }
931     dense_shape_size = MultiplyWithoutOverflow(
932         dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
933     return dense_shape_size;
934   }();
935 
936   if (shape_size < 0) {
937     return InvalidArgument("Shape %s size may overflow int64_t.",
938                            ShapeUtil::HumanString(shape));
939   }
940 
941   VLOG(3) << "Shape size is valid: " << shape_size;
942   return OkStatus();
943 }
944 
ValidateShapeWithOptionalLayout(const Shape & shape)945 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
946     const Shape& shape) {
947   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
948 
949   return LayoutUtil::ValidateLayoutInShape(shape,
950                                            /*allow_missing_layouts=*/true);
951 }
952 
ValidateShape(const Shape & shape)953 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
954   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
955 
956   return LayoutUtil::ValidateLayoutInShape(shape);
957 }
958 
// Returns a copy of `original` whose element type is `type`. For tuples the
// change is applied recursively to every element and the tuple is rebuilt
// with MakeTupleShape.
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  if (!original.IsTuple()) {
    Shape retyped = original;
    retyped.set_element_type(type);
    return retyped;
  }

  std::vector<Shape> retyped_elements;
  retyped_elements.reserve(original.tuple_shapes_size());
  for (const Shape& element : original.tuple_shapes()) {
    retyped_elements.push_back(ChangeElementType(element, type));
  }
  return MakeTupleShape(retyped_elements);
}
974 
IndexIsValid(const Shape & shape,ShapeIndexView index)975 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
976                                           ShapeIndexView index) {
977   const Shape* subshape = &shape;
978   for (auto i : index) {
979     if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
980       return false;
981     }
982     subshape = &subshape->tuple_shapes(i);
983   }
984   return true;
985 }
986 
// Returns the subshape of `shape` reached by following `index` through nested
// tuples. CHECK-fails, printing the index and shape, if the walk reaches a
// non-tuple before `index` is exhausted. The components of `index` are not
// range-checked here (see IndexIsValid / TryGetSubshape).
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (const int64_t element : index) {
    CHECK(return_shape->IsTuple())
        << "Invalid index " << ShapeIndex(index) << " for shape " << shape;
    const Shape& child = return_shape->tuple_shapes(element);
    return_shape = &child;
  }
  return *return_shape;
}
997 
TryGetSubshape(const Shape & shape,ShapeIndexView index)998 /* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
999     const Shape& shape, ShapeIndexView index) {
1000   const Shape* return_shape = &shape;
1001   for (auto i : index) {
1002     if (!return_shape->IsTuple() || i < 0 ||
1003         i >= return_shape->tuple_shapes_size()) {
1004       return InvalidArgument(
1005           "Shape index %s not a valid subshape index for tuple with shape %s",
1006           ShapeIndex(index).ToString(), shape.DebugString());
1007     }
1008     return_shape = &return_shape->tuple_shapes(i);
1009   }
1010   return return_shape;
1011 }
1012 
GetMutableSubshape(Shape * shape,ShapeIndexView index)1013 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
1014                                                   ShapeIndexView index) {
1015   Shape* return_shape = shape;
1016   for (auto i : index) {
1017     CHECK(return_shape->IsTuple());
1018     return_shape = return_shape->mutable_tuple_shapes(i);
1019   }
1020   return return_shape;
1021 }
1022 
1023 /* static */
IsLeafIndex(const Shape & shape,const ShapeIndex & index)1024 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
1025   return !GetSubshape(shape, index).IsTuple();
1026 }
1027 
GetLeafCount(const Shape & shape)1028 /* static */ int64_t ShapeUtil::GetLeafCount(const Shape& shape) {
1029   if (!shape.IsTuple()) {
1030     return 1;
1031   }
1032   int64_t count = 0;
1033   for (const Shape& subshape : shape.tuple_shapes()) {
1034     count += GetLeafCount(subshape);
1035   }
1036   return count;
1037 }
1038 
GetLeafShapes(const Shape & shape)1039 /* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
1040     const Shape& shape) {
1041   std::vector<IndexedShape> leaves;
1042   ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
1043     if (IsLeafIndex(shape, index)) {
1044       leaves.emplace_back(index, sub_shape);
1045     }
1046   });
1047   return leaves;
1048 }
1049 
HasDegenerateDimensions(const Shape & shape)1050 /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
1051   CHECK(shape.IsArray());
1052   return absl::c_linear_search(shape.dimensions(), 1);
1053 }
1054 
// Returns `shape` with every dimension of bound 1 removed, delegating the
// actual removal to FilterDimensions.
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  auto keep_dimension = [&shape](int64_t dim) -> bool {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(keep_dimension, shape);
}
1059 
1060 namespace {
1061 
1062 // Helper for ForEachSubshape which visits the subshapes of the given shape in
1063 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)1064 Status ForEachSubshapeHelper(const Shape& shape,
1065                              const ShapeUtil::StatusVisitorFunction& func,
1066                              ShapeIndex* index) {
1067   TF_RETURN_IF_ERROR(func(shape, *index));
1068   if (shape.IsTuple()) {
1069     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1070       index->push_back(i);
1071       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
1072           ShapeUtil::GetTupleElementShape(shape, i), func, index));
1073       index->pop_back();
1074     }
1075   }
1076   return OkStatus();
1077 }
1078 
1079 // Helper for ForEachMutableSubshape which visits the subshapes of the given
1080 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)1081 Status ForEachMutableSubshapeHelper(
1082     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
1083     ShapeIndex* index) {
1084   TF_RETURN_IF_ERROR(func(shape, *index));
1085   if (shape->IsTuple()) {
1086     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
1087       index->push_back(i);
1088       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
1089           shape->mutable_tuple_shapes(i), func, index));
1090       index->pop_back();
1091     }
1092   }
1093   return OkStatus();
1094 }
1095 
1096 }  // namespace
1097 
ForEachSubshape(const Shape & shape,const VisitorFunction & func)1098 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
1099                                              const VisitorFunction& func) {
1100   ShapeIndex index;
1101   ForEachSubshapeHelper(
1102       shape,
1103       [&func](const Shape& subshape, const ShapeIndex& index) {
1104         func(subshape, index);
1105         return OkStatus();
1106       },
1107       &index)
1108       .IgnoreError();
1109 }
1110 
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)1111 /* static */ void ShapeUtil::ForEachMutableSubshape(
1112     Shape* shape, const MutatingVisitorFunction& func) {
1113   ShapeIndex index;
1114   ForEachMutableSubshapeHelper(
1115       shape,
1116       [&func](Shape* subshape, const ShapeIndex& index) {
1117         func(subshape, index);
1118         return OkStatus();
1119       },
1120       &index)
1121       .IgnoreError();
1122 }
1123 
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)1124 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
1125     const Shape& shape, const StatusVisitorFunction& func) {
1126   ShapeIndex index;
1127   return ForEachSubshapeHelper(shape, func, &index);
1128 }
1129 
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)1130 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
1131     Shape* shape, const MutatingStatusVisitorFunction& func) {
1132   ShapeIndex index;
1133   return ForEachMutableSubshapeHelper(shape, func, &index);
1134 }
1135 
PermuteDimensions(absl::Span<const int64_t> permutation,const Shape & shape)1136 /* static */ Shape ShapeUtil::PermuteDimensions(
1137     absl::Span<const int64_t> permutation, const Shape& shape) {
1138   Shape new_shape = shape;
1139   new_shape.clear_dimensions();
1140   for (auto dim : Permute(shape.dimensions(), permutation)) {
1141     new_shape.add_dimensions(dim);
1142   }
1143   auto inv_permutation = InversePermutation(permutation);
1144   for (int64_t i = 0; i < shape.rank(); i++) {
1145     new_shape.set_dynamic_dimension(inv_permutation[i],
1146                                     shape.is_dynamic_dimension(i));
1147   }
1148 
1149   // If `shape` has a layout, by contract we choose a new layout such that the
1150   // transpose defined by this permutation is a bitcast.
1151   //
1152   // Some formalism helps to understand the correct way to do this.  We're going
1153   // to do algebra in the group of permutations of the dimensions of `shape`.
1154   //
1155   // Since the order of `shape`'s dimensions is not permuted relative to itself,
1156   // `shape`'s list of dimensions is isomorphic to the identity I.
1157   //
1158   // Let `shape`'s layout be L.  A layout is a permutation which maps a
1159   // minor-to-major physical dimension ordering to a shape's logical dimension
1160   // ordering.  Therefore the inverse of a layout maps from logical to physical
1161   // dims, and so the physical ordering of I is simply L'.I = L', where L' is
1162   // the inverse of L.
1163   //
1164   // Let the argument `permutation` be P.  This is a permutation over `shape`'s
1165   // dimensions, so our return value will be a shape with dims P.I = P.  Our
1166   // goal is to construct a layout permutation L* for this shape. The physical
1167   // dimension ordering of this returned shape must be the same as that of the
1168   // original shape, namely L'.
1169   //
1170   // Our returned shape has dims P and layout L*, so its in-memory ordering is
1171   // L*'.P.  Setting this equal to L' and solving for L*, we get:
1172   //
1173   //   L*'.P = L'    =>
1174   //   L*'   = L'P'  =>
1175   //   L*    = P.L
1176   //
1177   if (shape.has_layout()) {
1178     CHECK(LayoutUtil::IsDenseArray(shape));
1179     Layout* new_layout = new_shape.mutable_layout();
1180     new_layout->clear_minor_to_major();
1181     for (auto index : ComposePermutations(inv_permutation,
1182                                           shape.layout().minor_to_major())) {
1183       new_layout->add_minor_to_major(index);
1184     }
1185     // The permutation accepted by TransposeIsBitcast is the inverse of the
1186     // permutation here.
1187     CHECK(TransposeIsBitcast(shape, new_shape, permutation))
1188         << "shape=" << HumanStringWithLayout(shape)
1189         << ", new_shape=" << HumanStringWithLayout(new_shape)
1190         << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
1191   }
1192   return new_shape;
1193 }
1194 
1195 /* static */ std::optional<ShapeUtil::ShapeEqualityDescriptor>
InsertedOrDeleted1SizedDimensions(const Shape & shape_pre,const Shape & shape_post)1196 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
1197                                              const Shape& shape_post) {
1198   CHECK(shape_pre.IsArray());
1199   CHECK(shape_post.IsArray());
1200 
1201   std::vector<int64_t> deleted_indices;
1202   std::vector<int64_t> inserted_indices;
1203   // Returns false if any input/output index between prior_unmodified_dim_pair
1204   // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
1205   // the degerenate input/output dimensions in the gap to
1206   // deleted_indices/inserted_indices respectively.
1207   auto check_modified_dims =
1208       [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
1209           std::pair<int64_t, int64_t> prior_unmodified_dim_pair,
1210           std::pair<int64_t, int64_t> unmodified_dim_pair) {
1211         for (int64_t modified_input_dim = prior_unmodified_dim_pair.first + 1;
1212              modified_input_dim < unmodified_dim_pair.first;
1213              ++modified_input_dim) {
1214           if (shape_pre.dimensions(modified_input_dim) > 1) {
1215             return false;
1216           }
1217           deleted_indices.push_back(modified_input_dim);
1218         }
1219         for (int64_t modified_output_dim = prior_unmodified_dim_pair.second + 1;
1220              modified_output_dim < unmodified_dim_pair.second;
1221              ++modified_output_dim) {
1222           if (shape_post.dimensions(modified_output_dim) > 1) {
1223             return false;
1224           }
1225           inserted_indices.push_back(modified_output_dim);
1226         }
1227         return true;
1228       };
1229 
1230   std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
1231       DimensionsUnmodifiedByReshape(shape_pre, shape_post);
1232   // Returns nil if the reshape modifies any non-degenerate input/output
1233   // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
1234   // dimensions, so we only need to check whether dimensions in the gaps (thus
1235   // modified) have size >1.
1236   for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
1237     // Check (modified) dimensions between unmodified_dims[i-1] and
1238     // unmodified_dims[i].
1239     auto prior_unmodified_dim_pair =
1240         i > 0 ? unmodified_dims[i - 1] : std::pair<int64_t, int64_t>(-1, -1);
1241     auto unmodified_dim_pair =
1242         i < unmodified_dims.size()
1243             ? unmodified_dims[i]
1244             : std::make_pair(shape_pre.rank(), shape_post.rank());
1245     if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
1246       return std::nullopt;
1247     }
1248   }
1249 
1250   return ShapeEqualityDescriptor{deleted_indices, inserted_indices};
1251 }
1252 
1253 /* static */ std::vector<std::pair<int64_t, int64_t>>
DimensionsUnmodifiedByReshape(const Shape & input_shape,const Shape & output_shape)1254 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1255                                          const Shape& output_shape) {
1256   CHECK(input_shape.IsArray());
1257   CHECK(output_shape.IsArray());
1258 
1259   // Unmodified dimensions are merely common factors of rank 1.
1260   auto common_factors =
1261       CommonFactors(input_shape.dimensions(), output_shape.dimensions());
1262   for (size_t i = 0; i < common_factors.size() - 1;) {
1263     if (1 != common_factors[i + 1].first - common_factors[i].first ||
1264         1 != common_factors[i + 1].second - common_factors[i].second) {
1265       common_factors.erase(common_factors.begin() + i);
1266     } else {
1267       ++i;
1268     }
1269   }
1270   // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1271   common_factors.pop_back();
1272   return std::vector<std::pair<int64_t, int64_t>>(common_factors.begin(),
1273                                                   common_factors.end());
1274 }
1275 
1276 /* static */ std::optional<std::vector<int64_t>>
ReshapeLeavesDimensionsUnmodified(const Shape & from_shape,const Shape & to_shape,absl::Span<const int64_t> input_dim_indices)1277 ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1278     const Shape& from_shape, const Shape& to_shape,
1279     absl::Span<const int64_t> input_dim_indices) {
1280   if (!std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())) {
1281     return std::nullopt;
1282   }
1283 
1284   std::vector<int64_t> output_dim_indices;
1285   std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
1286       ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1287   size_t i = 0;  // index to unmodified_dims
1288   for (int64_t input_dim_index : input_dim_indices) {
1289     // Search unmodified_dims for input_dim_index. We can search from the last
1290     // matching position because input_dim_indices is guaranteed to be sorted.
1291     while (i < unmodified_dims.size() &&
1292            unmodified_dims[i].first < input_dim_index) {
1293       ++i;
1294     }
1295     if (i >= unmodified_dims.size() ||
1296         unmodified_dims[i].first != input_dim_index) {
1297       return std::nullopt;
1298     }
1299     output_dim_indices.push_back(unmodified_dims[i].second);
1300   }
1301   return output_dim_indices;
1302 }
1303 
TransposeIsBitcast(const Shape & input_shape,const Shape & output_shape,absl::Span<const int64_t> dimension_mapping)1304 /* static */ bool ShapeUtil::TransposeIsBitcast(
1305     const Shape& input_shape, const Shape& output_shape,
1306     absl::Span<const int64_t> dimension_mapping) {
1307   CHECK(LayoutUtil::IsDenseArray(input_shape)) << input_shape.ToString(true);
1308   CHECK(LayoutUtil::IsDenseArray(output_shape)) << output_shape.ToString(true);
1309   CHECK(input_shape.has_layout()) << input_shape.ToString(true);
1310   CHECK(output_shape.has_layout()) << output_shape.ToString(true);
1311 
1312   if (!SameElementType(input_shape, output_shape)) {
1313     return false;
1314   }
1315 
1316   // Check the reshape permutes the positions of each dimension in the
1317   // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1318   //   input_positions = apply(dimension_mapping, output_positions)
1319   //
1320   // Because the positions of each dimension are the inverse permutation of the
1321   // minor-to-major order, the above check is equivalent to
1322   //   inverse(input_dimensions) =
1323   //       apply(dimension_mapping, inverse(output_dimensions))
1324   //   # `I` indicates identity permutation.
1325   //   apply(input_dimensions, I) =
1326   //       apply(dimension_mapping, apply(output_dimensions, I))
1327   //   apply(input_dimensions, I) =
1328   //       apply((dimension_mapping * output_dimensions), I)
1329   //   input_dimensions = dimension_mapping * output_dimensions
1330   return absl::c_equal(
1331       ComposePermutations(dimension_mapping,
1332                           output_shape.layout().minor_to_major()),
1333       input_shape.layout().minor_to_major());
1334 }
1335 
ReshapeIsBitcast(const Shape & input_shape,const Shape & output_shape)1336 /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
1337                                               const Shape& output_shape) {
1338   CHECK(LayoutUtil::IsDenseArray(input_shape)) << input_shape.ToString(true);
1339   CHECK(LayoutUtil::IsDenseArray(output_shape)) << output_shape.ToString(true);
1340   CHECK(input_shape.has_layout()) << input_shape.ToString(true);
1341   CHECK(output_shape.has_layout()) << output_shape.ToString(true);
1342 
1343   if (!SameElementType(input_shape, output_shape)) {
1344     return false;
1345   }
1346 
1347   if (ElementsIn(input_shape) != ElementsIn(output_shape)) {
1348     VLOG(3) << "input_shape=" << input_shape.ShortDebugString()
1349             << ", output_shape=" << output_shape.ShortDebugString();
1350     return false;
1351   }
1352   if (ElementsIn(input_shape) == 0) {
1353     return true;
1354   }
1355 
1356   // TL;DR: The rest of the method checks that the reshape does not change the
1357   // physical location of any unit input or output index. Unit indices have
1358   // exactly one dimension that equals 1 and other dimensions 0. This condition
1359   // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
1360   // reshape shouldn't change the physical location of any element. It is also a
1361   // sufficient condition as is proved below (note: many details are omitted for
1362   // space).
1363   //
1364   // Definitions:
1365   //
1366   // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
1367   // the size of i-th least significant dimension of IS or OS (this is opposite
1368   // to how we define the index of Shape::dimensions()).
1369   //
1370   // * Given an input or output index I, denote by p(I) I's physical linear
1371   // index (or physical index for short) and l(I) I's logical linear index (or
1372   // logical index for short).
1373   //
1374   // * Given a logical index k, denote by II(k) the input index whose linear
1375   // index is k, and OI(k) the corresponding output index.
1376   //
1377   // * Denote by IT[i] the increment of physical index if i-th dimension of the
1378   // input index is increased by 1. Similarly, OT[i] means the increment if i-th
1379   // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
1380   // is a function of IS or OS and the layout, and not dependent on the specific
1381   // input or output index.
1382   //
1383   // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
1384   // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
1385   // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
1386   // to prove, with every increment on k, the above formula still holds.
1387   //
1388   // First, suppose reshaping from IS to OS is non-factorizable (we discuss
1389   // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
1390   // there exists (i,j) such that
1391   //
1392   //   0<=i<=|IS|
1393   //   0<=j<=|OS|
1394   //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
1395   //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
1396   //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
1397   //
1398   // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
1399   // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
1400   // unit indices which are already tested. This also means IT[0]=OT[0]
1401   // because p(II(1))=IT[0] and p(OI(1))=OT[0].
1402   //
1403   // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
1404   // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
1405   // to the output physical.
1406   //
1407   // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
1408   // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
1409   // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
1410   // logical index k-1 to logical index k, dimension 1 of the input index
1411   // is increased by 1 and dimension 0 is reset to 0 thus decreased by
1412   // IS[0]-1. Therefore, the physical input index is increased by
1413   //
1414   //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
1415   //
1416   // Because IS[0]<OS[0], the only change to the output index is that its
1417   // dimension 0 is increased by one. Therefore,
1418   //
1419   //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
1420   //
1421   // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
1422   // p(II(k))=p(OI(k)). Therefore,
1423   //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
1424   //   IT[1] = IS[0] * IT[0]
1425   // In other words, input dimension 1 is immediately more major than input
1426   // dimension 0. We can now conceptually collapse these two dimensions because
1427   // an increment in the logical index affecting only these two dimensions maps
1428   // to IT[0] in the physical index.
1429   //
1430   // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
1431   // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
1432   // identical.
1433   //
1434   // A factorizable reshape can be factorized into a list of non-factorizable
1435   // sub-reshapes, each of which can be handled similarly to the proof above.
1436   // For example,
1437   //
1438   //   [7x9x2x15] -> [63x6x5]
1439   //
1440   // can be factorized into
1441   //
1442   //   [7x9] -> [63] and [2x15] -> [6x5].
1443   //
1444   // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
1445   // same logical linear index. According to the factorization, we know
1446   // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
1447   // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
1448   // similar proof, with the increment of the logical index set to
1449   // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
1450   // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
1451   //
1452   //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
1453   //                  = p(y2,0,0) + p(0,0,y1,y0)
1454   //                  = p(y2,y1,y0)
1455   //
1456   // check_input_unit_indices checks one way of the condition: each input unit
1457   // index is mapped to an output index with the same physical location. This
1458   // lambda will be called again with input_shape and output_shape reversed to
1459   // check the other way.
1460   auto check_input_unit_indices = [](const Shape& input_shape,
1461                                      const Shape& output_shape) {
1462     // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
1463     // as input_shape/output_shape and the dimension-0-major layout. These two
1464     // shapes are used for conversion between logical linear indices and
1465     // multi-dimensional indices.
1466     Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
1467         input_shape.element_type(), input_shape.dimensions());
1468     Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
1469         output_shape.element_type(), output_shape.dimensions());
1470 
1471     for (int64_t input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
1472       if (input_shape.dimensions(input_dim) <= 1) {
1473         continue;
1474       }
1475 
1476       std::vector<int64_t> input_unit_index(input_shape.rank(), 0);
1477       input_unit_index[input_dim] = 1;
1478       int64_t logical_linear_index =
1479           IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
1480                                                         input_unit_index);
1481       // output_index has the same logical linear index as input_unit_index.
1482       std::vector<int64_t> output_index =
1483           IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
1484                                                         logical_linear_index);
1485       // Check input_unit_index and output_index have the same physical linear
1486       // index.
1487       if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
1488                                                         input_unit_index) !=
1489           IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
1490                                                         output_index)) {
1491         return false;
1492       }
1493     }
1494     return true;
1495   };
1496   return check_input_unit_indices(input_shape, output_shape) &&
1497          check_input_unit_indices(output_shape, input_shape);
1498 }
1499 
AlignLayouts(const Shape & input_shape,const Shape & output_shape)1500 /* static */ std::optional<Shape> ShapeUtil::AlignLayouts(
1501     const Shape& input_shape, const Shape& output_shape) {
1502   CHECK(input_shape.IsArray());
1503   CHECK(output_shape.IsArray());
1504   // Removing trivial dimensions from the shape simplifies the alignment
1505   // algorithm since ones can go in any position.
1506   if (HasDegenerateDimensions(input_shape) ||
1507       HasDegenerateDimensions(output_shape)) {
1508     auto simple_output_shape =
1509         AlignLayouts(DropDegenerateDimensions(input_shape),
1510                      DropDegenerateDimensions(output_shape));
1511     if (!simple_output_shape) {
1512       return std::nullopt;
1513     }
1514 
1515     std::vector<int64_t> layout =
1516         SpanToVector(simple_output_shape->layout().minor_to_major());
1517     // For each one sized dimension in the output, increment the dimension
1518     // numbers in layout that are more minor than the one.
1519     absl::InlinedVector<int64_t, 8> dim_map;
1520     dim_map.reserve(simple_output_shape->rank());
1521     for (int64_t i = 0; i < output_shape.rank(); ++i) {
1522       if (output_shape.dimensions(i) != 1) {
1523         dim_map.push_back(i);
1524       }
1525     }
1526     for (int64_t& d : layout) {
1527       d = dim_map[d];
1528     }
1529 
1530     // Add the ones in descending order to the layout. Descending layouts tend
1531     // to reduce the number of copies inserted in layout assignment.
1532     for (int64_t i = output_shape.rank() - 1; i >= 0; --i) {
1533       if (output_shape.dimensions(i) == 1) {
1534         layout.push_back(i);
1535       }
1536     }
1537     Shape output_shape_with_layout = output_shape;
1538     *output_shape_with_layout.mutable_layout() = Layout{layout};
1539     return output_shape_with_layout;
1540   }
1541 
1542   auto common_factors =
1543       CommonFactors(input_shape.dimensions(), output_shape.dimensions());
1544   const int64_t input_rank = input_shape.rank();
1545   DimensionVector input_to_factor(input_rank);
1546   for (int64_t pos = 0; pos < common_factors.size() - 1; ++pos) {
1547     const int64_t input_start = common_factors[pos].first;
1548     const int64_t input_end = common_factors[pos + 1].first;
1549     int64_t input_physical =
1550         PositionInContainer(input_shape.layout().minor_to_major(), input_start);
1551     input_to_factor[input_start] = pos;
1552     for (int64_t i = input_start + 1; i < input_end; ++i) {
1553       --input_physical;
1554       if (input_physical < 0 ||
1555           input_shape.layout().minor_to_major(input_physical) != i) {
1556         return std::nullopt;
1557       }
1558       input_to_factor[i] = pos;
1559     }
1560   }
1561 
1562   int64_t output_rank = output_shape.rank();
1563   DimensionVector output_layout;
1564   output_layout.reserve(output_rank);
1565   int64_t input_minor = 0;
1566   while (output_layout.size() < output_rank) {
1567     const int64_t input_dim = input_shape.layout().minor_to_major(input_minor);
1568     const int64_t common_factor = input_to_factor[input_dim];
1569     const auto start_factor = common_factors[common_factor];
1570     const auto end_factor = common_factors[common_factor + 1];
1571     for (int64_t dim = end_factor.second - 1; dim >= start_factor.second;
1572          --dim) {
1573       output_layout.push_back(dim);
1574     }
1575     input_minor += end_factor.first - start_factor.first;
1576   }
1577 
1578   Shape output_shape_with_layout = MakeShapeWithLayout(
1579       output_shape.element_type(), output_shape.dimensions(), output_layout);
1580   CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
1581       << "reshape is not a bitcast for input_shape: "
1582       << ShapeUtil::HumanStringWithLayout(input_shape)
1583       << " and output_shape_with_layout: "
1584       << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
1585   return output_shape_with_layout;
1586 }
1587 
DeleteDimension(int64_t dim_to_delete,Shape shape)1588 /* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete,
1589                                               Shape shape) {
1590   CHECK(shape.IsArray());
1591   shape.DeleteDimension(dim_to_delete);
1592   return shape;
1593 }
1594 
DynamicArrayShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1595 /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
1596     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1597   if (dynamic_shape.rank() != bounded_shape.rank()) {
1598     return false;
1599   }
1600   for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
1601     if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
1602       return false;
1603     }
1604   }
1605   return true;
1606 }
1607 
DynamicShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1608 /* static */ bool ShapeUtil::DynamicShapeIsCompatible(
1609     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1610   bool compatible = true;
1611   xla::ShapeUtil::ForEachSubshape(
1612       dynamic_shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
1613         if (compatible) {
1614           auto subshape_result = TryGetSubshape(bounded_shape, index);
1615           if (subshape_result.ok()) {
1616             const Shape* bounded_sub_shape = std::move(subshape_result).value();
1617             if (sub_shape.IsTuple()) {
1618               if (!bounded_sub_shape->IsTuple()) {
1619                 compatible = false;
1620               }
1621             } else {
1622               if (bounded_sub_shape->IsTuple()) {
1623                 compatible = false;
1624               } else if (!sub_shape.is_static() &&
1625                          !DynamicArrayShapeIsCompatible(sub_shape,
1626                                                         *bounded_sub_shape)) {
1627                 compatible = false;
1628               }
1629             }
1630           } else {
1631             compatible = false;
1632           }
1633         }
1634       });
1635   return compatible;
1636 }
1637 
DeleteDimensions(absl::Span<int64_t const> dims_to_delete,Shape shape)1638 /* static */ Shape ShapeUtil::DeleteDimensions(
1639     absl::Span<int64_t const> dims_to_delete, Shape shape) {
1640   std::vector<int64_t> dims_to_delete_v(dims_to_delete.begin(),
1641                                         dims_to_delete.end());
1642   absl::c_sort(dims_to_delete_v, std::greater<int64_t>());
1643   for (int64_t dim : dims_to_delete_v) {
1644     shape = DeleteDimension(dim, shape);
1645   }
1646   return shape;
1647 }
1648 
FilterDimensions(const std::function<bool (int64_t)> & p,Shape shape)1649 /* static */ Shape ShapeUtil::FilterDimensions(
1650     const std::function<bool(int64_t)>& p, Shape shape) {
1651   CHECK(shape.IsArray());
1652   std::vector<int64_t> dims_to_delete;
1653   for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) {
1654     if (!p(i)) {
1655       dims_to_delete.push_back(i);
1656     }
1657   }
1658   return DeleteDimensions(dims_to_delete, shape);
1659 }
1660 
1661 // Returns the indices of the first elements of all consecutive subarrays of the
1662 // given array. For example:
1663 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
ConsecutiveSegments(absl::Span<const int64_t> xs)1664 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64_t> xs) {
1665   std::vector<size_t> is = {0};
1666   for (size_t i = 1; i < xs.size(); ++i) {
1667     if (1 != xs[i] - xs[i - 1]) {
1668       is.push_back(i);
1669     }
1670   }
1671   return is;
1672 }
1673 
1674 // Merges the sequences of dimensions of the given shape which start at the
1675 // given indices `segs`.
MergeDimensions(absl::Span<const size_t> segs,const Shape & shape)1676 static Shape MergeDimensions(absl::Span<const size_t> segs,
1677                              const Shape& shape) {
1678   std::vector<int64_t> dimensions;
1679   const auto size = segs.size();
1680   dimensions.reserve(size);
1681   for (size_t i = 1; i <= size; ++i) {
1682     dimensions.push_back(std::accumulate(
1683         shape.dimensions().begin() + segs[i - 1],
1684         shape.dimensions().begin() +
1685             (segs.size() == i ? shape.dimensions().size() : segs[i]),
1686         1, std::multiplies<int64_t>()));
1687   }
1688   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1689                                                   dimensions);
1690 }
1691 
FindTranspose021(const Shape & a,const Shape & b)1692 /* static */ std::optional<Vector3> ShapeUtil::FindTranspose021(
1693     const Shape& a, const Shape& b) {
1694   if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
1695     return std::nullopt;
1696   }
1697 
1698   std::vector<int64_t> permutation(a.dimensions().size());
1699   absl::Span<const int64_t> minor_to_major_a = LayoutUtil::MinorToMajor(a);
1700   std::vector<int64_t> major_to_minor_a(minor_to_major_a.rbegin(),
1701                                         minor_to_major_a.rend());
1702   absl::Span<const int64_t> minor_to_major_b = LayoutUtil::MinorToMajor(b);
1703   std::vector<int64_t> major_to_minor_b(minor_to_major_b.rbegin(),
1704                                         minor_to_major_b.rend());
1705   for (size_t i = 0; i < permutation.size(); ++i) {
1706     permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
1707   }
1708 
1709   std::vector<size_t> segments = ConsecutiveSegments(permutation);
1710   if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
1711     Shape descending_layout_shape =
1712         ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
1713     Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
1714     absl::Span<const int64_t> normalized_dims = normalized_shape.dimensions();
1715     Vector3 dims_021;
1716     if (2 == segments.size()) {
1717       // The logical component-0 is of size one.
1718       dims_021 = {1, normalized_dims[1], normalized_dims[0]};
1719     } else {
1720       dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
1721     }
1722     return dims_021;
1723   }
1724 
1725   return std::nullopt;
1726 }
1727 
DeviceShapeToHostShape(Shape s)1728 Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
1729   ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
1730     if (subshape->IsArray()) {
1731       subshape->mutable_layout()->clear_tiles();
1732       subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
1733     }
1734   });
1735   return s;
1736 }
1737 
ElementCanUpcast(const Shape & from,const Shape & to)1738 /*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
1739                                             const Shape& to) {
1740   return HigherPrecisionElementType(from, to) == to.element_type();
1741 }
1742 
1743 /*static*/
ByteStrides(const Shape & shape,absl::Span<int64_t> strides)1744 Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span<int64_t> strides) {
1745   TF_RET_CHECK(shape.IsArray());
1746   TF_RET_CHECK(shape.has_layout());
1747   TF_RET_CHECK(shape.dimensions_size() == strides.size());
1748 
1749   int64_t stride = ByteSizeOfPrimitiveType(shape.element_type());
1750   for (int i : shape.layout().minor_to_major()) {
1751     strides.at(i) = stride;
1752     stride *= shape.dimensions(i);
1753   }
1754   return OkStatus();
1755 }
1756 
ArraySize(const Shape & shape)1757 /*static*/ int64_t ShapeUtil::ArraySize(const Shape& shape) {
1758   CHECK(shape.IsArray());
1759   CHECK(!shape.layout().tiles().empty());
1760 
1761   auto tile_dimensions = shape.layout().tiles(0).dimensions();
1762   auto shape_dimensions = shape.dimensions();
1763   auto minor_to_major = shape.layout().minor_to_major();
1764   int64_t shape_dim_size = shape_dimensions.size();
1765   int64_t tile_dim_size = tile_dimensions.size();
1766 
1767   // Use the top-level tile for shape size calculation. We assume the
1768   // sub-tiles won't cause additional padding.
1769   int64_t num_of_elements = 1;
1770   int64_t dim = 0;
1771   for (dim = 0; dim < tile_dim_size; dim++) {
1772     int64_t dim_size =
1773         dim < shape_dim_size ? shape_dimensions[minor_to_major[dim]] : 1;
1774     num_of_elements *=
1775         RoundUpTo(dim_size, tile_dimensions[tile_dim_size - dim - 1]);
1776   }
1777   for (; dim < shape_dim_size; dim++) {
1778     int64_t dim_size = shape_dimensions[minor_to_major[dim]];
1779     num_of_elements *= dim_size;
1780   }
1781   if (shape.layout().element_size_in_bits() != 0) {
1782     const int64_t kBitsPerByte = 8;
1783     return CeilOfRatio(num_of_elements * shape.layout().element_size_in_bits(),
1784                        static_cast<int64_t>(kBitsPerByte));
1785   }
1786   return num_of_elements * ByteSizeOfPrimitiveType(shape.element_type());
1787 }
1788 
ArrayDataSize(const Shape & shape)1789 /*static*/ int64_t ShapeUtil::ArrayDataSize(const Shape& shape) {
1790   CHECK(shape.IsArray());
1791   absl::InlinedVector<int64_t, 4> indices;
1792   for (int64_t dim : shape.dimensions()) {
1793     indices.push_back(dim - 1);
1794   }
1795   int64_t size = LayoutUtil::LinearIndex(shape, indices) + 1;
1796   if (shape.layout().element_size_in_bits() != 0) {
1797     const int64_t kBitsPerByte = 8;
1798     return CeilOfRatio(size * shape.layout().element_size_in_bits(),
1799                        static_cast<int64_t>(kBitsPerByte));
1800   }
1801   return (size * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()));
1802 }
1803 
1804 }  // namespace xla
1805