/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/shape_util.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {

using absl::StrAppend;
using absl::StrCat;

namespace {
// An array indexed by PrimitiveType which gives the size in bytes of each
// element of that primitive type, or 0 if the primitive type does not have a
// fixed element size (PRIMITIVE_TYPE_INVALID, TUPLE, OPAQUE_TYPE, TOKEN).
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
    0,                  // PRIMITIVE_TYPE_INVALID = 0,
    sizeof(int8),       // PRED = 1
    sizeof(int8),       // S8 = 2
    sizeof(int16),      // S16 = 3
    sizeof(int32),      // S32 = 4
    sizeof(int64),      // S64 = 5
    sizeof(uint8),      // U8 = 6
    sizeof(uint16),     // U16 = 7
    sizeof(uint32),     // U32 = 8
    sizeof(uint64),     // U64 = 9
    sizeof(float) / 2,  // F16 = 10
    sizeof(float),      // F32 = 11
    sizeof(double),     // F64 = 12
    0,                  // TUPLE = 13
    0,                  // OPAQUE_TYPE = 14
    sizeof(complex64),  // C64 = 15
    sizeof(float) / 2,  // BF16 = 16
    0,                  // TOKEN = 17
    sizeof(complex128)  // C128 = 18
};
}  // namespace

string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }

string ShapeIndexView::ToString() const {
  return StrCat("{", absl::StrJoin(indices_, ","), "}");
}

bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
  return indices_ == other.indices_;
}

bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
  return !(*this == other);
}

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
  out << shape_index.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
  out << shape_index.ToString();
  return out;
}

bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
  return size() >= prefix.size() &&
         indices_.subspan(0, prefix.size()) == prefix.indices_;
}

/* static */ bool ShapeUtil::IsArrayPrimitiveType(
    PrimitiveType primitive_type) {
  return primitive_util::IsArrayType(primitive_type);
}

namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  if (element_size_in_bits ==
      ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
    // Only set element_size_in_bits if it's different from the default value.
    element_size_in_bits = 0;
  }
  *shape.mutable_layout() = LayoutUtil::MakeLayout(
      minor_to_major, tiles, element_size_in_bits, memory_space);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace

/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
  bool equal = Shape::Equal()(lhs, rhs);

  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
            << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}

/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
                                                      const Shape& rhs) {
  bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
            << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}

/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
                                                      const Shape& rhs) {
  bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
            << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}

/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
                                            const Shape& rhs) {
  bool equal = true;
  ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
    equal &= IndexIsValid(rhs, index);
  });
  ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
    equal &= IndexIsValid(lhs, index);
  });

  return equal;
}

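// Example use of TrueRank below (illustrative, not from the original source):
//   Shape s = ShapeUtil::MakeShape(F32, {1, 5, 1, 7});
//   ShapeUtil::TrueRank(s);  // == 2; only the size-5 and size-7 dims count.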
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64 accum = 0;
  for (int64 dimension : shape.dimensions()) {
    // We do not count degenerate (size 1) dimensions.
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}

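// Illustrative behavior of FillNewShape (assumed example, not from the
// original source): FillNewShape(F32, {2, 3}, &shape) fills a dense f32[2,3]
// with the default minor_to_major {1, 0} and returns true, while a negative
// dimension or a byte size that overflows int64 makes it return false.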
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          Shape* shape) {
  const int eint = static_cast<int>(element_type);
  int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
                                ? primitive_byte_size[eint]
                                : 0);  // Out of range: force a failure
  if (dense_shape_size <= 0) {
    return false;
  }

  // Verify that array-based lookup is consistent with public API.
  DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
      << element_type;

  shape->set_element_type(element_type);
  const int ndims = dimensions.size();
  auto layout = shape->mutable_layout();
  layout->set_format(DENSE);
  auto* minor_to_major = layout->mutable_minor_to_major();
  for (int i = 0; i < ndims; i++) {
    const int64 d = dimensions[i];
    if (d < 0) {
      return false;
    }
    dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
    if (dense_shape_size < 0) {
      return false;
    }

    shape->add_dimensions(d);
    minor_to_major->push_back(ndims - 1 - i);
  }
  return true;
}

/* static */ ProgramShape ShapeUtil::MakeProgramShape(
    std::initializer_list<Shape> parameters, Shape result) {
  ProgramShape program_shape;
  for (const Shape& shape : parameters) {
    *program_shape.add_parameters() = shape;
  }
  *program_shape.mutable_result() = std::move(result);
  return program_shape;
}

/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
                                        absl::Span<const int64> dimensions) {
  Shape shape;
  CHECK(FillNewShape(element_type, dimensions, &shape));
  return shape;
}

/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
  return MakeShape(element_type, {});
}

/* static */ Shape ShapeUtil::MakeShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
      .ValueOrDie();
}

/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape output = shape;
  output.clear_dynamic_dimensions();
  return output;
}

/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  Shape shape;
  if (!FillNewShape(element_type, dimensions, &shape)) {
    return InvalidArgument("invalid shape type=%d, dims=[%s]",
                           static_cast<int>(element_type),
                           absl::StrJoin(dimensions, ","));
  }
  return shape;
}

/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  if (dynamic_dimensions.size() != dimensions.size()) {
    return InvalidArgument(
        "dynamic dimensions size %d did not match number of dimensions %d",
        dynamic_dimensions.size(), dimensions.size());
  }

  Shape shape;
  if (!FillNewShape(element_type, dimensions, &shape)) {
    return InvalidArgument("invalid shape type=%d, dims=[%s]",
                           static_cast<int>(element_type),
                           absl::StrJoin(dimensions, ","));
  }
  for (int i = 0, n = dimensions.size(); i < n; i++) {
    shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  return shape;
}

/* static */ Shape ShapeUtil::MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  auto ret =
      MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
                                  tiles, element_size_in_bits, memory_space);
  if (!ret.ok()) LOG(ERROR) << ret.status();
  return ret.ValueOrDie();
}

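// Example (illustrative): MakeShapeWithDescendingLayout(F32, {2, 3, 4})
// returns f32[2,3,4] with minor_to_major {2, 1, 0}, i.e. logical dimension 0
// is the most major (slowest varying) dimension in memory.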
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  std::vector<int64> layout(dimensions.size());
  std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
  return MakeShapeWithLayout(element_type, dimensions, layout);
}

/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    const Shape& shape) {
  std::vector<int64> dims(shape.dimensions_size());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
  // Since the physical layout is kept the same, the tiles and element size are
  // the same also.
  new_shape.mutable_layout()->mutable_tiles()->assign(
      shape.layout().tiles().begin(), shape.layout().tiles().end());
  new_shape.mutable_layout()->set_element_size_in_bits(
      shape.layout().element_size_in_bits());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
  }
  return new_shape;
}

/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
                                             absl::Span<const int64> dimensions,
                                             Shape* shape) {
  shape->Clear();
  shape->set_element_type(element_type);
  for (int64 dimension : dimensions) {
    shape->add_dimensions(dimension);
  }
  LayoutUtil::SetToDefaultLayout(shape);
  return ValidateShape(*shape);
}

/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
  Shape result = original;
  result.clear_dynamic_dimensions();
  return result;
}

/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
  Shape result;
  result.set_element_type(TUPLE);
  result.mutable_tuple_shapes()->reserve(shapes.size());
  for (const auto& shape : shapes) {
    AppendShapeToTuple(shape, &result);
  }
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}

/* static */ Shape ShapeUtil::MakeOpaqueShape() {
  Shape result;
  result.set_element_type(OPAQUE_TYPE);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}

/* static */ Shape ShapeUtil::MakeTokenShape() {
  Shape result;
  result.set_element_type(TOKEN);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}

/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
                                                Shape* tuple_shape) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
  *tuple_shape->add_tuple_shapes() = shape;
}

/* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64 index,
                                              Shape* tuple_shape) {
  CHECK(index < tuple_shape->tuple_shapes_size());
  *tuple_shape->mutable_tuple_shapes(index) = shape;
}

/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
                                                    ShapeIndexView index,
                                                    int64 dim,
                                                    bool is_dynamic) {
  if (index.empty()) {
    CHECK(!shape->IsTuple());
    shape->set_dynamic_dimension(dim, is_dynamic);
    return;
  }

  UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
                         index.ConsumeFront(), dim, is_dynamic);
}

/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
  CHECK(LayoutUtil::IsDenseArray(*shape));
  shape->mutable_layout()->add_minor_to_major(shape->rank());
  shape->add_dimensions(bound);
  TF_DCHECK_OK(ValidateShape(*shape));
}

/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
                                                   const Shape& from) {
  CHECK_EQ(to->rank(), from.rank());
  for (int64 i = 0; i < from.rank(); ++i) {
    to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
  }
  TF_DCHECK_OK(ValidateShape(*to));
}

/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
  return primitive_util::IsIntegralType(shape.element_type());
}

/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
                                                       int32 bits) {
  return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
}

/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
  if (!shape.IsArray()) {
    return false;
  }
  return primitive_util::BitWidth(shape.element_type()) == bits;
}

/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
  switch (shape.element_type()) {
    case S8:
    case S16:
    case S32:
    case S64:
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;

    case PRED:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    case TUPLE:
    case OPAQUE_TYPE:
    case TOKEN:
      return false;

    default:
      LOG(FATAL) << "Unhandled element type " << shape.element_type();
  }
}

/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
  return primitive_util::IsComplexType(shape.element_type());
}

/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
  return primitive_util::IsFloatingPointType(shape.element_type());
}

/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
  return shape.IsTuple() &&
         absl::c_any_of(shape.tuple_shapes(),
                        [](const Shape& s) { return s.IsTuple(); });
}

/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
  return shape.IsTuple() && TupleElementCount(shape) == 0;
}

/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
  CHECK(shape.IsTuple()) << HumanString(shape);
  return shape.tuple_shapes_size();
}

/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64 index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  return shape.tuple_shapes(index);
}

/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
  int64 n = 0;
  ForEachSubshape(shape, [&](const Shape& literal_subshape,
                             const ShapeIndex& index) { ++n; });
  return n;
}

/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
                                         int64 limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  CHECK_LE(start, TupleElementCount(tuple));
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}

// Returns the shape of a real or imaginary component.
/* static */ Shape ShapeUtil::ComplexComponentShape(
    const Shape& complex_shape) {
  CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
  return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
                                              complex_shape.element_type()));
}

/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
  DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
  DCHECK_EQ(shape.dimensions_size(), shape.rank());
  if (shape.dimensions().size() == 1) {
    return shape.dimensions()[0];
  }
  return std::accumulate<decltype(shape.dimensions().begin()), int64>(
      shape.dimensions().begin(), shape.dimensions().end(), 1LL,
      std::multiplies<int64>());
}

/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
  CHECK(shape.IsArray() || shape.IsTuple());
  if (shape.IsArray()) {
    return ElementsIn(shape);
  }
  int64 count = 0;
  for (const Shape& element_shape : shape.tuple_shapes()) {
    count += ElementsInRecursive(element_shape);
  }
  return count;
}

/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
                                              PrimitiveType primitive_type) {
  if (shape.element_type() == primitive_type) {
    return true;
  }
  for (const Shape& element_shape : shape.tuple_shapes()) {
    if (HasPrimitiveType(element_shape, primitive_type)) {
      return true;
    }
  }
  return false;
}

/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
  return shape.IsArray() && ElementsIn(shape) == 0;
}

/* static */ bool ShapeUtil::IsScalarWithElementType(
    const Shape& shape, PrimitiveType element_type) {
  return IsScalar(shape) && shape.element_type() == element_type;
}

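// Examples of the format produced below (illustrative): a static f32 matrix
// prints as "f32[2,3]", a shape whose second dimension is dynamic prints as
// "f32[2,<=3]", and a tuple of the two prints as "(f32[2,3], f32[2,<=3])".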
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanString(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  std::vector<string> dim_elements;
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    if (shape.is_dynamic_dimension(i)) {
      dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
    } else {
      dim_elements.push_back(StrCat(shape.dimensions(i)));
    }
  }
  return StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
      absl::StrJoin(dim_elements, ","), "]");
}

/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  string result = HumanString(shape);
  if (IsScalar(shape)) {
    string layout_str = LayoutUtil::HumanString(shape.layout());
    // Don't print "{}" as layout for scalars.
    if (layout_str != "{}") {
      StrAppend(&result, layout_str);
    }
  } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
    StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
  }
  return result;
}

/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
  std::vector<string> parameters;
  for (auto& shape : program_shape.parameters()) {
    const int i = parameters.size();
    parameters.push_back(StrCat(i < program_shape.parameter_names_size()
                                    ? program_shape.parameter_names(i)
                                    : "(unknown)",
                                ": ", HumanString(shape)));
  }
  return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
                HumanString(program_shape.result()));
}

/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
                                            const Shape& rhs) {
  CHECK(lhs.IsArray());
  CHECK(rhs.IsArray());
  return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}

/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
  CHECK(lhs.IsArray());
  CHECK(rhs.IsArray());
  return lhs.rank() == rhs.rank();
}

/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreElementType()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
                                            const Shape& rhs) {
  return Shape::Equal()
      .IgnoreElementType()
      .IgnoreLayout()
      .IgnoreDimensions()
      .IgnoreDynamicDimension()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreFpPrecision()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                           int64 dimension_number) {
  return shape.dimensions(GetDimensionNumber(shape, dimension_number));
}

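// Example (illustrative): for a rank-3 shape, GetDimensionNumber(shape, -1)
// returns 2, so GetDimension(shape, -1) reads the last logical dimension,
// mirroring Python-style negative indexing.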
/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
                                                 int64 dimension_number) {
  if (dimension_number < 0) {
    dimension_number += shape.rank();
  }
  CHECK_GE(dimension_number, 0);
  return dimension_number;
}

/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return sizeof(int8);
    case S8:
      return sizeof(int8);
    case S16:
      return sizeof(int16);
    case S32:
      return sizeof(int32);
    case S64:
      return sizeof(int64);
    case U8:
      return sizeof(uint8);
    case U16:
      return sizeof(uint16);
    case U32:
      return sizeof(uint32);
    case U64:
      return sizeof(uint64);
    case BF16:
      return sizeof(float) / 2;
    case F16:
      return sizeof(float) / 2;
    case F32:
      return sizeof(float);
    case F64:
      return sizeof(double);
    case C64:
      return sizeof(complex64);
    case C128:
      return sizeof(complex128);
    case TOKEN:
      // Tokens require no space.
      return 0;
    case TUPLE:
    case OPAQUE_TYPE:
      LOG(FATAL) << PrimitiveType_Name(primitive_type)
                 << " primitive type has no definitive size";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}

/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
                                         int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  if (shape.element_type() == TUPLE) {
    return ByteSizeOfTupleIndexTable(shape, pointer_size);
  } else if (shape.IsArray()) {
    return ByteSizeOfElements(shape);
  } else if (shape.element_type() == TOKEN) {
    return 0;
  } else if (shape.element_type() == OPAQUE_TYPE) {
    CHECK_GT(pointer_size, 0);
    return pointer_size;
  }
  LOG(FATAL) << PrimitiveType_Name(shape.element_type())
             << " primitive type has no definitive size";
}

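// Example (illustrative): with pointer_size == 8, a tuple shape
// (f32[4], s32[2], pred[]) has an 8 * 3 == 24 byte index table, regardless of
// the sizes of the element buffers themselves.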
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
                                                        int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK_EQ(TUPLE, shape.element_type());
  CHECK_GT(pointer_size, 0);
  return pointer_size * shape.tuple_shapes_size();
}

/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK(shape.IsArray());
  int64 allocated_element_count;

  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
  allocated_element_count = ElementsIn(shape);
  return allocated_element_count *
         ByteSizeOfPrimitiveType(shape.element_type());
}

/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques should not have layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  for (int64 i = 0; i < shape.rank(); ++i) {
    int64 dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}

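// Example (illustrative): s8[1LL << 62, 4] would make the total byte count
// overflow int64, so ValidateShapeSize below returns InvalidArgument, whereas
// f32[0, 100] is fine (its size is simply 0).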
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
  VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);

  if (!shape.IsArray()) {
    return Status::OK();
  }

  int64 shape_size = [&]() {
    int64 dense_shape_size = 1;
    if (shape.dimensions().empty()) {
      return dense_shape_size;
    }

    absl::Span<const int64> shape_max_dimensions =
        AsInt64Slice(shape.dimensions());
    for (int64 dim : shape_max_dimensions) {
      dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
      if (dense_shape_size < 0) {
        return dense_shape_size;
      }
    }
    dense_shape_size = MultiplyWithoutOverflow(
        dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
    return dense_shape_size;
  }();

  if (shape_size < 0) {
    return InvalidArgument("Shape %s size may overflow int64.",
                           ShapeUtil::HumanString(shape));
  }

  VLOG(3) << "Shape size is valid: " << shape_size;
  return Status::OK();
}

/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
    const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape,
                                           /*allow_missing_layouts=*/true);
}

/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape);
}

/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  if (original.IsTuple()) {
    std::vector<Shape> new_operands;
    new_operands.reserve(original.tuple_shapes_size());
    for (const Shape& operand : original.tuple_shapes()) {
      new_operands.push_back(ChangeElementType(operand, type));
    }
    return MakeTupleShape(new_operands);
  } else {
    Shape new_shape = original;
    new_shape.set_element_type(type);
    return new_shape;
  }
}

/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
                                          ShapeIndexView index) {
  const Shape* subshape = &shape;
  for (auto i : index) {
    if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
      return false;
    }
    subshape = &subshape->tuple_shapes(i);
  }
  return true;
}

/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    return_shape = &return_shape->tuple_shapes(i);
  }
  return *return_shape;
}

/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
    const Shape& shape, ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    if (!return_shape->IsTuple() || i < 0 ||
        i >= return_shape->tuple_shapes_size()) {
      return InvalidArgument(
          "Shape index %s not a valid subshape index for tuple with shape %s",
          index.ToString(), shape.DebugString());
    }
    return_shape = &return_shape->tuple_shapes(i);
  }
  return return_shape;
}

/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
                                                  ShapeIndexView index) {
  Shape* return_shape = shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple());
    return_shape = return_shape->mutable_tuple_shapes(i);
  }
  return return_shape;
}

/* static */
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
  return !GetSubshape(shape, index).IsTuple();
}

/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
  if (!shape.IsTuple()) {
    return 1;
  }
  int64 count = 0;
  for (const Shape& subshape : shape.tuple_shapes()) {
    count += GetLeafCount(subshape);
  }
  return count;
}

/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
    const Shape& shape) {
  std::vector<IndexedShape> leaves;
  ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
    if (IsLeafIndex(shape, index)) {
      leaves.emplace_back(index, sub_shape);
    }
  });
  return leaves;
}

/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
  CHECK(shape.IsArray());
  return absl::c_linear_search(shape.dimensions(), 1);
}

/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  return FilterDimensions(
      [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
}

namespace {

// Helper for ForEachSubshape which visits the subshapes of the given shape in
// DFS pre-order starting with the index.
Status ForEachSubshapeHelper(const Shape& shape,
                             const ShapeUtil::StatusVisitorFunction& func,
                             ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(shape, *index));
  if (shape.IsTuple()) {
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      index->push_back(i);
      TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
          ShapeUtil::GetTupleElementShape(shape, i), func, index));
      index->pop_back();
    }
  }
  return Status::OK();
}

// Helper for ForEachMutableSubshape which visits the subshapes of the given
// shape in DFS pre-order starting with the index.
Status ForEachMutableSubshapeHelper(
    Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
    ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(shape, *index));
  if (shape->IsTuple()) {
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
      index->push_back(i);
      TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
          shape->mutable_tuple_shapes(i), func, index));
      index->pop_back();
    }
  }
  return Status::OK();
}

}  // namespace

/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
                                             const VisitorFunction& func) {
  ShapeIndex index;
  ForEachSubshapeHelper(
      shape,
      [&func](const Shape& subshape, const ShapeIndex& index) {
        func(subshape, index);
        return Status::OK();
      },
      &index)
      .IgnoreError();
}

/* static */ void ShapeUtil::ForEachMutableSubshape(
    Shape* shape, const MutatingVisitorFunction& func) {
  ShapeIndex index;
  ForEachMutableSubshapeHelper(
      shape,
      [&func](Shape* subshape, const ShapeIndex& index) {
        func(subshape, index);
        return Status::OK();
      },
      &index)
      .IgnoreError();
}

/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
    const Shape& shape, const StatusVisitorFunction& func) {
  ShapeIndex index;
  return ForEachSubshapeHelper(shape, func, &index);
}

/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
    Shape* shape, const MutatingStatusVisitorFunction& func) {
  ShapeIndex index;
  return ForEachMutableSubshapeHelper(shape, func, &index);
}

/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(shape.dimensions(), permutation)) {
    new_shape.add_dimensions(dim);
  }
  auto inv_permutation = InversePermutation(permutation);
  for (int64 i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(inv_permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this. We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to itself,
  // `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L. A layout is a permutation which maps a
  // minor-to-major physical dimension ordering to a shape's logical dimension
  // ordering. Therefore the inverse of a layout maps from logical to physical
  // dims, and so the physical ordering of I is simply L'.I = L', where L' is
  // the inverse of L.
  //
  // Let the argument `permutation` be P. This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P. Our
  // goal is to construct a layout permutation L* for this shape. The physical
  // dimension ordering of this returned shape must be the same as that of the
  // original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory ordering is
  // L*'.P. Setting this equal to L' and solving for L*, we get:
  //
  //   L*'.P = L'  =>
  //   L*'   = L'P' =>
  //   L*    = P.L
  //
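  // Worked instance of the algebra above (illustrative, using the composition
  // convention (P.L)[i] = P[L[i]]): for rank 3, take L = {0,2,1} and
  // P = {1,2,0}. Then L* = P.L = {P[0], P[2], P[1]} = {1,0,2}, and indeed
  // L*'.P = {0,2,1} = L', so the permuted shape keeps the original physical
  // ordering of its elements.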
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    for (auto index : ComposePermutations(
             inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape, permutation))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}

/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output index between prior_unmodified_dim_pair
  // and unmodified_dim_pair has size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i].
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}

/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                         const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());

  // Unmodified dimensions are merely common factors of rank 1.
  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
                                      AsInt64Slice(output_shape.dimensions()));
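  // Example (illustrative): for a reshape f32[2,3,5] -> f32[6,5],
  // CommonFactors returns {(0,0), (2,1), (3,2)}. Only the factor group that
  // spans exactly one dimension on each side survives the filtering below, so
  // the result is {{2, 1}}: input dimension 2 maps to output dimension 1.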
  for (size_t i = 0; i < common_factors.size() - 1;) {
    if (1 != common_factors[i + 1].first - common_factors[i].first ||
        1 != common_factors[i + 1].second - common_factors[i].second) {
      common_factors.erase(common_factors.begin() + i);
    } else {
      ++i;
    }
  }
  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
  common_factors.pop_back();
  return std::vector<std::pair<int64, int64>>(common_factors.begin(),
                                              common_factors.end());
}

/* static */ absl::optional<std::vector<int64>>
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
    const Shape& from_shape, const Shape& to_shape,
    absl::Span<const int64> input_dim_indices) {
  CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));

  std::vector<int64> output_dim_indices;
  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
  size_t i = 0;  // index to unmodified_dims
  for (int64 input_dim_index : input_dim_indices) {
    // Search unmodified_dims for input_dim_index. We can search from the last
    // matching position because input_dim_indices is guaranteed to be sorted.
    while (i < unmodified_dims.size() &&
           unmodified_dims[i].first < input_dim_index) {
      ++i;
    }
    if (i >= unmodified_dims.size() ||
        unmodified_dims[i].first != input_dim_index) {
      return absl::nullopt;
    }
    output_dim_indices.push_back(unmodified_dims[i].second);
  }
  return output_dim_indices;
}

/* static */ bool ShapeUtil::TransposeIsBitcast(
    const Shape& input_shape, const Shape& output_shape,
    absl::Span<const int64> dimension_mapping) {
  CHECK(LayoutUtil::HasLayout(input_shape) &&
        LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  // Check that the transpose permutes the positions of each dimension in the
  // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
  //   input_positions = apply(dimension_mapping, output_positions)
  //
  // Because the positions of each dimension are the inverse permutation of the
  // minor-to-major order, the above check is equivalent to
  //   inverse(input_dimensions) =
  //       apply(dimension_mapping, inverse(output_dimensions))
  //   # `I` indicates identity permutation.
  //   apply(input_dimensions, I) =
  //       apply(dimension_mapping, apply(output_dimensions, I))
  //   apply(input_dimensions, I) =
  //       apply((dimension_mapping * output_dimensions), I)
  //   input_dimensions = dimension_mapping * output_dimensions
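  //
  // Example (illustrative): transposing f32[32,64]{1,0} with
  // dimension_mapping = {1,0} into f32[64,32]{0,1} is a bitcast: composing
  // {1,0} with the output minor_to_major {0,1} gives {1,0}, which equals the
  // input minor_to_major, so the data does not move.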
  return absl::c_equal(
      ComposePermutations(dimension_mapping,
                          AsInt64Slice(output_shape.layout().minor_to_major())),
      input_shape.layout().minor_to_major());
}

/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape))
      << "input_shape=" << input_shape.ShortDebugString()
      << ", output_shape=" << output_shape.ShortDebugString();
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // factorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without loss of
  // generality, suppose IS[0]<OS[0] and thus k=IS[0]. A similar proof applies
  // to IS[0]>OS[0]. Note that IS[0]!=OS[0] because the reshape is
  // non-factorizable. From logical index k-1 to logical index k, dimension 1
  // of the input index is increased by 1 and dimension 0 is reset to 0 thus
  // decreased by IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64 logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}

/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    std::vector<int64> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // Map each dimension number of the degenerate-free layout back to its
    // dimension number in the original output shape (which still contains the
    // size-one dimensions).
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64 i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the size-one dimensions to the layout in descending order.
    // Descending layouts tend to reduce the number of copies inserted in
    // layout assignment.
    for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  int64 input_rank = input_shape.rank();
  int64 output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part
  // if the products of their dimension bounds are the same. In the easiest
  // case, an alignment part consists of one input dimension and one output
  // dimension which both have the same dimension bound. An alignment part
  // specifies which dimensions need to be kept together in a physical layout
  // if we want a reshape to be a bitcast. The order of the alignment parts is
  // defined by the physical layout of the input shape, so when we construct
  // the layout for the output shape we just process the alignment parts in
  // this order, and then layout the dimensions belonging to each part in
  // descending (major to minor) order.
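  //
  // For example, for the reshape [8x2x3] -> [4x4x3] there are two alignment
  // parts: input dimensions {0,1} align with output dimensions {0,1} (both
  // products are 16), and input dimension {2} aligns with output dimension
  // {2} (both products are 3).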

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it
  // belongs to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64 input_dimension_product = 1, output_dimension_product = 1;
  for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check whether the input's physical layout can be matched by choosing a
  // suitable physical layout for the output shape. We need to check that all
  // dimension numbers that belong to the same alignment part appear
  // consecutively and in descending order. (Trivial dimension bounds of 1 were
  // already stripped above, since they can be placed anywhere.)
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  for (int64 i = 0; i < input_rank;) {
    int64 current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64 current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the next dimension numbers in the input layout (as many as
    // the current alignment part has input dimensions) all belong to the
    // current alignment part and appear in descending order.
    for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
                              alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64 j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}

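// Returns a copy of `shape` with array dimension `dim_to_delete` removed. For
// example, deleting dimension 1 of an F32[2,3,4] array yields an F32[2,4]
// array.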
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
                                              Shape shape) {
  CHECK(shape.IsArray());
  shape.DeleteDimension(dim_to_delete);
  return shape;
}

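// A dynamic array shape is compatible with a bounded array shape if both have
// the same rank and every dimension size of the dynamic shape is within the
// corresponding bound.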
/* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
    const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

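// Recursively checks that every subshape of `dynamic_shape` has a
// corresponding subshape in `bounded_shape` of the same kind (tuple vs.
// array), and that every dynamic array subshape fits within the bounds of its
// counterpart.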
/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
    const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
  bool compatible = true;
  xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
                                                     const ShapeIndex& index) {
    if (compatible) {
      auto subshape_result = TryGetSubshape(bounded_shape, index);
      if (subshape_result.ok()) {
        const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
        if (sub_shape.IsTuple()) {
          if (!bounded_sub_shape->IsTuple()) {
            compatible = false;
          }
        } else {
          if (bounded_sub_shape->IsTuple()) {
            compatible = false;
          } else if (!sub_shape.is_static() &&
                     !DynamicArrayShapeIsCompatible(sub_shape,
                                                    *bounded_sub_shape)) {
            compatible = false;
          }
        }
      } else {
        compatible = false;
      }
    }
  });
  return compatible;
}

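// Deletes every dimension i of `shape` for which p(i) returns false.
// Deletion proceeds from the highest dimension index down, so earlier
// deletions do not shift the indices of dimensions deleted later.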
/* static */ Shape ShapeUtil::FilterDimensions(
    const std::function<bool(int64)>& p, Shape shape) {
  CHECK(shape.IsArray());
  std::vector<int64> dims_to_delete;
  for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
    if (!p(i)) {
      dims_to_delete.push_back(i);
    }
  }
  for (int64 dim : dims_to_delete) {
    shape = DeleteDimension(dim, shape);
  }
  return shape;
}

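// For array shapes the hash combines the element type, the dimension bounds,
// the dynamic-dimension flags and the layout; for tuple shapes it combines
// the hashes of the element shapes.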
/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
  using tensorflow::hash;
  using tensorflow::Hash64Combine;

  size_t hash_value = hash<PrimitiveType>()(shape.element_type());

  if (shape.tuple_shapes().empty()) {
    for (int i = 0; i < shape.dimensions_size(); ++i) {
      hash_value =
          Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
      hash_value = Hash64Combine(hash_value,
                                 hash<bool>()(shape.is_dynamic_dimension(i)));
    }

    hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
  } else {
    hash_value = 0;
    for (const Shape& subshape : shape.tuple_shapes()) {
      hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
    }
  }

  return hash_value;
}

// Returns the indices of the first elements of all consecutive subarrays of
// the given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
  std::vector<size_t> is = {0};
  for (size_t i = 1; i < xs.size(); ++i) {
    if (1 != xs[i] - xs[i - 1]) {
      is.push_back(i);
    }
  }
  return is;
}

// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
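// For example, with segs = {0, 2} and a shape with dimensions [2,3,5], the
// result has dimensions [6,5]: dimensions 0 and 1 are merged into one
// dimension of size 2*3, and dimension 2 is kept as-is.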
static Shape MergeDimensions(absl::Span<const size_t> segs,
                             const Shape& shape) {
  std::vector<int64> dimensions;
  for (size_t i = 1; i <= segs.size(); ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}

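// Collapses consecutive physical dimensions of `a` (grouped by
// ConsecutiveSegments of the physical permutation from `a` to `b`) and, if
// `b`'s physical order corresponds to a 0-2-1 transpose of the collapsed
// dimensions, returns those collapsed sizes in 0-2-1 order (a leading 1 is
// used when there is no batch group).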
/*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
    const Shape& a, const Shape& b) {
  if (!CompatibleIgnoringElementType(a, b)) {
    return absl::nullopt;
  }

  std::vector<int64> permutation(a.dimensions().size());
  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
                                      minor_to_major_a.rend());
  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
                                      minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64> normalized_dims =
        AsInt64Slice(normalized_shape.dimensions());
    std::vector<int64> dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }

    return dims_021;
  }

  return absl::nullopt;
}

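// Strips device-specific layout information from `s`: the tiling is cleared
// and the memory space of every array subshape is reset to the default, so
// the resulting shape describes the host-visible form of the data.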
Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
  ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
    if (subshape->IsArray()) {
      subshape->mutable_layout()->clear_tiles();
      subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
    }
  });
  return s;
}

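// An element type can be upcast to another if both are floating point or both
// are integral with the same signedness, and the target type has at least as
// much precision as the source type.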
/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
                                            const Shape& to) {
  return ElementIsFloating(from) == ElementIsFloating(to) &&
         ElementIsSigned(from) == ElementIsSigned(to) &&
         HigherPrecisionElementType(from, to) == to.element_type();
}

}  // namespace xla
