• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape_util.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/strings/ascii.h"
28 #include "absl/strings/numbers.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/strings/strip.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/compiler/xla/index_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/overflow_util.h"
39 #include "tensorflow/compiler/xla/permutation_util.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/gtl/iterator_range.h"
46 #include "tensorflow/core/lib/hash/hash.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/protobuf.h"
50 #include "tensorflow/core/platform/regexp.h"
51 
52 namespace xla {
53 
54 using absl::StrAppend;
55 using absl::StrCat;
56 
namespace {
// Byte width of each PrimitiveType, indexed directly by the enum value.
// Entries for non-array types (TUPLE, OPAQUE_TYPE, TOKEN) and the invalid
// sentinel are 0 so that lookups for them can be detected and rejected by
// callers (see FillNewShape).
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
    0,                  // PRIMITIVE_TYPE_INVALID = 0,
    sizeof(int8),       // PRED = 1
    sizeof(int8),       // S8 = 2
    sizeof(int16),      // S16 = 3
    sizeof(int32),      // S32 = 4
    sizeof(int64),      // S64 = 5
    sizeof(uint8),      // U8 = 6
    sizeof(uint16),     // U16 = 7
    sizeof(uint32),     // U32 = 8
    sizeof(uint64),     // U64 = 9
    sizeof(float) / 2,  // F16 = 10
    sizeof(float),      // F32 = 11
    sizeof(double),     // F64 = 12
    0,                  // TUPLE = 13
    0,                  // OPAQUE_TYPE = 14
    sizeof(complex64),  // C64 = 15
    sizeof(float) / 2,  // BF16 = 16
    0,                  // TOKEN = 17
    sizeof(complex128)  // C128 = 18
};
// When printing long tuples, a /*index=N*/ annotation is emitted every
// kAnnotationPrintInterval elements to help readers keep their place.
constexpr int64_t kAnnotationPrintInterval = 5;
}  // namespace
84 
ToString() const85 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
86 
ToString() const87 string ShapeIndexView::ToString() const {
88   return StrCat("{", absl::StrJoin(indices_, ","), "}");
89 }
90 
operator ==(const ShapeIndexView & other) const91 bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
92   return indices_ == other.indices_;
93 }
94 
operator !=(const ShapeIndexView & other) const95 bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
96   return !(*this == other);
97 }
98 
operator <<(std::ostream & out,const ShapeIndex & shape_index)99 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
100   out << shape_index.ToString();
101   return out;
102 }
103 
operator <<(std::ostream & out,const ShapeIndexView & shape_index)104 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
105   out << shape_index.ToString();
106   return out;
107 }
108 
StartsWith(ShapeIndexView prefix) const109 bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
110   return size() >= prefix.size() &&
111          indices_.subspan(0, prefix.size()) == prefix.indices_;
112 }
113 
IsArrayPrimitiveType(PrimitiveType primitive_type)114 /* static */ bool ShapeUtil::IsArrayPrimitiveType(
115     PrimitiveType primitive_type) {
116   return primitive_util::IsArrayType(primitive_type);
117 }
118 
namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout. Returns an error if the layout does not name exactly one
// physical position per logical dimension, or if element_type cannot be an
// array element.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64_t element_size_in_bits, int64_t memory_space) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  // Tuples and opaque values are not arrays and cannot carry a dense layout.
  if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  if (element_size_in_bits ==
      ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
    // Only set element_size_in_bits if it's different from the default value;
    // 0 means "unset / use the natural width".
    element_size_in_bits = 0;
  }
  *shape.mutable_layout() = LayoutUtil::MakeLayout(
      minor_to_major, tiles, element_size_in_bits, memory_space);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace
