1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/shape_util.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "tensorflow/compiler/xla/index_util.h"
26 #include "tensorflow/compiler/xla/layout_util.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/gtl/iterator_range.h"
34 #include "tensorflow/core/lib/gtl/optional.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/lib/strings/strcat.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 #include "tensorflow/core/platform/regexp.h"
41
42 namespace xla {
43
44 string ShapeIndex::ToString() const {
45 return tensorflow::strings::StrCat(
46 "{", tensorflow::str_util::Join(indices_, ","), "}");
47 }
48
49 string ShapeIndexView::ToString() const {
50 return tensorflow::strings::StrCat(
51 "{",
52 tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
53 ","),
54 "}");
55 }
56
57 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
58 out << shape_index.ToString();
59 return out;
60 }
61
62 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
63 out << shape_index.ToString();
64 return out;
65 }
66
67 namespace {
68
69 // Recursive helper for comparing the equality of two shapes. Returns true if
70 // the shapes are the same. If compare_layouts is true, then layouts must also
71 // match.
72 bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
73 if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
74 return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
75 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
76 [=](const Shape& l, const Shape& r) {
77 return CompareShapes(l, r, compare_layouts);
78 });
79 } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
80 return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
81 }
82
83 if (compare_layouts) {
84 if (lhs.layout().format() != rhs.layout().format()) {
85 return false;
86 }
87 if (LayoutUtil::IsDenseArray(lhs)) {
88 if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
89 LayoutUtil::MinorToMajor(rhs))) {
90 VLOG(3) << "CompareShapes: lhs layout != rhs layout";
91 return false;
92 }
93 if (!ContainersEqual(lhs.layout().padded_dimensions(),
94 rhs.layout().padded_dimensions())) {
95 VLOG(3)
96 << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions";
97 return false;
98 }
99 if (lhs.layout().padding_value() != rhs.layout().padding_value()) {
100 VLOG(3) << "CompareShapes: lhs padding_value != rhs padding_value";
101 return false;
102 }
103 }
104 }
105
106 if (!ShapeUtil::SameDimensions(lhs, rhs)) {
107 VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
108 return false;
109 }
110 if (!ShapeUtil::SameElementType(lhs, rhs)) {
111 VLOG(3) << "CompareShapes: lhs element type != rhs element type";
112 return false;
113 }
114 return true;
115 }
116
117 // Constructs and returns the new shape with the given minor_to_major order in
118 // its Layout.
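// For example (illustrative): dimensions {2, 3} with minor_to_major {0, 1} and
// element type F32 produce a column-major f32[2,3]{0,1}.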
119 StatusOr<Shape> MakeShapeWithLayoutInternal(
120 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
121 tensorflow::gtl::ArraySlice<int64> minor_to_major) {
122 if (dimensions.size() != minor_to_major.size()) {
123 return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
124 dimensions.size(), minor_to_major.size());
125 }
126 if (element_type == OPAQUE || element_type == TUPLE) {
127 return InvalidArgument("Unsupported element type: %s",
128 PrimitiveType_Name(element_type).c_str());
129 }
130 Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
131 auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
132 min2maj->Clear();
133 for (int64 value : minor_to_major) {
134 min2maj->Add(value);
135 }
136 if (!shape.has_layout()) {
137 return InvalidArgument("Shape has no layout.");
138 }
139 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
140 return shape;
141 }
142
143 } // namespace
144
145 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
146 bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
147 if (!equal && VLOG_IS_ON(3)) {
148 VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
149 << ", rhs = " << rhs.ShortDebugString();
150 }
151
152 return equal;
153 }
154
155 /* static */ int64 ShapeUtil::Rank(const Shape& shape) {
156 CHECK(!ShapeUtil::IsTuple(shape))
157 << "Tuples do not have a rank, shape: " << shape;
158 return shape.dimensions_size();
159 }
160
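// Illustrative note (follows from the loop below): TrueRank counts only the
// non-degenerate dimensions, so for a hypothetical shape f32[1,2,1,3] it
// returns 2 while Rank returns 4.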
161 /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
162 int64 accum = 0;
163 for (int64 dimension : shape.dimensions()) {
164 // We do not count degenerate dimensions of size 1.
165 if (dimension != 1) {
166 accum += 1;
167 }
168 }
169 return accum;
170 }
171
172 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
173 std::initializer_list<Shape> parameters, Shape result) {
174 ProgramShape program_shape;
175 for (const auto& shape : parameters) {
176 *program_shape.add_parameters() = shape;
177 }
178 *program_shape.mutable_result() = std::move(result);
179 return program_shape;
180 }
181
182 /* static */ Shape ShapeUtil::MakeShape(
183 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
184 DCHECK_NE(TUPLE, element_type);
185 DCHECK_NE(OPAQUE, element_type);
186 Shape result;
187 PopulateShape(element_type, dimensions, &result);
188 return result;
189 }
190
191 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
192 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
193 tensorflow::gtl::ArraySlice<int64> minor_to_major) {
194 return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
195 .ValueOrDie();
196 }
197
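// Illustrative usage sketch (the values follow from the std::iota call below):
//   Shape s = ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 3, 4});
//   // s.layout().minor_to_major() is {2, 1, 0}, i.e. dimension 0 is the
//   // most-major (slowest-varying) dimension.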
198 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
199 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
200 std::vector<int64> layout(dimensions.size());
201 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
202 return MakeShapeWithLayout(element_type, dimensions, layout);
203 }
204
205 /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
206 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
207 int64 max_sparse_elements) {
208 DCHECK_NE(TUPLE, element_type);
209 DCHECK_NE(OPAQUE, element_type);
210 Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
211 *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
212 TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
213 return shape;
214 }
215
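// Illustrative example (assumes LayoutUtil::Major(layout, i) returns the i-th
// most-major logical dimension number): a column-major f32[256,3]{0,1} maps to
// f32[3,256]{1,0}, which enumerates the same elements in the same physical
// order.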
216 /* static */ Shape
217 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
218 const Shape& shape) {
219 std::vector<int64> dims(shape.dimensions_size());
220 for (int i = 0; i < shape.dimensions_size(); ++i) {
221 dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
222 }
223 return MakeShapeWithDescendingLayout(shape.element_type(), dims);
224 }
225
226 /* static */ void ShapeUtil::PopulateShape(
227 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
228 Shape* shape) {
229 shape->Clear();
230 shape->set_element_type(element_type);
231 for (int64 dimension : dimensions) {
232 shape->add_dimensions(dimension);
233 }
234 LayoutUtil::SetToDefaultLayout(shape);
235 TF_DCHECK_OK(ValidateShape(*shape));
236 }
237
238 /* static */ Shape ShapeUtil::MakeTupleShape(
239 tensorflow::gtl::ArraySlice<Shape> shapes) {
240 Shape result;
241 result.set_element_type(TUPLE);
242 for (const auto& shape : shapes) {
243 AppendShapeToTuple(shape, &result);
244 }
245 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
246 return result;
247 }
248
249 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
250 Shape result;
251 result.set_element_type(OPAQUE);
252 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
253 return result;
254 }
255
256 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
257 Shape* tuple_shape) {
258 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
259 *tuple_shape->add_tuple_shapes() = shape;
260 }
261
262 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
263 CHECK(LayoutUtil::IsDenseArray(*shape));
264 shape->mutable_layout()->add_minor_to_major(Rank(*shape));
265 shape->add_dimensions(bound);
266 TF_DCHECK_OK(ValidateShape(*shape));
267 }
268
269 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
270 return primitive_util::IsIntegralType(shape.element_type());
271 }
272
273 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
274 int32 bits) {
275 return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
276 }
277
278 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
279 if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
280 return false;
281 }
282 return primitive_util::BitWidth(shape.element_type()) == bits;
283 }
284
285 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
286 switch (shape.element_type()) {
287 case S8:
288 case S16:
289 case S32:
290 case S64:
291 case F16:
292 case BF16:
293 case F32:
294 case F64:
295 return true;
296
297 case PRED:
298 case U8:
299 case U16:
300 case U32:
301 case U64:
302 case C64:
303 case TUPLE:
304 case OPAQUE:
305 return false;
306
307 default:
308 LOG(FATAL) << "Unhandled element type " << shape.element_type();
309 }
310 }
311
312 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
313 return primitive_util::IsComplexType(shape.element_type());
314 }
315
316 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
317 return primitive_util::IsFloatingPointType(shape.element_type());
318 }
319
320 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
321 return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
322 shape.tuple_shapes().end(), IsTuple);
323 }
324
325 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
326 return IsTuple(shape) && TupleElementCount(shape) == 0;
327 }
328
329 /* static */ bool ShapeUtil::IsNil(const Shape& shape) {
330 return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape);
331 }
332
333 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
334 CHECK(IsTuple(shape)) << HumanString(shape);
335 return shape.tuple_shapes_size();
336 }
337
338 /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
339 int64 index) {
340 CHECK(IsTuple(shape));
341 CHECK_GT(TupleElementCount(shape), index);
342 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
343 return shape.tuple_shapes(index);
344 }
345
346 /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
347 int64 limit) {
348 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
349 CHECK(IsTuple(tuple));
350 CHECK_LE(start, TupleElementCount(tuple));
351 CHECK_LE(limit, TupleElementCount(tuple));
352
353 std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
354 tuple.tuple_shapes().begin() + limit);
355 return MakeTupleShape(new_elements);
356 }
357
358 // Returns the shape of a real or imaginary component.
359 /* static */ Shape ShapeUtil::ComplexComponentShape(
360 const Shape& complex_shape) {
361 CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
362 return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
363 complex_shape.element_type()));
364 }
365
366 /* static */ bool ShapeUtil::ShapeIs(const Shape& shape,
367 PrimitiveType element_type,
368 std::initializer_list<int64> dimensions) {
369 return Equal(shape, MakeShape(element_type, dimensions));
370 }
371
372 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
373 CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
374 CHECK_EQ(shape.dimensions_size(), Rank(shape));
375 return std::accumulate<decltype(shape.dimensions().begin()), int64>(
376 shape.dimensions().begin(), shape.dimensions().end(), 1LL,
377 std::multiplies<int64>());
378 }
379
380 /* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
381 return ElementsIn(shape) == 0;
382 }
383
384 /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
385 return shape.element_type() == F32 && Rank(shape) == 0;
386 }
387
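// Illustrative outputs (per the formatting logic below): an array shape prints
// as "f32[2,3]", a scalar as "f32[]", and a nested tuple as
// "(f32[2], (s32[], pred[4]))".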
388 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
389 if (IsTuple(shape)) {
390 string text = "(";
391 const char* prefix = "";
392 for (const Shape& elem_shape : shape.tuple_shapes()) {
393 tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
394 prefix = ", ";
395 }
396 text += ")";
397 return text;
398 } else {
399 return tensorflow::strings::StrCat(
400 tensorflow::str_util::Lowercase(
401 PrimitiveType_Name(shape.element_type())),
402 "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
403 }
404 }
405
406 namespace {
407
408 // Class to memoize the computation of
409 // tensorflow::str_util::Lowercase(PrimitiveType_Name(p))
410 // for all PrimitiveType values "p"
411 class PrimitiveTypeNameGenerator {
412 public:
413 PrimitiveTypeNameGenerator() {
414 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
415 if (PrimitiveType_IsValid(i)) {
416 lowercase_name_[i] = tensorflow::str_util::Lowercase(
417 PrimitiveType_Name(static_cast<PrimitiveType>(i)));
418 }
419 }
420 }
421 const string& LowercaseName(PrimitiveType t) {
422 return lowercase_name_[static_cast<int>(t)];
423 }
424
425 private:
426 string lowercase_name_[PrimitiveType_ARRAYSIZE];
427 };
428
429 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
430 static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator();
431 return gen->LowercaseName(s);
432 }
433
434 StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
435 static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
436 static auto* map = new std::unordered_map<string, PrimitiveType>;
437 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
438 if (PrimitiveType_IsValid(i)) {
439 auto value = static_cast<PrimitiveType>(i);
440 (*map)[LowercasePrimitiveTypeName(value)] = value;
441 }
442 }
443 return map;
444 }();
445 auto found = name_to_type->find(name);
446 if (found == name_to_type->end()) {
447 return InvalidArgument("Invalid element type string: \"%s\".",
448 name.c_str());
449 }
450 return found->second;
451 }
452
453 } // namespace
454
455 /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
456 if (IsTuple(shape)) {
457 string text = "(";
458 const char* prefix = "";
459 for (const Shape& elem_shape : shape.tuple_shapes()) {
460 tensorflow::strings::StrAppend(&text, prefix,
461 HumanStringWithLayout(elem_shape));
462 prefix = ", ";
463 }
464 text += ")";
465 return text;
466 } else {
467 string result = tensorflow::strings::StrCat(
468 LowercasePrimitiveTypeName(shape.element_type()), "[");
469 for (int i = 0; i < shape.dimensions().size(); i++) {
470 tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
471 shape.dimensions(i));
472 }
473 result += "]";
474 if (!IsScalar(shape) && !IsOpaque(shape)) {
475 if (LayoutUtil::HasLayout(shape)) {
476 tensorflow::strings::StrAppend(&result,
477 LayoutUtil::HumanString(shape.layout()));
478 }
479 }
480 return result;
481 }
482 }
483
484 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
485 std::vector<string> parameters;
486 for (auto& shape : program_shape.parameters()) {
487 const int i = parameters.size();
488 parameters.push_back(
489 tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
490 ? program_shape.parameter_names(i)
491 : "(unknown)",
492 ": ", HumanString(shape)));
493 }
494 return tensorflow::strings::StrCat(
495 "(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
496 HumanString(program_shape.result()));
497 }
498
499 namespace {
500 // Parses shapes with simple recursive descent structure -- consumes from the
501 // front of s and passes that view recursively as required.
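// Examples of accepted inputs (illustrative, per the regex and tuple handling
// below): "f32[1,2]", "f32[1,2]{1,0}", "s32[10]sparse{5}", and tuples such as
// "(f32[2], pred[])".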
502 StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
503 tensorflow::str_util::RemoveLeadingWhitespace(s);
504
505 if (s->Consume("(")) { // Tuple.
506 std::vector<Shape> shapes;
507 bool must_end = false;
508 while (true) {
509 if (s->Consume(")")) {
510 break;
511 } else if (must_end) {
512 return InvalidArgument("Expected end of tuple; got: \"%s\"",
513 s->ToString().c_str());
514 }
515 shapes.emplace_back();
516 TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
517 tensorflow::str_util::RemoveLeadingWhitespace(s);
518 must_end = !s->Consume(",");
519 }
520 return ShapeUtil::MakeTupleShape(shapes);
521 }
522
523 string element_type_string;
524 string dimensions_string;
525 string format_string;
526 string layout_string;
527 // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
528 // we convert it to the RE2-consumable type and then consume the corresponding
529 // amount from our StringPiece type.
530 tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
531 if (RE2::Consume(
532 &s_consumable,
533 "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?",
534 &element_type_string, &dimensions_string, &format_string,
535 &layout_string)) {
536 size_t consumed = s->size() - s_consumable.size();
537 s->remove_prefix(consumed);
538 auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
539 int64 element;
540 if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) {
541 return InvalidArgument(
542 "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
543 input.c_str(), s->ToString().c_str());
544 }
545 return element;
546 };
547
548 auto comma_list_to_int64s =
549 [&s,
550 string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
551 std::vector<int64> results;
552 for (const string& piece : tensorflow::str_util::Split(input, ',')) {
553 TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
554 results.push_back(element);
555 }
556 return results;
557 };
558
559 // Extract the dimensions.
560 TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
561 comma_list_to_int64s(dimensions_string));
562
563 // Extract the primitive element type.
564 TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
565 StringToPrimitiveType(element_type_string));
566 if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
567 primitive_type == OPAQUE) {
568 return InvalidArgument("Invalid element type string: \"%s\".",
569 element_type_string.c_str());
570 }
571
572 Shape result;
573 if (format_string.empty() && layout_string.empty()) {
574 // Create a shape without a layout set.
575 result = ShapeUtil::MakeShape(primitive_type, dimensions);
576 } else if (format_string == "sparse") {
577 TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string));
578 result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions,
579 max_elements);
580 } else if (format_string.empty() || format_string == "dense") {
581 // Extract the layout minor-to-major and set it.
582 TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
583 comma_list_to_int64s(layout_string));
584 TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
585 primitive_type, dimensions, min2maj));
586 } else {
587 // This should not be reached.
588 LOG(FATAL) << "Unhandled condition when parsing shape; format: \""
589 << format_string << "\", layout: \"" << layout_string << "\"";
590 }
591 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
592 return std::move(result);
593 }
594
595 return InvalidArgument("Invalid shape string to parse: \"%s\"",
596 s->ToString().c_str());
597 }
598 } // namespace
599
600 /* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
601 tensorflow::StringPiece s) {
602 TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
603 if (!s.empty()) {
604 return InvalidArgument("Invalid shape string to parse: \"%s\"",
605 s.ToString().c_str());
606 }
607 return shape;
608 }
609
610 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
611 const Shape& rhs) {
612 return ContainersEqual(lhs.dimensions(), rhs.dimensions());
613 }
614
615 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
616 if (lhs.element_type() == TUPLE) {
617 return rhs.element_type() == TUPLE &&
618 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
619 }
620 return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
621 }
622
623 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
624 const Shape& rhs) {
625 if (lhs.element_type() == TUPLE) {
626 return rhs.element_type() == TUPLE &&
627 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
628 CompatibleIgnoringElementType);
629 }
630 return SameDimensions(lhs, rhs);
631 }
632
633 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
634 const Shape& rhs) {
635 if (lhs.element_type() == TUPLE) {
636 return rhs.element_type() == TUPLE &&
637 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
638 CompatibleIgnoringFpPrecision);
639 }
640 if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
641 return CompatibleIgnoringElementType(lhs, rhs);
642 }
643 return false;
644 }
645
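// Note (illustrative): dimension_number may be negative, Python-style, so
// GetDimension(shape, -1) returns the bound of the last dimension (see
// GetDimensionNumber below).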
646 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
647 int64 dimension_number) {
648 return shape.dimensions(GetDimensionNumber(shape, dimension_number));
649 }
650
651 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
652 int64 dimension_number) {
653 if (dimension_number < 0) {
654 dimension_number += Rank(shape);
655 }
656 CHECK_GE(dimension_number, 0);
657 return dimension_number;
658 }
659
660 /* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
661 PrimitiveType primitive_type) {
662 switch (primitive_type) {
663 case PRED:
664 return sizeof(int8);
665 case TUPLE:
666 LOG(FATAL) << "tuples have no definitive size";
667 case OPAQUE:
668 LOG(FATAL) << "opaque have no definitive size";
669 case S8:
670 return sizeof(int8);
671 case S16:
672 return sizeof(int16);
673 case S32:
674 return sizeof(int32);
675 case S64:
676 return sizeof(int64);
677 case U8:
678 return sizeof(uint8);
679 case U16:
680 return sizeof(uint16);
681 case U32:
682 return sizeof(uint32);
683 case U64:
684 return sizeof(uint64);
685 case BF16:
686 return sizeof(float) / 2;
687 case F16:
688 return sizeof(float) / 2;
689 case F32:
690 return sizeof(float);
691 case F64:
692 return sizeof(double);
693 case C64:
694 return sizeof(complex64);
695 default:
696 LOG(FATAL) << "Unhandled primitive type " << primitive_type;
697 }
698 }
699
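// Illustrative sizes (per the helpers below): a dense, unpadded f32[4,10]
// occupies 4 * 10 * sizeof(float) = 160 bytes, while a 3-element tuple with
// pointer_size 8 contributes a 24-byte index table at the top level.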
700 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
701 int64 pointer_size) {
702 TF_DCHECK_OK(ValidateShape(shape));
703 DCHECK_NE(OPAQUE, shape.element_type());
704 if (shape.element_type() == TUPLE) {
705 return ByteSizeOfTupleIndexTable(shape, pointer_size);
706 }
707 int64 byte_size = ByteSizeOfElements(shape);
708 if (LayoutUtil::IsSparseArray(shape)) {
709 byte_size += ByteSizeOfSparseIndices(shape);
710 }
711 return byte_size;
712 }
713
714 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
715 int64 pointer_size) {
716 TF_DCHECK_OK(ValidateShape(shape));
717 DCHECK_EQ(TUPLE, shape.element_type());
718 CHECK_GT(pointer_size, 0);
719 return pointer_size * shape.tuple_shapes_size();
720 }
721
722 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
723 TF_DCHECK_OK(ValidateShape(shape));
724 DCHECK(ShapeUtil::IsArray(shape));
725 int64 allocated_element_count;
726
727 if (LayoutUtil::IsSparseArray(shape)) {
728 allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
729 } else {
730 CHECK(LayoutUtil::IsDenseArray(shape));
731 tensorflow::gtl::ArraySlice<int64> padded_dimensions =
732 LayoutUtil::PaddedDimensions(shape);
733 if (!padded_dimensions.empty()) {
734 CHECK_EQ(Rank(shape), padded_dimensions.size());
735 allocated_element_count = 1;
736 for (int64 dimension_size : padded_dimensions) {
737 allocated_element_count *= dimension_size;
738 }
739 } else {
740 allocated_element_count = ElementsIn(shape);
741 }
742 }
743 return allocated_element_count *
744 ByteSizeOfPrimitiveType(shape.element_type());
745 }
746
747 /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
748 TF_DCHECK_OK(ValidateShape(shape));
749 DCHECK(LayoutUtil::IsSparseArray(shape));
750 return LayoutUtil::MaxSparseElements(shape.layout()) *
751 ShapeUtil::Rank(shape) * sizeof(int64);
752 }
753
754 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
755 const Shape& shape) {
756 if (shape.element_type() == TUPLE) {
757 if (shape.dimensions_size() != 0) {
758 return InvalidArgument("tuples must not have dimensions specified");
759 }
760 for (auto& element_shape : shape.tuple_shapes()) {
761 TF_RETURN_IF_ERROR(
762 ValidateShapeWithOptionalLayoutInternal(element_shape));
763 }
764 return Status::OK();
765 }
766
767 // Non-tuple shape.
768 if (shape.tuple_shapes_size() > 0) {
769 return InvalidArgument("non-tuple shape has tuple_shapes field");
770 }
771 if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
772 return InvalidArgument("shape has invalid element type: %s",
773 shape.ShortDebugString().c_str());
774 }
775 if (Rank(shape) != shape.dimensions_size()) {
776 return InvalidArgument(
777 "shape's rank is mismatched with dimension count; rank=%lld "
778 "dimensions_size=%d",
779 Rank(shape), shape.dimensions_size());
780 }
781 for (int64 i = 0; i < Rank(shape); ++i) {
782 int64 dimension = shape.dimensions(i);
783 if (dimension < 0) {
784 return InvalidArgument(
785 "shape's dimensions must not be < 0; dimension at index %lld was "
786 "%lld",
787 i, dimension);
788 }
789 }
790
791 return Status::OK();
792 }
793
794 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
795 const Shape& shape) {
796 if (LayoutUtil::HasLayout(shape)) {
797 // Since a layout is present, upgrade to the full set of invariant checks.
798 return ValidateShape(shape);
799 }
800 return ValidateShapeWithOptionalLayoutInternal(shape);
801 }
802
803 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
804 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
805
806 return LayoutUtil::ValidateLayoutInShape(shape);
807 }
808
809 /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
810 PrimitiveType type) {
811 Shape new_shape = original;
812 new_shape.set_element_type(type);
813 return new_shape;
814 }
815
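// Illustrative example: for a shape (f32[2], (s32[3], pred[])) and index
// {1, 0}, GetSubshape returns the s32[3] subshape.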
816 /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
817 ShapeIndexView index) {
818 const Shape* return_shape = &shape;
819 for (auto i : index) {
820 CHECK(IsTuple(*return_shape))
821 << "Invalid index " << index << " for shape " << shape;
822 return_shape = &return_shape->tuple_shapes(i);
823 }
824 return *return_shape;
825 }
826
827 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
828 ShapeIndexView index) {
829 Shape* return_shape = shape;
830 for (auto i : index) {
831 CHECK(IsTuple(*return_shape));
832 return_shape = return_shape->mutable_tuple_shapes(i);
833 }
834 return return_shape;
835 }
836
837 /* static */
838 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
839 return !IsTuple(GetSubshape(shape, index));
840 }
841
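// Illustrative example (values follow from the renumbering logic below):
// stripping f32[1,2,1,3]{3,2,1,0} yields f32[2,3]{1,0}; the relative order of
// the surviving dimensions is preserved in both the bounds and the layout.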
842 /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
843 std::vector<int64> dimension_sizes;
844 std::vector<int64> degenerate_dimensions;
845 for (int64 i = 0; i < shape.dimensions_size(); ++i) {
846 if (shape.dimensions(i) == 1) {
847 degenerate_dimensions.push_back(i);
848 } else {
849 dimension_sizes.push_back(shape.dimensions(i));
850 }
851 }
852
853 // Construct minor_to_major of stripped shape. The order of the non-degenerate
854 // dimensions should be preserved from the original shape. First, create
855 // vector of the non-degenerate dimensions from the original minor_to_major
856 // array.
857 std::vector<int64> minor_to_major;
858 for (int64 i : shape.layout().minor_to_major()) {
859 if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
860 i) == degenerate_dimensions.end()) {
861 minor_to_major.push_back(i);
862 }
863 }
864
865 // The dimensions in minor_to_major need to be renumbered to account for the
866 // degenerate dimensions which have been removed. Decrement each dimension number
867 // once for each degenerate dimension which has a smaller number.
868 for (int i = 0; i < minor_to_major.size(); ++i) {
869 int adjustment = 0;
870 for (int64 dim : degenerate_dimensions) {
871 if (minor_to_major[i] > dim) {
872 adjustment++;
873 }
874 }
875 minor_to_major[i] -= adjustment;
876 }
877
878 {
879 std::vector<int64> dims(minor_to_major.size());
880 std::iota(dims.begin(), dims.end(), 0);
881 DCHECK(minor_to_major.size() == dims.size() &&
882 std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
883 dims.begin()));
884 }
885 Shape stripped_shape =
886 shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
887 dimension_sizes, minor_to_major)
888 : MakeShape(shape.element_type(), dimension_sizes);
889
890 VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
891 VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
892 return stripped_shape;
893 }
894
895 namespace {
896
897 // Helper for ForEachSubshape which visits the subshapes of the given shape in
898 // DFS pre-order starting with the index.
899 Status ForEachSubshapeHelper(const Shape& shape,
900 const ShapeUtil::StatusVisitorFunction& func,
901 ShapeIndex* index) {
902 TF_RETURN_IF_ERROR(func(shape, *index));
903 if (ShapeUtil::IsTuple(shape)) {
904 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
905 index->push_back(i);
906 TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
907 ShapeUtil::GetTupleElementShape(shape, i), func, index));
908 index->pop_back();
909 }
910 }
911 return Status::OK();
912 }
913
914 // Helper for ForEachMutableSubshape which visits the subshapes of the given
915 // shape in DFS pre-order starting with the index.
916 Status ForEachMutableSubshapeHelper(
917 Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
918 ShapeIndex* index) {
919 TF_RETURN_IF_ERROR(func(shape, *index));
920 if (ShapeUtil::IsTuple(*shape)) {
921 for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
922 index->push_back(i);
923 TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
924 shape->mutable_tuple_shapes(i), func, index));
925 index->pop_back();
926 }
927 }
928 return Status::OK();
929 }
930
931 } // namespace
932
933 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
934 const VisitorFunction& func) {
935 ShapeIndex index;
936 ForEachSubshapeHelper(
937 shape,
938 [&func](const Shape& subshape, const ShapeIndex& index) {
939 func(subshape, index);
940 return Status::OK();
941 },
942 &index)
943 .IgnoreError();
944 }
945
946 /* static */ void ShapeUtil::ForEachMutableSubshape(
947 Shape* shape, const MutatingVisitorFunction& func) {
948 ShapeIndex index;
949 ForEachMutableSubshapeHelper(
950 shape,
951 [&func](Shape* subshape, const ShapeIndex& index) {
952 func(subshape, index);
953 return Status::OK();
954 },
955 &index)
956 .IgnoreError();
957 }
958
959 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
960 const Shape& shape, const StatusVisitorFunction& func) {
961 ShapeIndex index;
962 return ForEachSubshapeHelper(shape, func, &index);
963 }
964
965 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
966 Shape* shape, const MutatingStatusVisitorFunction& func) {
967 ShapeIndex index;
968 return ForEachMutableSubshapeHelper(shape, func, &index);
969 }
970
971 /* static */ Shape ShapeUtil::PermuteDimensions(
972 tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) {
973 Shape new_shape = shape;
974 new_shape.clear_dimensions();
975 for (auto dim : Permute(permutation, shape.dimensions())) {
976 new_shape.add_dimensions(dim);
977 }
978 if (shape.has_layout()) {
979 CHECK(LayoutUtil::IsDenseArray(shape));
980 Layout* new_layout = new_shape.mutable_layout();
981 new_layout->set_format(DENSE);
982 new_layout->clear_minor_to_major();
983 for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
984 new_layout->add_minor_to_major(index);
985 }
986 if (shape.layout().padded_dimensions_size() > 0) {
987 new_layout->clear_padded_dimensions();
988 for (auto dim :
989 Permute(permutation, shape.layout().padded_dimensions())) {
990 new_layout->add_padded_dimensions(dim);
991 }
992 }
993 }
994 return new_shape;
995 }
996
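// Illustrative examples (based on the gap-checking logic below): reshaping
// f32[2,3,1] to f32[2,3] only deletes the trailing degenerate dimension, while
// reshaping f32[2,3] to f32[3,2] modifies non-degenerate dimensions and yields
// the "nil" result.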
997 /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
998 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
999 const Shape& shape_post) {
1000 auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
1001
1002 std::vector<int64> deleted_indices;
1003 std::vector<int64> inserted_indices;
1004 // Returns false if any input/output index between prior_unmodified_dim_pair
1005 // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
1006 // the degenerate input/output dimensions in the gap to
1007 // deleted_indices/inserted_indices respectively.
1008 auto check_modified_dims =
1009 [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
1010 std::pair<int64, int64> prior_unmodified_dim_pair,
1011 std::pair<int64, int64> unmodified_dim_pair) {
1012 for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
1013 modified_input_dim < unmodified_dim_pair.first;
1014 ++modified_input_dim) {
1015 if (shape_pre.dimensions(modified_input_dim) > 1) {
1016 return false;
1017 }
1018 deleted_indices.push_back(modified_input_dim);
1019 }
1020 for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
1021 modified_output_dim < unmodified_dim_pair.second;
1022 ++modified_output_dim) {
1023 if (shape_post.dimensions(modified_output_dim) > 1) {
1024 return false;
1025 }
1026 inserted_indices.push_back(modified_output_dim);
1027 }
1028 return true;
1029 };
1030
1031 std::vector<std::pair<int64, int64>> unmodified_dims =
1032 DimensionsUnmodifiedByReshape(shape_pre, shape_post);
1033 // Returns nil if the reshape modifies any non-degenerate input/output
1034 // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
1035 // dimensions, so we only need to check whether dimensions in the gaps (thus
1036 // modified) have size >1.
1037 for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
1038 // Check (modified) dimensions between unmodified_dims[i-1] and
1039 // unmodified_dims[i].
1040 auto prior_unmodified_dim_pair =
1041 i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL);
1042 auto unmodified_dim_pair =
1043 i < unmodified_dims.size()
1044 ? unmodified_dims[i]
1045 : std::make_pair(Rank(shape_pre), Rank(shape_post));
1046 if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
1047 return nil;
1048 }
1049 }
1050
1051 return std::make_tuple(true, deleted_indices, inserted_indices);
1052 }
1053
1054 /* static */ std::vector<std::pair<int64, int64>>
1055 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
1056 const Shape& output_shape) {
1057 // Unmodified dimensions are merely common factors of rank 1.
1058 auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
1059 AsInt64Slice(output_shape.dimensions()));
1060 for (size_t i = 0; i < common_factors.size() - 1;) {
1061 if (1 != common_factors[i + 1].first - common_factors[i].first ||
1062 1 != common_factors[i + 1].second - common_factors[i].second) {
1063 common_factors.erase(common_factors.begin() + i);
1064 } else {
1065 ++i;
1066 }
1067 }
1068 // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
1069 common_factors.pop_back();
1070 return common_factors;
1071 }
1072
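// Illustrative example (assuming the permutation composition check below):
// transposing f32[2,3]{1,0} with dimension_mapping {1,0} into f32[3,2]{0,1} is
// a bitcast, since both shapes enumerate elements in the same physical order.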
1073 /* static */ bool ShapeUtil::TransposeIsBitcast(
1074 const Shape& input_shape, const Shape& output_shape,
1075 tensorflow::gtl::ArraySlice<int64> dimension_mapping) {
1076 // Can't insert bitcasts without layout information.
1077 if (!LayoutUtil::HasLayout(input_shape) &&
1078 !LayoutUtil::HasLayout(output_shape)) {
1079 return false;
1080 }
1081
1082 // Padding is not handled.
1083 if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) {
1084 return false;
1085 }
1086
1087 // Check the reshape permutes the positions of each dimension in the
1088 // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
1089 // input_positions = apply(dimension_mapping, output_positions)
1090 //
1091 // Because the positions of each dimension are the inverse permutation of the
1092 // minor-to-major order, the above check is equivalent to
1093 // inverse(input_dimensions) =
1094 // apply(dimension_mapping, inverse(output_dimensions))
1095 // # `I` indicates identity permutation.
1096 // apply(input_dimensions, I) =
1097 // apply(dimension_mapping, apply(output_dimensions, I))
1098 // apply(input_dimensions, I) =
1099 // apply((dimension_mapping * output_dimensions), I)
1100 // input_dimensions = dimension_mapping * output_dimensions
1101 return ContainersEqual(
1102 ComposePermutations(dimension_mapping,
1103 AsInt64Slice(output_shape.layout().minor_to_major())),
1104 input_shape.layout().minor_to_major());
1105 }
1106
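// Illustrative example (per the unit-index check below): reshaping
// f32[2,3]{1,0} to f32[6]{0} is a bitcast, whereas reshaping f32[2,3]{0,1}
// (column-major) to f32[6]{0} is not, because element (0,1) would move.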
1107 /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
1108 const Shape& output_shape) {
1109 // Can't convert reshapes into bitcasts without layout information.
1110 if (!LayoutUtil::HasLayout(input_shape) ||
1111 !LayoutUtil::HasLayout(output_shape)) {
1112 return false;
1113 }
1114
1115 // Padding is not handled.
1116 if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) {
1117 return false;
1118 }
1119
1120 CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
1121 if (ElementsIn(input_shape) == 0) {
1122 return true;
1123 }
1124
1125 // TL;DR: The rest of the method checks that the reshape does not change the
1126 // physical location of any unit input or output index. Unit indices have
1127 // exactly one dimension that equals 1 and other dimensions 0. This condition
1128 // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
1129 // reshape shouldn't change the physical location of any element. It is also a
1130 // sufficient condition as is proved below (note: many details are omitted for
1131 // space).
1132 //
1133 // Definitions:
1134 //
1135 // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
1136 // the size of i-th least significant dimension of IS or OS (this is opposite
1137 // to how we define the index of Shape::dimensions()).
1138 //
1139 // * Given an input or output index I, denote by p(I) I's physical linear
1140 // index (or physical index for short) and l(I) I's logical linear index (or
1141 // logical index for short).
1142 //
1143 // * Given a logical index k, denote by II(k) the input index whose linear
1144 // index is k, and OI(k) the corresponding output index.
1145 //
1146 // * Denote by IT[i] the increment of physical index if i-th dimension of the
1147 // input index is increased by 1. Similarly, OT[i] means the increment if i-th
1148 // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
1149 // is a function of IS or OS and the layout, and not dependent on the specific
1150 // input or output index.
1151 //
1152 // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
1153 // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
1154 // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
1155 // to prove, with every increment on k, the above formula still holds.
1156 //
1157 // First, suppose reshaping from IS to OS is non-factorizable (we discuss
1158 // factorizable reshapes later). A reshape from IS to OS is factorizable, if
1159 // there exists (i,j) such that
1160 //
1161 // 0<=i<=|IS|
1162 // 0<=j<=|OS|
1163 // |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
1164 // product(IS[i], IS[i+1], ..., IS[|IS|-1])
1165 // = product(OS[j], OS[j+1], ..., OS[|OS|-1])
1166 //
1167 // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
1168 // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
1169 // unit indices which are already tested. This also means IT[0]=OT[0]
1170 // because p(II(1))=IT[0] and p(OI(1))=OT[0].
1171 //
1172 // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
1173 // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
1174 // to the output physical.
1175 //
1176 // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
1177 // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
1178 // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
1179 // logical index k-1 to logical index k, dimension 1 of the input index
1180 // is increased by 1 and dimension 0 is reset to 0 thus decreased by
1181 // IS[0]-1. Therefore, the physical input index is increased by
1182 //
1183 // p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
1184 //
1185 // Because IS[0]<OS[0], the only change to the output index is that its
1186 // dimension 0 is increased by one. Therefore,
1187 //
1188 // p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
1189 //
1190 // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that
1191 // p(II(k))=p(OI(k)). Therefore,
1192 // IT[1] - (IS[0]-1) * IT[0] = IT[0]
1193 // IT[1] = IS[0] * IT[0]
1194 // In other words, input dimension 1 is immediately more major than input
1195 // dimension 0. We can now conceptually collapse these two dimensions because
1196 // an increment in the logical index affecting only these two dimensions maps
1197 // to IT[0] in the physical index.
1198 //
1199 // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
1200 // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
1201 // identical.
1202 //
1203 // A factorizable reshape can be factorized into a list of non-factorizable
1204 // sub-reshapes, each of which can be handled similarly to the proof above.
1205 // For example,
1206 //
1207 // [7x9x2x15] -> [63x6x5]
1208 //
1209 // can be factorized into
1210 //
1211 // [7x9] -> [63] and [2x15] -> [6x5].
1212 //
1213 // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
1214 // same logical linear index. According to the factorization, we know
1215 // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
1216 // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
1217 // similar proof, with the increment of the logical index set to
1218 // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
1219 // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
1220 //
1221 // p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
1222 // = p(y2,0,0) + p(0,0,y1,y0)
1223 // = p(y2,y1,y0)
1224 //
1225 // check_input_unit_indices checks one way of the condition: each input unit
1226 // index is mapped to an output index with the same physical location. This
1227 // lambda will be called again with input_shape and output_shape reversed to
1228 // check the other way.
1229 auto check_input_unit_indices = [](const Shape& input_shape,
1230 const Shape& output_shape) {
1231 // input_shape_dim0_major/output_shape_dim0_major have the same "dimensions"
1232 // as input_shape/output_shape and the dimension-0-major layout. These two
1233 // shapes are used for conversion between logical linear indices and
1234 // multi-dimensional indices.
1235 Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
1236 input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
1237 Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
1238 output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
1239
1240 for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) {
1241 if (input_shape.dimensions(input_dim) <= 1) {
1242 continue;
1243 }
1244
1245 std::vector<int64> input_unit_index(Rank(input_shape), 0);
1246 input_unit_index[input_dim] = 1;
1247 int64 logical_linear_index =
1248 IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
1249 input_unit_index);
1250 // output_index has the same logical linear index as input_unit_index.
1251 std::vector<int64> output_index =
1252 IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
1253 logical_linear_index);
1254 // Check input_unit_index and output_index have the same physical linear
1255 // index.
1256 if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
1257 input_unit_index) !=
1258 IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
1259 output_index)) {
1260 return false;
1261 }
1262 }
1263 return true;
1264 };
1265 return check_input_unit_indices(input_shape, output_shape) &&
1266 check_input_unit_indices(output_shape, input_shape);
1267 }
1268
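// Illustrative example (a sketch of the alignment below): for an input shape
// f32[6,4]{1,0} and an output shape with dimensions {2,3,4}, the aligned
// result is f32[2,3,4]{2,1,0}, for which ReshapeIsBitcast holds.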
1269 /* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
1270 const Shape& input_shape, const Shape& output_shape) {
1271 int64 input_rank = Rank(input_shape);
1272 int64 output_rank = Rank(output_shape);
1273
1274 // First, calculate an alignment of the dimensions. A consecutive sequence of
1275 // input dimensions and output dimensions belong to the same alignment part if
1276 // the products of their dimension bounds are the same. In the easiest case,
1277 // an alignment part consists of one input dimension and one output dimension
1278 // which both have the same dimension bound. An alignment part specifies which
1279 // dimensions need to be kept together in a physical layout if we want a
1280 // reshape to be a bitcast. The order of the alignment parts is defined by the
1281 // physical layout of the input shape, so when we construct the layout for the
1282 // output shape we just process the alignment parts in this order, and then
1283 // layout the dimensions belonging to each part in descending (major to minor)
1284 // order.
1285
1286 // Stores the input and output dimension numbers where each alignment part
1287 // starts.
1288 std::vector<std::pair<int64, int64>> alignment;
1289 alignment.push_back({0, 0});
1290
1291 // Stores a mapping from the input dimension to the alignment part it belongs
1292 // to.
1293 std::vector<int64> dimension_to_alignment_index(input_rank);
1294 int64 input_dimension_product = 1, output_dimension_product = 1;
1295 for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
1296 // Check if we have reached the end of an alignment part.
1297 if (input_dimension_product == output_dimension_product &&
1298 input_dimension_product > 1) {
1299 alignment.push_back({i, j});
1300 input_dimension_product = output_dimension_product = 1;
1301 }
1302 if (input_dimension_product < output_dimension_product ||
1303 j == output_rank) {
1304 if (i == input_rank) {
1305 return tensorflow::gtl::nullopt;
1306 }
1307 dimension_to_alignment_index[i] = alignment.size() - 1;
1308 input_dimension_product *= input_shape.dimensions(i);
1309 ++i;
1310 } else {
1311 output_dimension_product *= output_shape.dimensions(j);
1312 ++j;
1313 }
1314 }
1315 if (input_dimension_product != output_dimension_product) {
1316 return tensorflow::gtl::nullopt;
1317 }
1318 // We also need to store an end element so that we know where the last
1319 // alignment part ends.
1320 alignment.push_back({input_rank, output_rank});
1321
1322 // Now check if the physical layout can potentially be aligned to the output
1323 // shape by changing the physical layout of the output shape. We need to check
1324 // that all dimension numbers that belong to the same alignment part appear
1325 // consecutively, and are in descending order. However we can ignore any
1326 // trivial dimension bounds of 1, because they can be placed anywhere.
1327 auto input_dimension_numbers = input_shape.layout().minor_to_major();
1328 std::vector<int64> output_layout;
1329 output_layout.reserve(output_rank);
1330 for (int64 i = 0; i < input_rank;) {
1331 int64 current_dimension_number = input_dimension_numbers[i];
1332
1333 // Skip trivial dimensions with a bound of 1.
1334 if (input_shape.dimensions(current_dimension_number) == 1) {
1335 ++i;
1336 continue;
1337 }
1338
1339 // Calculate the number of non-trivial dimension bounds in the input shape
1340 // belonging to the current alignment part.
1341 const int64 current_alignment_index =
1342 dimension_to_alignment_index[current_dimension_number];
1343 // Because of the special end element that we added, we can be sure that
1344 // 'current_alignment_index' is < alignment.size() - 1.
1345 CHECK_LT(current_alignment_index, alignment.size() - 1);
1346 int64 num_non_trivial_dimensions_in_alignment_part = 0;
1347 for (int64 j = alignment[current_alignment_index].first;
1348 j < alignment[current_alignment_index + 1].first; ++j) {
1349 if (input_shape.dimensions(j) != 1) {
1350 ++num_non_trivial_dimensions_in_alignment_part;
1351 }
1352 }
1353
1354 // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
1355 // dimension numbers (ignoring dimension numbers with dimension bound 1) are
1356 // in descending order and belong to the current alignment part.
1357 for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part;
1358 ++i, ++j) {
1359 if (i == input_rank) {
1360 return tensorflow::gtl::nullopt;
1361 }
1362 // Skip trivial dimensions with a bound of 1.
1363 if (input_shape.dimensions(input_dimension_numbers[i]) == 1) {
1364 --j;
1365 continue;
1366 }
1367 // If the current dimension number belongs to a different alignment part,
1368 // or the dimension numbers are not in descending order, we can return
1369 // early.
1370 if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
1371 current_alignment_index ||
1372 input_dimension_numbers[i] > current_dimension_number) {
1373 return tensorflow::gtl::nullopt;
1374 }
1375 current_dimension_number = input_dimension_numbers[i];
1376 }
1377
1378 // The output dimension numbers that belong to the current alignment part
1379 // need to appear in the same descending order as in the input. Again, we
1380 // can skip dimensions with a bound of 1.
1381 for (int64 j = alignment[current_alignment_index + 1].second - 1;
1382 j >= alignment[current_alignment_index].second; --j) {
1383 if (output_shape.dimensions(j) != 1) {
1384 output_layout.push_back(j);
1385 }
1386 }
1387 }
1388 // Now add all the dimensions with dimension bound 1 at the end of
1389 // 'output_layout'.
1390 for (int64 i = 0; i < output_rank; ++i) {
1391 if (output_shape.dimensions(i) == 1) {
1392 output_layout.push_back(i);
1393 }
1394 }
1395 CHECK_EQ(output_layout.size(), output_rank);
1396 Shape output_shape_with_layout = MakeShapeWithLayout(
1397 output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
1398 output_layout);
1399 CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout));
1400 return output_shape_with_layout;
1401 }
1402
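// Illustrative example (follows from the renumbering below): deleting
// dimension 1 of f32[2,3,4]{2,1,0} yields f32[2,4]{1,0}.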
1403 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
1404 Shape shape) {
1405 shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
1406 if (LayoutUtil::HasLayout(shape)) {
1407 Layout* layout = shape.mutable_layout();
1408 layout->set_format(DENSE);
1409 for (size_t i = 0; i < layout->minor_to_major().size();) {
1410 if (layout->minor_to_major(i) == dim_to_delete) {
1411 layout->mutable_minor_to_major()->erase(
1412 layout->minor_to_major().begin() + i);
1413 continue;
1414 }
1415 if (layout->minor_to_major(i) > dim_to_delete) {
1416 (*layout->mutable_minor_to_major())[i] -= 1;
1417 }
1418 ++i;
1419 }
1420 }
1421 return shape;
1422 }
1423
1424 /* static */ Shape ShapeUtil::FilterDimensions(
1425 const std::function<bool(int64)>& p, Shape shape) {
1426 std::vector<int64> dims_to_delete;
1427 for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
1428 if (!p(i)) {
1429 dims_to_delete.push_back(i);
1430 }
1431 }
1432 for (int64 dim : dims_to_delete) {
1433 shape = DeleteDimension(dim, shape);
1434 }
1435 return shape;
1436 }
1437
1438 std::ostream& operator<<(std::ostream& out, const Shape& shape) {
1439 out << ShapeUtil::HumanString(shape);
1440 return out;
1441 }
1442
1443 } // namespace xla
1444