1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/shape_util.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/ascii.h"
27 #include "absl/strings/numbers.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/strings/str_split.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/strings/strip.h"
33 #include "absl/types/optional.h"
34 #include "tensorflow/compiler/xla/index_util.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/overflow_util.h"
37 #include "tensorflow/compiler/xla/primitive_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/iterator_range.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/lib/strings/numbers.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/regexp.h"
48
49 namespace xla {
50
51 using absl::StrAppend;
52 using absl::StrCat;
53
ToString() const54 string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
55
ToString() const56 string ShapeIndexView::ToString() const {
57 return StrCat("{", absl::StrJoin(indices_, ","), "}");
58 }
59
// Two views are equal iff their index sequences are element-wise equal.
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
  return indices_ == other.indices_;
}
63
// Defined as the negation of operator== to keep the two consistent.
bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
  return !(*this == other);
}
67
operator <<(std::ostream & out,const ShapeIndex & shape_index)68 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
69 out << shape_index.ToString();
70 return out;
71 }
72
operator <<(std::ostream & out,const ShapeIndexView & shape_index)73 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
74 out << shape_index.ToString();
75 return out;
76 }
77
// Returns true if `prefix` is a (possibly equal-length) prefix of this index.
bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
  return size() >= prefix.size() &&
         indices_.subspan(0, prefix.size()) == prefix.indices_;
}
82
// Thin forwarder: a primitive type is an "array" type iff primitive_util says
// so (i.e. it is not TUPLE/OPAQUE/TOKEN-like).
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
    PrimitiveType primitive_type) {
  return primitive_util::IsArrayType(primitive_type);
}
87
namespace {
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout. Returns an error if dimensions/minor_to_major disagree in size,
// if the element type cannot carry a layout, or if the result fails
// validation.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type));
  }
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dimensions));
  if (element_size_in_bits ==
      ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
    // Only set element_size_in_bits if it's different from the default value.
    element_size_in_bits = 0;
  }
  *shape.mutable_layout() = LayoutUtil::MakeLayout(
      minor_to_major, tiles, element_size_in_bits, memory_space);
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
}  // namespace
119
// Full equality (element type, dimensions, layout, dynamic dims). On mismatch
// logs both shapes at VLOG(3) to aid debugging.
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
  bool equal = Shape::Equal()(lhs, rhs);

  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
            << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}
130
// Equality that disregards the element type (dimensions/layout still compared).
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
                                                      const Shape& rhs) {
  bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
            << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}
