• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape_util.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/ascii.h"
27 #include "absl/strings/numbers.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/strings/str_split.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/strings/strip.h"
33 #include "absl/types/optional.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/overflow_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/iterator_range.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/lib/strings/numbers.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/regexp.h"
48 
49 namespace xla {
50 
51 using absl::StrAppend;
52 using absl::StrCat;
53 
ToString() const54 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
55 
ToString() const56 string ShapeIndexView::ToString() const {
57   return StrCat("{", absl::StrJoin(indices_, ","), "}");
58 }
59 
operator ==(const ShapeIndexView & other) const60 bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
61   return indices_ == other.indices_;
62 }
63 
operator !=(const ShapeIndexView & other) const64 bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
65   return !(*this == other);
66 }
67 
operator <<(std::ostream & out,const ShapeIndex & shape_index)68 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
69   out << shape_index.ToString();
70   return out;
71 }
72 
operator <<(std::ostream & out,const ShapeIndexView & shape_index)73 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
74   out << shape_index.ToString();
75   return out;
76 }
77 
StartsWith(ShapeIndexView prefix) const78 bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
79   return size() >= prefix.size() &&
80          indices_.subspan(0, prefix.size()) == prefix.indices_;
81 }
82 
// Thin wrapper: whether `primitive_type` is an array element type, as
// defined by primitive_util::IsArrayType.
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
    PrimitiveType primitive_type) {
  return primitive_util::IsArrayType(primitive_type);
}
87 
namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout. Returns InvalidArgument when the layout does not match the
// dimension count or when the element type cannot carry a layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  // Every dimension needs exactly one entry in the minor-to-major order.
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  // Only array shapes carry layouts; tuples and opaques are rejected here.
  if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  if (element_size_in_bits ==
      ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
    // Only set element_size_in_bits if it's different from the default value.
    element_size_in_bits = 0;
  }
  *shape.mutable_layout() = LayoutUtil::MakeLayout(
      minor_to_major, tiles, element_size_in_bits, memory_space);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  // Final validation covers both the shape and its freshly-set layout.
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace
119 
Equal(const Shape & lhs,const Shape & rhs)120 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
121   bool equal = Shape::Equal()(lhs, rhs);
122 
123   if (!equal && VLOG_IS_ON(3)) {
124     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
125             << ", rhs = " << rhs.ShortDebugString();
126   }
127 
128   return equal;
129 }
130 
EqualIgnoringElementType(const Shape & lhs,const Shape & rhs)131 /* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
132                                                       const Shape& rhs) {
133   bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
134   if (!equal && VLOG_IS_ON(3)) {
135     VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
136             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
137   }
138 
139   return equal;
140 }
141 
EqualIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)142 /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
143                                                       const Shape& rhs) {
144   bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
145   if (!equal && VLOG_IS_ON(3)) {
146     VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
147             << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
148   }
149 
150   return equal;
151 }
152 
// Returns the number of non-degenerate dimensions, i.e. the count of
// dimensions whose bound is not 1. Note that size-0 dimensions ARE counted;
// only size-1 dimensions are skipped.
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64 accum = 0;
  for (int64 dimension : shape.dimensions()) {
    // We do not count degenerate (size 1) dimensions.
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}
163 
MakeProgramShape(std::initializer_list<Shape> parameters,Shape result)164 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
165     std::initializer_list<Shape> parameters, Shape result) {
166   ProgramShape program_shape;
167   for (const Shape& shape : parameters) {
168     *program_shape.add_parameters() = shape;
169   }
170   *program_shape.mutable_result() = std::move(result);
171   return program_shape;
172 }
173 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions)174 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
175                                         absl::Span<const int64> dimensions) {
176   return MakeValidatedShape(element_type, dimensions).ValueOrDie();
177 }
178 
// Returns a rank-0 (scalar) array shape of the given element type.
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
  return MakeShape(element_type, {});
}
182 
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)183 /* static */ Shape ShapeUtil::MakeShape(
184     PrimitiveType element_type, absl::Span<const int64> dimensions,
185     const std::vector<bool>& dynamic_dimensions) {
186   return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
187       .ValueOrDie();
188 }
189 
// Returns a copy of `shape` with every dynamic-dimension marker cleared.
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape result = shape;
  result.clear_dynamic_dimensions();
  return result;
}
196 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions)197 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
198     PrimitiveType element_type, absl::Span<const int64> dimensions) {
199   CHECK(IsArrayPrimitiveType(element_type)) << element_type;
200   Shape result;
201   TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
202   return result;
203 }
204 
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)205 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
206     PrimitiveType element_type, absl::Span<const int64> dimensions,
207     const std::vector<bool>& dynamic_dimensions) {
208   TF_ASSIGN_OR_RETURN(Shape shape,
209                       MakeValidatedShape(element_type, dimensions));
210   for (int i = 0; i < dynamic_dimensions.size(); ++i) {
211     shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
212   }
213   return shape;
214 }
215 
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits,int64 memory_space)216 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
217     PrimitiveType element_type, absl::Span<const int64> dimensions,
218     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
219     int64 element_size_in_bits, int64 memory_space) {
220   return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
221                                      tiles, element_size_in_bits, memory_space)
222       .ValueOrDie();
223 }
224 
MakeShapeWithDescendingLayout(PrimitiveType element_type,absl::Span<const int64> dimensions)225 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
226     PrimitiveType element_type, absl::Span<const int64> dimensions) {
227   std::vector<int64> layout(dimensions.size());
228   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
229   return MakeShapeWithLayout(element_type, dimensions, layout);
230 }
231 
// Returns a shape whose logical dimensions are permuted so that its layout
// becomes descending (minor_to_major = rank-1..0) while the physical order
// of elements in memory stays the same as in `shape`. Tiles and element
// size are copied over from the input layout.
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    const Shape& shape) {
  std::vector<int64> dims(shape.dimensions_size());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    // dims[i] is the bound of the i-th most-major dimension of `shape`.
    dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
  // Since the physical layout is kept the same, the tiles and element size are
  // the same also.
  new_shape.mutable_layout()->mutable_tiles()->assign(
      shape.layout().tiles().begin(), shape.layout().tiles().end());
  new_shape.mutable_layout()->set_element_size_in_bits(
      shape.layout().element_size_in_bits());
  // NOTE(review): dynamic-dimension flags are copied by logical index and are
  // NOT permuted along with dims above — confirm this is intended.
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
  }
  return new_shape;
}
251 
PopulateShape(PrimitiveType element_type,absl::Span<const int64> dimensions,Shape * shape)252 /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
253                                              absl::Span<const int64> dimensions,
254                                              Shape* shape) {
255   shape->Clear();
256   shape->set_element_type(element_type);
257   for (int64 dimension : dimensions) {
258     shape->add_dimensions(dimension);
259   }
260   LayoutUtil::SetToDefaultLayout(shape);
261   return ValidateShape(*shape);
262 }
263 
MakeTupleShape(absl::Span<const Shape> shapes)264 /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
265   Shape result;
266   result.set_element_type(TUPLE);
267   result.mutable_tuple_shapes()->reserve(shapes.size());
268   for (const auto& shape : shapes) {
269     AppendShapeToTuple(shape, &result);
270   }
271   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
272   return result;
273 }
274 
// Creates a shape with element type OPAQUE_TYPE (no dimensions, no layout).
/* static */ Shape ShapeUtil::MakeOpaqueShape() {
  Shape result;
  result.set_element_type(OPAQUE_TYPE);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
281 
// Creates a shape with element type TOKEN (no dimensions, no layout).
/* static */ Shape ShapeUtil::MakeTokenShape() {
  Shape result;
  result.set_element_type(TOKEN);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
288 
// Appends `shape` as a new trailing element of the tuple `*tuple_shape`,
// DCHECK-validating `shape` first. The caller is responsible for
// `*tuple_shape` actually being a tuple.
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
                                                Shape* tuple_shape) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
  *tuple_shape->add_tuple_shapes() = shape;
}
294 
UpdateTupleShape(const Shape & shape,int64 index,Shape * tuple_shape)295 /* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64 index,
296                                               Shape* tuple_shape) {
297   CHECK(index < tuple_shape->tuple_shapes_size());
298   *tuple_shape->mutable_tuple_shapes(index) = shape;
299 }
300 
UpdateDynamicDimension(Shape * shape,ShapeIndexView index,int64 dim,bool is_dynamic)301 /* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
302                                                     ShapeIndexView index,
303                                                     int64 dim,
304                                                     bool is_dynamic) {
305   if (index.empty()) {
306     CHECK(!shape->IsTuple());
307     shape->set_dynamic_dimension(dim, is_dynamic);
308     return;
309   }
310 
311   UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
312                          index.ConsumeFront(), dim, is_dynamic);
313 }
314 
AppendMajorDimension(int bound,Shape * shape)315 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
316   CHECK(LayoutUtil::IsDenseArray(*shape));
317   shape->mutable_layout()->add_minor_to_major(shape->rank());
318   shape->add_dimensions(bound);
319   TF_DCHECK_OK(ValidateShape(*shape));
320 }
321 
// True iff the shape's element type is an integral type per primitive_util.
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
  return primitive_util::IsIntegralType(shape.element_type());
}
325 
ElementIsIntegralWithBits(const Shape & shape,int32 bits)326 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
327                                                        int32 bits) {
328   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
329 }
330 
ElementHasBitWidth(const Shape & shape,int bits)331 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
332   if (!shape.IsArray()) {
333     return false;
334   }
335   return primitive_util::BitWidth(shape.element_type()) == bits;
336 }
337 
// Returns true for element types that can represent negative values: the
// signed integer types and the real floating-point types. Unsigned
// integers, PRED, complex, and non-array types return false. LOG(FATAL)s
// on enum values not enumerated here so new types must be handled
// explicitly.
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
  switch (shape.element_type()) {
    case S8:
    case S16:
    case S32:
    case S64:
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;

    case PRED:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    case TUPLE:
    case OPAQUE_TYPE:
    case TOKEN:
      return false;

    default:
      LOG(FATAL) << "Unhandled element type " << shape.element_type();
  }
}
366 
// True iff the shape's element type is a complex type per primitive_util.
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
  return primitive_util::IsComplexType(shape.element_type());
}
370 
// True iff the shape's element type is floating-point per primitive_util.
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
  return primitive_util::IsFloatingPointType(shape.element_type());
}
374 
IsNestedTuple(const Shape & shape)375 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
376   return shape.IsTuple() &&
377          absl::c_any_of(shape.tuple_shapes(),
378                         [](const Shape& s) { return s.IsTuple(); });
379 }
380 
IsEmptyTuple(const Shape & shape)381 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
382   return shape.IsTuple() && TupleElementCount(shape) == 0;
383 }
384 
// Number of top-level elements in the tuple. CHECK-fails on non-tuples.
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
  CHECK(shape.IsTuple()) << HumanString(shape);
  return shape.tuple_shapes_size();
}
389 
// Returns the shape of tuple element `index`. CHECK-fails if `shape` is not
// a tuple or if `index` is out of range.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64 index) {
  CHECK(shape.IsTuple());
  // Reject negative indices explicitly; the original CHECK_GT alone let them
  // through to the proto accessor.
  CHECK_GE(index, 0);
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  return shape.tuple_shapes(index);
}
397 
SubshapeCount(const Shape & shape)398 /* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
399   int64 n = 0;
400   ForEachSubshape(shape, [&](const Shape& literal_subshape,
401                              const ShapeIndex& index) { ++n; });
402   return n;
403 }
404 
// Returns a new tuple containing elements [start, limit) of `tuple`.
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
                                         int64 limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  // Check the full set of range preconditions. The original checks did not
  // reject a negative `start` or `start > limit`, either of which produces
  // an invalid iterator range (undefined behavior) below.
  CHECK_GE(start, 0);
  CHECK_LE(start, limit);
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
416 
417 // Returns the shape of a real or imaginary component.
ComplexComponentShape(const Shape & complex_shape)418 /* static */ Shape ShapeUtil::ComplexComponentShape(
419     const Shape& complex_shape) {
420   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
421   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
422                                               complex_shape.element_type()));
423 }
424 
ElementsIn(const Shape & shape)425 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
426   DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
427   DCHECK_EQ(shape.dimensions_size(), shape.rank());
428   if (shape.dimensions().size() == 1) {
429     return shape.dimensions()[0];
430   }
431   return std::accumulate<decltype(shape.dimensions().begin()), int64>(
432       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
433       std::multiplies<int64>());
434 }
435 
ElementsInRecursive(const Shape & shape)436 /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
437   CHECK(shape.IsArray() || shape.IsTuple());
438   if (shape.IsArray()) {
439     return ElementsIn(shape);
440   }
441   int64 count = 0;
442   for (const Shape& element_shape : shape.tuple_shapes()) {
443     count += ElementsInRecursive(element_shape);
444   }
445   return count;
446 }
447 
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)448 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
449                                               PrimitiveType primitive_type) {
450   if (shape.element_type() == primitive_type) {
451     return true;
452   }
453   for (const Shape& element_shape : shape.tuple_shapes()) {
454     if (HasPrimitiveType(element_shape, primitive_type)) {
455       return true;
456     }
457   }
458   return false;
459 }
460 
IsZeroElementArray(const Shape & shape)461 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
462   return shape.IsArray() && ElementsIn(shape) == 0;
463 }
464 
IsScalarWithElementType(const Shape & shape,PrimitiveType element_type)465 /* static */ bool ShapeUtil::IsScalarWithElementType(
466     const Shape& shape, PrimitiveType element_type) {
467   return IsScalar(shape) && shape.element_type() == element_type;
468 }
469 
HumanString(const Shape & shape)470 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
471   if (shape.IsTuple()) {
472     string text = "(";
473     const char* prefix = "";
474     for (const Shape& elem_shape : shape.tuple_shapes()) {
475       StrAppend(&text, prefix, HumanString(elem_shape));
476       prefix = ", ";
477     }
478     text += ")";
479     return text;
480   }
481   std::vector<string> dim_elements;
482   for (int i = 0; i < shape.dimensions_size(); ++i) {
483     if (shape.is_dynamic_dimension(i)) {
484       dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
485     } else {
486       dim_elements.push_back(StrCat(shape.dimensions(i)));
487     }
488   }
489   return StrCat(
490       primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
491       absl::StrJoin(dim_elements, ","), "]");
492 }
493 
// As HumanString, but additionally appends the layout (e.g. "f32[2,3]{1,0}")
// for array shapes that have one. Dynamic dimensions are rendered with a
// "<=" bound prefix.
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    // Tuples render recursively as "(elem, elem, ...)".
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  string result = StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[");
  for (int i = 0; i < shape.dimensions().size(); i++) {
    StrAppend(&result, (i > 0) ? "," : "",
              shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i));
  }
  result += "]";
  if (IsScalar(shape)) {
    string layout_str = LayoutUtil::HumanString(shape.layout());
    // Don't print "{}" as layout for scalars.
    if (layout_str != "{}") {
      StrAppend(&result, layout_str);
    }
  } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
    StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
  }
  return result;
}
523 
HumanString(const ProgramShape & program_shape)524 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
525   std::vector<string> parameters;
526   for (auto& shape : program_shape.parameters()) {
527     const int i = parameters.size();
528     parameters.push_back(StrCat(i < program_shape.parameter_names_size()
529                                     ? program_shape.parameter_names(i)
530                                     : "(unknown)",
531                                 ": ", HumanString(shape)));
532   }
533   return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
534                 HumanString(program_shape.result()));
535 }
536 
SameDimensions(const Shape & lhs,const Shape & rhs)537 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
538                                             const Shape& rhs) {
539   CHECK(lhs.IsArray());
540   CHECK(rhs.IsArray());
541   return absl::c_equal(lhs.dimensions(), rhs.dimensions());
542 }
543 
// Compatibility: shape equality ignoring layout and dynamic-dimension
// markers (element type and dimension bounds must still match).
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}
547 
// As Compatible, but additionally ignores the element type.
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreElementType()
      .IgnoreLayout()(lhs, rhs);
}
555 
// As Compatible, but treats floating-point types of differing precision as
// interchangeable.
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreFpPrecision()
      .IgnoreLayout()(lhs, rhs);
}
563 
// Returns the bound of dimension `dimension_number`. Negative values index
// from the back (Python-style) via GetDimensionNumber.
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                           int64 dimension_number) {
  return shape.dimensions(GetDimensionNumber(shape, dimension_number));
}
568 
GetDimensionNumber(const Shape & shape,int64 dimension_number)569 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
570                                                  int64 dimension_number) {
571   if (dimension_number < 0) {
572     dimension_number += shape.rank();
573   }
574   CHECK_GE(dimension_number, 0);
575   return dimension_number;
576 }
577 
// Size in bytes of a single element of `primitive_type`. TOKEN occupies no
// space; TUPLE and OPAQUE_TYPE have no fixed per-element size and LOG(FATAL),
// as do unknown enum values.
/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return sizeof(int8);
    case S8:
      return sizeof(int8);
    case S16:
      return sizeof(int16);
    case S32:
      return sizeof(int32);
    case S64:
      return sizeof(int64);
    case U8:
      return sizeof(uint8);
    case U16:
      return sizeof(uint16);
    case U32:
      return sizeof(uint32);
    case U64:
      return sizeof(uint64);
    case BF16:
      // Half the width of a float: 2 bytes.
      return sizeof(float) / 2;
    case F16:
      return sizeof(float) / 2;
    case F32:
      return sizeof(float);
    case F64:
      return sizeof(double);
    case C64:
      return sizeof(complex64);
    case C128:
      return sizeof(complex128);
    case TOKEN:
      // Tokens require no space.
      return 0;
    case TUPLE:
    case OPAQUE_TYPE:
      LOG(FATAL) << PrimitiveType_Name(primitive_type)
                 << " primitive type has no definitive size";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
