/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/shape_util.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {

using absl::StrAppend;
using absl::StrCat;

string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }

string ShapeIndexView::ToString() const {
  return StrCat("{", absl::StrJoin(indices_, ","), "}");
}

bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
  return indices_ == other.indices_;
}

bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
  return !(*this == other);
}

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
  out << shape_index.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out,
                         const ShapeIndexView& shape_index) {
  out << shape_index.ToString();
  return out;
}

bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
  return size() >= prefix.size() &&
         indices_.subspan(0, prefix.size()) == prefix.indices_;
}

/* static */ bool ShapeUtil::IsArrayPrimitiveType(
    PrimitiveType primitive_type) {
  return primitive_util::IsArrayType(primitive_type);
}

namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  if (element_type == OPAQUE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  *shape.mutable_layout() =
      LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace

/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
  bool equal = Shape::Equal()(lhs, rhs);

  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
            << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}

/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
                                                      const Shape& rhs) {
  bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
            << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}

/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64 accum = 0;
  for (int64 dimension : shape.dimensions()) {
    // We do not count degenerate (size 1) dimensions.
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}
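
// Example (illustrative): for a shape like f32[2,1,3], TrueRank returns 2,
// since the size-1 dimension does not contribute.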

/* static */ ProgramShape ShapeUtil::MakeProgramShape(
    std::initializer_list<Shape> parameters, Shape result) {
  ProgramShape program_shape;
  for (const Shape& shape : parameters) {
    *program_shape.add_parameters() = shape;
  }
  *program_shape.mutable_result() = std::move(result);
  return program_shape;
}

/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
                                        absl::Span<const int64> dimensions) {
  return MakeValidatedShape(element_type, dimensions).ValueOrDie();
}
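
// Example (illustrative):
//   Shape s = ShapeUtil::MakeShape(F32, {2, 3});
//   // HumanString(s) == "f32[2,3]"; a default (descending) layout is set.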

/* static */ Shape ShapeUtil::MakeShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
      .ValueOrDie();
}

/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  CHECK(IsArrayPrimitiveType(element_type)) << element_type;
  Shape result;
  TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
  return result;
}

/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  TF_ASSIGN_OR_RETURN(Shape shape,
                      MakeValidatedShape(element_type, dimensions));
  for (int i = 0; i < dynamic_dimensions.size(); ++i) {
    shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  return shape;
}

/* static */ Shape ShapeUtil::MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits) {
  return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
                                     tiles, element_size_in_bits)
      .ValueOrDie();
}

/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  std::vector<int64> layout(dimensions.size());
  std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
  return MakeShapeWithLayout(element_type, dimensions, layout);
}
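
// Example (illustrative): MakeShapeWithDescendingLayout(F32, {10, 20, 30})
// yields f32[10,20,30]{2,1,0}, i.e. dimension 2 is the most minor (row-major).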

/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    int64 max_sparse_elements) {
  CHECK(IsArrayPrimitiveType(element_type));
  Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
  *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
  TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
  return shape;
}

/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    const Shape& shape) {
  std::vector<int64> dims(shape.dimensions_size());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  return MakeShapeWithDescendingLayout(shape.element_type(), dims);
}

/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
                                             absl::Span<const int64> dimensions,
                                             Shape* shape) {
  shape->Clear();
  shape->set_element_type(element_type);
  for (int64 dimension : dimensions) {
    shape->add_dimensions(dimension);
  }
  LayoutUtil::SetToDefaultLayout(shape);
  return ValidateShape(*shape);
}

/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
  Shape result;
  result.set_element_type(TUPLE);
  result.mutable_tuple_shapes()->reserve(shapes.size());
  for (const auto& shape : shapes) {
    AppendShapeToTuple(shape, &result);
  }
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
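
// Example (illustrative):
//   Shape t = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(S32, {})});
//   // HumanString(t) == "(f32[2], s32[])"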

/* static */ Shape ShapeUtil::MakeOpaqueShape() {
  Shape result;
  result.set_element_type(OPAQUE);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}

/* static */ Shape ShapeUtil::MakeTokenShape() {
  Shape result;
  result.set_element_type(TOKEN);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}

/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
                                                Shape* tuple_shape) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
  *tuple_shape->add_tuple_shapes() = shape;
}

/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
  CHECK(LayoutUtil::IsDenseArray(*shape));
  shape->mutable_layout()->add_minor_to_major(shape->rank());
  shape->add_dimensions(bound);
  TF_DCHECK_OK(ValidateShape(*shape));
}

/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
  return primitive_util::IsIntegralType(shape.element_type());
}

/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
                                                       int32 bits) {
  return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
}

/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
  if (!shape.IsArray()) {
    return false;
  }
  return primitive_util::BitWidth(shape.element_type()) == bits;
}

/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
  switch (shape.element_type()) {
    case S8:
    case S16:
    case S32:
    case S64:
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;

    case PRED:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    case TUPLE:
    case OPAQUE:
    case TOKEN:
      return false;

    default:
      LOG(FATAL) << "Unhandled element type " << shape.element_type();
  }
}

/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
  return primitive_util::IsComplexType(shape.element_type());
}

/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
  return primitive_util::IsFloatingPointType(shape.element_type());
}

/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
  return shape.IsTuple() &&
         absl::c_any_of(shape.tuple_shapes(),
                        [](const Shape& s) { return s.IsTuple(); });
}

/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
  return shape.IsTuple() && TupleElementCount(shape) == 0;
}

/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
  CHECK(shape.IsTuple()) << HumanString(shape);
  return shape.tuple_shapes_size();
}

/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64 index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  return shape.tuple_shapes(index);
}

/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
  int64 n = 0;
  ForEachSubshape(shape, [&](const Shape& literal_subshape,
                             const ShapeIndex& index) { ++n; });
  return n;
}

/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
                                         int64 limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  CHECK_LE(start, TupleElementCount(tuple));
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
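
// Example (illustrative): for a tuple t = (f32[2], s32[], pred[4]),
// SliceTuple(t, 1, 3) keeps elements [1, 3) and returns (s32[], pred[4]).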

// Returns the shape of a real or imaginary component.
/* static */ Shape ShapeUtil::ComplexComponentShape(
    const Shape& complex_shape) {
  CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
  return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
                                              complex_shape.element_type()));
}

/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
  DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
  DCHECK_EQ(shape.dimensions_size(), shape.rank());
  if (shape.dimensions().size() == 1) {
    return shape.dimensions()[0];
  }
  return std::accumulate<decltype(shape.dimensions().begin()), int64>(
      shape.dimensions().begin(), shape.dimensions().end(), 1LL,
      std::multiplies<int64>());
}

/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
  CHECK(shape.IsArray() || shape.IsTuple());
  if (shape.IsArray()) {
    return ElementsIn(shape);
  }
  int64 count = 0;
  for (const Shape& element_shape : shape.tuple_shapes()) {
    count += ElementsInRecursive(element_shape);
  }
  return count;
}

/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
                                              PrimitiveType primitive_type) {
  if (shape.element_type() == primitive_type) {
    return true;
  }
  for (const Shape& element_shape : shape.tuple_shapes()) {
    if (HasPrimitiveType(element_shape, primitive_type)) {
      return true;
    }
  }
  return false;
}

/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
  return shape.IsArray() && ElementsIn(shape) == 0;
}

/* static */ bool ShapeUtil::IsScalarWithElementType(
    const Shape& shape, PrimitiveType element_type) {
  return IsScalar(shape) && shape.element_type() == element_type;
}

/* static */ string ShapeUtil::HumanString(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanString(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  std::vector<string> dim_elements;
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    if (shape.is_dynamic_dimension(i)) {
      dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
    } else {
      dim_elements.push_back(StrCat(shape.dimensions(i)));
    }
  }
  return StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
      absl::StrJoin(dim_elements, ","), "]");
}
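
// Example outputs of HumanString (illustrative): "f32[2,3]" for a static
// array, "f32[<=10]" when dimension 0 is dynamic with bound 10, and
// "(f32[2], s32[])" for a tuple.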

/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  string result = StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[");
  for (int i = 0; i < shape.dimensions().size(); i++) {
    StrAppend(&result, (i > 0) ? "," : "",
              shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i));
  }
  result += "]";
  if (IsScalar(shape)) {
    string layout_str = LayoutUtil::HumanString(shape.layout());
    // Don't print "{}" as layout for scalars.
    if (layout_str != "{}") {
      StrAppend(&result, layout_str);
    }
  } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
    StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
  }
  return result;
}

/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
  std::vector<string> parameters;
  for (auto& shape : program_shape.parameters()) {
    const int i = parameters.size();
    parameters.push_back(StrCat(i < program_shape.parameter_names_size()
                                    ? program_shape.parameter_names(i)
                                    : "(unknown)",
                                ": ", HumanString(shape)));
  }
  return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
                HumanString(program_shape.result()));
}

/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
                                            const Shape& rhs) {
  CHECK(lhs.IsArray());
  CHECK(rhs.IsArray());
  return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}

/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs);
}

/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                           int64 dimension_number) {
  return shape.dimensions(GetDimensionNumber(shape, dimension_number));
}

/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
                                                 int64 dimension_number) {
  if (dimension_number < 0) {
    dimension_number += shape.rank();
  }
  CHECK_GE(dimension_number, 0);
  return dimension_number;
}
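
// Example (illustrative): negative dimension numbers count from the back, as
// in Python; for a rank-3 shape, GetDimensionNumber(shape, -1) returns 2.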

/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return sizeof(int8);
    case S8:
      return sizeof(int8);
    case S16:
      return sizeof(int16);
    case S32:
      return sizeof(int32);
    case S64:
      return sizeof(int64);
    case U8:
      return sizeof(uint8);
    case U16:
      return sizeof(uint16);
    case U32:
      return sizeof(uint32);
    case U64:
      return sizeof(uint64);
    case BF16:
      return sizeof(float) / 2;
    case F16:
      return sizeof(float) / 2;
    case F32:
      return sizeof(float);
    case F64:
      return sizeof(double);
    case C64:
      return sizeof(complex64);
    case C128:
      return sizeof(complex128);
    case TOKEN:
      // Tokens require no space.
      return 0;
    case TUPLE:
    case OPAQUE:
      LOG(FATAL) << PrimitiveType_Name(primitive_type)
                 << " primitive type has no definitive size";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}

/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
                                         int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  if (shape.element_type() == TUPLE) {
    return ByteSizeOfTupleIndexTable(shape, pointer_size);
  } else if (shape.IsArray()) {
    int64 byte_size = ByteSizeOfElements(shape);
    if (LayoutUtil::IsSparseArray(shape)) {
      byte_size += ByteSizeOfSparseIndices(shape);
    }
    return byte_size;
  } else if (shape.element_type() == TOKEN) {
    return 0;
  } else if (shape.element_type() == OPAQUE) {
    CHECK_GT(pointer_size, 0);
    return pointer_size;
  }
  LOG(FATAL) << PrimitiveType_Name(shape.element_type())
             << " primitive type has no definitive size";
}

/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
                                                        int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK_EQ(TUPLE, shape.element_type());
  CHECK_GT(pointer_size, 0);
  return pointer_size * shape.tuple_shapes_size();
}

/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK(shape.IsArray());
  int64 allocated_element_count;

  if (LayoutUtil::IsSparseArray(shape)) {
    allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
  } else {
    CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
    allocated_element_count = ElementsIn(shape);
  }
  return allocated_element_count *
         ByteSizeOfPrimitiveType(shape.element_type());
}

/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK(LayoutUtil::IsSparseArray(shape));
  return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
         sizeof(int64);
}

/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques must not have a layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) {
    return InvalidArgument("sparse arrays must have rank > 0");
  }
  for (int64 i = 0; i < shape.rank(); ++i) {
    int64 dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}

/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
  VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);

  if (!shape.IsArray()) {
    return Status::OK();
  }

  // We can only reason about some aspects of the array's shape if it has a
  // valid layout; those aspects are ignored otherwise.
  bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) &&
                                LayoutUtil::ValidateLayoutInShape(shape).ok();

  int64 shape_size = [&]() {
    if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) {
      int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout());
      if (max_sparse_elements < 0) {
        return max_sparse_elements;
      }
      int64 sparse_elements_size = MultiplyWithoutOverflow(
          max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type()));
      if (sparse_elements_size < 0) {
        return sparse_elements_size;
      }
      int64 sparse_indices_size =
          MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
      if (sparse_indices_size < 0) {
        return sparse_indices_size;
      }
      sparse_indices_size =
          MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64));
      if (sparse_indices_size < 0) {
        return sparse_indices_size;
      }
      // At this point, both sparse_indices_size and sparse_elements_size are
      // non-negative, so we can easily check if adding them wraps.
      if (static_cast<uint64>(sparse_elements_size) +
              static_cast<uint64>(sparse_indices_size) >
          INT64_MAX) {
        return static_cast<int64>(-1);
      }
    }

    // This is intentionally unconditional: even if the shape is sparse, we
    // want to verify the densified version has a reasonable size.
    int64 dense_shape_size = 1;
    if (shape.dimensions().empty()) {
      return dense_shape_size;
    }

    absl::Span<const int64> shape_max_dimensions =
        AsInt64Slice(shape.dimensions());
    for (int64 dim : shape_max_dimensions) {
      dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
      if (dense_shape_size < 0) {
        return dense_shape_size;
      }
    }
    dense_shape_size = MultiplyWithoutOverflow(
        dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
    return dense_shape_size;
  }();

  if (shape_size < 0) {
    return InvalidArgument("Shape %s size may overflow int64.",
                           ShapeUtil::HumanString(shape));
  }

  VLOG(3) << "Shape size is valid: " << shape_size;
  return Status::OK();
}

/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
    const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape,
                                           /*allow_missing_layouts=*/true);
}

/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape);
}

/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  Shape new_shape = original;
  new_shape.set_element_type(type);
  return new_shape;
}

/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
                                          ShapeIndexView index) {
  const Shape* subshape = &shape;
  for (auto i : index) {
    if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
      return false;
    }
    subshape = &subshape->tuple_shapes(i);
  }
  return true;
}

/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    return_shape = &return_shape->tuple_shapes(i);
  }
  return *return_shape;
}
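
// Example (illustrative, assuming ShapeIndexView accepts a braced index list):
//   Shape t = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {2}),
//        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})});
//   const Shape& s = ShapeUtil::GetSubshape(t, {1, 0});  // s is s32[].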

/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
    const Shape& shape, ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    if (!return_shape->IsTuple() || i < 0 ||
        i >= return_shape->tuple_shapes_size()) {
      return InvalidArgument(
          "Shape index %s not a valid subshape index for tuple with shape %s",
          index.ToString(), shape.DebugString());
    }
    return_shape = &return_shape->tuple_shapes(i);
  }
  return return_shape;
}

/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
                                                  ShapeIndexView index) {
  Shape* return_shape = shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple());
    return_shape = return_shape->mutable_tuple_shapes(i);
  }
  return return_shape;
}

/* static */
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
  return !GetSubshape(shape, index).IsTuple();
}

/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
  if (!shape.IsTuple()) {
    return 1;
  }
  int64 count = 0;
  for (const Shape& subshape : shape.tuple_shapes()) {
    count += GetLeafCount(subshape);
  }
  return count;
}

/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
    const Shape& shape) {
  std::vector<IndexedShape> leaves;
  ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
    if (IsLeafIndex(shape, index)) {
      leaves.emplace_back(index, sub_shape);
    }
  });
  return leaves;
}

/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
  CHECK(shape.IsArray());
  return absl::c_linear_search(shape.dimensions(), 1);
}

/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  return FilterDimensions(
      [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
}

namespace {

// Helper for ForEachSubshape which visits the subshapes of the given shape in
// DFS pre-order starting with the index.
Status ForEachSubshapeHelper(const Shape& shape,
                             const ShapeUtil::StatusVisitorFunction& func,
                             ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(shape, *index));
  if (shape.IsTuple()) {
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      index->push_back(i);
      TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
          ShapeUtil::GetTupleElementShape(shape, i), func, index));
      index->pop_back();
    }
  }
  return Status::OK();
}

// Helper for ForEachMutableSubshape which visits the subshapes of the given
// shape in DFS pre-order starting with the index.
Status ForEachMutableSubshapeHelper(
    Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
    ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(shape, *index));
  if (shape->IsTuple()) {
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
      index->push_back(i);
      TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
          shape->mutable_tuple_shapes(i), func, index));
      index->pop_back();
    }
  }
  return Status::OK();
}

}  // namespace

/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
                                             const VisitorFunction& func) {
  ShapeIndex index;
  ForEachSubshapeHelper(
      shape,
      [&func](const Shape& subshape, const ShapeIndex& index) {
        func(subshape, index);
        return Status::OK();
      },
      &index)
      .IgnoreError();
}
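
// Example usage (illustrative): visit every subshape, printing its index.
//   ShapeUtil::ForEachSubshape(
//       shape, [](const Shape& subshape, const ShapeIndex& index) {
//         VLOG(1) << index << ": " << ShapeUtil::HumanString(subshape);
//       });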

/* static */ void ShapeUtil::ForEachMutableSubshape(
    Shape* shape, const MutatingVisitorFunction& func) {
  ShapeIndex index;
  ForEachMutableSubshapeHelper(
      shape,
      [&func](Shape* subshape, const ShapeIndex& index) {
        func(subshape, index);
        return Status::OK();
      },
      &index)
      .IgnoreError();
}

/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
    const Shape& shape, const StatusVisitorFunction& func) {
  ShapeIndex index;
  return ForEachSubshapeHelper(shape, func, &index);
}

/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
    Shape* shape, const MutatingStatusVisitorFunction& func) {
  ShapeIndex index;
  return ForEachMutableSubshapeHelper(shape, func, &index);
}

/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(permutation, shape.dimensions())) {
    new_shape.add_dimensions(dim);
  }
  for (int64 i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this. We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to
  // itself, `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L. A layout is a permutation which maps a
  // minor-to-major physical layout to the order of a shape's logical dims.
  // Therefore the inverse of a layout maps from logical to physical dims, and
  // so the physical layout of I is simply L'.I = L', where L' is the inverse
  // of L.
  //
  // Let the argument `permutation` be P. This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P. Our
  // goal is to construct a layout permutation L* that we can apply to P such
  // that the physical dimension ordering of the returned shape is the same
  // as that of the original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory layout is
  // L*'.P. Setting this equal to L' and solving for L*, we get:
  //
  //   L*'.P = L'  =>
  //   L*'   = L'P' =>
  //   L*    = P.L
  //
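  // Worked example (illustrative): transposing f32[2,3]{1,0} with
  // permutation {1,0} produces f32[3,2]{0,1}; the bytes in memory are
  // unchanged, so the transpose is a bitcast.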
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    for (auto index : ComposePermutations(
             permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape,
                             InversePermutation(permutation)))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}

/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output dimension between
  // prior_unmodified_dim_pair and unmodified_dim_pair has size >1. Otherwise,
  // returns true and appends the degenerate input/output dimensions in the gap
  // to deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i].
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}
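
// Example (illustrative): reshaping f32[2,1,3] to f32[2,3,1] only moves a
// degenerate dimension, so this returns (true, {1}, {2}): input dimension 1
// was deleted and output dimension 2 was inserted.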

/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                         const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());

  // Unmodified dimensions are merely common factors of rank 1.
  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
                                      AsInt64Slice(output_shape.dimensions()));
  for (size_t i = 0; i < common_factors.size() - 1;) {
    if (1 != common_factors[i + 1].first - common_factors[i].first ||
        1 != common_factors[i + 1].second - common_factors[i].second) {
      common_factors.erase(common_factors.begin() + i);
    } else {
      ++i;
    }
  }
  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
  common_factors.pop_back();
  return common_factors;
}
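
// Example (illustrative): for a reshape from f32[2,3,5] to f32[2,15], only
// dimension 0 passes through unmodified, so this returns {{0, 0}}.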

/* static */ bool ShapeUtil::TransposeIsBitcast(
    const Shape& input_shape, const Shape& output_shape,
    absl::Span<const int64> dimension_mapping) {
  CHECK(LayoutUtil::HasLayout(input_shape) &&
        LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  // Check the reshape permutes the positions of each dimension in the
  // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
  //   input_positions = apply(dimension_mapping, output_positions)
  //
  // Because the positions of each dimension are the inverse permutation of the
  // minor-to-major order, the above check is equivalent to
  //   inverse(input_dimensions) =
  //       apply(dimension_mapping, inverse(output_dimensions))
  //   # `I` indicates identity permutation.
  //   apply(input_dimensions, I) =
  //       apply(dimension_mapping, apply(output_dimensions, I))
  //   apply(input_dimensions, I) =
  //       apply((dimension_mapping * output_dimensions), I)
  //   input_dimensions = dimension_mapping * output_dimensions
  return absl::c_equal(
      ComposePermutations(
          dimension_mapping,
          AsInt64Slice(output_shape.layout().minor_to_major())),
      input_shape.layout().minor_to_major());
}

/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
  if (ElementsIn(input_shape) == 0) {
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also
  // a sufficient condition, as is proved below (note: many details are omitted
  // for space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i]
  // means the size of the i-th least significant dimension of IS or OS (this
  // is opposite to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if
  // i-th dimension of the output index is increased by 1. Note that IT[i] or
  // OT[i] is a function of IS or OS and the layout, and not dependent on the
  // specific input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // factorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing
  // generality, suppose IS[0]<OS[0] and thus k=IS[0]. A similar proof applies
  // to IS[0]>OS[0]. Note that IS[0]!=OS[0] because the reshape is
  // non-factorizable. From logical index k-1 to logical index k, dimension 1
  // of the input index is increased by 1 and dimension 0 is reset to 0 thus
  // decreased by IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64 logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(
              output_shape_dim0_major, logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
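
// Example (illustrative): ReshapeIsBitcast(f32[2,3]{1,0}, f32[6]{0}) is true,
// since flattening a row-major array does not move any element in memory;
// with a column-major input layout {0,1} it would be false.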

/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    auto layout = simple_output_shape->layout().minor_to_major();
    // For each one sized dimension in the output, increment the dimension
    // numbers in layout that are more minor than the one.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64 i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() =
        layout;
    return output_shape_with_layout;
  }

  int64 input_rank = input_shape.rank();
  int64 output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part
  // if the products of their dimension bounds are the same. In the easiest
  // case, an alignment part consists of one input dimension and one output
  // dimension which both have the same dimension bound. An alignment part
  // specifies which dimensions need to be kept together in a physical layout
  // if we want a reshape to be a bitcast. The order of the alignment parts is
  // defined by the physical layout of the input shape, so when we construct
  // the layout for the output shape we just process the alignment parts in
  // this order, and then layout the dimensions belonging to each part in
  // descending (major to minor) order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64 input_dimension_product = 1, output_dimension_product = 1;
  for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to
  // check that all dimension numbers that belong to the same alignment part
  // appear consecutively, and are in descending order. However we can ignore
  // any trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  for (int64 i = 0; i < input_rank;) {
    int64 current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64 current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
    // dimension numbers (ignoring dimension numbers with dimension bound 1)
    // are in descending order and belong to the current alignment part.
    for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
                              alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64 j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
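
// Example (illustrative): AlignLayouts(f32[2,3]{1,0}, f32[6]) returns
// f32[6]{0}, making the reshape a bitcast; for a column-major input such as
// f32[2,3]{0,1} no such output layout exists, so it returns nullopt.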

/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
                                              Shape shape) {
  CHECK(shape.IsArray());
  shape.DeleteDimension(dim_to_delete);
  return shape;
}

/* static */ Shape ShapeUtil::FilterDimensions(
    const std::function<bool(int64)>& p, Shape shape) {
  CHECK(shape.IsArray());
  std::vector<int64> dims_to_delete;
  for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
    if (!p(i)) {
      dims_to_delete.push_back(i);
    }
  }
  for (int64 dim : dims_to_delete) {
    shape = DeleteDimension(dim, shape);
  }
  return shape;
}
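
// Example (illustrative): keep every dimension except dimension 1.
//   Shape s = ShapeUtil::FilterDimensions(
//       [](int64 dim) { return dim != 1; },
//       ShapeUtil::MakeShape(F32, {2, 3, 4}));  // s is f32[2,4].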

/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
  using tensorflow::hash;
  using tensorflow::Hash64Combine;

  size_t hash_value = hash<PrimitiveType>()(shape.element_type());

  if (shape.tuple_shapes().empty()) {
    for (int i = 0; i < shape.dimensions_size(); ++i) {
      hash_value =
          Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
      hash_value = Hash64Combine(hash_value,
                                 hash<bool>()(shape.is_dynamic_dimension(i)));
    }

    hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
  } else {
    hash_value = 0;
    for (const Shape& subshape : shape.tuple_shapes()) {
      hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
    }
  }

  return hash_value;
}

}  // namespace xla