1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/shape_util.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/strings/ascii.h"
28 #include "absl/strings/numbers.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/strings/strip.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/compiler/xla/index_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/overflow_util.h"
39 #include "tensorflow/compiler/xla/permutation_util.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/gtl/iterator_range.h"
46 #include "tensorflow/core/lib/hash/hash.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/protobuf.h"
50 #include "tensorflow/core/platform/regexp.h"
51
52 namespace xla {
53
54 using absl::StrAppend;
55 using absl::StrCat;
56
namespace {
// Lookup table mapping a PrimitiveType enum value to the byte size of one
// element of that type, or 0 if the type has no fixed per-element size
// (invalid, TUPLE, OPAQUE_TYPE, TOKEN).  Indexed directly by the enum's
// integer value.
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
    0,                  // PRIMITIVE_TYPE_INVALID = 0,
    sizeof(int8),       // PRED = 1
    sizeof(int8),       // S8 = 2
    sizeof(int16),      // S16 = 3
    sizeof(int32),      // S32 = 4
    sizeof(int64),      // S64 = 5
    sizeof(uint8),      // U8 = 6
    sizeof(uint16),     // U16 = 7
    sizeof(uint32),     // U32 = 8
    sizeof(uint64),     // U64 = 9
    sizeof(float) / 2,  // F16 = 10; half-precision occupies two bytes
    sizeof(float),      // F32 = 11
    sizeof(double),     // F64 = 12
    0,                  // TUPLE = 13
    0,                  // OPAQUE_TYPE = 14
    sizeof(complex64),  // C64 = 15
    sizeof(float) / 2,  // BF16 = 16; bfloat16 also occupies two bytes
    0,                  // TOKEN = 17
    sizeof(complex128)  // C128 = 18
};
// When printing long tuples, HumanString emits an /*index=N*/ annotation
// every this many elements to keep the output readable.
constexpr int64_t kAnnotationPrintInterval = 5;
}  // namespace
84
ToString() const85 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
86
ToString() const87 string ShapeIndexView::ToString() const {
88 return StrCat("{", absl::StrJoin(indices_, ","), "}");
89 }
90
operator ==(const ShapeIndexView & other) const91 bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
92 return indices_ == other.indices_;
93 }
94
operator !=(const ShapeIndexView & other) const95 bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
96 return !(*this == other);
97 }
98
operator <<(std::ostream & out,const ShapeIndex & shape_index)99 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
100 out << shape_index.ToString();
101 return out;
102 }
103
operator <<(std::ostream & out,const ShapeIndexView & shape_index)104 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
105 out << shape_index.ToString();
106 return out;
107 }
108
StartsWith(ShapeIndexView prefix) const109 bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
110 return size() >= prefix.size() &&
111 indices_.subspan(0, prefix.size()) == prefix.indices_;
112 }
113
IsArrayPrimitiveType(PrimitiveType primitive_type)114 /* static */ bool ShapeUtil::IsArrayPrimitiveType(
115 PrimitiveType primitive_type) {
116 return primitive_util::IsArrayType(primitive_type);
117 }
118
namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.  Fails if the dimension and layout sizes disagree, if the
// element type is not an array type, or if the final shape does not validate.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64_t element_size_in_bits, int64_t memory_space) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  if (element_size_in_bits ==
      ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
    // Only set element_size_in_bits if it's different from the default value.
    element_size_in_bits = 0;
  }
  *shape.mutable_layout() = LayoutUtil::MakeLayout(
      minor_to_major, tiles, element_size_in_bits, memory_space);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace
150
Equal(const Shape & lhs,const Shape & rhs)151 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
152 bool equal = Shape::Equal()(lhs, rhs);
153
154 if (!equal && VLOG_IS_ON(3)) {
155 VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
156 << ", rhs = " << rhs.ShortDebugString();
157 }
158
159 return equal;
160 }
161
EqualIgnoringElementType(const Shape & lhs,const Shape & rhs)162 /* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
163 const Shape& rhs) {
164 bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
165 if (!equal && VLOG_IS_ON(3)) {
166 VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
167 << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
168 }
169
170 return equal;
171 }
172
EqualIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)173 /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
174 const Shape& rhs) {
175 bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
176 if (!equal && VLOG_IS_ON(3)) {
177 VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
178 << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
179 }
180
181 return equal;
182 }
183
EqualStructure(const Shape & lhs,const Shape & rhs)184 /* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
185 const Shape& rhs) {
186 bool equal = true;
187 ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
188 equal &= IndexIsValid(rhs, index);
189 });
190 ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
191 equal &= IndexIsValid(lhs, index);
192 });
193
194 return equal;
195 }
196
// Returns the number of non-degenerate dimensions of `shape`.
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64_t accum = 0;
  for (int64_t dimension : shape.dimensions()) {
    // We do not count degenerate (size-1) dimensions.  Note that dimensions
    // of size 0 *are* counted.  (The previous comment said "zero dimensions",
    // which contradicted the `!= 1` test below.)
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}
207
// Fills `shape` in place as a dense array of `element_type` with the given
// dimensions and a default (descending) minor-to-major layout.  Returns false
// — possibly leaving `shape` partially filled — if the element type has no
// positive fixed size, any dimension is negative, or the total byte size
// overflows int64.
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          Shape* shape) {
  const int eint = static_cast<int>(element_type);
  int64_t dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
                                  ? primitive_byte_size[eint]
                                  : 0);  // Out of range: force a failure
  if (dense_shape_size <= 0) {
    return false;
  }

  // Verify that array-based lookup is consistent with public API.
  DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
      << element_type;

  shape->set_element_type(element_type);
  const int ndims = dimensions.size();
  auto layout = shape->mutable_layout();
  layout->set_format(DENSE);
  auto* minor_to_major = layout->mutable_minor_to_major();
  for (int i = 0; i < ndims; i++) {
    const int64_t d = dimensions[i];
    if (d < 0) {
      return false;
    }
    // Track the running byte size; MultiplyWithoutOverflow returns a
    // negative value on overflow.
    dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
    if (dense_shape_size < 0) {
      return false;
    }

    shape->add_dimensions(d);
    // Default layout: dimension 0 is most-major.
    minor_to_major->push_back(ndims - 1 - i);
  }
  return true;
}
243
MakeProgramShape(std::initializer_list<Shape> parameters,Shape result)244 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
245 std::initializer_list<Shape> parameters, Shape result) {
246 ProgramShape program_shape;
247 for (const Shape& shape : parameters) {
248 *program_shape.add_parameters() = shape;
249 }
250 *program_shape.mutable_result() = std::move(result);
251 return program_shape;
252 }
253
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions)254 /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
255 absl::Span<const int64> dimensions) {
256 Shape shape;
257 CHECK(FillNewShape(element_type, dimensions, &shape));
258 return shape;
259 }
260
MakeScalarShape(PrimitiveType element_type)261 /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
262 return MakeShape(element_type, {});
263 }
264
MakeShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)265 /* static */ Shape ShapeUtil::MakeShape(
266 PrimitiveType element_type, absl::Span<const int64> dimensions,
267 const std::vector<bool>& dynamic_dimensions) {
268 return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
269 .ValueOrDie();
270 }
271
// Returns a copy of `shape` with every dynamic-dimension marker removed.
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape static_shape = shape;
  static_shape.clear_dynamic_dimensions();
  return static_shape;
}
278
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions)279 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
280 PrimitiveType element_type, absl::Span<const int64> dimensions) {
281 Shape shape;
282 if (!FillNewShape(element_type, dimensions, &shape)) {
283 return InvalidArgument("invalid shape type=%d, dims=[%s]",
284 static_cast<int>(element_type),
285 absl::StrJoin(dimensions, ","));
286 }
287 return shape;
288 }
289
MakeValidatedShape(PrimitiveType element_type,absl::Span<const int64> dimensions,const std::vector<bool> & dynamic_dimensions)290 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
291 PrimitiveType element_type, absl::Span<const int64> dimensions,
292 const std::vector<bool>& dynamic_dimensions) {
293 if (dynamic_dimensions.size() != dimensions.size()) {
294 return InvalidArgument(
295 "dynamic dimensions size %d did not match number of dimensions %d",
296 dynamic_dimensions.size(), dimensions.size());
297 }
298
299 Shape shape;
300 if (!FillNewShape(element_type, dimensions, &shape)) {
301 return InvalidArgument("invalid shape type=%d, dims=[%s]",
302 static_cast<int>(element_type),
303 absl::StrJoin(dimensions, ","));
304 }
305 for (int i = 0, n = dimensions.size(); i < n; i++) {
306 shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
307 }
308 return shape;
309 }
310
MakeShapeWithLayout(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)311 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
312 PrimitiveType element_type, absl::Span<const int64> dimensions,
313 absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
314 int64_t element_size_in_bits, int64_t memory_space) {
315 auto ret =
316 MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
317 tiles, element_size_in_bits, memory_space);
318 if (!ret.ok()) LOG(ERROR) << ret.status();
319 return ret.ValueOrDie();
320 }
321
// Returns `shape` with dimension `dim` moved to the most-major position of
// the layout; tuples are transformed element-wise.
/* static */ Shape ShapeUtil::MoveDimToMajor(const Shape& shape, int64_t dim) {
  if (shape.IsTuple()) {
    std::vector<Shape> result_shapes;
    result_shapes.reserve(shape.tuple_shapes_size());
    for (const Shape& s : shape.tuple_shapes()) {
      result_shapes.push_back(MoveDimToMajor(s, dim));
    }
    return ShapeUtil::MakeTupleShape(result_shapes);
  }

  Shape ret = shape;
  if (!ret.has_layout()) {
    LayoutUtil::SetToDefaultLayout(&ret);
  }
  // NOTE(review): the layout assigned here is immediately rebuilt below from
  // minor_to_major alone, which drops tiles/element size from the first
  // assignment — confirm this double assignment is intentional.
  *ret.mutable_layout() = LayoutUtil::MoveDimToMajor(ret.layout(), dim);
  // Rebuild minor_to_major with `dim` forced to the most-major (last) slot.
  DimensionVector minor_to_major;
  for (int64_t d : LayoutUtil::MinorToMajor(ret)) {
    if (d != dim) {
      minor_to_major.push_back(d);
    }
  }
  minor_to_major.push_back(dim);
  *ret.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  return ret;
}
347
MakeShapeWithDescendingLayout(PrimitiveType element_type,absl::Span<const int64> dimensions)348 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
349 PrimitiveType element_type, absl::Span<const int64> dimensions) {
350 std::vector<int64> layout(dimensions.size());
351 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
352 return MakeShapeWithLayout(element_type, dimensions, layout);
353 }
354
// Returns a shape with logically permuted dimensions so that the layout is
// descending while the physical byte order matches that of `shape`.
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    const Shape& shape) {
  std::vector<int64> dims(shape.dimensions_size());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    // dims[i] is the size of the i-th most-major dimension of `shape`.
    dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
  // Since the physical layout is kept the same, the tiles and element size are
  // the same also.
  new_shape.mutable_layout()->mutable_tiles()->assign(
      shape.layout().tiles().begin(), shape.layout().tiles().end());
  new_shape.mutable_layout()->set_element_size_in_bits(
      shape.layout().element_size_in_bits());
  // NOTE(review): dynamic-dimension flags are copied by logical index and are
  // not permuted along with the dimensions — confirm this is intended.
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
  }
  return new_shape;
}
374
PopulateShape(PrimitiveType element_type,absl::Span<const int64> dimensions,Shape * shape)375 /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
376 absl::Span<const int64> dimensions,
377 Shape* shape) {
378 shape->Clear();
379 shape->set_element_type(element_type);
380 for (int64_t dimension : dimensions) {
381 shape->add_dimensions(dimension);
382 }
383 LayoutUtil::SetToDefaultLayout(shape);
384 return ValidateShape(*shape);
385 }
386
// Returns a copy of `original` with every dimension marked static.
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
  Shape static_shape = original;
  static_shape.clear_dynamic_dimensions();
  return static_shape;
}
392
MakeTupleShape(absl::Span<const Shape> shapes)393 /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
394 Shape result;
395 result.set_element_type(TUPLE);
396 result.mutable_tuple_shapes()->reserve(shapes.size());
397 for (const auto& shape : shapes) {
398 AppendShapeToTuple(shape, &result);
399 }
400 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
401 return result;
402 }
403
MakeOpaqueShape()404 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
405 Shape result;
406 result.set_element_type(OPAQUE_TYPE);
407 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
408 return result;
409 }
410
MakeTokenShape()411 /* static */ Shape ShapeUtil::MakeTokenShape() {
412 Shape result;
413 result.set_element_type(TOKEN);
414 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
415 return result;
416 }
417
AppendShapeToTuple(const Shape & shape,Shape * tuple_shape)418 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
419 Shape* tuple_shape) {
420 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
421 *tuple_shape->add_tuple_shapes() = shape;
422 }
423
UpdateTupleShape(const Shape & shape,int64_t index,Shape * tuple_shape)424 /* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64_t index,
425 Shape* tuple_shape) {
426 CHECK(index < tuple_shape->tuple_shapes_size());
427 *tuple_shape->mutable_tuple_shapes(index) = shape;
428 }
429
UpdateDynamicDimension(Shape * shape,ShapeIndexView index,int64_t dim,bool is_dynamic)430 /* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
431 ShapeIndexView index,
432 int64_t dim,
433 bool is_dynamic) {
434 if (index.empty()) {
435 CHECK(!shape->IsTuple());
436 shape->set_dynamic_dimension(dim, is_dynamic);
437 return;
438 }
439
440 UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
441 index.ConsumeFront(), dim, is_dynamic);
442 }
443
AppendMajorDimension(int bound,Shape * shape)444 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
445 CHECK(LayoutUtil::IsDenseArray(*shape));
446 shape->mutable_layout()->add_minor_to_major(shape->rank());
447 shape->add_dimensions(bound);
448 TF_DCHECK_OK(ValidateShape(*shape));
449 }
450
CopyDynamicDimensions(Shape * to,const Shape & from)451 /* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
452 const Shape& from) {
453 CHECK_EQ(to->rank(), from.rank());
454 for (int64_t i = 0; i < from.rank(); ++i) {
455 to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
456 }
457 TF_DCHECK_OK(ValidateShape(*to));
458 }
459
ElementIsIntegral(const Shape & shape)460 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
461 return primitive_util::IsIntegralType(shape.element_type());
462 }
463
ElementIsIntegralWithBits(const Shape & shape,int32_t bits)464 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
465 int32_t bits) {
466 return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
467 }
468
ElementHasBitWidth(const Shape & shape,int bits)469 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
470 if (!shape.IsArray()) {
471 return false;
472 }
473 return primitive_util::BitWidth(shape.element_type()) == bits;
474 }
475
// Returns true for signed integral element types and for floating-point
// types (which carry a sign); false for PRED, unsigned integers, complex,
// and non-array types.  Unknown enum values are fatal.
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
  switch (shape.element_type()) {
    case S8:
    case S16:
    case S32:
    case S64:
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;

    case PRED:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    case TUPLE:
    case OPAQUE_TYPE:
    case TOKEN:
      return false;

    default:
      LOG(FATAL) << "Unhandled element type " << shape.element_type();
  }
}
504
ElementIsComplex(const Shape & shape)505 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
506 return primitive_util::IsComplexType(shape.element_type());
507 }
508
ElementIsFloating(const Shape & shape)509 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
510 return primitive_util::IsFloatingPointType(shape.element_type());
511 }
512
IsNestedTuple(const Shape & shape)513 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
514 return shape.IsTuple() &&
515 absl::c_any_of(shape.tuple_shapes(),
516 [](const Shape& s) { return s.IsTuple(); });
517 }
518
IsEmptyTuple(const Shape & shape)519 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
520 return shape.IsTuple() && TupleElementCount(shape) == 0;
521 }
522
TupleElementCount(const Shape & shape)523 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
524 CHECK(shape.IsTuple()) << HumanString(shape);
525 return shape.tuple_shapes_size();
526 }
527
// Returns a reference to element `index` of the tuple `shape`.  `index`
// must be in range.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64_t index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  const Shape& element = shape.tuple_shapes(index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(element));
  return element;
}
535
SubshapeCount(const Shape & shape)536 /* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
537 int64_t n = 0;
538 ForEachSubshape(shape, [&](const Shape& literal_subshape,
539 const ShapeIndex& index) { ++n; });
540 return n;
541 }
542
// Builds a new tuple from the elements of `tuple` in the half-open index
// range [start, limit).
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64_t start,
                                         int64_t limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  const int64_t count = TupleElementCount(tuple);
  CHECK_LE(start, count);
  CHECK_LE(limit, count);

  const auto begin = tuple.tuple_shapes().begin();
  std::vector<Shape> selected(begin + start, begin + limit);
  return MakeTupleShape(selected);
}
554
555 // Returns the shape of a real or imaginary component.
ComplexComponentShape(const Shape & complex_shape)556 /* static */ Shape ShapeUtil::ComplexComponentShape(
557 const Shape& complex_shape) {
558 CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
559 return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
560 complex_shape.element_type()));
561 }
562
ElementsIn(const Shape & shape)563 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
564 DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
565 DCHECK_EQ(shape.dimensions_size(), shape.rank());
566 if (shape.dimensions().size() == 1) {
567 return shape.dimensions()[0];
568 }
569 return std::accumulate<decltype(shape.dimensions().begin()), int64>(
570 shape.dimensions().begin(), shape.dimensions().end(), 1LL,
571 std::multiplies<int64>());
572 }
573
ElementsInRecursive(const Shape & shape)574 /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
575 CHECK(shape.IsArray() || shape.IsTuple());
576 if (shape.IsArray()) {
577 return ElementsIn(shape);
578 }
579 int64_t count = 0;
580 for (const Shape& element_shape : shape.tuple_shapes()) {
581 count += ElementsInRecursive(element_shape);
582 }
583 return count;
584 }
585
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)586 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
587 PrimitiveType primitive_type) {
588 if (shape.element_type() == primitive_type) {
589 return true;
590 }
591 for (const Shape& element_shape : shape.tuple_shapes()) {
592 if (HasPrimitiveType(element_shape, primitive_type)) {
593 return true;
594 }
595 }
596 return false;
597 }
598
IsZeroElementArray(const Shape & shape)599 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
600 return shape.IsArray() && ElementsIn(shape) == 0;
601 }
602
IsScalarWithElementType(const Shape & shape,PrimitiveType element_type)603 /* static */ bool ShapeUtil::IsScalarWithElementType(
604 const Shape& shape, PrimitiveType element_type) {
605 return IsScalar(shape) && shape.element_type() == element_type;
606 }
607
HumanString(const Shape & shape)608 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
609 if (shape.IsTuple()) {
610 string text = "(";
611 const auto& tuple_shapes = shape.tuple_shapes();
612 for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
613 const Shape& elem_shape = tuple_shapes[i];
614 if (i != 0) {
615 StrAppend(&text, ", ");
616 if (i % kAnnotationPrintInterval == 0) {
617 StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
618 }
619 }
620 StrAppend(&text, HumanString(elem_shape));
621 }
622 text += ")";
623 return text;
624 }
625 std::vector<string> dim_elements;
626 for (int i = 0; i < shape.dimensions_size(); ++i) {
627 if (shape.is_dynamic_dimension(i)) {
628 dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
629 } else {
630 dim_elements.push_back(StrCat(shape.dimensions(i)));
631 }
632 }
633 return StrCat(
634 primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
635 absl::StrJoin(dim_elements, ","), "]");
636 }
637
// Like HumanString, but appends the layout (e.g. "{1,0}") to array shapes.
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const auto& tuple_shapes = shape.tuple_shapes();
    for (int64_t i = 0; i < tuple_shapes.size(); ++i) {
      const Shape& elem_shape = tuple_shapes[i];
      if (i != 0) {
        StrAppend(&text, ", ");
        // Periodically annotate the element index for long tuples.
        if (i % kAnnotationPrintInterval == 0) {
          StrAppend(&text, absl::StrFormat("/*index=%lld*/", i));
        }
      }
      StrAppend(&text, HumanStringWithLayout(elem_shape));
    }
    text += ")";
    return text;
  }
  string result = HumanString(shape);
  if (IsScalar(shape)) {
    string layout_str = LayoutUtil::HumanString(shape.layout());
    // Don't print "{}" as layout for scalars.
    if (layout_str != "{}") {
      StrAppend(&result, layout_str);
    }
  } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
    StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
  }
  return result;
}
667
HumanString(const ProgramShape & program_shape)668 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
669 std::vector<string> parameters;
670 for (auto& shape : program_shape.parameters()) {
671 const int i = parameters.size();
672 parameters.push_back(StrCat(i < program_shape.parameter_names_size()
673 ? program_shape.parameter_names(i)
674 : "(unknown)",
675 ": ", HumanString(shape)));
676 }
677 return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
678 HumanString(program_shape.result()));
679 }
680
SameDimensions(const Shape & lhs,const Shape & rhs)681 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
682 const Shape& rhs) {
683 CHECK(lhs.IsArray());
684 CHECK(rhs.IsArray());
685 return absl::c_equal(lhs.dimensions(), rhs.dimensions());
686 }
687
SameRank(const Shape & lhs,const Shape & rhs)688 /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
689 CHECK(lhs.IsArray());
690 CHECK(rhs.IsArray());
691 return lhs.rank() == rhs.rank();
692 }
693
Compatible(const Shape & lhs,const Shape & rhs)694 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
695 return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
696 }
697
CompatibleIgnoringElementType(const Shape & lhs,const Shape & rhs)698 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
699 const Shape& rhs) {
700 return Shape::Equal()
701 .IgnoreDynamicDimension()
702 .IgnoreElementType()
703 .IgnoreLayout()(lhs, rhs);
704 }
705
CompatibleKind(const Shape & lhs,const Shape & rhs)706 /* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
707 const Shape& rhs) {
708 return Shape::Equal()
709 .IgnoreElementType()
710 .IgnoreLayout()
711 .IgnoreDimensions()
712 .IgnoreDynamicDimension()(lhs, rhs);
713 }
714
CompatibleIgnoringFpPrecision(const Shape & lhs,const Shape & rhs)715 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
716 const Shape& rhs) {
717 return Shape::Equal()
718 .IgnoreDynamicDimension()
719 .IgnoreFpPrecision()
720 .IgnoreLayout()(lhs, rhs);
721 }
722
GetDimension(const Shape & shape,int64_t dimension_number)723 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
724 int64_t dimension_number) {
725 return shape.dimensions(GetDimensionNumber(shape, dimension_number));
726 }
727
GetDimensionNumber(const Shape & shape,int64_t dimension_number)728 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
729 int64_t dimension_number) {
730 if (dimension_number < 0) {
731 dimension_number += shape.rank();
732 }
733 CHECK_GE(dimension_number, 0);
734 return dimension_number;
735 }
736
// Returns the size in bytes of a single element of `primitive_type`.
// TOKEN is zero-sized; TUPLE and OPAQUE_TYPE have no fixed element size and
// are fatal here, as are unknown enum values.
/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return sizeof(int8);
    case S8:
      return sizeof(int8);
    case S16:
      return sizeof(int16);
    case S32:
      return sizeof(int32);
    case S64:
      return sizeof(int64);
    case U8:
      return sizeof(uint8);
    case U16:
      return sizeof(uint16);
    case U32:
      return sizeof(uint32);
    case U64:
      return sizeof(uint64);
    case BF16:
      // Half-width floats occupy two bytes (sizeof(float) / 2).
      return sizeof(float) / 2;
    case F16:
      return sizeof(float) / 2;
    case F32:
      return sizeof(float);
    case F64:
      return sizeof(double);
    case C64:
      return sizeof(complex64);
    case C128:
      return sizeof(complex128);
    case TOKEN:
      // Tokens require no space.
      return 0;
    case TUPLE:
    case OPAQUE_TYPE:
      LOG(FATAL) << PrimitiveType_Name(primitive_type)
                 << " primitive type has no definitive size";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
781
ByteSizeOf(const Shape & shape,int64_t pointer_size)782 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
783 int64_t pointer_size) {
784 TF_DCHECK_OK(ValidateShape(shape));
785 if (shape.element_type() == TUPLE) {
786 return ByteSizeOfTupleIndexTable(shape, pointer_size);
787 } else if (shape.IsArray()) {
788 return ByteSizeOfElements(shape);
789 } else if (shape.element_type() == TOKEN) {
790 return 0;
791 } else if (shape.element_type() == OPAQUE_TYPE) {
792 CHECK_GT(pointer_size, 0);
793 return pointer_size;
794 }
795 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
796 << " primitive type has no definitive size";
797 }
798
ByteSizeOfTupleIndexTable(const Shape & shape,int64_t pointer_size)799 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
800 int64_t pointer_size) {
801 TF_DCHECK_OK(ValidateShape(shape));
802 CHECK_EQ(TUPLE, shape.element_type());
803 CHECK_GT(pointer_size, 0);
804 return pointer_size * shape.tuple_shapes_size();
805 }
806
ByteSizeOfElements(const Shape & shape)807 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
808 TF_DCHECK_OK(ValidateShape(shape));
809 CHECK(shape.IsArray());
810 int64_t allocated_element_count;
811
812 CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
813 allocated_element_count = ElementsIn(shape);
814 return allocated_element_count *
815 ByteSizeOfPrimitiveType(shape.element_type());
816 }
817
// Validates the non-layout invariants of a shape tree: a valid element type,
// tuple/array fields used consistently, no dimensions or layout on
// TOKEN/OPAQUE_TYPE, non-negative dimensions, and a non-overflowing total
// size.  Layout contents are validated separately.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    // Recursively validate each element of the tuple.
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques should not have layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  // Array shape: every dimension must be non-negative.
  for (int64_t i = 0; i < shape.rank(); ++i) {
    int64_t dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}
870
ValidateShapeSize(const Shape & shape)871 /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
872 VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
873
874 if (!shape.IsArray()) {
875 return Status::OK();
876 }
877
878 int64_t shape_size = [&]() {
879 int64_t dense_shape_size = 1;
880 if (shape.dimensions().empty()) {
881 return dense_shape_size;
882 }
883
884 absl::Span<const int64> shape_max_dimensions =
885 AsInt64Slice(shape.dimensions());
886 for (int64_t dim : shape_max_dimensions) {
887 dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
888 if (dense_shape_size < 0) {
889 return dense_shape_size;
890 }
891 }
892 dense_shape_size = MultiplyWithoutOverflow(
893 dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
894 return dense_shape_size;
895 }();
896
897 if (shape_size < 0) {
898 return InvalidArgument("Shape %s size may overflow int64.",
899 ShapeUtil::HumanString(shape));
900 }
901
902 VLOG(3) << "Shape size is valid: " << shape_size;
903 return Status::OK();
904 }
905
ValidateShapeWithOptionalLayout(const Shape & shape)906 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
907 const Shape& shape) {
908 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
909
910 return LayoutUtil::ValidateLayoutInShape(shape,
911 /*allow_missing_layouts=*/true);
912 }
913
ValidateShape(const Shape & shape)914 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
915 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
916
917 return LayoutUtil::ValidateLayoutInShape(shape);
918 }
919
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  // Non-tuple: copy the shape and swap in the new element type.
  if (!original.IsTuple()) {
    Shape result = original;
    result.set_element_type(type);
    return result;
  }
  // Tuple: rebuild the tuple with every element converted recursively.
  std::vector<Shape> converted_elements;
  converted_elements.reserve(original.tuple_shapes_size());
  for (const Shape& element : original.tuple_shapes()) {
    converted_elements.push_back(ChangeElementType(element, type));
  }
  return MakeTupleShape(converted_elements);
}
935
IndexIsValid(const Shape & shape,ShapeIndexView index)936 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
937 ShapeIndexView index) {
938 const Shape* subshape = &shape;
939 for (auto i : index) {
940 if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
941 return false;
942 }
943 subshape = &subshape->tuple_shapes(i);
944 }
945 return true;
946 }
947
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  // Follows `index` through nested tuples; CHECK-fails if the index tries to
  // step into a non-tuple subshape.
  const Shape* current = &shape;
  for (auto step : index) {
    CHECK(current->IsTuple()) << "Invalid index " << index << " for shape "
                              << shape;
    current = &current->tuple_shapes(step);
  }
  return *current;
}
958
TryGetSubshape(const Shape & shape,ShapeIndexView index)959 /* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
960 const Shape& shape, ShapeIndexView index) {
961 const Shape* return_shape = &shape;
962 for (auto i : index) {
963 if (!return_shape->IsTuple() || i < 0 ||
964 i >= return_shape->tuple_shapes_size()) {
965 return InvalidArgument(
966 "Shape index %s not a valid subshape index for tuple with shape %s",
967 index.ToString(), shape.DebugString());
968 }
969 return_shape = &return_shape->tuple_shapes(i);
970 }
971 return return_shape;
972 }
973
GetMutableSubshape(Shape * shape,ShapeIndexView index)974 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
975 ShapeIndexView index) {
976 Shape* return_shape = shape;
977 for (auto i : index) {
978 CHECK(return_shape->IsTuple());
979 return_shape = return_shape->mutable_tuple_shapes(i);
980 }
981 return return_shape;
982 }
983
984 /* static */
IsLeafIndex(const Shape & shape,const ShapeIndex & index)985 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
986 return !GetSubshape(shape, index).IsTuple();
987 }
988
GetLeafCount(const Shape & shape)989 /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
990 if (!shape.IsTuple()) {
991 return 1;
992 }
993 int64_t count = 0;
994 for (const Shape& subshape : shape.tuple_shapes()) {
995 count += GetLeafCount(subshape);
996 }
997 return count;
998 }
999
GetLeafShapes(const Shape & shape)1000 /* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
1001 const Shape& shape) {
1002 std::vector<IndexedShape> leaves;
1003 ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
1004 if (IsLeafIndex(shape, index)) {
1005 leaves.emplace_back(index, sub_shape);
1006 }
1007 });
1008 return leaves;
1009 }
1010
HasDegenerateDimensions(const Shape & shape)1011 /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
1012 CHECK(shape.IsArray());
1013 return absl::c_linear_search(shape.dimensions(), 1);
1014 }
1015
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  // Keep exactly the dimensions whose bound differs from 1.
  auto keep_dim = [&shape](int64_t dim) {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(keep_dim, shape);
}
1020
1021 namespace {
1022
1023 // Helper for ForEachSubshape which visits the subshapes of the given shape in
1024 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)1025 Status ForEachSubshapeHelper(const Shape& shape,
1026 const ShapeUtil::StatusVisitorFunction& func,
1027 ShapeIndex* index) {
1028 TF_RETURN_IF_ERROR(func(shape, *index));
1029 if (shape.IsTuple()) {
1030 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1031 index->push_back(i);
1032 TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
1033 ShapeUtil::GetTupleElementShape(shape, i), func, index));
1034 index->pop_back();
1035 }
1036 }
1037 return Status::OK();
1038 }
1039
1040 // Helper for ForEachMutableSubshape which visits the subshapes of the given
1041 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)1042 Status ForEachMutableSubshapeHelper(
1043 Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
1044 ShapeIndex* index) {
1045 TF_RETURN_IF_ERROR(func(shape, *index));
1046 if (shape->IsTuple()) {
1047 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
1048 index->push_back(i);
1049 TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
1050 shape->mutable_tuple_shapes(i), func, index));
1051 index->pop_back();
1052 }
1053 }
1054 return Status::OK();
1055 }
1056
1057 } // namespace
1058
ForEachSubshape(const Shape & shape,const VisitorFunction & func)1059 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
1060 const VisitorFunction& func) {
1061 ShapeIndex index;
1062 ForEachSubshapeHelper(
1063 shape,
1064 [&func](const Shape& subshape, const ShapeIndex& index) {
1065 func(subshape, index);
1066 return Status::OK();
1067 },
1068 &index)
1069 .IgnoreError();
1070 }
1071
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)1072 /* static */ void ShapeUtil::ForEachMutableSubshape(
1073 Shape* shape, const MutatingVisitorFunction& func) {
1074 ShapeIndex index;
1075 ForEachMutableSubshapeHelper(
1076 shape,
1077 [&func](Shape* subshape, const ShapeIndex& index) {
1078 func(subshape, index);
1079 return Status::OK();
1080 },
1081 &index)
1082 .IgnoreError();
1083 }
1084
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)1085 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
1086 const Shape& shape, const StatusVisitorFunction& func) {
1087 ShapeIndex index;
1088 return ForEachSubshapeHelper(shape, func, &index);
1089 }
1090
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)1091 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
1092 Shape* shape, const MutatingStatusVisitorFunction& func) {
1093 ShapeIndex index;
1094 return ForEachMutableSubshapeHelper(shape, func, &index);
1095 }
1096
// Returns a copy of `shape` whose dimensions (and dynamic-dimension flags)
// are reordered according to `permutation`. If `shape` has a layout, the
// result gets a layout chosen so that the corresponding transpose is a
// bitcast (verified by a CHECK below).
/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(shape.dimensions(), permutation)) {
    new_shape.add_dimensions(dim);
  }
  // Dynamic-dimension flags must travel with their dimensions.
  auto inv_permutation = InversePermutation(permutation);
  for (int64_t i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(inv_permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this. We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to itself,
  // `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L. A layout is a permutation which maps a
  // minor-to-major physical dimension ordering to a shape's logical dimension
  // ordering. Therefore the inverse of a layout maps from logical to physical
  // dims, and so the physical ordering of I is simply L'.I = L', where L' is
  // the inverse of L.
  //
  // Let the argument `permutation` be P. This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P. Our
  // goal is to construct a layout permutation L* for this shape. The physical
  // dimension ordering of this returned shape must be the same as that of the
  // original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory ordering is
  // L*'.P. Setting this equal to L' and solving for L*, we get:
  //
  // L*'.P = L' =>
  // L*' = L'P' =>
  // L* = P.L
  //
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    // L* = P.L, computed as a composition of permutations.
    for (auto index : ComposePermutations(
             inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape, permutation))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}
1156
// If the reshape from `shape_pre` to `shape_post` only inserts and/or deletes
// 1-sized dimensions, returns (true, deleted_indices, inserted_indices)
// listing those dimensions; otherwise returns (false, {}, {}).
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output index between prior_unmodified_dim_pair
  // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64_t modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64_t modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i].
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}
1216
// Returns the list of (input_dim, output_dim) pairs that a reshape from
// `input_shape` to `output_shape` carries through unchanged (same bound,
// same relative position).
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                         const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());

  // Unmodified dimensions are merely common factors of rank 1.
  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
                                      AsInt64Slice(output_shape.dimensions()));
  // Drop factor-group boundaries that are more than one dimension apart on
  // either side; only one-apart boundaries correspond to a single dimension
  // mapped one-to-one.
  for (size_t i = 0; i < common_factors.size() - 1;) {
    if (1 != common_factors[i + 1].first - common_factors[i].first ||
        1 != common_factors[i + 1].second - common_factors[i].second) {
      common_factors.erase(common_factors.begin() + i);
    } else {
      ++i;
    }
  }
  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
  common_factors.pop_back();
  return std::vector<std::pair<int64, int64>>(common_factors.begin(),
                                              common_factors.end());
}
1239
1240 /* static */ absl::optional<std::vector<int64>>
ReshapeLeavesDimensionsUnmodified(const Shape & from_shape,const Shape & to_shape,absl::Span<const int64> input_dim_indices)1241 ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1242 const Shape& from_shape, const Shape& to_shape,
1243 absl::Span<const int64> input_dim_indices) {
1244 CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1245
1246 std::vector<int64> output_dim_indices;
1247 std::vector<std::pair<int64, int64>> unmodified_dims =
1248 ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1249 size_t i = 0; // index to unmodified_dims
1250 for (int64_t input_dim_index : input_dim_indices) {
1251 // Search unmodified_dims for input_dim_index. We can search from the last
1252 // matching position because input_dim_indices is guaranteed to be sorted.
1253 while (i < unmodified_dims.size() &&
1254 unmodified_dims[i].first < input_dim_index) {
1255 ++i;
1256 }
1257 if (i >= unmodified_dims.size() ||
1258 unmodified_dims[i].first != input_dim_index) {
1259 return absl::nullopt;
1260 }
1261 output_dim_indices.push_back(unmodified_dims[i].second);
1262 }
1263 return output_dim_indices;
1264 }
1265
// Returns true when transposing `input_shape` by `dimension_mapping` into
// `output_shape` moves no element in memory given the two layouts, i.e. the
// transpose can be implemented as a bitcast. Both shapes must have layouts.
/* static */ bool ShapeUtil::TransposeIsBitcast(
    const Shape& input_shape, const Shape& output_shape,
    absl::Span<const int64> dimension_mapping) {
  CHECK(LayoutUtil::HasLayout(input_shape) &&
        LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  // Check the reshape permutes the positions of each dimension in the
  // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
  // input_positions = apply(dimension_mapping, output_positions)
  //
  // Because the positions of each dimension are the inverse permutation of the
  // minor-to-major order, the above check is equivalent to
  //   inverse(input_dimensions) =
  //       apply(dimension_mapping, inverse(output_dimensions))
  //   # `I` indicates identity permutation.
  //   apply(input_dimensions, I) =
  //       apply(dimension_mapping, apply(output_dimensions, I))
  //   apply(input_dimensions, I) =
  //       apply((dimension_mapping * output_dimensions), I)
  //   input_dimensions = dimension_mapping * output_dimensions
  return absl::c_equal(
      ComposePermutations(dimension_mapping,
                          AsInt64Slice(output_shape.layout().minor_to_major())),
      input_shape.layout().minor_to_major());
}
1295
// Returns true when reshaping `input_shape` into `output_shape` moves no
// element to a different physical (in-memory) location given the two layouts,
// i.e. the reshape can be implemented as a bitcast. Both shapes must be
// arrays with layouts and contain the same number of elements.
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape))
      << "input_shape=" << input_shape.ShortDebugString()
      << ", output_shape=" << output_shape.ShortDebugString();
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
  // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
  // logical index k-1 to logical index k, dimension 1 of the input index
  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
  // IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64_t input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      // Dimensions of bound <= 1 admit no unit index, so skip them.
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64_t logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
1457
// Attempts to find a layout for `output_shape` such that reshaping
// `input_shape` (with its existing layout) into the result is a bitcast.
// Returns the output shape with that layout, or nullopt if no such layout
// exists.
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    std::vector<int64> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // For each one sized dimension in the output, increment the dimension
    // numbers in layout that are more minor than the one.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64_t i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    // Translate the stripped-shape dimension numbers back to the original
    // output shape's dimension numbers.
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64_t i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  int64_t input_rank = input_shape.rank();
  int64_t output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part if
  // the products of their dimension bounds are the same. In the easiest case,
  // an alignment part consists of one input dimension and one output dimension
  // which both have the same dimension bound. An alignment part specifies which
  // dimensions need to be kept together in a physical layout if we want a
  // reshape to be a bitcast. The order of the alignment parts is defined by the
  // physical layout of the input shape, so when we construct the layout for the
  // output shape we just process the alignment parts in this order, and then
  // layout the dimensions belonging to each part in descending (major to minor)
  // order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64_t input_dimension_product = 1, output_dimension_product = 1;
  for (int64_t i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to check
  // that all dimension numbers that belong to the same alignment part appear
  // consecutively, and are in descending order. However we can ignore any
  // trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  for (int64_t i = 0; i < input_rank;) {
    int64_t current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64_t current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
    // dimension numbers (ignoring dimension numbers with dimension bound 1) are
    // in descending order and belong to the current alignment part.
    for (int64_t j = 0; j < alignment[current_alignment_index + 1].first -
                                alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64_t j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  // Sanity-check that the constructed layout actually makes the reshape a
  // bitcast.
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
1607
DeleteDimension(int64_t dim_to_delete,Shape shape)1608 /* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete,
1609 Shape shape) {
1610 CHECK(shape.IsArray());
1611 shape.DeleteDimension(dim_to_delete);
1612 return shape;
1613 }
1614
DynamicArrayShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1615 /* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
1616 const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1617 if (dynamic_shape.rank() != bounded_shape.rank()) {
1618 return false;
1619 }
1620 for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
1621 if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
1622 return false;
1623 }
1624 }
1625 return true;
1626 }
1627
DynamicShapeIsCompatible(const xla::Shape & dynamic_shape,const xla::Shape & bounded_shape)1628 /* static */ bool ShapeUtil::DynamicShapeIsCompatible(
1629 const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
1630 bool compatible = true;
1631 xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
1632 const ShapeIndex& index) {
1633 if (compatible) {
1634 auto subshape_result = TryGetSubshape(bounded_shape, index);
1635 if (subshape_result.ok()) {
1636 const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
1637 if (sub_shape.IsTuple()) {
1638 if (!bounded_sub_shape->IsTuple()) {
1639 compatible = false;
1640 }
1641 } else {
1642 if (bounded_sub_shape->IsTuple()) {
1643 compatible = false;
1644 } else if (!sub_shape.is_static() &&
1645 !DynamicArrayShapeIsCompatible(sub_shape,
1646 *bounded_sub_shape)) {
1647 compatible = false;
1648 }
1649 }
1650 } else {
1651 compatible = false;
1652 }
1653 }
1654 });
1655 return compatible;
1656 }
1657
FilterDimensions(const std::function<bool (int64_t)> & p,Shape shape)1658 /* static */ Shape ShapeUtil::FilterDimensions(
1659 const std::function<bool(int64_t)>& p, Shape shape) {
1660 CHECK(shape.IsArray());
1661 std::vector<int64> dims_to_delete;
1662 for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) {
1663 if (!p(i)) {
1664 dims_to_delete.push_back(i);
1665 }
1666 }
1667 for (int64_t dim : dims_to_delete) {
1668 shape = DeleteDimension(dim, shape);
1669 }
1670 return shape;
1671 }
1672
Hash(const Shape & shape)1673 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1674 using tensorflow::hash;
1675 using tensorflow::Hash64Combine;
1676
1677 size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1678
1679 if (shape.tuple_shapes().empty()) {
1680 for (int i = 0; i < shape.dimensions_size(); ++i) {
1681 hash_value =
1682 Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1683 hash_value = Hash64Combine(hash_value,
1684 hash<bool>()(shape.is_dynamic_dimension(i)));
1685 }
1686
1687 hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1688 } else {
1689 hash_value = 0;
1690 for (const Shape& subshape : shape.tuple_shapes()) {
1691 hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1692 }
1693 }
1694
1695 return hash_value;
1696 }
1697
1698 // Returns the indices of the first elements of all consecutive subarrays of the
1699 // given array. For example:
1700 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
ConsecutiveSegments(absl::Span<const int64> xs)1701 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
1702 std::vector<size_t> is = {0};
1703 for (size_t i = 1; i < xs.size(); ++i) {
1704 if (1 != xs[i] - xs[i - 1]) {
1705 is.push_back(i);
1706 }
1707 }
1708 return is;
1709 }
1710
1711 // Merges the sequences of dimensions of the given shape which start at the
1712 // given indices `segs`.
MergeDimensions(absl::Span<const size_t> segs,const Shape & shape)1713 static Shape MergeDimensions(absl::Span<const size_t> segs,
1714 const Shape& shape) {
1715 std::vector<int64> dimensions;
1716 for (size_t i = 1; i <= segs.size(); ++i) {
1717 dimensions.push_back(std::accumulate(
1718 shape.dimensions().begin() + segs[i - 1],
1719 shape.dimensions().begin() +
1720 (segs.size() == i ? shape.dimensions().size() : segs[i]),
1721 1, std::multiplies<int64>()));
1722 }
1723 return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
1724 dimensions);
1725 }
1726
FindTranspose021(const Shape & a,const Shape & b)1727 /*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
1728 const Shape& a, const Shape& b) {
1729 if (!CompatibleIgnoringElementType(a, b)) {
1730 return absl::nullopt;
1731 }
1732
1733 std::vector<int64> permutation(a.dimensions().size());
1734 absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
1735 std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
1736 minor_to_major_a.rend());
1737 absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
1738 std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
1739 minor_to_major_b.rend());
1740 for (size_t i = 0; i < permutation.size(); ++i) {
1741 permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
1742 }
1743
1744 std::vector<size_t> segments = ConsecutiveSegments(permutation);
1745 if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
1746 Shape descending_layout_shape =
1747 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
1748 Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
1749 absl::Span<const int64> normalized_dims =
1750 AsInt64Slice(normalized_shape.dimensions());
1751 std::vector<int64> dims_021;
1752 if (2 == segments.size()) {
1753 // The logical component-0 is of size one.
1754 dims_021 = {1, normalized_dims[1], normalized_dims[0]};
1755 } else {
1756 dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
1757 }
1758
1759 return dims_021;
1760 }
1761
1762 return absl::nullopt;
1763 }
1764
DeviceShapeToHostShape(Shape s)1765 Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
1766 ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
1767 if (subshape->IsArray()) {
1768 subshape->mutable_layout()->clear_tiles();
1769 subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
1770 }
1771 });
1772 return s;
1773 }
1774
ElementCanUpcast(const Shape & from,const Shape & to)1775 /*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
1776 const Shape& to) {
1777 return ElementIsFloating(from) == ElementIsFloating(to) &&
1778 ElementIsSigned(from) == ElementIsSigned(to) &&
1779 HigherPrecisionElementType(from, to) == to.element_type();
1780 }
1781
1782 /*static*/
ByteStrides(const Shape & shape,absl::Span<int64_t> strides)1783 Status ShapeUtil::ByteStrides(const Shape& shape, absl::Span<int64_t> strides) {
1784 TF_RET_CHECK(shape.IsArray());
1785 TF_RET_CHECK(shape.has_layout());
1786 TF_RET_CHECK(shape.dimensions_size() == strides.size());
1787
1788 int64_t stride = ByteSizeOfPrimitiveType(shape.element_type());
1789 for (int i : shape.layout().minor_to_major()) {
1790 strides.at(i) = stride;
1791 stride *= shape.dimensions(i);
1792 }
1793 return Status::OK();
1794 }
1795
1796 } // namespace xla
1797