150 
Equal(const Shape & lhs,const Shape & rhs)151 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
152   bool equal = Shape::Equal()(lhs, rhs);
153 
154   if (!equal && VLOG_IS_ON(3)) {
155     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
156             << ", rhs = " << rhs.ShortDebugString();
157   }
158 
159   return equal;
160 }
161 
EqualIgnoringElementType(const Shape & lhs,const Shape & rhs)162 /* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
163                                                       const Shape& rhs) {
164   bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
165   if (!equal && VLOG_IS_ON(3)) {
166     VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
167             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
168   }
169 
170   return equal;
171 }
172 
EqualIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)173 /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
174                                                       const Shape& rhs) {
175   bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
176   if (!equal && VLOG_IS_ON(3)) {
177     VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
178             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
179   }
180 
181   return equal;
182 }
183 
EqualStructure(const Shape & lhs,const Shape & rhs)184 /* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
185                                             const Shape& rhs) {
186   bool equal = true;
187   ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
188     equal &= IndexIsValid(rhs, index);
189   });
190   ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
191     equal &= IndexIsValid(lhs, index);
192   });
193 
194   return equal;
195 }
196 
TrueRank(const Shape & shape)197 /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
198   int64_t accum = 0;
199   for (int64_t dimension : shape.dimensions()) {
200     // We do not count zero dimensions.
201     if (dimension != 1) {
202       accum += 1;
203     }
204   }
205   return accum;
206 }
207 
FillNewShape(PrimitiveType element_type,absl::Span<const int64> dimensions,Shape * shape)208 /* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
209                                           absl::Span<const int64> dimensions,
210                                           Shape* shape) {
211   const int eint = static_cast<int>(element_type);
212   int64_t dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
213                                   ? primitive_byte_size[eint]
214                                   : 0);  // Out of range: force a failure
215   if (dense_shape_size <= 0) {
216     return false;
217   }
218 
219   // Verify that array-based lookup is consistent with public API.
220   DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
221       << element_type;
222 
223   shape->set_element_type(element_type);
224   const int ndims = dimensions.size();
225   auto layout = shape->mutable_layout();
226   layout->set_format(DENSE);
227   auto* minor_to_major = layout->mutable_minor_to_major();
228   for (int i = 0; i < ndims; i++) {
229     const int64_t d = dimensions[i];
230     if (d < 0) {
231       return false;
232     }
233     dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
234     if (dense_shape_size < 0) {
235       return false;
236     }
237 
238     shape->add_dimensions(d);
239     minor_to_major->push_back(ndims - 1 - i);
240   }
241   return true;
242 }
243 
MakeProgramShape(std::initializer_list<Shape> parameters,Shape result)244 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
245     std::initializer_list<Shape> parameters, Shape result) {
246   ProgramShape program_shape;
247   for (const Shape& shape : parameters) {
248     *program_shape.add_parameters() = shape;
249   }
250   *program_shape.mutable_result() = std::move(result);
251   return program_shape;
252 }
253 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions)254 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
255                                         absl::Span<const int64> dimensions) {
256   Shape shape;
257   CHECK(FillNewShape(element_type, dimensions, &shape));
258   return shape;
259 }
260 
MakeScalarShape(PrimitiveType element_type)261 /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
262   return MakeShape(element_type, {});
263 }
264 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)265 /* static */ Shape ShapeUtil::MakeShape(
266     PrimitiveType element_type, absl::Span<const int64> dimensions,
267     const std::vector<bool>& dynamic_dimensions) {
268   return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
269       .ValueOrDie();
270 }
271 
// Returns a copy of `shape` with every dimension marked static.
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape result = shape;
  result.clear_dynamic_dimensions();
  return result;
}
278 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions)279 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
280     PrimitiveType element_type, absl::Span<const int64> dimensions) {
281   Shape shape;
282   if (!FillNewShape(element_type, dimensions, &shape)) {
283     return InvalidArgument("invalid shape type=%d, dims=[%s]",
284                            static_cast<int>(element_type),
285                            absl::StrJoin(dimensions, ","));
286   }
287   return shape;
288 }
289 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)290 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
291     PrimitiveType element_type, absl::Span<const int64> dimensions,
292     const std::vector<bool>& dynamic_dimensions) {
293   if (dynamic_dimensions.size() != dimensions.size()) {
294     return InvalidArgument(
295         "dynamic dimensions size %d did not match number of dimensions %d",
296         dynamic_dimensions.size(), dimensions.size());
297   }
298 
299   Shape shape;
300   if (!FillNewShape(element_type, dimensions, &shape)) {
301     return InvalidArgument("invalid shape type=%d, dims=[%s]",
302                            static_cast<int>(element_type),
303                            absl::StrJoin(dimensions, ","));
304   }
305   for (int i = 0, n = dimensions.size(); i < n; i++) {
306     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
307   }
308   return shape;
309 }
310 
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)311 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
312     PrimitiveType element_type, absl::Span<const int64> dimensions,
313     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
314     int64_t element_size_in_bits, int64_t memory_space) {
315   auto ret =
316       MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
317                                   tiles, element_size_in_bits, memory_space);
318   if (!ret.ok()) LOG(ERROR) << ret.status();
319   return ret.ValueOrDie();
320 }
321 
// Returns `shape` with dimension `dim` made the most-major dimension of the
// layout; tuples are rewritten element-wise.
/* static */ Shape ShapeUtil::MoveDimToMajor(const Shape& shape, int64_t dim) {
  if (shape.IsTuple()) {
    std::vector<Shape> result_shapes;
    result_shapes.reserve(shape.tuple_shapes_size());
    for (const Shape& s : shape.tuple_shapes()) {
      result_shapes.push_back(MoveDimToMajor(s, dim));
    }
    return ShapeUtil::MakeTupleShape(result_shapes);
  }

  Shape ret = shape;
  if (!ret.has_layout()) {
    LayoutUtil::SetToDefaultLayout(&ret);
  }
  // First let LayoutUtil produce the moved layout, then rebuild a plain
  // layout with the same minor-to-major order. NOTE(review): the final
  // MakeLayout call drops any tiling / element-size info carried by the
  // intermediate layout — presumably intentional; confirm against callers.
  *ret.mutable_layout() = LayoutUtil::MoveDimToMajor(ret.layout(), dim);
  DimensionVector minor_to_major;
  for (int64_t d : LayoutUtil::MinorToMajor(ret)) {
    if (d != dim) {
      minor_to_major.push_back(d);
    }
  }
  minor_to_major.push_back(dim);
  *ret.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  return ret;
}
347 
MakeShapeWithDescendingLayout(PrimitiveType element_type,absl::Span<const int64> dimensions)348 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
349     PrimitiveType element_type, absl::Span<const int64> dimensions) {
350   std::vector<int64> layout(dimensions.size());
351   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
352   return MakeShapeWithLayout(element_type, dimensions, layout);
353 }
354 
355 /* static */ Shape
MakeShapeWithDescendingLayoutAndSamePhysicalLayout(const Shape & shape)356 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
357     const Shape& shape) {
358   std::vector<int64> dims(shape.dimensions_size());
359   for (int i = 0; i < shape.dimensions_size(); ++i) {
360     dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
361   }
362   Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
363   // Since the physical layout is kept the same, the tiles and element size are
364   // the same also.
365   new_shape.mutable_layout()->mutable_tiles()->assign(
366       shape.layout().tiles().begin(), shape.layout().tiles().end());
367   new_shape.mutable_layout()->set_element_size_in_bits(
368       shape.layout().element_size_in_bits());
369   for (int i = 0; i < shape.dimensions_size(); ++i) {
370     new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
371   }
372   return new_shape;
373 }
374 
PopulateShape(PrimitiveType element_type,absl::Span<const int64> dimensions,Shape * shape)375 /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
376                                              absl::Span<const int64> dimensions,
377                                              Shape* shape) {
378   shape->Clear();
379   shape->set_element_type(element_type);
380   for (int64_t dimension : dimensions) {
381     shape->add_dimensions(dimension);
382   }
383   LayoutUtil::SetToDefaultLayout(shape);
384   return ValidateShape(*shape);
385 }
386 
// Copy of `original` with all dynamic-dimension markers removed.
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
  Shape static_shape = original;
  static_shape.clear_dynamic_dimensions();
  return static_shape;
}
392 
MakeTupleShape(absl::Span<const Shape> shapes)393 /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
394   Shape result;
395   result.set_element_type(TUPLE);
396   result.mutable_tuple_shapes()->reserve(shapes.size());
397   for (const auto& shape : shapes) {
398     AppendShapeToTuple(shape, &result);
399   }
400   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
401   return result;
402 }
403 
MakeOpaqueShape()404 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
405   Shape result;
406   result.set_element_type(OPAQUE_TYPE);
407   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
408   return result;
409 }
410 
MakeTokenShape()411 /* static */ Shape ShapeUtil::MakeTokenShape() {
412   Shape result;
413   result.set_element_type(TOKEN);
414   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
415   return result;
416 }
417 
AppendShapeToTuple(const Shape & shape,Shape * tuple_shape)418 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
419                                                 Shape* tuple_shape) {
420   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
421   *tuple_shape->add_tuple_shapes() = shape;
422 }
423 
UpdateTupleShape(const Shape & shape,int64_t index,Shape * tuple_shape)424 /* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64_t index,
425                                               Shape* tuple_shape) {
426   CHECK(index < tuple_shape->tuple_shapes_size());
427   *tuple_shape->mutable_tuple_shapes(index) = shape;
428 }
429 
UpdateDynamicDimension(Shape * shape,ShapeIndexView index,int64_t dim,bool is_dynamic)430 /* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
431                                                     ShapeIndexView index,
432                                                     int64_t dim,
433                                                     bool is_dynamic) {
434   if (index.empty()) {
435     CHECK(!shape->IsTuple());
436     shape->set_dynamic_dimension(dim, is_dynamic);
437     return;
438   }
439 
440   UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
441                          index.ConsumeFront(), dim, is_dynamic);
442 }
443 
AppendMajorDimension(int bound,Shape * shape)444 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
445   CHECK(LayoutUtil::IsDenseArray(*shape));
446   shape->mutable_layout()->add_minor_to_major(shape->rank());
447   shape->add_dimensions(bound);
448   TF_DCHECK_OK(ValidateShape(*shape));
449 }
450 
CopyDynamicDimensions(Shape * to,const Shape & from)451 /* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
452                                                    const Shape& from) {
453   CHECK_EQ(to->rank(), from.rank());
454   for (int64_t i = 0; i < from.rank(); ++i) {
455     to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
456   }
457   TF_DCHECK_OK(ValidateShape(*to));
458 }
459 
ElementIsIntegral(const Shape & shape)460 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
461   return primitive_util::IsIntegralType(shape.element_type());
462 }
463 
ElementIsIntegralWithBits(const Shape & shape,int32_t bits)464 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
465                                                        int32_t bits) {
466   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
467 }
468 
ElementHasBitWidth(const Shape & shape,int bits)469 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
470   if (!shape.IsArray()) {
471     return false;
472   }
473   return primitive_util::BitWidth(shape.element_type()) == bits;
474 }
475 
ElementIsSigned(const Shape & shape)476 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
477   switch (shape.element_type()) {
478     case S8:
479     case S16:
480     case S32:
481     case S64:
482     case F16:
483     case BF16:
484     case F32:
485     case F64:
486       return true;
487 
488     case PRED:
489     case U8:
490     case U16:
491     case U32:
492     case U64:
493     case C64:
494     case C128:
495     case TUPLE:
496     case OPAQUE_TYPE:
497     case TOKEN:
498       return false;
499 
500     default:
501       LOG(FATAL) << "Unhandled element type " << shape.element_type();
502   }
503 }
504 
ElementIsComplex(const Shape & shape)505 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
506   return primitive_util::IsComplexType(shape.element_type());
507 }
508 
ElementIsFloating(const Shape & shape)509 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
510   return primitive_util::IsFloatingPointType(shape.element_type());
511 }
512 
IsNestedTuple(const Shape & shape)513 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
514   return shape.IsTuple() &&
515          absl::c_any_of(shape.tuple_shapes(),
516                         [](const Shape& s) { return s.IsTuple(); });
517 }
518 
IsEmptyTuple(const Shape & shape)519 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
520   return shape.IsTuple() && TupleElementCount(shape) == 0;
521 }
522 
TupleElementCount(const Shape & shape)523 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
524   CHECK(shape.IsTuple()) << HumanString(shape);
525   return shape.tuple_shapes_size();
526 }
527 
// Returns (by reference) element `index` of a tuple shape, after checking
// that the index is in range and the element validates.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64_t index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  const Shape& element = shape.tuple_shapes(index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(element));
  return element;
}
535 
SubshapeCount(const Shape & shape)536 /* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
537   int64_t n = 0;
538   ForEachSubshape(shape, [&](const Shape& literal_subshape,
539                              const ShapeIndex& index) { ++n; });
540   return n;
541 }
542 
// Returns a tuple of the elements of `tuple` in the half-open range
// [start, limit).
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64_t start,
                                         int64_t limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  // Validate the range fully: a negative start or an inverted range
  // (start > limit) would make the iterator-pair constructor below UB.
  CHECK_GE(start, 0);
  CHECK_LE(start, limit);
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
554 
555 // Returns the shape of a real or imaginary component.
ComplexComponentShape(const Shape & complex_shape)556 /* static */ Shape ShapeUtil::ComplexComponentShape(
557     const Shape& complex_shape) {
558   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
559   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
560                                               complex_shape.element_type()));
561 }
562 
ElementsIn(const Shape & shape)563 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
564   DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
565   DCHECK_EQ(shape.dimensions_size(), shape.rank());
566   if (shape.dimensions().size() == 1) {
567     return shape.dimensions()[0];
568   }
569   return std::accumulate<decltype(shape.dimensions().begin()), int64>(
570       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
571       std::multiplies<int64>());
572 }
573 
ElementsInRecursive(const Shape & shape)574 /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
575   CHECK(shape.IsArray() || shape.IsTuple());
576   if (shape.IsArray()) {
577     return ElementsIn(shape);
578   }
579   int64_t count = 0;
580   for (const Shape& element_shape : shape.tuple_shapes()) {
581     count += ElementsInRecursive(element_shape);
582   }
583   return count;
584 }
585 
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)586 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
587                                               PrimitiveType primitive_type) {
588   if (shape.element_type() == primitive_type) {
589     return true;
590   }
591   for (const Shape& element_shape : shape.tuple_shapes()) {
592     if (HasPrimitiveType(element_shape, primitive_type)) {
593       return true;
594     }
595   }
596   return false;
597 }
598 
IsZeroElementArray(const Shape & shape)599 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
600   return shape.IsArray() && ElementsIn(shape) == 0;
601 }
602 
IsScalarWithElementType(const Shape & shape,PrimitiveType element_type)603 /* static */ bool ShapeUtil::IsScalarWithElementType(
604     const Shape& shape, PrimitiveType element_type) {
605   return IsScalar(shape) && shape.element_type() == element_type;
606 }
607 
HumanString(const Shape & shape)608 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
609   if (shape.IsTuple()) {
610     string text = "(";
611     const auto& tuple_shapes = shape.tuple_shapes();
612     for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
613       const Shape& elem_shape = tuple_shapes[i];
614       if (i != 0) {
615         StrAppend(&text, ", ");
616         if (i % kAnnotationPrintInterval == 0) {
617           StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
618         }
619       }
620       StrAppend(&text, HumanString(elem_shape));
621     }
622     text += ")";
623     return text;
624   }
625   std::vector<string> dim_elements;
626   for (int i = 0; i < shape.dimensions_size(); ++i) {
627     if (shape.is_dynamic_dimension(i)) {
628       dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
629     } else {
630       dim_elements.push_back(StrCat(shape.dimensions(i)));
631     }
632   }
633   return StrCat(
634       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
635       absl::StrJoin(dim_elements, ","), "]");
636 }
637 
HumanStringWithLayout(const Shape & shape)638 /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
639   if (shape.IsTuple()) {
640     string text = "(";
641     const auto& tuple_shapes = shape.tuple_shapes();
642     for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
643       const Shape& elem_shape = tuple_shapes[i];
644       if (i != 0) {
645         StrAppend(&text, ", ");
646         if (i % kAnnotationPrintInterval == 0) {
647           StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
648         }
649       }
650       StrAppend(&text, HumanStringWithLayout(elem_shape));
651     }
652     text += ")";
653     return text;
654   }
655   string result = HumanString(shape);
656   if (IsScalar(shape)) {
657     string layout_str = LayoutUtil::HumanString(shape.layout());
658     // Don't print "{}" as layout for scalars.
659     if (layout_str != "{}") {
660       StrAppend(&result, layout_str);
661     }
662   } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
663     StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
664   }
665   return result;
666 }
667 
HumanString(const ProgramShape & program_shape)668 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
669   std::vector<string> parameters;
670   for (auto& shape : program_shape.parameters()) {
671     const int i = parameters.size();
672     parameters.push_back(StrCat(i < program_shape.parameter_names_size()
673                                     ? program_shape.parameter_names(i)
674                                     : "(unknown)",
675                                 ": ", HumanString(shape)));
676   }
677   return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
678                 HumanString(program_shape.result()));
679 }
680 
SameDimensions(const Shape & lhs,const Shape & rhs)681 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
682                                             const Shape& rhs) {
683   CHECK(lhs.IsArray());
684   CHECK(rhs.IsArray());
685   return absl::c_equal(lhs.dimensions(), rhs.dimensions());
686 }
687 
SameRank(const Shape & lhs,const Shape & rhs)688 /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
689   CHECK(lhs.IsArray());
690   CHECK(rhs.IsArray());
691   return lhs.rank() == rhs.rank();
692 }
693 
Compatible(const Shape & lhs,const Shape & rhs)694 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
695   return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
696 }
697 
CompatibleIgnoringElementType(const Shape & lhs,const Shape & rhs)698 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
699                                                            const Shape& rhs) {
700   return Shape::Equal()
701       .IgnoreDynamicDimension()
702       .IgnoreElementType()
703       .IgnoreLayout()(lhs, rhs);
704 }
705 
CompatibleKind(const Shape & lhs,const Shape & rhs)706 /* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
707                                             const Shape& rhs) {
708   return Shape::Equal()
709       .IgnoreElementType()
710       .IgnoreLayout()
711       .IgnoreDimensions()
712       .IgnoreDynamicDimension()(lhs, rhs);
713 }
714 
CompatibleIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)715 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
716                                                            const Shape& rhs) {
717   return Shape::Equal()
718       .IgnoreDynamicDimension()
719       .IgnoreFpPrecision()
720       .IgnoreLayout()(lhs, rhs);
721 }
722 
GetDimension(const Shape & shape,int64_t dimension_number)723 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
724                                            int64_t dimension_number) {
725   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
726 }
727 
GetDimensionNumber(const Shape & shape,int64_t dimension_number)728 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
729                                                  int64_t dimension_number) {
730   if (dimension_number < 0) {
731     dimension_number += shape.rank();
732   }
733   CHECK_GE(dimension_number, 0);
734   return dimension_number;
735 }
736 
ByteSizeOfPrimitiveType(PrimitiveType primitive_type)737 /* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
738     PrimitiveType primitive_type) {
739   switch (primitive_type) {
740     case PRED:
741       return sizeof(int8);
742     case S8:
743       return sizeof(int8);
744     case S16:
745       return sizeof(int16);
746     case S32:
747       return sizeof(int32);
748     case S64:
749       return sizeof(int64);
750     case U8:
751       return sizeof(uint8);
752     case U16:
753       return sizeof(uint16);
754     case U32:
755       return sizeof(uint32);
756     case U64:
757       return sizeof(uint64);
758     case BF16:
759       return sizeof(float) / 2;
760     case F16:
761       return sizeof(float) / 2;
762     case F32:
763       return sizeof(float);
764     case F64:
765       return sizeof(double);
766     case C64:
767       return sizeof(complex64);
768     case C128:
769       return sizeof(complex128);
770     case TOKEN:
771       // Tokens require no space.
772       return 0;
773     case TUPLE:
774     case OPAQUE_TYPE:
775       LOG(FATAL) << PrimitiveType_Name(primitive_type)
776                  << " primitive type has no definitive size";
777     default:
778       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
779   }
780 }
781 
ByteSizeOf(const Shape & shape,int64_t pointer_size)782 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
783                                          int64_t pointer_size) {
784   TF_DCHECK_OK(ValidateShape(shape));
785   if (shape.element_type() == TUPLE) {
786     return ByteSizeOfTupleIndexTable(shape, pointer_size);
787   } else if (shape.IsArray()) {
788     return ByteSizeOfElements(shape);
789   } else if (shape.element_type() == TOKEN) {
790     return 0;
791   } else if (shape.element_type() == OPAQUE_TYPE) {
792     CHECK_GT(pointer_size, 0);
793     return pointer_size;
794   }
795   LOG(FATAL) << PrimitiveType_Name(shape.element_type())
796              << " primitive type has no definitive size";
797 }
798 
ByteSizeOfTupleIndexTable(const Shape & shape,int64_t pointer_size)799 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
800                                                         int64_t pointer_size) {
801   TF_DCHECK_OK(ValidateShape(shape));
802   CHECK_EQ(TUPLE, shape.element_type());
803   CHECK_GT(pointer_size, 0);
804   return pointer_size * shape.tuple_shapes_size();
805 }
806 
ByteSizeOfElements(const Shape & shape)807 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
808   TF_DCHECK_OK(ValidateShape(shape));
809   CHECK(shape.IsArray());
810   int64_t allocated_element_count;
811 
812   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
813   allocated_element_count = ElementsIn(shape);
814   return allocated_element_count *
815          ByteSizeOfPrimitiveType(shape.element_type());
816 }
817 
// Validates all non-layout fields of `shape` (element type, dimensions,
// tuple structure), recursing into tuple elements. Layouts are not required
// to be present and are not inspected here.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  // The element type must be a real enum value, not the invalid sentinel.
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    // Tuples carry structure only in tuple_shapes; dimensions must be empty.
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    // Recursively validate every element shape.
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques should not have layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  // Array shape: every dimension bound must be non-negative.
  for (int64_t i = 0; i < shape.rank(); ++i) {
    int64_t dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  // Finally, reject shapes whose total byte size would overflow int64.
  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}
870 
// Verifies that the total byte size of an array shape fits in int64.
// Non-array shapes trivially pass. Overflow is detected via
// MultiplyWithoutOverflow, which returns a negative value on overflow.
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
  VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);

  if (!shape.IsArray()) {
    return Status::OK();
  }

  // Computed in a lambda so overflow can short-circuit with the negative
  // sentinel produced by MultiplyWithoutOverflow.
  int64_t shape_size = [&]() {
    int64_t dense_shape_size = 1;
    if (shape.dimensions().empty()) {
      // Rank-0 (scalar): one element.
      return dense_shape_size;
    }

    absl::Span<const int64> shape_max_dimensions =
        AsInt64Slice(shape.dimensions());
    for (int64_t dim : shape_max_dimensions) {
      dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
      if (dense_shape_size < 0) {
        // Overflow: propagate the sentinel immediately.
        return dense_shape_size;
      }
    }
    // Multiply the element count by the per-element byte size, again checked
    // for overflow by the caller below.
    dense_shape_size = MultiplyWithoutOverflow(
        dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
    return dense_shape_size;
  }();

  if (shape_size < 0) {
    return InvalidArgument("Shape %s size may overflow int64.",
                           ShapeUtil::HumanString(shape));
  }

  VLOG(3) << "Shape size is valid: " << shape_size;
  return Status::OK();
}
905 
ValidateShapeWithOptionalLayout(const Shape & shape)906 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
907     const Shape& shape) {
908   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
909 
910   return LayoutUtil::ValidateLayoutInShape(shape,
911                                            /*allow_missing_layouts=*/true);
912 }
913 
ValidateShape(const Shape & shape)914 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
915   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
916 
917   return LayoutUtil::ValidateLayoutInShape(shape);
918 }
919 
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  // Returns a copy of `original` with every leaf's element type replaced by
  // `type`; tuple structure, dimensions, and layouts are preserved.
  if (!original.IsTuple()) {
    Shape result = original;
    result.set_element_type(type);
    return result;
  }
  std::vector<Shape> converted_elements;
  converted_elements.reserve(original.tuple_shapes_size());
  for (const Shape& element : original.tuple_shapes()) {
    converted_elements.push_back(ChangeElementType(element, type));
  }
  return MakeTupleShape(converted_elements);
}
935 
IndexIsValid(const Shape & shape,ShapeIndexView index)936 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
937                                           ShapeIndexView index) {
938   const Shape* subshape = &shape;
939   for (auto i : index) {
940     if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
941       return false;
942     }
943     subshape = &subshape->tuple_shapes(i);
944   }
945   return true;
946 }
947 
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  // Descends through the tuple tree one index element at a time;
  // CHECK-fails if the index steps into a non-tuple.
  const Shape* current = &shape;
  for (auto i : index) {
    CHECK(current->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    current = &current->tuple_shapes(i);
  }
  return *current;
}
958 
TryGetSubshape(const Shape & shape,ShapeIndexView index)959 /* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
960     const Shape& shape, ShapeIndexView index) {
961   const Shape* return_shape = &shape;
962   for (auto i : index) {
963     if (!return_shape->IsTuple() || i < 0 ||
964         i >= return_shape->tuple_shapes_size()) {
965       return InvalidArgument(
966           "Shape index %s not a valid subshape index for tuple with shape %s",
967           index.ToString(), shape.DebugString());
968     }
969     return_shape = &return_shape->tuple_shapes(i);
970   }
971   return return_shape;
972 }
973 
GetMutableSubshape(Shape * shape,ShapeIndexView index)974 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
975                                                   ShapeIndexView index) {
976   Shape* return_shape = shape;
977   for (auto i : index) {
978     CHECK(return_shape->IsTuple());
979     return_shape = return_shape->mutable_tuple_shapes(i);
980   }
981   return return_shape;
982 }
983 
984 /* static */
IsLeafIndex(const Shape & shape,const ShapeIndex & index)985 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
986   return !GetSubshape(shape, index).IsTuple();
987 }
988 
GetLeafCount(const Shape & shape)989 /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
990   if (!shape.IsTuple()) {
991     return 1;
992   }
993   int64_t count = 0;
994   for (const Shape& subshape : shape.tuple_shapes()) {
995     count += GetLeafCount(subshape);
996   }
997   return count;
998 }
999 
GetLeafShapes(const Shape & shape)1000 /* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
1001     const Shape& shape) {
1002   std::vector<IndexedShape> leaves;
1003   ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
1004     if (IsLeafIndex(shape, index)) {
1005       leaves.emplace_back(index, sub_shape);
1006     }
1007   });
1008   return leaves;
1009 }
1010 
HasDegenerateDimensions(const Shape & shape)1011 /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
1012   CHECK(shape.IsArray());
1013   return absl::c_linear_search(shape.dimensions(), 1);
1014 }
1015 
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  // Keeps only the dimensions whose bound differs from 1.
  auto is_nondegenerate = [&shape](int64_t dim) -> bool {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(is_nondegenerate, shape);
}
1020 
1021 namespace {
1022 
1023 // Helper for ForEachSubshape which visits the subshapes of the given shape in
1024 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)1025 Status ForEachSubshapeHelper(const Shape& shape,
1026                              const ShapeUtil::StatusVisitorFunction& func,
1027                              ShapeIndex* index) {
1028   TF_RETURN_IF_ERROR(func(shape, *index));
1029   if (shape.IsTuple()) {
1030     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1031       index->push_back(i);
1032       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
1033           ShapeUtil::GetTupleElementShape(shape, i), func, index));
1034       index->pop_back();
1035     }
1036   }
1037   return Status::OK();
1038 }
1039 
1040 // Helper for ForEachMutableSubshape which visits the subshapes of the given
1041 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)1042 Status ForEachMutableSubshapeHelper(
1043     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
1044     ShapeIndex* index) {
1045   TF_RETURN_IF_ERROR(func(shape, *index));
1046   if (shape->IsTuple()) {
1047     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
1048       index->push_back(i);
1049       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
1050           shape->mutable_tuple_shapes(i), func, index));
1051       index->pop_back();
1052     }
1053   }
1054   return Status::OK();
1055 }
1056 
1057 }  // namespace
1058 
ForEachSubshape(const Shape & shape,const VisitorFunction & func)1059 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
1060                                              const VisitorFunction& func) {
1061   ShapeIndex index;
1062   ForEachSubshapeHelper(
1063       shape,
1064       [&func](const Shape& subshape, const ShapeIndex& index) {
1065         func(subshape, index);
1066         return Status::OK();
1067       },
1068       &index)
1069       .IgnoreError();
1070 }
1071 
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)1072 /* static */ void ShapeUtil::ForEachMutableSubshape(
1073     Shape* shape, const MutatingVisitorFunction& func) {
1074   ShapeIndex index;
1075   ForEachMutableSubshapeHelper(
1076       shape,
1077       [&func](Shape* subshape, const ShapeIndex& index) {
1078         func(subshape, index);
1079         return Status::OK();
1080       },
1081       &index)
1082       .IgnoreError();
1083 }
1084 
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)1085 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
1086     const Shape& shape, const StatusVisitorFunction& func) {
1087   ShapeIndex index;
1088   return ForEachSubshapeHelper(shape, func, &index);
1089 }
1090 
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)1091 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
1092     Shape* shape, const MutatingStatusVisitorFunction& func) {
1093   ShapeIndex index;
1094   return ForEachMutableSubshapeHelper(shape, func, &index);
1095 }
1096 
// Returns a copy of `shape` with its dimensions (and dynamic-dimension bits)
// reordered by `permutation`. If `shape` carries a layout, a new layout is
// derived such that the transpose described by `permutation` is a bitcast.
/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(shape.dimensions(), permutation)) {
    new_shape.add_dimensions(dim);
  }
  // Carry each dimension's dynamic bit to its permuted position.
  auto inv_permutation = InversePermutation(permutation);
  for (int64_t i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(inv_permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this.  We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to itself,
  // `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L.  A layout is a permutation which maps a
  // minor-to-major physical dimension ordering to a shape's logical dimension
  // ordering.  Therefore the inverse of a layout maps from logical to physical
  // dims, and so the physical ordering of I is simply L'.I = L', where L' is
  // the inverse of L.
  //
  // Let the argument `permutation` be P.  This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P.  Our
  // goal is to construct a layout permutation L* for this shape. The physical
  // dimension ordering of this returned shape must be the same as that of the
  // original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory ordering is
  // L*'.P.  Setting this equal to L' and solving for L*, we get:
  //
  //   L*'.P = L'    =>
  //   L*'   = L'P'  =>
  //   L*    = P.L
  //
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    // L* = P.L, computed via the inverse permutation composed with L's
    // minor-to-major list.
    for (auto index : ComposePermutations(
             inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape, permutation))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}
1156 
// Returns (true, deleted_indices, inserted_indices) if `shape_post` can be
// obtained from `shape_pre` purely by deleting and/or inserting size-1
// dimensions; otherwise returns (false, {}, {}).
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  // The "failure" return value.
  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output index between prior_unmodified_dim_pair
  // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64_t modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64_t modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i]. Sentinels (-1,-1) and (rank,rank) bound the first
    // and last gaps.
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}
1216 
// Returns (input_dim, output_dim) pairs for every dimension carried over
// unchanged (same bound, same relative position among surviving dims) by a
// reshape from `input_shape` to `output_shape`.
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                         const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());

  // Unmodified dimensions are merely common factors of rank 1.
  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
                                      AsInt64Slice(output_shape.dimensions()));
  // Keep only factor-group boundaries exactly one dimension apart on both
  // sides. Note: `i` is deliberately not incremented after an erase, so the
  // element shifted into slot i is re-examined against its new successor.
  for (size_t i = 0; i < common_factors.size() - 1;) {
    if (1 != common_factors[i + 1].first - common_factors[i].first ||
        1 != common_factors[i + 1].second - common_factors[i].second) {
      common_factors.erase(common_factors.begin() + i);
    } else {
      ++i;
    }
  }
  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
  common_factors.pop_back();
  return std::vector<std::pair<int64, int64>>(common_factors.begin(),
                                              common_factors.end());
}
1239 
1240 /* static */ absl::optional<std::vector<int64>>
ReshapeLeavesDimensionsUnmodified(const Shape & from_shape,const Shape & to_shape,absl::Span<const int64> input_dim_indices)1241 ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1242     const Shape& from_shape, const Shape& to_shape,
1243     absl::Span<const int64> input_dim_indices) {
1244   CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1245 
1246   std::vector<int64> output_dim_indices;
1247   std::vector<std::pair<int64, int64>> unmodified_dims =
1248       ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1249   size_t i = 0;  // index to unmodified_dims
1250   for (int64_t input_dim_index : input_dim_indices) {
1251     // Search unmodified_dims for input_dim_index. We can search from the last
1252     // matching position because input_dim_indices is guaranteed to be sorted.
1253     while (i < unmodified_dims.size() &&
1254            unmodified_dims[i].first < input_dim_index) {
1255       ++i;
1256     }
1257     if (i >= unmodified_dims.size() ||
1258         unmodified_dims[i].first != input_dim_index) {
1259       return absl::nullopt;
1260     }
1261     output_dim_indices.push_back(unmodified_dims[i].second);
1262   }
1263   return output_dim_indices;
1264 }
1265 
// Returns true iff transposing `input_shape` by `dimension_mapping` into
// `output_shape` leaves every element at the same physical offset, i.e. the
// transpose can be implemented as a bitcast. Both shapes must have layouts.
/* static */ bool ShapeUtil::TransposeIsBitcast(
    const Shape& input_shape, const Shape& output_shape,
    absl::Span<const int64> dimension_mapping) {
  CHECK(LayoutUtil::HasLayout(input_shape) &&
        LayoutUtil::HasLayout(output_shape));

  // Differing element types can never be a bitcast.
  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  // Check the reshape permutes the positions of each dimension in the
  // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
  //   input_positions = apply(dimension_mapping, output_positions)
  //
  // Because the positions of each dimension are the inverse permutation of the
  // minor-to-major order, the above check is equivalent to
  //   inverse(input_dimensions) =
  //       apply(dimension_mapping, inverse(output_dimensions))
  //   # `I` indicates identity permutation.
  //   apply(input_dimensions, I) =
  //       apply(dimension_mapping, apply(output_dimensions, I))
  //   apply(input_dimensions, I) =
  //       apply((dimension_mapping * output_dimensions), I)
  //   input_dimensions = dimension_mapping * output_dimensions
  return absl::c_equal(
      ComposePermutations(dimension_mapping,
                          AsInt64Slice(output_shape.layout().minor_to_major())),
      input_shape.layout().minor_to_major());
}
1295 
// Returns true iff reshaping `input_shape` (with its layout) into
// `output_shape` (with its layout) leaves every element at the same physical
// offset, i.e. the reshape can be implemented as a bitcast.
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  // Differing element types can never be a bitcast.
  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape))
      << "input_shape=" << input_shape.ShortDebugString()
      << ", output_shape=" << output_shape.ShortDebugString();
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
  // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
  // logical index k-1 to logical index k, dimension 1 of the input index
  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
  // IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64_t input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      // Size-0/1 dimensions admit no unit index; skip them.
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64_t logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
1457 
// Attempts to choose a layout for `output_shape` such that reshaping
// `input_shape` (with its given layout) to the result is a bitcast, i.e. a
// pure relabeling of the same physical element order. Returns the output
// shape with that layout, or nullopt if no such layout exists. The final
// CHECK(ReshapeIsBitcast(...)) below asserts the constructed layout really
// has this property.
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    // Solve the problem on the degenerate-free shapes first, then map the
    // resulting layout back onto the original output dimension numbers.
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    std::vector<int64> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // For each one sized dimension in the output, increment the dimension
    // numbers in layout that are more minor than the one.
    // dim_map[k] is the original dimension number of the k-th non-degenerate
    // output dimension.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64_t i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    // Translate the degenerate-free layout's dimension numbers back to the
    // numbering of the original (degenerate-bearing) output shape.
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64_t i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  int64_t input_rank = input_shape.rank();
  int64_t output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part if
  // the products of their dimension bounds are the same. In the easiest case,
  // an alignment part consists of one input dimension and one output dimension
  // which both have the same dimension bound. An alignment part specifies which
  // dimensions need to be kept together in a physical layout if we want a
  // reshape to be a bitcast. The order of the alignment parts is defined by the
  // physical layout of the input shape, so when we construct the layout for the
  // output shape we just process the alignment parts in this order, and then
  // layout the dimensions belonging to each part in descending (major to minor)
  // order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64_t input_dimension_product = 1, output_dimension_product = 1;
  // Two-pointer sweep over input and output dimensions; an alignment part is
  // closed whenever the running products on both sides agree.
  for (int64_t i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    // Advance whichever side currently has the smaller product (the input side
    // also advances when the output side is exhausted).
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        // Output still has elements to account for but input is exhausted.
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    // Total element counts differ; the shapes cannot be aligned.
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to check
  // that all dimension numbers that belong to the same alignment part appear
  // consecutively, and are in descending order. However we can ignore any
  // trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  // `i` walks the input's minor-to-major list one alignment part at a time;
  // the inner loops advance it.
  for (int64_t i = 0; i < input_rank;) {
    int64_t current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64_t current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
    // dimension numbers (ignoring dimension numbers with dimension bound 1) are
    // in descending order and belong to the current alignment part.
    for (int64_t j = 0; j < alignment[current_alignment_index + 1].first -
                                alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64_t j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  // Sanity check: the layout we constructed must make the reshape a bitcast.
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
1607 
DeleteDimension(int64_t dim_to_delete,Shape shape)1608 /* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete,
1609                                               Shape shape) {
1610   CHECK(shape.IsArray());
1611   shape.DeleteDimension(dim_to_delete);
1612   return shape;
1613 }
1614 
DynamicArrayShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1615 /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
1616     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1617   if (dynamic_shape.rank() != bounded_shape.rank()) {
1618     return false;
1619   }
1620   for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
1621     if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
1622       return false;
1623     }
1624   }
1625   return true;
1626 }
1627 
DynamicShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1628 /* static */ bool ShapeUtil::DynamicShapeIsCompatible(
1629     const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1630   bool compatible = true;
1631   xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
1632                                                      const ShapeIndex& index) {
1633     if (compatible) {
1634       auto subshape_result = TryGetSubshape(bounded_shape, index);
1635       if (subshape_result.ok()) {
1636         const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
1637         if (sub_shape.IsTuple()) {
1638           if (!bounded_sub_shape->IsTuple()) {
1639             compatible = false;
1640           }
1641         } else {
1642           if (bounded_sub_shape->IsTuple()) {
1643             compatible = false;
1644           } else if (!sub_shape.is_static() &&
1645                      !DynamicArrayShapeIsCompatible(sub_shape,
1646                                                     *bounded_sub_shape)) {
1647             compatible = false;
1648           }
1649         }
1650       } else {
1651         compatible = false;
1652       }
1653     }
1654   });
1655   return compatible;
1656 }
1657 
FilterDimensions(const std::function<bool (int64_t)> & p,Shape shape)1658 /* static */ Shape ShapeUtil::FilterDimensions(
1659     const std::function<bool(int64_t)>& p, Shape shape) {
1660   CHECK(shape.IsArray());
1661   std::vector<int64> dims_to_delete;
1662   for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) {
1663     if (!p(i)) {
1664       dims_to_delete.push_back(i);
1665     }
1666   }
1667   for (int64_t dim : dims_to_delete) {
1668     shape = DeleteDimension(dim, shape);
1669   }
1670   return shape;
1671 }
1672 
Hash(const Shape & shape)1673 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1674   using tensorflow::hash;
1675   using tensorflow::Hash64Combine;
1676 
1677   size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1678 
1679   if (shape.tuple_shapes().empty()) {
1680     for (int i = 0; i < shape.dimensions_size(); ++i) {
1681       hash_value =
1682           Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1683       hash_value = Hash64Combine(hash_value,
1684                                  hash<bool>()(shape.is_dynamic_dimension(i)));
1685     }
1686 
1687     hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1688   } else {
1689     hash_value = 0;
1690     for (const Shape& subshape : shape.tuple_shapes()) {
1691       hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1692     }
1693   }
1694 
1695   return hash_value;
1696 }
1697 
1698 // Returns the indices of the first elements of all consecutive subarrays of the
1699 // given array. For example:
1700 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
ConsecutiveSegments(absl::Span<const int64> xs)1701 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
1702   std::vector<size_t> is = {0};
1703   for (size_t i = 1; i < xs.size(); ++i) {
1704     if (1 != xs[i] - xs[i - 1]) {
1705       is.push_back(i);
1706     }
1707   }
1708   return is;
1709 }
1710 
1711 // Merges the sequences of dimensions of the given shape which start at the
1712 // given indices `segs`.
MergeDimensions(absl::Span<const size_t> segs,const Shape & shape)1713 static Shape MergeDimensions(absl::Span<const size_t> segs,
1714                              const Shape& shape) {
1715   std::vector<int64> dimensions;
1716   for (size_t i = 1; i <= segs.size(); ++i) {
1717     dimensions.push_back(std::accumulate(
1718         shape.dimensions().begin() + segs[i - 1],
1719         shape.dimensions().begin() +
1720             (segs.size() == i ? shape.dimensions().size() : segs[i]),
1721         1, std::multiplies<int64>()));
1722   }
1723   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1724                                                   dimensions);
1725 }
1726 
// Checks whether the physical layout change from shape `a` to shape `b` is a
// 0-2-1 transpose (i.e. the two minormost merged components are swapped while
// the majormost stays in place). If so, returns the three merged dimension
// sizes {dim0, dim2, dim1} as seen from `a`'s descending layout; otherwise
// returns nullopt.
/*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
    const Shape& a, const Shape& b) {
  if (!CompatibleIgnoringElementType(a, b)) {
    return absl::nullopt;
  }

  // permutation[i] gives, for the i-th dimension in a's major-to-minor order,
  // its position in b's major-to-minor order.
  std::vector<int64> permutation(a.dimensions().size());
  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
                                      minor_to_major_a.rend());
  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
                                      minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  // Runs of consecutive values in `permutation` are dimensions that move
  // together; a 0-2-1 transpose has either three such runs with the first one
  // staying put, or just two runs (an implicit size-one leading component).
  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    // Collapse each run into a single merged dimension.
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64> normalized_dims =
        AsInt64Slice(normalized_shape.dimensions());
    std::vector<int64> dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }

    return dims_021;
  }

  return absl::nullopt;
}
1764 
DeviceShapeToHostShape(Shape s)1765 Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
1766   ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
1767     if (subshape->IsArray()) {
1768       subshape->mutable_layout()->clear_tiles();
1769       subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
1770     }
1771   });
1772   return s;
1773 }
1774 
ElementCanUpcast(const Shape & from,const Shape & to)1775 /*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
1776                                             const Shape& to) {
1777   return ElementIsFloating(from) == ElementIsFloating(to) &&
1778          ElementIsSigned(from) == ElementIsSigned(to) &&
1779          HigherPrecisionElementType(from, to) == to.element_type();
1780 }
1781 
1782 /*static*/
ByteStrides(const Shape & shape,absl::Span<int64_t> strides)1783 Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span<int64_t> strides) {
1784   TF_RET_CHECK(shape.IsArray());
1785   TF_RET_CHECK(shape.has_layout());
1786   TF_RET_CHECK(shape.dimensions_size() == strides.size());
1787 
1788   int64_t stride = ByteSizeOfPrimitiveType(shape.element_type());
1789   for (int i : shape.layout().minor_to_major()) {
1790     strides.at(i) = stride;
1791     stride *= shape.dimensions(i);
1792   }
1793   return Status::OK();
1794 }
1795 
1796 }  // namespace xla
1797