141
// Equality that treats floating-point types of different precision (e.g. F32
// vs F16) as equivalent.
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
                                                      const Shape& rhs) {
  bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
  if (!equal && VLOG_IS_ON(3)) {
    VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
            << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
  }

  return equal;
}
152
// Returns the number of non-degenerate dimensions: dimensions of size 1 do
// not contribute to the "true" rank.
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64 accum = 0;
  for (int64 dimension : shape.dimensions()) {
    // We do not count degenerate (size 1) dimensions.
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}
163
// Builds a ProgramShape from parameter shapes (copied) and a result shape
// (moved).
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
    std::initializer_list<Shape> parameters, Shape result) {
  ProgramShape program_shape;
  for (const Shape& shape : parameters) {
    *program_shape.add_parameters() = shape;
  }
  *program_shape.mutable_result() = std::move(result);
  return program_shape;
}
173
// Infallible variant of MakeValidatedShape: CHECK-fails (via ValueOrDie) if
// the element type / dimensions are invalid.
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
                                        absl::Span<const int64> dimensions) {
  return MakeValidatedShape(element_type, dimensions).ValueOrDie();
}
178
// A scalar is simply a rank-0 array shape.
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
  return MakeShape(element_type, {});
}
182
// Infallible variant of MakeValidatedShape with per-dimension dynamic flags;
// CHECK-fails (via ValueOrDie) on invalid input.
/* static */ Shape ShapeUtil::MakeShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
      .ValueOrDie();
}
189
// Returns a copy of `shape` with every dynamic-dimension marker removed, so
// all dimensions become static.
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
    const Shape& shape) {
  Shape static_shape = shape;
  static_shape.clear_dynamic_dimensions();
  return static_shape;
}
196
// Builds and validates an array shape. CHECK-fails if the element type is not
// an array primitive type; other problems surface as a Status error.
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  CHECK(IsArrayPrimitiveType(element_type)) << element_type;
  Shape result;
  TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result));
  return result;
}
204
// As MakeValidatedShape above, then tags each dimension i as dynamic per
// dynamic_dimensions[i]. NOTE(review): assumes dynamic_dimensions.size() <=
// dimensions.size() — not checked here; verify at call sites.
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    const std::vector<bool>& dynamic_dimensions) {
  TF_ASSIGN_OR_RETURN(Shape shape,
                      MakeValidatedShape(element_type, dimensions));
  for (int i = 0; i < dynamic_dimensions.size(); ++i) {
    shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  return shape;
}
215
// Infallible wrapper over MakeShapeWithLayoutInternal; CHECK-fails (via
// ValueOrDie) on invalid input.
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions,
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
                                     tiles, element_size_in_bits, memory_space)
      .ValueOrDie();
}
224
// Makes a shape whose layout lists dimensions from major to minor in
// declaration order, i.e. minor_to_major = {rank-1, ..., 1, 0}.
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
    PrimitiveType element_type, absl::Span<const int64> dimensions) {
  std::vector<int64> layout(dimensions.size());
  // Fill in reverse so layout == {n-1, n-2, ..., 0}.
  std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
  return MakeShapeWithLayout(element_type, dimensions, layout);
}
231
// Produces a shape with a descending (major-to-minor) layout whose physical
// element ordering matches `shape`'s: logical dimensions are permuted into
// physical (major-first) order, then given the canonical descending layout.
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    const Shape& shape) {
  std::vector<int64> dims(shape.dimensions_size());
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    // dims[i] is the size of the i-th most-major physical dimension.
    dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
  // Since the physical layout is kept the same, the tiles and element size are
  // the same also.
  new_shape.mutable_layout()->mutable_tiles()->assign(
      shape.layout().tiles().begin(), shape.layout().tiles().end());
  new_shape.mutable_layout()->set_element_size_in_bits(
      shape.layout().element_size_in_bits());
  // NOTE(review): dynamic flags are copied by logical dimension index, not
  // permuted like the sizes above — confirm this is intended.
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
  }
  return new_shape;
}
251
// Resets *shape to an array of the given element type and dimensions with the
// default layout, then validates it.
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
                                             absl::Span<const int64> dimensions,
                                             Shape* shape) {
  shape->Clear();
  shape->set_element_type(element_type);
  for (int64 dimension : dimensions) {
    shape->add_dimensions(dimension);
  }
  LayoutUtil::SetToDefaultLayout(shape);
  return ValidateShape(*shape);
}
263
// Builds a (possibly empty) tuple shape from the given element shapes.
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
  Shape result;
  result.set_element_type(TUPLE);
  // Reserve up front to avoid repeated reallocation while appending.
  result.mutable_tuple_shapes()->reserve(shapes.size());
  for (const auto& shape : shapes) {
    AppendShapeToTuple(shape, &result);
  }
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
274
// An opaque shape has only an element type — no dimensions and no layout.
/* static */ Shape ShapeUtil::MakeOpaqueShape() {
  Shape result;
  result.set_element_type(OPAQUE_TYPE);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
281
// A token shape has only an element type — no dimensions and no layout.
/* static */ Shape ShapeUtil::MakeTokenShape() {
  Shape result;
  result.set_element_type(TOKEN);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
  return result;
}
288
// Appends a copy of `shape` as the last element of *tuple_shape.
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
                                                Shape* tuple_shape) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
  *tuple_shape->add_tuple_shapes() = shape;
}
294
UpdateTupleShape(const Shape & shape,int64 index,Shape * tuple_shape)295 /* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64 index,
296 Shape* tuple_shape) {
297 CHECK(index < tuple_shape->tuple_shapes_size());
298 *tuple_shape->mutable_tuple_shapes(index) = shape;
299 }
300
// Sets the dynamic flag of dimension `dim` on the subshape addressed by
// `index`, recursing through tuple levels. An empty index addresses `shape`
// itself, which must then be a (non-tuple) array.
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
                                                    ShapeIndexView index,
                                                    int64 dim,
                                                    bool is_dynamic) {
  if (index.empty()) {
    CHECK(!shape->IsTuple());
    shape->set_dynamic_dimension(dim, is_dynamic);
    return;
  }

  // Descend one tuple level and recurse with the remaining index.
  UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
                         index.ConsumeFront(), dim, is_dynamic);
}
314
// Appends a new most-major dimension of size `bound` to a dense array shape:
// the new dimension index (old rank) is appended to minor_to_major, making it
// the most-major entry.
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
  CHECK(LayoutUtil::IsDenseArray(*shape));
  shape->mutable_layout()->add_minor_to_major(shape->rank());
  shape->add_dimensions(bound);
  TF_DCHECK_OK(ValidateShape(*shape));
}
321
// True iff the element type is a signed or unsigned integer type.
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
  return primitive_util::IsIntegralType(shape.element_type());
}
325
// True iff the element type is integral AND has exactly `bits` bit width.
/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
                                                       int32 bits) {
  return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
}
330
ElementHasBitWidth(const Shape & shape,int bits)331 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
332 if (!shape.IsArray()) {
333 return false;
334 }
335 return primitive_util::BitWidth(shape.element_type()) == bits;
336 }
337
// True iff the element type is a signed numeric type. Floating-point types
// count as signed; complex, unsigned, predicate and non-numeric types do not.
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
  switch (shape.element_type()) {
    // Signed integers and floating-point types.
    case S8:
    case S16:
    case S32:
    case S64:
    case F16:
    case BF16:
    case F32:
    case F64:
      return true;

    // Unsigned, complex, and non-numeric types.
    case PRED:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    case TUPLE:
    case OPAQUE_TYPE:
    case TOKEN:
      return false;

    default:
      LOG(FATAL) << "Unhandled element type " << shape.element_type();
  }
}
366
// True iff the element type is complex (C64 or C128).
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
  return primitive_util::IsComplexType(shape.element_type());
}
370
// True iff the element type is a floating-point type.
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
  return primitive_util::IsFloatingPointType(shape.element_type());
}
374
// True iff `shape` is a tuple with at least one tuple element.
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
  return shape.IsTuple() &&
         absl::c_any_of(shape.tuple_shapes(),
                        [](const Shape& s) { return s.IsTuple(); });
}
380
// True iff `shape` is a tuple with zero elements.
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
  return shape.IsTuple() && TupleElementCount(shape) == 0;
}
384
// Number of direct elements of a tuple shape; CHECK-fails on non-tuples.
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
  CHECK(shape.IsTuple()) << HumanString(shape);
  return shape.tuple_shapes_size();
}
389
// Returns the shape of tuple element `index`; CHECK-fails on non-tuples or
// out-of-range indices.
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
                                                          int64 index) {
  CHECK(shape.IsTuple());
  CHECK_GT(TupleElementCount(shape), index);
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
  return shape.tuple_shapes(index);
}
397
SubshapeCount(const Shape & shape)398 /* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
399 int64 n = 0;
400 ForEachSubshape(shape, [&](const Shape& literal_subshape,
401 const ShapeIndex& index) { ++n; });
402 return n;
403 }
404
// Returns a new tuple shape holding elements [start, limit) of `tuple`.
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
                                         int64 limit) {
  TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
  CHECK(tuple.IsTuple());
  CHECK_LE(start, TupleElementCount(tuple));
  CHECK_LE(limit, TupleElementCount(tuple));

  std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
                                  tuple.tuple_shapes().begin() + limit);
  return MakeTupleShape(new_elements);
}
416
// Returns the shape of a real or imaginary component: same dimensions with
// the element type replaced by the complex type's component type (e.g.
// C64 -> F32).
/* static */ Shape ShapeUtil::ComplexComponentShape(
    const Shape& complex_shape) {
  CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
  return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
                                              complex_shape.element_type()));
}
424
ElementsIn(const Shape & shape)425 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
426 DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
427 DCHECK_EQ(shape.dimensions_size(), shape.rank());
428 if (shape.dimensions().size() == 1) {
429 return shape.dimensions()[0];
430 }
431 return std::accumulate<decltype(shape.dimensions().begin()), int64>(
432 shape.dimensions().begin(), shape.dimensions().end(), 1LL,
433 std::multiplies<int64>());
434 }
435
// Total element count across a shape tree: sums ElementsIn over all array
// leaves of a (possibly nested) tuple, or returns ElementsIn for an array.
/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
  CHECK(shape.IsArray() || shape.IsTuple());
  if (shape.IsArray()) {
    return ElementsIn(shape);
  }
  int64 count = 0;
  for (const Shape& element_shape : shape.tuple_shapes()) {
    count += ElementsInRecursive(element_shape);
  }
  return count;
}
447
HasPrimitiveType(const Shape & shape,PrimitiveType primitive_type)448 /* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
449 PrimitiveType primitive_type) {
450 if (shape.element_type() == primitive_type) {
451 return true;
452 }
453 for (const Shape& element_shape : shape.tuple_shapes()) {
454 if (HasPrimitiveType(element_shape, primitive_type)) {
455 return true;
456 }
457 }
458 return false;
459 }
460
// True iff `shape` is an array with at least one zero-sized dimension.
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
  return shape.IsArray() && ElementsIn(shape) == 0;
}
464
// True iff `shape` is a rank-0 array of exactly `element_type`.
/* static */ bool ShapeUtil::IsScalarWithElementType(
    const Shape& shape, PrimitiveType element_type) {
  return IsScalar(shape) && shape.element_type() == element_type;
}
469
// Layout-free human-readable form, e.g. "f32[2,<=3]" or "(f32[2], s32[])".
// Dynamic dimensions are prefixed with "<="; tuples recurse element-wise.
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanString(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  std::vector<string> dim_elements;
  for (int i = 0; i < shape.dimensions_size(); ++i) {
    if (shape.is_dynamic_dimension(i)) {
      dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
    } else {
      dim_elements.push_back(StrCat(shape.dimensions(i)));
    }
  }
  return StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
      absl::StrJoin(dim_elements, ","), "]");
}
493
// Like HumanString but appends the layout, e.g. "f32[2,3]{1,0}". Scalars
// suppress the trivial "{}" layout; non-array shapes never print a layout.
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    string text = "(";
    const char* prefix = "";
    for (const Shape& elem_shape : shape.tuple_shapes()) {
      StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
      prefix = ", ";
    }
    text += ")";
    return text;
  }
  string result = StrCat(
      primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[");
  for (int i = 0; i < shape.dimensions().size(); i++) {
    StrAppend(&result, (i > 0) ? "," : "",
              shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i));
  }
  result += "]";
  if (IsScalar(shape)) {
    string layout_str = LayoutUtil::HumanString(shape.layout());
    // Don't print "{}" as layout for scalars.
    if (layout_str != "{}") {
      StrAppend(&result, layout_str);
    }
  } else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
    StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
  }
  return result;
}
523
// Renders a program signature, e.g. "(x: f32[2], (unknown): s32[]) -> f32[]".
// Parameters without a recorded name are shown as "(unknown)".
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
  std::vector<string> parameters;
  for (auto& shape : program_shape.parameters()) {
    const int i = parameters.size();
    parameters.push_back(StrCat(i < program_shape.parameter_names_size()
                                    ? program_shape.parameter_names(i)
                                    : "(unknown)",
                                ": ", HumanString(shape)));
  }
  return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
                HumanString(program_shape.result()));
}
536
// True iff two array shapes have identical dimension-size sequences
// (element type and layout are ignored). CHECK-fails on non-arrays.
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
                                            const Shape& rhs) {
  CHECK(lhs.IsArray());
  CHECK(rhs.IsArray());
  return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}