622 
// Logical size in bytes of `shape`:
//  - tuple:  the size of its index table (one pointer per element),
//  - array:  element count times per-element size,
//  - token:  0,
//  - opaque: one pointer (pointer_size must be > 0 in that case).
// LOG(FATAL)s for any other element type.
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
                                         int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  if (shape.element_type() == TUPLE) {
    return ByteSizeOfTupleIndexTable(shape, pointer_size);
  } else if (shape.IsArray()) {
    int64 byte_size = ByteSizeOfElements(shape);
    return byte_size;
  } else if (shape.element_type() == TOKEN) {
    return 0;
  } else if (shape.element_type() == OPAQUE_TYPE) {
    CHECK_GT(pointer_size, 0);
    return pointer_size;
  }
  LOG(FATAL) << PrimitiveType_Name(shape.element_type())
             << " primitive type has no definitive size";
}
640 
ByteSizeOfTupleIndexTable(const Shape & shape,int64 pointer_size)641 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
642                                                         int64 pointer_size) {
643   TF_DCHECK_OK(ValidateShape(shape));
644   CHECK_EQ(TUPLE, shape.element_type());
645   CHECK_GT(pointer_size, 0);
646   return pointer_size * shape.tuple_shapes_size();
647 }
648 
ByteSizeOfElements(const Shape & shape)649 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
650   TF_DCHECK_OK(ValidateShape(shape));
651   CHECK(shape.IsArray());
652   int64 allocated_element_count;
653 
654   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
655   allocated_element_count = ElementsIn(shape);
656   return allocated_element_count *
657          ByteSizeOfPrimitiveType(shape.element_type());
658 }
659 
// Structural validation shared by ValidateShape and
// ValidateShapeWithOptionalLayout: checks the element type, recursively
// validates tuple elements, enforces that tokens/opaques carry neither
// dimensions nor layout, rejects negative dimension bounds, and finally
// checks the total byte size for int64 overflow. Layout validation proper
// is handled separately by the callers.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    // A tuple is valid iff all of its elements are valid.
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques should not have a layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  // Array shape: every dimension bound must be non-negative.
  for (int64 i = 0; i < shape.rank(); ++i) {
    int64 dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}
