1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/layout_util.h"
17
18 #include <stddef.h>
19 #include <algorithm>
20 #include <functional>
21 #include <random>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/protobuf_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38
39 namespace xla {
40 namespace {
41
42 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
43 // minor_to_major to the value that represents the default layout.
SetDefaultLayoutToContainer(std::vector<int64> * minor_to_major)44 void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
45 // The default XLA layout is major-to-minor (dim 0 is major).
46 // For more information on XLA layouts, see:
47 // https://www.tensorflow.org/performance/xla/shapes
48 const int64 size = minor_to_major->size();
49 for (int64 i = 0; i < size; ++i) {
50 (*minor_to_major)[i] = size - 1 - i;
51 }
52 }
53
54 } // namespace
55
MakeLayout(absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits)56 /* static */ Layout LayoutUtil::MakeLayout(
57 absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
58 int64 element_size_in_bits) {
59 Layout layout;
60 layout.set_format(DENSE);
61 for (int64 dimension_number : minor_to_major) {
62 layout.add_minor_to_major(dimension_number);
63 }
64 for (Tile tile : tiles) {
65 for (int64 dim : tile.dimensions()) {
66 if (dim < 0 && dim != Tile::kCombineDimension) {
67 LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if "
68 "it's negative. Value is "
69 << dim;
70 }
71 }
72 *layout.add_tiles() = tile;
73 }
74 layout.set_element_size_in_bits(element_size_in_bits);
75 return layout;
76 }
77
MakeDescendingLayout(int64 rank)78 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
79 std::vector<int64> layout(rank);
80 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
81 return MakeLayout(layout);
82 }
83
MakeLayoutFromMajorToMinor(absl::Span<const int64> major_to_minor)84 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
85 absl::Span<const int64> major_to_minor) {
86 Layout layout;
87 layout.set_format(DENSE);
88 for (int i = major_to_minor.size() - 1; i >= 0; i--) {
89 layout.add_minor_to_major(major_to_minor[i]);
90 }
91 return layout;
92 }
93
MakeSparseLayout(int64 max_sparse_elements)94 /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
95 Layout layout;
96 layout.set_format(SPARSE);
97 layout.set_max_sparse_elements(max_sparse_elements);
98 return layout;
99 }
100
101 namespace {
102
103 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64 rank)104 Layout CreateDefaultLayoutForRank(int64 rank) {
105 Layout layout;
106 layout.set_format(DENSE);
107 std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
108 minor_to_major->resize(rank, 0);
109 SetDefaultLayoutToContainer(minor_to_major);
110 return layout;
111 }
112
113 } // namespace
114
GetDefaultLayoutForShape(const Shape & shape)115 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
116 if (shape.IsOpaque() || shape.IsToken()) {
117 // Opaque and token types have empty layouts.
118 return Layout();
119 }
120
121 // A Layout proto corresponds to a single array, not a tuple.
122 CHECK(shape.IsArray());
123 return CreateDefaultLayoutForRank(shape.dimensions_size());
124 }
125
GetDefaultLayoutForRank(int64 rank)126 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
127 return CreateDefaultLayoutForRank(rank);
128 }
129
GetDefaultLayoutForR2()130 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
131 return CreateDefaultLayoutForRank(2);
132 }
133
GetDefaultLayoutForR3()134 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
135 return CreateDefaultLayoutForRank(3);
136 }
137
GetDefaultLayoutForR4()138 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
139 return CreateDefaultLayoutForRank(4);
140 }
141
SetToDefaultLayout(Shape * shape)142 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
143 if (shape->IsTuple()) {
144 // Tuple shape.
145 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
146 SetToDefaultLayout(&element_shape);
147 }
148 shape->clear_layout();
149 } else if (shape->IsArray()) {
150 shape->mutable_layout()->set_format(DENSE);
151 auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
152 minor_to_major->resize(shape->dimensions_size(), 0);
153 SetDefaultLayoutToContainer(minor_to_major);
154 } else {
155 // Opaque, token types etc. have no layout.
156 shape->clear_layout();
157 }
158 }
159
// Returns a copy of `shape` with every layout reset via SetToDefaultLayout.
// The input shape is not modified.
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape copy(shape);
  // The address-of expression had been mis-encoded as the HTML entity for
  // the copyright sign, which made this call ill-formed.
  LayoutUtil::SetToDefaultLayout(&copy);
  return copy;
}
165
SetToDefaultLayout(ProgramShape * program_shape)166 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
167 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
168 LayoutUtil::SetToDefaultLayout(¶meter_shape);
169 }
170 LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
171 }
172
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)173 /* static */ Status LayoutUtil::ValidateLayoutInShape(
174 const Shape& shape, bool allow_missing_layouts) {
175 if (shape.IsTuple()) {
176 // Tuple shape.
177 if (shape.has_layout()) {
178 return InvalidArgument("tuple should not have a layout field");
179 }
180 for (auto& element_shape : shape.tuple_shapes()) {
181 TF_RETURN_IF_ERROR(
182 ValidateLayoutInShape(element_shape, allow_missing_layouts));
183 }
184 return Status::OK();
185 } else if (shape.IsArray()) {
186 if (!shape.has_layout()) {
187 if (allow_missing_layouts) {
188 return Status::OK();
189 }
190 return InvalidArgument("shape %s does not have a layout",
191 ShapeUtil::HumanString(shape));
192 }
193 return ValidateLayoutForShape(shape.layout(), shape);
194 } else {
195 // Token, opaque, etc. shape.
196 if (shape.has_layout()) {
197 return InvalidArgument(
198 "shape of primitive type %s should not have a layout",
199 PrimitiveType_Name(shape.element_type()));
200 }
201 return Status::OK();
202 }
203 }
204
ValidateLayoutForShape(const Layout & layout,const Shape & shape)205 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
206 const Shape& shape) {
207 if (shape.IsTuple()) {
208 return InvalidArgument("a single Layout is not valid for tuple shapes");
209 }
210
211 if (!shape.IsArray()) {
212 if (layout.minor_to_major_size() != 0) {
213 return InvalidArgument(
214 "shape of primitive type %s should not have a non-trivial layout",
215 PrimitiveType_Name(shape.element_type()));
216 }
217 return Status::OK();
218 }
219
220 if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
221 return InvalidArgument("Layout has an invalid format (%d)",
222 layout.format());
223 }
224
225 if (layout.format() == DENSE) {
226 if (layout.minor_to_major_size() != shape.rank()) {
227 return InvalidArgument(
228 "layout minor_to_major field contains %d elements, "
229 "but shape is rank %d: {%s}; shape: %s",
230 layout.minor_to_major_size(), shape.rank(),
231 absl::StrJoin(layout.minor_to_major(), ", "),
232 shape.ShortDebugString());
233 }
234
235 std::vector<bool> dimensions_in_layout(shape.rank(), false);
236 for (int64 i = 0; i < shape.rank(); ++i) {
237 int64 dim = layout.minor_to_major(i);
238 if (dim < 0 || dim >= shape.rank()) {
239 return InvalidArgument(
240 "layout minor_to_major field has out-of-bounds value: %s",
241 HumanString(layout));
242 }
243 if (dimensions_in_layout[dim]) {
244 return InvalidArgument(
245 "layout minor_to_major field has duplicate values: {%s}",
246 HumanString(layout));
247 }
248 dimensions_in_layout[dim] = true;
249 }
250 } else {
251 if (layout.tiles_size() != 0) {
252 return InvalidArgument("Only dense layouts can be tiled.");
253 }
254 }
255
256 return Status::OK();
257 }
258
ClearLayout(Shape * shape)259 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
260 shape->clear_layout();
261 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
262 ClearLayout(&element_shape);
263 }
264 }
265
ClearLayout(ProgramShape * program_shape)266 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
267 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
268 LayoutUtil::ClearLayout(¶meter_shape);
269 }
270 LayoutUtil::ClearLayout(program_shape->mutable_result());
271 }
272
IsDenseArray(const Shape & shape)273 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
274 return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
275 }
276
IsDense(const Layout & layout)277 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
278 return layout.format() == DENSE;
279 }
280
IsMonotonicWithDim0Minor(const Layout & layout)281 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
282 CHECK(layout.format() == DENSE);
283 return std::is_sorted(layout.minor_to_major().begin(),
284 layout.minor_to_major().end());
285 }
286
IsMonotonicWithDim0Major(const Layout & layout)287 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
288 CHECK(layout.format() == DENSE);
289 return std::is_sorted(layout.minor_to_major().begin(),
290 layout.minor_to_major().end(), std::greater<int64>());
291 }
292
IsSparseArray(const Shape & shape)293 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
294 return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
295 }
296
IsSparse(const Layout & layout)297 /* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
298 return layout.format() == SPARSE;
299 }
300
MaxSparseElements(const Layout & layout)301 /* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
302 CHECK(IsSparse(layout));
303 return layout.max_sparse_elements();
304 }
305
HasLayout(const Shape & shape)306 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
307 if (shape.IsTuple()) {
308 // Tuple shape: all subshapes must have a layout.
309 return absl::c_all_of(shape.tuple_shapes(),
310 [](const Shape& s) { return HasLayout(s); });
311 } else if (!shape.IsArray()) {
312 // Opaque, token types etc. ignore layout.
313 return true;
314 }
315 return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
316 }
317
HasLayout(const ProgramShape & program_shape)318 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
319 for (auto& parameter_shape : program_shape.parameters()) {
320 if (!LayoutUtil::HasLayout(parameter_shape)) {
321 return false;
322 }
323 }
324 return LayoutUtil::HasLayout(program_shape.result());
325 }
326
Equal(const Layout & lhs,const Layout & rhs)327 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
328 return lhs == rhs;
329 }
330
MinorToMajor(const Shape & shape)331 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
332 const Shape& shape) {
333 CHECK(IsDenseArray(shape));
334 return AsInt64Slice(shape.layout().minor_to_major());
335 }
336
MinorToMajor(const Layout & layout)337 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
338 const Layout& layout) {
339 CHECK(layout.format() == DENSE);
340 return AsInt64Slice(layout.minor_to_major());
341 }
342
Major(const Layout & layout,int64 physical_dimension_number)343 /* static */ int64 LayoutUtil::Major(const Layout& layout,
344 int64 physical_dimension_number) {
345 CHECK_LE(0, physical_dimension_number);
346 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
347 return Minor(layout,
348 layout.minor_to_major_size() - 1 - physical_dimension_number);
349 }
350
Minor(const Layout & layout,int64 physical_dimension_number)351 /* static */ int64 LayoutUtil::Minor(const Layout& layout,
352 int64 physical_dimension_number) {
353 CHECK_EQ(layout.format(), DENSE);
354 CHECK_LE(0, physical_dimension_number);
355 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
356 return layout.minor_to_major(physical_dimension_number);
357 }
358
MakeLogicalToPhysical(const Layout & layout)359 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
360 const Layout& layout) {
361 std::vector<int64> logical_to_physical(layout.minor_to_major_size());
362 for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
363 const int64 logical = Major(layout, physical);
364 logical_to_physical[logical] = physical;
365 }
366 return logical_to_physical;
367 }
368
HumanString(const Layout & layout)369 /* static */ string LayoutUtil::HumanString(const Layout& layout) {
370 return layout.ToString();
371 }
372
373 namespace {
374
375 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)376 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
377 if (src.IsTuple() != dst->IsTuple()) {
378 return InvalidArgument(
379 "cannot copy layout from shape: shape structure differs");
380 }
381 if (src.IsTuple()) {
382 if (ShapeUtil::TupleElementCount(src) !=
383 ShapeUtil::TupleElementCount(*dst)) {
384 return InvalidArgument(
385 "cannot copy layout from shape: tuple element count differs");
386 }
387 for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
388 TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
389 dst->mutable_tuple_shapes(i)));
390 }
391 } else {
392 if (src.has_layout()) {
393 if (src.rank() != dst->rank()) {
394 return InvalidArgument("cannot copy layout from shape: ranks differs");
395 }
396 TF_RETURN_IF_ERROR(
397 LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
398 *dst->mutable_layout() = src.layout();
399 } else {
400 dst->clear_layout();
401 }
402 }
403 return Status::OK();
404 }
405
406 } // namespace
407
408 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)409 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
410 return CopyLayoutInternal(src, dst);
411 }
412
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)413 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
414 const Shape& rhs) {
415 if (lhs.IsTuple()) {
416 if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
417 ShapeUtil::TupleElementCount(rhs)) {
418 return false;
419 }
420 for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
421 if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
422 return false;
423 }
424 }
425 return true;
426 } else if (lhs.IsArray()) {
427 return lhs.rank() == rhs.rank() &&
428 LayoutUtil::Equal(lhs.layout(), rhs.layout());
429 } else {
430 // Layouts of non-array and non-tuple shapes is ignored.
431 return true;
432 }
433 }
434
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64> dims)435 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
436 const Layout& layout, absl::Span<const int64> dims) {
437 CHECK(IsDense(layout));
438 std::vector<int64> positions_in_layout;
439 for (int64 dim : dims) {
440 positions_in_layout.push_back(
441 PositionInContainer(layout.minor_to_major(), dim));
442 }
443 absl::c_sort(positions_in_layout);
444 for (size_t i = 1; i < positions_in_layout.size(); ++i) {
445 if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
446 return false;
447 }
448 }
449 return true;
450 }
451
Hash(const Layout & layout)452 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
453 using tensorflow::hash;
454 using tensorflow::Hash64Combine;
455
456 size_t hash_value = hash<Format>()(layout.format());
457
458 for (int64 minor_to_major : layout.minor_to_major()) {
459 hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
460 }
461 hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
462
463 for (Tile tile : layout.tiles()) {
464 for (int64 tile_dim : tile.dimensions()) {
465 hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
466 }
467 }
468 hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
469
470 return hash_value;
471 }
472
473 } // namespace xla
474