543
// Equality ignoring layout and dynamic-dimension flags.
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}
547
// Compatibility (no layout / dynamic-dim comparison) that also ignores the
// element type.
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreElementType()
      .IgnoreLayout()(lhs, rhs);
}
555
// Compatibility (no layout / dynamic-dim comparison) that treats different
// floating-point precisions as equivalent.
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreFpPrecision()
      .IgnoreLayout()(lhs, rhs);
}
563
// Dimension size lookup that accepts negative (Python-style, from-the-end)
// dimension numbers via GetDimensionNumber.
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
                                           int64 dimension_number) {
  return shape.dimensions(GetDimensionNumber(shape, dimension_number));
}
568
// Normalizes a possibly-negative dimension number: -1 means the last
// dimension, etc. CHECK-fails if the result is still negative (i.e. the
// input was < -rank).
/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
                                                 int64 dimension_number) {
  if (dimension_number < 0) {
    dimension_number += shape.rank();
  }
  CHECK_GE(dimension_number, 0);
  return dimension_number;
}
577
// Bytes occupied by a single element of the given primitive type. TOKEN is
// size 0; TUPLE and OPAQUE have no well-defined size and LOG(FATAL).
/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      // Predicates are stored one per byte.
      return sizeof(int8);
    case S8:
      return sizeof(int8);
    case S16:
      return sizeof(int16);
    case S32:
      return sizeof(int32);
    case S64:
      return sizeof(int64);
    case U8:
      return sizeof(uint8);
    case U16:
      return sizeof(uint16);
    case U32:
      return sizeof(uint32);
    case U64:
      return sizeof(uint64);
    case BF16:
      // Half-width float types: 2 bytes.
      return sizeof(float) / 2;
    case F16:
      return sizeof(float) / 2;
    case F32:
      return sizeof(float);
    case F64:
      return sizeof(double);
    case C64:
      return sizeof(complex64);
    case C128:
      return sizeof(complex128);
    case TOKEN:
      // Tokens require no space.
      return 0;
    case TUPLE:
    case OPAQUE_TYPE:
      LOG(FATAL) << PrimitiveType_Name(primitive_type)
                 << " primitive type has no definitive size";
    default:
      LOG(FATAL) << "Unhandled primitive type " << primitive_type;
  }
}
622
// Total byte size of a shape's data: element storage for arrays, a pointer
// table for tuples (pointer_size must then be > 0), 0 for tokens, and one
// pointer for opaques. LOG(FATAL)s for types with no definitive size.
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
                                         int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  if (shape.element_type() == TUPLE) {
    return ByteSizeOfTupleIndexTable(shape, pointer_size);
  } else if (shape.IsArray()) {
    int64 byte_size = ByteSizeOfElements(shape);
    return byte_size;
  } else if (shape.element_type() == TOKEN) {
    return 0;
  } else if (shape.element_type() == OPAQUE_TYPE) {
    CHECK_GT(pointer_size, 0);
    return pointer_size;
  }
  LOG(FATAL) << PrimitiveType_Name(shape.element_type())
             << " primitive type has no definitive size";
}
640
// Size of a tuple's index table: one pointer per direct element.
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
                                                        int64 pointer_size) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK_EQ(TUPLE, shape.element_type());
  CHECK_GT(pointer_size, 0);
  return pointer_size * shape.tuple_shapes_size();
}
648
ByteSizeOfElements(const Shape & shape)649 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
650 TF_DCHECK_OK(ValidateShape(shape));
651 CHECK(shape.IsArray());
652 int64 allocated_element_count;
653
654 CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
655 allocated_element_count = ElementsIn(shape);
656 return allocated_element_count *
657 ByteSizeOfPrimitiveType(shape.element_type());
658 }
659
// Structural validation, ignoring layout: checks element-type validity,
// tuple/non-tuple field consistency, TOKEN/OPAQUE emptiness, non-negative
// dimensions, and that the total byte size does not overflow int64.
// Recurses into tuple elements.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    const Shape& shape) {
  if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
      !PrimitiveType_IsValid(shape.element_type())) {
    return InvalidArgument("shape has invalid element type: %s",
                           shape.ShortDebugString());
  }
  if (shape.element_type() == TUPLE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument("tuples must not have dimensions specified");
    }
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateShapeWithOptionalLayoutInternal(element_shape));
    }
    return Status::OK();
  }

  // Non-tuple shape.
  if (shape.tuple_shapes_size() > 0) {
    return InvalidArgument("non-tuple shape has tuple_shapes field");
  }

  // Tokens and opaques should not have layout or dimensions.
  if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
    if (shape.dimensions_size() != 0) {
      return InvalidArgument(
          "shape has %s element type, but has dimensions field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape has %s element type, but has layout field: %s",
          primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
          shape.ShortDebugString());
    }
    return Status::OK();
  }

  for (int64 i = 0; i < shape.rank(); ++i) {
    int64 dimension = shape.dimensions(i);
    if (dimension < 0) {
      return InvalidArgument(
          "shape's dimensions must not be < 0; dimension at index %d was %d", i,
          dimension);
    }
  }

  TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
  return Status::OK();
}
712
// Verifies that an array shape's total byte size fits in int64, using
// overflow-checked multiplication (MultiplyWithoutOverflow returns a negative
// value on overflow). Non-array shapes trivially pass.
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
  VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);

  if (!shape.IsArray()) {
    return Status::OK();
  }

  int64 shape_size = [&]() {
    int64 dense_shape_size = 1;
    if (shape.dimensions().empty()) {
      // Rank-0 scalar: one element.
      return dense_shape_size;
    }

    absl::Span<const int64> shape_max_dimensions =
        AsInt64Slice(shape.dimensions());
    for (int64 dim : shape_max_dimensions) {
      dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
      if (dense_shape_size < 0) {
        // Overflow detected; propagate the negative sentinel.
        return dense_shape_size;
      }
    }
    // Convert the element count to bytes, again overflow-checked.
    dense_shape_size = MultiplyWithoutOverflow(
        dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
    return dense_shape_size;
  }();

  if (shape_size < 0) {
    return InvalidArgument("Shape %s size may overflow int64.",
                           ShapeUtil::HumanString(shape));
  }

  VLOG(3) << "Shape size is valid: " << shape_size;
  return Status::OK();
}
747
// Validates structure, then layout — allowing shapes without a layout.
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
    const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape,
                                           /*allow_missing_layouts=*/true);
}
755
// Validates structure, then layout — requiring a layout to be present.
/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
  TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));

  return LayoutUtil::ValidateLayoutInShape(shape);
}
761
// Returns a copy of `original` with only the element type replaced;
// dimensions, layout and dynamic flags are preserved.
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  Shape result = original;
  result.set_element_type(type);
  return result;
}
768
// True iff `index` addresses an existing subshape: every step must descend
// into a tuple with an in-range element number.
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
                                          ShapeIndexView index) {
  const Shape* subshape = &shape;
  for (auto i : index) {
    if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
      return false;
    }
    subshape = &subshape->tuple_shapes(i);
  }
  return true;
}
780
// Returns the subshape addressed by `index`, CHECK-failing on an invalid
// index. An empty index returns `shape` itself.
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                 ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple())
        << "Invalid index " << index << " for shape " << shape;
    return_shape = &return_shape->tuple_shapes(i);
  }
  return *return_shape;
}
791
// Non-CHECKing counterpart of GetSubshape: returns InvalidArgument instead of
// crashing when `index` does not address a subshape.
/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
    const Shape& shape, ShapeIndexView index) {
  const Shape* return_shape = &shape;
  for (auto i : index) {
    if (!return_shape->IsTuple() || i < 0 ||
        i >= return_shape->tuple_shapes_size()) {
      return InvalidArgument(
          "Shape index %s not a valid subshape index for tuple with shape %s",
          index.ToString(), shape.DebugString());
    }
    return_shape = &return_shape->tuple_shapes(i);
  }
  return return_shape;
}
806
// Mutable counterpart of GetSubshape; CHECK-fails when a step is not a tuple.
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
                                                  ShapeIndexView index) {
  Shape* return_shape = shape;
  for (auto i : index) {
    CHECK(return_shape->IsTuple());
    return_shape = return_shape->mutable_tuple_shapes(i);
  }
  return return_shape;
}
816
/* static */
// A leaf is any subshape that is not itself a tuple (arrays, tokens, opaques).
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
  return !GetSubshape(shape, index).IsTuple();
}
821
// Number of leaf (non-tuple) subshapes; a non-tuple shape is a single leaf.
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
  if (!shape.IsTuple()) {
    return 1;
  }
  int64 count = 0;
  for (const Shape& subshape : shape.tuple_shapes()) {
    count += GetLeafCount(subshape);
  }
  return count;
}
832
// Collects all leaf subshapes (paired with their indices) in DFS pre-order.
/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
    const Shape& shape) {
  std::vector<IndexedShape> leaves;
  ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
    if (IsLeafIndex(shape, index)) {
      leaves.emplace_back(index, sub_shape);
    }
  });
  return leaves;
}
843
// True iff the array shape has at least one size-1 (degenerate) dimension.
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
  CHECK(shape.IsArray());
  return absl::c_linear_search(shape.dimensions(), 1);
}
848
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
  // Keep only the dimensions whose bound differs from 1.
  auto is_nondegenerate = [&shape](int64 dim) -> bool {
    return shape.dimensions()[dim] != 1;
  };
  return FilterDimensions(is_nondegenerate, shape);
}
853
854 namespace {
855
856 // Helper for ForEachSubshape which visits the subshapes of the given shape in
857 // DFS pre-order starting with the index.
ForEachSubshapeHelper(const Shape & shape,const ShapeUtil::StatusVisitorFunction & func,ShapeIndex * index)858 Status ForEachSubshapeHelper(const Shape& shape,
859 const ShapeUtil::StatusVisitorFunction& func,
860 ShapeIndex* index) {
861 TF_RETURN_IF_ERROR(func(shape, *index));
862 if (shape.IsTuple()) {
863 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
864 index->push_back(i);
865 TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
866 ShapeUtil::GetTupleElementShape(shape, i), func, index));
867 index->pop_back();
868 }
869 }
870 return Status::OK();
871 }
872
873 // Helper for ForEachMutableSubshape which visits the subshapes of the given
874 // shape in DFS pre-order starting with the index.
ForEachMutableSubshapeHelper(Shape * shape,const ShapeUtil::MutatingStatusVisitorFunction & func,ShapeIndex * index)875 Status ForEachMutableSubshapeHelper(
876 Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
877 ShapeIndex* index) {
878 TF_RETURN_IF_ERROR(func(shape, *index));
879 if (shape->IsTuple()) {
880 for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
881 index->push_back(i);
882 TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
883 shape->mutable_tuple_shapes(i), func, index));
884 index->pop_back();
885 }
886 }
887 return Status::OK();
888 }
889
890 } // namespace
891
ForEachSubshape(const Shape & shape,const VisitorFunction & func)892 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
893 const VisitorFunction& func) {
894 ShapeIndex index;
895 ForEachSubshapeHelper(
896 shape,
897 [&func](const Shape& subshape, const ShapeIndex& index) {
898 func(subshape, index);
899 return Status::OK();
900 },
901 &index)
902 .IgnoreError();
903 }
904
ForEachMutableSubshape(Shape * shape,const MutatingVisitorFunction & func)905 /* static */ void ShapeUtil::ForEachMutableSubshape(
906 Shape* shape, const MutatingVisitorFunction& func) {
907 ShapeIndex index;
908 ForEachMutableSubshapeHelper(
909 shape,
910 [&func](Shape* subshape, const ShapeIndex& index) {
911 func(subshape, index);
912 return Status::OK();
913 },
914 &index)
915 .IgnoreError();
916 }
917
ForEachSubshapeWithStatus(const Shape & shape,const StatusVisitorFunction & func)918 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
919 const Shape& shape, const StatusVisitorFunction& func) {
920 ShapeIndex index;
921 return ForEachSubshapeHelper(shape, func, &index);
922 }
923
ForEachMutableSubshapeWithStatus(Shape * shape,const MutatingStatusVisitorFunction & func)924 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
925 Shape* shape, const MutatingStatusVisitorFunction& func) {
926 ShapeIndex index;
927 return ForEachMutableSubshapeHelper(shape, func, &index);
928 }
929
// Returns a shape whose dimensions (and dynamic-dimension markers) are
// `shape`'s dimensions rearranged by `permutation`. If `shape` carries a
// layout, the result's layout is chosen so that the transpose defined by
// `permutation` is a bitcast (verified by the CHECK below).
/* static */ Shape ShapeUtil::PermuteDimensions(
    absl::Span<const int64> permutation, const Shape& shape) {
  Shape new_shape = shape;
  new_shape.clear_dimensions();
  for (auto dim : Permute(permutation, shape.dimensions())) {
    new_shape.add_dimensions(dim);
  }
  // Dynamic-dimension flags travel together with their dimensions.
  for (int64 i = 0; i < shape.rank(); i++) {
    new_shape.set_dynamic_dimension(permutation[i],
                                    shape.is_dynamic_dimension(i));
  }

  // If `shape` has a layout, by contract we choose a new layout such that the
  // transpose defined by this permutation is a bitcast.
  //
  // Some formalism helps to understand the correct way to do this. We're going
  // to do algebra in the group of permutations of the dimensions of `shape`.
  //
  // Since the order of `shape`'s dimensions is not permuted relative to itself,
  // `shape`'s list of dimensions is isomorphic to the identity I.
  //
  // Let `shape`'s layout be L. A layout is a permutation which maps a
  // minor-to-major physical layout to the order of a shape's logical dims.
  // Therefore inverse of a layout maps from logical to physical dims, and so
  // the physical layout of I is simply L'.I = L', where L' is the inverse of L.
  //
  // Let the argument `permutation` be P. This is a permutation over `shape`'s
  // dimensions, so our return value will be a shape with dims P.I = P. Our
  // goal is to construct a layout permutation L* that we can apply to P such
  // that the physical dimension ordering of the returned shape is the same
  // as that of the original shape, namely L'.
  //
  // Our returned shape has dims P and layout L*, so its in-memory layout is
  // L*'.P. Setting this equal to L' and solving for L*, we get:
  //
  //   L*'.P = L'    =>
  //   L*'   = L'P'  =>
  //   L*    = P.L
  //
  if (shape.has_layout()) {
    CHECK(LayoutUtil::IsDenseArray(shape));
    Layout* new_layout = new_shape.mutable_layout();
    new_layout->set_format(DENSE);
    new_layout->clear_minor_to_major();
    // L* = P.L, computed by composing `permutation` with the old layout.
    for (auto index : ComposePermutations(
             permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
      new_layout->add_minor_to_major(index);
    }
    // The permutation accepted by TransposeIsBitcast is the inverse of the
    // permutation here.
    CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
        << "shape=" << HumanStringWithLayout(shape)
        << ", new_shape=" << HumanStringWithLayout(new_shape)
        << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
  }
  return new_shape;
}
987
// Returns {true, deleted_indices, inserted_indices} if the reshape from
// shape_pre to shape_post only deletes and/or inserts size-1 dimensions
// (listing the deleted input dims and the inserted output dims); otherwise
// returns {false, {}, {}}.
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                             const Shape& shape_post) {
  CHECK(shape_pre.IsArray());
  CHECK(shape_post.IsArray());

  // The "failure" return value: not a pure insert/delete-of-1s reshape.
  auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());

  std::vector<int64> deleted_indices;
  std::vector<int64> inserted_indices;
  // Returns false if any input/output index between prior_unmodified_dim_pair
  // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);
  // Returns nil if the reshape modifies any non-degenerate input/output
  // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
  // dimensions, so we only need to check whether dimensions in the gaps (thus
  // modified) have size >1.
  for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
    // Check (modified) dimensions between unmodified_dims[i-1] and
    // unmodified_dims[i]. Sentinels of (-1,-1) and (rank,rank) cover the gaps
    // before the first and after the last unmodified pair.
    auto prior_unmodified_dim_pair =
        i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
    auto unmodified_dim_pair =
        i < unmodified_dims.size()
            ? unmodified_dims[i]
            : std::make_pair(shape_pre.rank(), shape_post.rank());
    if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
      return nil;
    }
  }

  return std::make_tuple(true, deleted_indices, inserted_indices);
}
1047
// Returns the list of (input_dim, output_dim) pairs whose dimensions are
// carried 1:1 through the reshape from input_shape to output_shape.
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                         const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());

  // Unmodified dimensions are merely common factors of rank 1.
  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
                                      AsInt64Slice(output_shape.dimensions()));
  // Keep only the factor boundaries whose successor is exactly one input
  // dimension and one output dimension away; those boundaries correspond to
  // dimensions mapped unmodified by the reshape. Note: `i` is only advanced
  // when nothing is erased, so every element is examined exactly once.
  for (size_t i = 0; i < common_factors.size() - 1;) {
    if (1 != common_factors[i + 1].first - common_factors[i].first ||
        1 != common_factors[i + 1].second - common_factors[i].second) {
      common_factors.erase(common_factors.begin() + i);
    } else {
      ++i;
    }
  }
  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
  common_factors.pop_back();
  return std::vector<std::pair<int64, int64>>(common_factors.begin(),
                                              common_factors.end());
}
1070
1071 /* static */ absl::optional<std::vector<int64>>
ReshapeLeavesDimensionsUnmodified(const Shape & from_shape,const Shape & to_shape,absl::Span<const int64> input_dim_indices)1072 ShapeUtil::ReshapeLeavesDimensionsUnmodified(
1073 const Shape& from_shape, const Shape& to_shape,
1074 absl::Span<const int64> input_dim_indices) {
1075 CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1076
1077 std::vector<int64> output_dim_indices;
1078 std::vector<std::pair<int64, int64>> unmodified_dims =
1079 ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
1080 size_t i = 0; // index to unmodified_dims
1081 for (int64 input_dim_index : input_dim_indices) {
1082 // Search unmodified_dims for input_dim_index. We can search from the last
1083 // matching position because input_dim_indices is guaranteed to be sorted.
1084 while (i < unmodified_dims.size() &&
1085 unmodified_dims[i].first < input_dim_index) {
1086 ++i;
1087 }
1088 if (i >= unmodified_dims.size() ||
1089 unmodified_dims[i].first != input_dim_index) {
1090 return absl::nullopt;
1091 }
1092 output_dim_indices.push_back(unmodified_dims[i].second);
1093 }
1094 return output_dim_indices;
1095 }
1096
TransposeIsBitcast(const Shape & input_shape,const Shape & output_shape,absl::Span<const int64> dimension_mapping)1097 /* static */ bool ShapeUtil::TransposeIsBitcast(
1098 const Shape& input_shape, const Shape& output_shape,
1099 absl::Span<const int64> dimension_mapping) {
1100 CHECK(LayoutUtil::HasLayout(input_shape) &&
1101 LayoutUtil::HasLayout(output_shape));
1102
1103 if (!SameElementType(input_shape, output_shape)) {
1104 return false;
1105 }
1106
1107 // Check the reshape permutes the positions of each dimension in the
1108 // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1109 // input_positions = apply(dimension_mapping, output_positions)
1110 //
1111 // Because the positions of each dimension are the inverse permutation of the
1112 // minor-to-major order, the above check is equivalent to
1113 // inverse(input_dimensions) =
1114 // apply(dimension_mapping, inverse(output_dimensions))
1115 // # `I` indicates identity permutation.
1116 // apply(input_dimensions, I) =
1117 // apply(dimension_mapping, apply(output_dimensions, I))
1118 // apply(input_dimensions, I) =
1119 // apply((dimension_mapping * output_dimensions), I)
1120 // input_dimensions = dimension_mapping * output_dimensions
1121 return absl::c_equal(
1122 ComposePermutations(dimension_mapping,
1123 AsInt64Slice(output_shape.layout().minor_to_major())),
1124 input_shape.layout().minor_to_major());
1125 }
1126
// Returns true iff reshaping `input_shape` into `output_shape` moves no
// element to a different physical memory location, i.e. the reshape can be
// implemented as a bitcast. Both shapes must be arrays with layouts and have
// the same number of elements.
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
                                              const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  CHECK(LayoutUtil::HasLayout(input_shape));
  CHECK(LayoutUtil::HasLayout(output_shape));

  if (!SameElementType(input_shape, output_shape)) {
    return false;
  }

  CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
  if (ElementsIn(input_shape) == 0) {
    // With no elements, no data movement is possible, so any reshape between
    // these shapes is trivially a bitcast.
    return true;
  }

  // TL;DR: The rest of the method checks that the reshape does not change the
  // physical location of any unit input or output index. Unit indices have
  // exactly one dimension that equals 1 and other dimensions 0. This condition
  // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
  // reshape shouldn't change the physical location of any element. It is also a
  // sufficient condition as is proved below (note: many details are omitted for
  // space).
  //
  // Definitions:
  //
  // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
  // the size of i-th least significant dimension of IS or OS (this is opposite
  // to how we define the index of Shape::dimensions()).
  //
  // * Given an input or output index I, denote by p(I) I's physical linear
  // index (or physical index for short) and l(I) I's logical linear index (or
  // logical index for short).
  //
  // * Given a logical index k, denote by II(k) the input index whose linear
  // index is k, and OI(k) the corresponding output index.
  //
  // * Denote by IT[i] the increment of physical index if i-th dimension of the
  // input index is increased by 1. Similarly, OT[i] means the increment if i-th
  // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
  // is a function of IS or OS and the layout, and not dependent on the specific
  // input or output index.
  //
  // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
  // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
  // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
  // to prove, with every increment on k, the above formula still holds.
  //
  // First, suppose reshaping from IS to OS is non-factorizable (we discuss
  // refactorizable reshapes later). A reshape from IS to OS is factorizable, if
  // there exists (i,j) such that
  //
  //   0<=i<=|IS|
  //   0<=j<=|OS|
  //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
  //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
  //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
  //
  // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
  // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
  // unit indices which are already tested. This also means IT[0]=OT[0]
  // because p(II(1))=IT[0] and p(OI(1))=OT[0].
  //
  // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
  // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
  // to the output physical.
  //
  // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
  // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
  // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
  // logical index k-1 to logical index k, dimension 1 of the input index
  // is increased by 1 and dimension 0 is reset to 0 thus decreased by
  // IS[0]-1. Therefore, the physical input index is increased by
  //
  //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
  //
  // Because IS[0]<OS[0], the only change to the output index is that its
  // dimension 0 is increased by one. Therefore,
  //
  //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
  //
  // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
  // p(II(k))=p(OI(k)). Therefore,
  //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
  //   IT[1] = IS[0] * IT[0]
  // In other words, input dimension 1 is immediately more major than input
  // dimension 0. We can now conceptually collapse these two dimensions because
  // an increment in the logical index affecting only these two dimensions maps
  // to IT[0] in the physical index.
  //
  // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
  // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
  // identical.
  //
  // A factorizable reshape can be factorized into a list of non-factorizable
  // sub-reshapes, each of which can be handled similarly to the proof above.
  // For example,
  //
  //   [7x9x2x15] -> [63x6x5]
  //
  // can be factorized into
  //
  //   [7x9] -> [63] and [2x15] -> [6x5].
  //
  // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
  // same logical linear index. According to the factorization, we know
  // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
  // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
  // similar proof, with the increment of the logical index set to
  // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
  // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
  //
  //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
  //                  = p(y2,0,0) + p(0,0,y1,y0)
  //                  = p(y2,y1,y0)
  //
  // check_input_unit_indices checks one way of the condition: each input unit
  // index is mapped to an output index with the same physical location. This
  // lambda will be called again with input_shape and output_shape reversed to
  // check the other way.
  auto check_input_unit_indices = [](const Shape& input_shape,
                                     const Shape& output_shape) {
    // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
    // as input_shape/output_shape and the dimension-0-major layout. These two
    // shapes are used for conversion between logical linear indices and
    // multi-dimensional indices.
    Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
        input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
    Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));

    for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
      // Dimensions of bound <= 1 admit no unit index, so skip them.
      if (input_shape.dimensions(input_dim) <= 1) {
        continue;
      }

      std::vector<int64> input_unit_index(input_shape.rank(), 0);
      input_unit_index[input_dim] = 1;
      int64 logical_linear_index =
          IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
                                                        input_unit_index);
      // output_index has the same logical linear index as input_unit_index.
      std::vector<int64> output_index =
          IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
                                                        logical_linear_index);
      // Check input_unit_index and output_index have the same physical linear
      // index.
      if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
                                                        input_unit_index) !=
          IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
                                                        output_index)) {
        return false;
      }
    }
    return true;
  };
  return check_input_unit_indices(input_shape, output_shape) &&
         check_input_unit_indices(output_shape, input_shape);
}
1286
// Attempts to find a layout for `output_shape` such that the reshape from
// `input_shape` (with its existing layout) to the result is a bitcast.
// Returns nullopt if no such layout exists.
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
    const Shape& input_shape, const Shape& output_shape) {
  CHECK(input_shape.IsArray());
  CHECK(output_shape.IsArray());
  // Removing trivial dimensions from the shape simplifies the alignment
  // algorithm since ones can go in any position.
  if (HasDegenerateDimensions(input_shape) ||
      HasDegenerateDimensions(output_shape)) {
    // Solve the problem on the degenerate-free shapes, then re-insert the
    // size-1 output dimensions into the resulting layout.
    auto simple_output_shape =
        AlignLayouts(DropDegenerateDimensions(input_shape),
                     DropDegenerateDimensions(output_shape));
    if (!simple_output_shape) {
      return absl::nullopt;
    }

    std::vector<int64> layout =
        SpanToVector(simple_output_shape->layout().minor_to_major());
    // For each one sized dimension in the output, increment the dimension
    // numbers in layout that are more minor than the one.
    absl::InlinedVector<int64, 8> dim_map;
    dim_map.reserve(simple_output_shape->rank());
    for (int64 i = 0; i < output_shape.rank(); ++i) {
      if (output_shape.dimensions(i) != 1) {
        dim_map.push_back(i);
      }
    }
    for (int64& d : layout) {
      d = dim_map[d];
    }

    // Add the ones in descending order to the layout. Descending layouts tend
    // to reduce the number of copies inserted in layout assignment.
    for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
      if (output_shape.dimensions(i) == 1) {
        layout.push_back(i);
      }
    }
    Shape output_shape_with_layout = output_shape;
    *output_shape_with_layout.mutable_layout() = Layout{layout};
    return output_shape_with_layout;
  }

  int64 input_rank = input_shape.rank();
  int64 output_rank = output_shape.rank();

  // First, calculate an alignment of the dimensions. A consecutive sequence of
  // input dimensions and output dimensions belong to the same alignment part if
  // the products of their dimension bounds are the same. In the easiest case,
  // an alignment part consists of one input dimension and one output dimension
  // which both have the same dimension bound. An alignment part specifies which
  // dimensions need to be kept together in a physical layout if we want a
  // reshape to be a bitcast. The order of the alignment parts is defined by the
  // physical layout of the input shape, so when we construct the layout for the
  // output shape we just process the alignment parts in this order, and then
  // layout the dimensions belonging to each part in descending (major to minor)
  // order.

  // Stores the input and output dimension numbers where each alignment part
  // starts.
  std::vector<std::pair<int64, int64>> alignment;
  alignment.push_back({0, 0});

  // Stores a mapping from the input dimension to the alignment part it belongs
  // to.
  std::vector<int64> dimension_to_alignment_index(input_rank);
  int64 input_dimension_product = 1, output_dimension_product = 1;
  for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
    // Check if we have reached the end of an alignment part.
    if (input_dimension_product == output_dimension_product &&
        input_dimension_product > 1) {
      alignment.push_back({i, j});
      input_dimension_product = output_dimension_product = 1;
    }
    // Consume from whichever side has the smaller running product, so that
    // both products meet at each alignment-part boundary.
    if (input_dimension_product < output_dimension_product ||
        j == output_rank) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      dimension_to_alignment_index[i] = alignment.size() - 1;
      input_dimension_product *= input_shape.dimensions(i);
      ++i;
    } else {
      output_dimension_product *= output_shape.dimensions(j);
      ++j;
    }
  }
  if (input_dimension_product != output_dimension_product) {
    return absl::nullopt;
  }

  // We also need to store an end element so that we know where the last
  // alignment part ends.
  alignment.push_back({input_rank, output_rank});
  // Now check if the physical layout can potentially be aligned to the output
  // shape by changing the physical layout of the output shape. We need to check
  // that all dimension numbers that belong to the same alignment part appear
  // consecutively, and are in descending order. However we can ignore any
  // trivial dimension bounds of 1, because they can be placed anywhere.
  auto input_dimension_numbers = input_shape.layout().minor_to_major();
  std::vector<int64> output_layout;
  output_layout.reserve(output_rank);
  for (int64 i = 0; i < input_rank;) {
    int64 current_dimension_number = input_dimension_numbers[i];

    // Trivial dimensions are stripped.
    CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
    const int64 current_alignment_index =
        dimension_to_alignment_index[current_dimension_number];
    // Because of the special end element that we added, we can be sure that
    // 'current_alignment_index' is < alignment.size() - 1.
    CHECK_LT(current_alignment_index, alignment.size() - 1);

    // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
    // dimension numbers (ignoring dimension numbers with dimension bound 1) are
    // in descending order and belong to the current alignment part.
    for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
                              alignment[current_alignment_index].first;
         ++i, ++j) {
      if (i == input_rank) {
        return absl::nullopt;
      }
      // If the current dimension number belongs to a different alignment part,
      // or the dimension numbers are not in descending order, we can return
      // early.
      if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
              current_alignment_index ||
          input_dimension_numbers[i] > current_dimension_number) {
        return absl::nullopt;
      }
      current_dimension_number = input_dimension_numbers[i];
    }
    // The output dimension numbers that belong to the current alignment part
    // need to appear in the same descending order as in the input.
    for (int64 j = alignment[current_alignment_index + 1].second - 1;
         j >= alignment[current_alignment_index].second; --j) {
      output_layout.push_back(j);
    }
  }
  CHECK_EQ(output_layout.size(), output_rank);
  Shape output_shape_with_layout = MakeShapeWithLayout(
      output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
      output_layout);
  // Sanity-check that the constructed layout indeed makes the reshape a
  // bitcast.
  CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
      << "reshape is not a bitcast for input_shape: "
      << ShapeUtil::HumanStringWithLayout(input_shape)
      << " and output_shape_with_layout: "
      << ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
  return output_shape_with_layout;
}
1436
DeleteDimension(int64 dim_to_delete,Shape shape)1437 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
1438 Shape shape) {
1439 CHECK(shape.IsArray());
1440 shape.DeleteDimension(dim_to_delete);
1441 return shape;
1442 }
1443
FilterDimensions(const std::function<bool (int64)> & p,Shape shape)1444 /* static */ Shape ShapeUtil::FilterDimensions(
1445 const std::function<bool(int64)>& p, Shape shape) {
1446 CHECK(shape.IsArray());
1447 std::vector<int64> dims_to_delete;
1448 for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
1449 if (!p(i)) {
1450 dims_to_delete.push_back(i);
1451 }
1452 }
1453 for (int64 dim : dims_to_delete) {
1454 shape = DeleteDimension(dim, shape);
1455 }
1456 return shape;
1457 }
1458
Hash(const Shape & shape)1459 /*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
1460 using tensorflow::hash;
1461 using tensorflow::Hash64Combine;
1462
1463 size_t hash_value = hash<PrimitiveType>()(shape.element_type());
1464
1465 if (shape.tuple_shapes().empty()) {
1466 for (int i = 0; i < shape.dimensions_size(); ++i) {
1467 hash_value =
1468 Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
1469 hash_value = Hash64Combine(hash_value,
1470 hash<bool>()(shape.is_dynamic_dimension(i)));
1471 }
1472
1473 hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
1474 } else {
1475 hash_value = 0;
1476 for (const Shape& subshape : shape.tuple_shapes()) {
1477 hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
1478 }
1479 }
1480
1481 return hash_value;
1482 }
1483
1484 // Returns the indices of the first elements of all consecutive subarrays of the
1485 // given array. For example:
1486 // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
ConsecutiveSegments(absl::Span<const int64> xs)1487 static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
1488 std::vector<size_t> is = {0};
1489 for (size_t i = 1; i < xs.size(); ++i) {
1490 if (1 != xs[i] - xs[i - 1]) {
1491 is.push_back(i);
1492 }
1493 }
1494 return is;
1495 }
1496
// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`. Each merged dimension is the product of the original
// bounds in [segs[i-1], segs[i]); the final segment runs to the end of the
// shape. The result gets a dense descending (major-to-minor = 0,1,...) layout.
static Shape MergeDimensions(absl::Span<const size_t> segs,
                             const Shape& shape) {
  std::vector<int64> dimensions;
  for (size_t i = 1; i <= segs.size(); ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}
1512
// If the physical transpose from shape `a` to shape `b` (which must have the
// same dimensions, ignoring element type) reduces to a 0-2-1 transpose of at
// most three merged dimensions, returns those three dimension bounds in
// {0, 2, 1} order; otherwise returns nullopt.
/*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
    const Shape& a, const Shape& b) {
  if (!CompatibleIgnoringElementType(a, b)) {
    return absl::nullopt;
  }

  // Compute, for each physical (major-to-minor) position of `a`, the physical
  // position in `b` that holds the same logical dimension.
  std::vector<int64> permutation(a.dimensions().size());
  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
                                      minor_to_major_a.rend());
  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
                                      minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  // Runs of consecutive values in `permutation` can be collapsed into one
  // merged dimension. The accepted cases are: three segments with the first
  // left in place (a plain 0-2-1), or two segments (0-2-1 where the merged
  // component-0 has size one).
  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64> normalized_dims =
        AsInt64Slice(normalized_shape.dimensions());
    std::vector<int64> dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }

    return dims_021;
  }

  return absl::nullopt;
}
1550
1551 } // namespace xla
1552