712 
// Verifies that the total byte size of an array shape (product of all
// dimension bounds times the element size) fits in an int64. Relies on
// MultiplyWithoutOverflow returning a negative value to signal overflow.
// Non-array shapes trivially pass.
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
  VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);

  if (!shape.IsArray()) {
    return Status::OK();
  }

  int64 shape_size = [&]() {
    int64 dense_shape_size = 1;
    if (shape.dimensions().empty()) {
      // Rank-0 (scalar): one element.
      return dense_shape_size;
    }

    absl::Span<const int64> shape_max_dimensions =
        AsInt64Slice(shape.dimensions());
    for (int64 dim : shape_max_dimensions) {
      dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
      if (dense_shape_size < 0) {
        // Overflow detected; propagate the negative sentinel.
        return dense_shape_size;
      }
    }
    dense_shape_size = MultiplyWithoutOverflow(
        dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
    return dense_shape_size;
  }();

  if (shape_size < 0) {
    return InvalidArgument("Shape %s size may overflow int64.",
                           ShapeUtil::HumanString(shape));
  }

  VLOG(3) << "Shape size is valid: " << shape_size;
  return Status::OK();
}
747 
// Validates the shape's structure; layouts are checked when present but are
// allowed to be missing.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
    const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape,
                                           /*allow_missing_layouts=*/true);
}
755 
ValidateShape(const Shape & shape)756 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
757   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
758 
759   return LayoutUtil::ValidateLayoutInShape(shape);
760 }
761 
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  // Copy everything (dimensions, layout, tuple structure) from `original`
  // and replace only the element type.
  Shape result(original);
  result.set_element_type(type);
  return result;
}
768 
IndexIsValid(const Shape & shape,ShapeIndexView index)769 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
770                                           ShapeIndexView index) {
771   const Shape* subshape = &shape;
772   for (auto i : index) {
773     if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
774       return false;
775     }
776     subshape = &subshape->tuple_shapes(i);
777   }
778   return true;
779 }
780 
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  // Descend one tuple level per index element; CHECK-fails if the index
  // tries to step into a non-tuple shape.
  const Shape* subshape = &shape;
  for (auto element : index) {
    CHECK(subshape->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    subshape = &subshape->tuple_shapes(element);
  }
  return *subshape;
}
791 
TryGetSubshape(const Shape & shape,ShapeIndexView index)792 /* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
793     const Shape& shape, ShapeIndexView index) {
794   const Shape* return_shape = &shape;
795   for (auto i : index) {
796     if (!return_shape->IsTuple() || i < 0 ||
797         i >= return_shape->tuple_shapes_size()) {
798       return InvalidArgument(
799           "Shape index %s not a valid subshape index for tuple with shape %s",
800           index.ToString(), shape.DebugString());
801     }
802     return_shape = &return_shape->tuple_shapes(i);
803   }
804   return return_shape;
805 }
806 
GetMutableSubshape(Shape * shape,ShapeIndexView index)807 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
808                                                   ShapeIndexView index) {
809   Shape* return_shape = shape;
810   for (auto i : index) {
811     CHECK(return_shape->IsTuple());
812     return_shape = return_shape->mutable_tuple_shapes(i);
813   }
814   return return_shape;
815 }
816 
817 /* static */
IsLeafIndex(const Shape & shape,const ShapeIndex & index)818 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
819   return !GetSubshape(shape, index).IsTuple();
820 }
821 
GetLeafCount(const Shape & shape)822 /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
823   if (!shape.IsTuple()) {
824     return 1;
825   }
826   int64 count = 0;
827   for (const Shape& subshape : shape.tuple_shapes()) {
828     count += GetLeafCount(subshape);
829   }
830   return count;
831 }
832 
GetLeafShapes(const Shape & shape)833 /* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
834     const Shape& shape) {
835   std::vector<IndexedShape> leaves;
836   ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
837     if (IsLeafIndex(shape, index)) {
838       leaves.emplace_back(index, sub_shape);
839     }
840   });
841   return leaves;
842 }
843 
HasDegenerateDimensions(const Shape & shape)844 /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
845   CHECK(shape.IsArray());
846   return absl::c_linear_search(shape.dimensions(), 1);
847 }
848 
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  // Keep exactly those dimensions whose bound differs from 1.
  auto is_nondegenerate = [&shape](int64 dim) -> bool {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(is_nondegenerate, shape);
}
853 
854 namespace {
855 
856 // Helper for ForEachSubshape which visits the subshapes of the given shape in
857 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)858 Status ForEachSubshapeHelper(const Shape& shape,
859                              const ShapeUtil::StatusVisitorFunction& func,
860                              ShapeIndex* index) {
861   TF_RETURN_IF_ERROR(func(shape, *index));
862   if (shape.IsTuple()) {
863     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
864       index->push_back(i);
865       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
866           ShapeUtil::GetTupleElementShape(shape, i), func, index));
867       index->pop_back();
868     }
869   }
870   return Status::OK();
871 }
872 
873 // Helper for ForEachMutableSubshape which visits the subshapes of the given
874 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)875 Status ForEachMutableSubshapeHelper(
876     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
877     ShapeIndex* index) {
878   TF_RETURN_IF_ERROR(func(shape, *index));
879   if (shape->IsTuple()) {
880     for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
881       index->push_back(i);
882       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
883           shape->mutable_tuple_shapes(i), func, index));
884       index->pop_back();
885     }
886   }
887   return Status::OK();
888 }
889 
890 }  // namespace
891 
ForEachSubshape(const Shape & shape,const VisitorFunction & func)892 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
893                                              const VisitorFunction& func) {
894   ShapeIndex index;
895   ForEachSubshapeHelper(
896       shape,
897       [&func](const Shape& subshape, const ShapeIndex& index) {
898         func(subshape, index);
899         return Status::OK();
900       },
901       &index)
902       .IgnoreError();
903 }
904 
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)905 /* static */ void ShapeUtil::ForEachMutableSubshape(
906     Shape* shape, const MutatingVisitorFunction& func) {
907   ShapeIndex index;
908   ForEachMutableSubshapeHelper(
909       shape,
910       [&func](Shape* subshape, const ShapeIndex& index) {
911         func(subshape, index);
912         return Status::OK();
913       },
914       &index)
915       .IgnoreError();
916 }
917 
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)918 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
919     const Shape& shape, const StatusVisitorFunction& func) {
920   ShapeIndex index;
921   return ForEachSubshapeHelper(shape, func, &index);
922 }
923 
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)924 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
925     Shape* shape, const MutatingStatusVisitorFunction& func) {
926   ShapeIndex index;
927   return ForEachMutableSubshapeHelper(shape, func, &index);
928 }
929 
/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  // Permute the dimension bounds.
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(permutation, shape.dimensions())) {
    new_shape.add_dimensions(dim);
  }
  // Each dimension's dynamic bit must follow its dimension to the new
  // position: dimension i of `shape` ends up at position permutation[i].
  for (int64 i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this.  We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to itself,
  // `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L.  A layout is a permutation which maps a
  // minor-to-major physical layout to the order of a shape's logical dims.
  // Therefore inverse of a layout maps from logical to physical dims, and so
  // the physical layout of I is simply L'.I = L', where L' is the inverse of L.
  //
  // Let the argument `permutation` be P.  This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P.  Our
  // goal is to construct a layout permutation L* that we can apply to P such
  // that the physical dimension ordering of the returned shape is the same
  // as that of the original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory layout is
  // L*'.P.  Setting this equal to L' and solving for L*, we get:
  //
  //   L*'.P = L'    =>
  //   L*'   = L'P'  =>
  //   L*    = P.L
  //
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    // L* = P.L, computed as the composition of `permutation` with the
    // original minor-to-major list (see derivation above).
    for (auto index : ComposePermutations(
             permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}
987 
// Determines whether the reshape from `shape_pre` to `shape_post` only
// inserts and/or deletes degenerate (size-1) dimensions. On success returns
// {true, deleted input dims, inserted output dims}; otherwise {false, {}, {}}.
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  // Failure value: the reshape modifies some non-degenerate dimension.
  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output index between prior_unmodified_dim_pair
  // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i]. Sentinel pairs (-1, -1) and (rank, rank) bound the
    // gaps before the first and after the last unmodified pair.
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}
1047 
1048 /* static */ std::vector<std::pair<int64, int64>>
DimensionsUnmodifiedByReshape(const Shape & input_shape,const Shape & output_shape)1049 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1050                                          const Shape& output_shape) {
1051   CHECK(input_shape.IsArray());
1052   CHECK(output_shape.IsArray());
1053 
1054   // Unmodified dimensions are merely common factors of rank 1.
1055   auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
1056                                       AsInt64Slice(output_shape.dimensions()));
1057   for (size_t i = 0; i < common_factors.size() - 1;) {
1058     if (1 != common_factors[i + 1].first - common_factors[i].first ||
1059         1 != common_factors[i + 1].second - common_factors[i].second) {
1060       common_factors.erase(common_factors.begin() + i);
1061     } else {
1062       ++i;
1063     }
1064   }
1065   // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1066   common_factors.pop_back();
1067   return std::vector<std::pair<int64, int64>>(common_factors.begin(),
1068                                               common_factors.end());
1069 }
1070 
1071 /* static */ absl::optional<std::vector<int64>>
ReshapeLeavesDimensionsUnmodified(const Shape & from_shape,const Shape & to_shape,absl::Span<const int64> input_dim_indices)1072 ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1073     const Shape& from_shape, const Shape& to_shape,
1074     absl::Span<const int64> input_dim_indices) {
1075   CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1076 
1077   std::vector<int64> output_dim_indices;
1078   std::vector<std::pair<int64, int64>> unmodified_dims =
1079       ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1080   size_t i = 0;  // index to unmodified_dims
1081   for (int64 input_dim_index : input_dim_indices) {
1082     // Search unmodified_dims for input_dim_index. We can search from the last
1083     // matching position because input_dim_indices is guaranteed to be sorted.
1084     while (i < unmodified_dims.size() &&
1085            unmodified_dims[i].first < input_dim_index) {
1086       ++i;
1087     }
1088     if (i >= unmodified_dims.size() ||
1089         unmodified_dims[i].first != input_dim_index) {
1090       return absl::nullopt;
1091     }
1092     output_dim_indices.push_back(unmodified_dims[i].second);
1093   }
1094   return output_dim_indices;
1095 }
1096 
TransposeIsBitcast(const Shape & input_shape,const Shape & output_shape,absl::Span<const int64> dimension_mapping)1097 /* static */ bool ShapeUtil::TransposeIsBitcast(
1098     const Shape& input_shape, const Shape& output_shape,
1099     absl::Span<const int64> dimension_mapping) {
1100   CHECK(LayoutUtil::HasLayout(input_shape) &&
1101         LayoutUtil::HasLayout(output_shape));
1102 
1103   if (!SameElementType(input_shape, output_shape)) {
1104     return false;
1105   }
1106 
1107   // Check the reshape permutes the positions of each dimension in the
1108   // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1109   //   input_positions = apply(dimension_mapping, output_positions)
1110   //
1111   // Because the positions of each dimension are the inverse permutation of the
1112   // minor-to-major order, the above check is equivalent to
1113   //   inverse(input_dimensions) =
1114   //       apply(dimension_mapping, inverse(output_dimensions))
1115   //   # `I` indicates identity permutation.
1116   //   apply(input_dimensions, I) =
1117   //       apply(dimension_mapping, apply(output_dimensions, I))
1118   //   apply(input_dimensions, I) =
1119   //       apply((dimension_mapping * output_dimensions), I)
1120   //   input_dimensions = dimension_mapping * output_dimensions
1121   return absl::c_equal(
1122       ComposePermutations(dimension_mapping,
1123                           AsInt64Slice(output_shape.layout().minor_to_major())),
1124       input_shape.layout().minor_to_major());
1125 }
1126 
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  // A bitcast cannot change the element type.
  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
  // A shape with no elements has nothing to relocate, so any reshape of it
  // is trivially a bitcast.
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
  // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
  // logical index k-1 to logical index k, dimension 1 of the input index
  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
  // IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      // Dimensions of bound <= 1 admit no unit index distinct from the
      // origin, so there is nothing to check for them.
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64 logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
1286 
// Attempts to construct a layout for `output_shape` that makes the reshape
// from `input_shape` to `output_shape` a bitcast; returns nullopt when no
// such layout exists.
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    // Solve the degenerate-free problem recursively, then splice the size-1
    // dimensions back into the resulting layout.
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    std::vector<int64> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // For each one sized dimension in the output, increment the dimension
    // numbers in layout that are more minor than the one.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64 i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    // Translate the degenerate-free layout's dimension numbers back into the
    // dimension numbering of the original (degenerate-carrying) output shape.
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  int64 input_rank = input_shape.rank();
  int64 output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part if
  // the products of their dimension bounds are the same. In the easiest case,
  // an alignment part consists of one input dimension and one output dimension
  // which both have the same dimension bound. An alignment part specifies which
  // dimensions need to be kept together in a physical layout if we want a
  // reshape to be a bitcast. The order of the alignment parts is defined by the
  // physical layout of the input shape, so when we construct the layout for the
  // output shape we just process the alignment parts in this order, and then
  // layout the dimensions belonging to each part in descending (major to minor)
  // order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64 input_dimension_product = 1, output_dimension_product = 1;
  for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    // Greedily consume whichever side has the smaller running product; the
    // two products must meet at the end of every alignment part.
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to check
  // that all dimension numbers that belong to the same alignment part appear
  // consecutively, and are in descending order. However we can ignore any
  // trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  for (int64 i = 0; i < input_rank;) {
    int64 current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64 current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
    // dimension numbers (ignoring dimension numbers with dimension bound 1) are
    // in descending order and belong to the current alignment part.
    for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
                              alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64 j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  // Sanity check: the constructed layout must actually make the reshape a
  // bitcast.
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
1436 
DeleteDimension(int64 dim_to_delete,Shape shape)1437 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
1438                                               Shape shape) {
1439   CHECK(shape.IsArray());
1440   shape.DeleteDimension(dim_to_delete);
1441   return shape;
1442 }
1443 
FilterDimensions(const std::function<bool (int64)> & p,Shape shape)1444 /* static */ Shape ShapeUtil::FilterDimensions(
1445     const std::function<bool(int64)>& p, Shape shape) {
1446   CHECK(shape.IsArray());
1447   std::vector<int64> dims_to_delete;
1448   for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
1449     if (!p(i)) {
1450       dims_to_delete.push_back(i);
1451     }
1452   }
1453   for (int64 dim : dims_to_delete) {
1454     shape = DeleteDimension(dim, shape);
1455   }
1456   return shape;
1457 }
1458 
Hash(const Shape & shape)1459 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1460   using tensorflow::hash;
1461   using tensorflow::Hash64Combine;
1462 
1463   size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1464 
1465   if (shape.tuple_shapes().empty()) {
1466     for (int i = 0; i < shape.dimensions_size(); ++i) {
1467       hash_value =
1468           Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1469       hash_value = Hash64Combine(hash_value,
1470                                  hash<bool>()(shape.is_dynamic_dimension(i)));
1471     }
1472 
1473     hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1474   } else {
1475     hash_value = 0;
1476     for (const Shape& subshape : shape.tuple_shapes()) {
1477       hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1478     }
1479   }
1480 
1481   return hash_value;
1482 }
1483 
1484 // Returns the indices of the first elements of all consecutive subarrays of the
1485 // given array. For example:
1486 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
ConsecutiveSegments(absl::Span<const int64> xs)1487 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
1488   std::vector<size_t> is = {0};
1489   for (size_t i = 1; i < xs.size(); ++i) {
1490     if (1 != xs[i] - xs[i - 1]) {
1491       is.push_back(i);
1492     }
1493   }
1494   return is;
1495 }
1496 
1497 // Merges the sequences of dimensions of the given shape which start at the
1498 // given indices `segs`.
MergeDimensions(absl::Span<const size_t> segs,const Shape & shape)1499 static Shape MergeDimensions(absl::Span<const size_t> segs,
1500                              const Shape& shape) {
1501   std::vector<int64> dimensions;
1502   for (size_t i = 1; i <= segs.size(); ++i) {
1503     dimensions.push_back(std::accumulate(
1504         shape.dimensions().begin() + segs[i - 1],
1505         shape.dimensions().begin() +
1506             (segs.size() == i ? shape.dimensions().size() : segs[i]),
1507         1, std::multiplies<int64>()));
1508   }
1509   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1510                                                   dimensions);
1511 }
1512 
FindTranspose021(const Shape & a,const Shape & b)1513 /*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
1514     const Shape& a, const Shape& b) {
1515   if (!CompatibleIgnoringElementType(a, b)) {
1516     return absl::nullopt;
1517   }
1518 
1519   std::vector<int64> permutation(a.dimensions().size());
1520   absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
1521   std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
1522                                       minor_to_major_a.rend());
1523   absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
1524   std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
1525                                       minor_to_major_b.rend());
1526   for (size_t i = 0; i < permutation.size(); ++i) {
1527     permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
1528   }
1529 
1530   std::vector<size_t> segments = ConsecutiveSegments(permutation);
1531   if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
1532     Shape descending_layout_shape =
1533         ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
1534     Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
1535     absl::Span<const int64> normalized_dims =
1536         AsInt64Slice(normalized_shape.dimensions());
1537     std::vector<int64> dims_021;
1538     if (2 == segments.size()) {
1539       // The logical component-0 is of size one.
1540       dims_021 = {1, normalized_dims[1], normalized_dims[0]};
1541     } else {
1542       dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
1543     }
1544 
1545     return dims_021;
1546   }
1547 
1548   return absl::nullopt;
1549 }
1550 
1551 }  // namespace xla
1552