/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/layout_util.h"

#include <stddef.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace {

// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
// minor_to_major to the value that represents the default layout.
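// For example, for a rank-3 array the default minor_to_major is {2, 1, 0}.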
template <typename T>
void SetDefaultLayoutToContainer(T* minor_to_major) {
  // The default XLA layout is major-to-minor (dim 0 is major).
  // For more information on XLA layouts, see:
  // https://www.tensorflow.org/performance/xla/shapes
  const int64 size = minor_to_major->size();
  for (int64 i = 0; i < size; ++i) {
    (*minor_to_major)[i] = size - 1 - i;
  }
}

}  // namespace

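// Constructs a dense Layout with the given minor-to-major order, tiles,
// element size in bits, and memory space.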
/* static */ Layout LayoutUtil::MakeLayout(
    absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
    int64 element_size_in_bits, int64 memory_space) {
  Layout layout;
  layout.set_format(DENSE);
  for (int64 dimension_number : minor_to_major) {
    layout.add_minor_to_major(dimension_number);
  }
  for (Tile tile : tiles) {
    for (int64 dim : tile.dimensions()) {
      if (dim < 0 && dim != Tile::kCombineDimension) {
        LOG(FATAL) << "Tile dimension size needs to be minimum int64 value if "
                      "it's negative. Value is "
                   << dim;
      }
    }
    *layout.add_tiles() = tile;
  }
  layout.set_element_size_in_bits(element_size_in_bits);
  layout.set_memory_space(memory_space);
  return layout;
}

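// Returns a dense layout whose minor_to_major is {rank-1, ..., 1, 0}, i.e.
// dimension 0 is the most major dimension.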
/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
  std::vector<int64> layout(rank);
  std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
  return MakeLayout(layout);
}

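// Constructs a dense Layout from a dimension ordering given major-to-minor;
// the order is reversed to populate the minor_to_major field.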
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
    absl::Span<const int64> major_to_minor) {
  Layout layout;
  layout.set_format(DENSE);
  for (int i = major_to_minor.size() - 1; i >= 0; i--) {
    layout.add_minor_to_major(major_to_minor[i]);
  }
  return layout;
}

namespace {

// Internal helper that creates a default layout for an array of the given
// rank.
Layout CreateDefaultLayoutForRank(int64 rank) {
  Layout layout;
  layout.set_format(DENSE);
  auto* minor_to_major = layout.mutable_minor_to_major();
  minor_to_major->resize(rank, 0);
  SetDefaultLayoutToContainer(minor_to_major);
  return layout;
}

}  // namespace

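// Returns the default layout for the given shape: an empty Layout for opaque
// and token shapes, and the default (major-to-minor) layout for arrays.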
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
  if (shape.IsOpaque() || shape.IsToken()) {
    // Opaque and token types have empty layouts.
    return Layout();
  }

  // A Layout proto corresponds to a single array, not a tuple.
  CHECK(shape.IsArray());
  return CreateDefaultLayoutForRank(shape.dimensions_size());
}

/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
  return CreateDefaultLayoutForRank(rank);
}

/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
  return CreateDefaultLayoutForRank(2);
}

/* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
  return CreateDefaultLayoutForRank(3);
}

/* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
  return CreateDefaultLayoutForRank(4);
}

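// Resets the layouts in the given shape to the default: recurses into tuple
// elements, assigns the default dense layout to arrays, and clears the layout
// otherwise.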
/* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
  if (shape->IsTuple()) {
    // Tuple shape.
    for (auto& element_shape : *shape->mutable_tuple_shapes()) {
      SetToDefaultLayout(&element_shape);
    }
    shape->clear_layout();
  } else if (shape->IsArray()) {
    shape->mutable_layout()->set_format(DENSE);
    auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
    minor_to_major->resize(shape->dimensions_size(), 0);
    SetDefaultLayoutToContainer(minor_to_major);
  } else {
    // Opaque, token types etc. have no layout.
    shape->clear_layout();
  }
}

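// Returns a copy of the given shape with all layouts set to the default.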
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape copy(shape);
  LayoutUtil::SetToDefaultLayout(&copy);
  return copy;
}

/* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
  for (auto& parameter_shape : *program_shape->mutable_parameters()) {
    LayoutUtil::SetToDefaultLayout(&parameter_shape);
  }
  LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
}

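// Validates the layouts within the given shape: tuples must not carry a
// layout of their own, arrays must have a layout valid for their rank (unless
// allow_missing_layouts is true), and other types must have no layout.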
/* static */ Status LayoutUtil::ValidateLayoutInShape(
    const Shape& shape, bool allow_missing_layouts) {
  if (shape.IsTuple()) {
    // Tuple shape.
    if (shape.has_layout()) {
      return InvalidArgument("tuple should not have a layout field");
    }
    for (auto& element_shape : shape.tuple_shapes()) {
      TF_RETURN_IF_ERROR(
          ValidateLayoutInShape(element_shape, allow_missing_layouts));
    }
    return Status::OK();
  } else if (shape.IsArray()) {
    if (!shape.has_layout()) {
      if (allow_missing_layouts) {
        return Status::OK();
      }
      return InvalidArgument("shape %s does not have a layout",
                             ShapeUtil::HumanString(shape));
    }
    return ValidateLayoutForShape(shape.layout(), shape);
  } else {
    // Token, opaque, etc. shape.
    if (shape.has_layout()) {
      return InvalidArgument(
          "shape of primitive type %s should not have a layout",
          PrimitiveType_Name(shape.element_type()));
    }
    return Status::OK();
  }
}

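// Validates a Layout against a non-tuple shape: a dense layout must list each
// of the shape's dimensions exactly once in minor_to_major, and only dense
// layouts may be tiled.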
/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
                                                       const Shape& shape) {
  if (shape.IsTuple()) {
    return InvalidArgument("a single Layout is not valid for tuple shapes");
  }

  if (!shape.IsArray()) {
    if (layout.minor_to_major_size() != 0) {
      return InvalidArgument(
          "shape of primitive type %s should not have a non-trivial layout",
          PrimitiveType_Name(shape.element_type()));
    }
    return Status::OK();
  }

  if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
    return InvalidArgument("Layout has an invalid format (%d)",
                           layout.format());
  }

  if (layout.format() == DENSE) {
    if (layout.minor_to_major_size() != shape.rank()) {
      return InvalidArgument(
          "layout minor_to_major field contains %d elements, "
          "but shape is rank %d: {%s}; shape: %s",
          layout.minor_to_major_size(), shape.rank(),
          absl::StrJoin(layout.minor_to_major(), ", "),
          shape.ShortDebugString());
    }

    std::vector<bool> dimensions_in_layout(shape.rank(), false);
    for (int64 i = 0; i < shape.rank(); ++i) {
      int64 dim = layout.minor_to_major(i);
      if (dim < 0 || dim >= shape.rank()) {
        return InvalidArgument(
            "layout minor_to_major field has out-of-bounds value: %s",
            HumanString(layout));
      }
      if (dimensions_in_layout[dim]) {
        return InvalidArgument(
            "layout minor_to_major field has duplicate values: {%s}",
            HumanString(layout));
      }
      dimensions_in_layout[dim] = true;
    }
  } else {
    if (layout.tiles_size() != 0) {
      return InvalidArgument("Only dense layouts can be tiled.");
    }
  }

  return Status::OK();
}

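// Recursively clears the layouts of the given shape and its tuple elements.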
/* static */ void LayoutUtil::ClearLayout(Shape* shape) {
  shape->clear_layout();
  for (auto& element_shape : *shape->mutable_tuple_shapes()) {
    ClearLayout(&element_shape);
  }
}

/* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
  for (auto& parameter_shape : *program_shape->mutable_parameters()) {
    LayoutUtil::ClearLayout(&parameter_shape);
  }
  LayoutUtil::ClearLayout(program_shape->mutable_result());
}

/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
  return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
}

/* static */ bool LayoutUtil::IsDense(const Layout& layout) {
  return layout.format() == DENSE;
}

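// Returns true if the dense layout's minor_to_major is monotonically
// increasing, i.e. dimension 0 is the most minor dimension.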
/* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
  CHECK(layout.format() == DENSE);
  return std::is_sorted(layout.minor_to_major().begin(),
                        layout.minor_to_major().end());
}

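// Returns true if the dense layout's minor_to_major is monotonically
// decreasing, i.e. dimension 0 is the most major dimension (the default
// layout).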
/* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
  CHECK(layout.format() == DENSE);
  return std::is_sorted(layout.minor_to_major().begin(),
                        layout.minor_to_major().end(), std::greater<int64>());
}

/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
  if (shape.IsTuple()) {
    // Tuple shape: all subshapes must have a layout.
    return absl::c_all_of(shape.tuple_shapes(),
                          [](const Shape& s) { return HasLayout(s); });
  } else if (!shape.IsArray()) {
    // Opaque, token types etc. ignore layout.
    return true;
  }
  return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
}

/* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
  for (auto& parameter_shape : program_shape.parameters()) {
    if (!LayoutUtil::HasLayout(parameter_shape)) {
      return false;
    }
  }
  return LayoutUtil::HasLayout(program_shape.result());
}

/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
  return lhs == rhs;
}

/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    const Shape& shape) {
  CHECK(IsDenseArray(shape));
  return AsInt64Slice(shape.layout().minor_to_major());
}

/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    const Layout& layout) {
  CHECK(layout.format() == DENSE);
  return AsInt64Slice(layout.minor_to_major());
}

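// Returns the logical dimension number at the given physical position,
// counting from the major end (physical position 0 is the most major).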
/* static */ int64 LayoutUtil::Major(const Layout& layout,
                                     int64 physical_dimension_number) {
  CHECK_LE(0, physical_dimension_number);
  CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
  return Minor(layout,
               layout.minor_to_major_size() - 1 - physical_dimension_number);
}

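// Returns the logical dimension number at the given physical position,
// counting from the minor end (physical position 0 is the most minor).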
/* static */ int64 LayoutUtil::Minor(const Layout& layout,
                                     int64 physical_dimension_number) {
  CHECK_EQ(layout.format(), DENSE);
  CHECK_LE(0, physical_dimension_number);
  CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
  return layout.minor_to_major(physical_dimension_number);
}

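// Builds the inverse of minor_to_major: maps each logical dimension to the
// physical position it occupies, with position 0 being the most major.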
/* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
    const Layout& layout) {
  std::vector<int64> logical_to_physical(layout.minor_to_major_size());
  for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
    const int64 logical = Major(layout, physical);
    logical_to_physical[logical] = physical;
  }
  return logical_to_physical;
}

/* static */ string LayoutUtil::HumanString(const Layout& layout) {
  return layout.ToString();
}

namespace {

// Internal helper for recursively copying layouts.
Status CopyLayoutInternal(const Shape& src, Shape* dst) {
  if (src.IsTuple() != dst->IsTuple()) {
    return InvalidArgument(
        "cannot copy layout from shape: shape structure differs");
  }
  if (src.IsTuple()) {
    if (ShapeUtil::TupleElementCount(src) !=
        ShapeUtil::TupleElementCount(*dst)) {
      return InvalidArgument(
          "cannot copy layout from shape: tuple element count differs");
    }
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
      TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
                                            dst->mutable_tuple_shapes(i)));
    }
  } else {
    if (src.has_layout()) {
      if (src.rank() != dst->rank()) {
        return InvalidArgument("cannot copy layout from shape: ranks differ");
      }
      TF_RETURN_IF_ERROR(
          LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
      *dst->mutable_layout() = src.layout();
    } else {
      dst->clear_layout();
    }
  }
  return Status::OK();
}

}  // namespace

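// Copies layouts from src onto dst recursively; returns an error if the
// shapes' tuple structure or ranks differ.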
/* static */
Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
  return CopyLayoutInternal(src, dst);
}

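// Returns true if lhs and rhs have equal layouts, recursing into tuple
// elements; layouts of non-array, non-tuple shapes are ignored.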
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
                                                   const Shape& rhs) {
  if (lhs.IsTuple()) {
    if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
                              ShapeUtil::TupleElementCount(rhs)) {
      return false;
    }
    for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
      if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
        return false;
      }
    }
    return true;
  } else if (lhs.IsArray()) {
    return lhs.rank() == rhs.rank() &&
           LayoutUtil::Equal(lhs.layout(), rhs.layout());
  } else {
    // Layouts of non-array and non-tuple shapes are ignored.
    return true;
  }
}

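// Returns true if the given logical dimensions occupy consecutive positions
// in the layout's minor_to_major order, in any relative order.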
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
    const Layout& layout, absl::Span<const int64> dims) {
  CHECK(IsDense(layout));
  std::vector<int64> positions_in_layout;
  for (int64 dim : dims) {
    positions_in_layout.push_back(
        PositionInContainer(layout.minor_to_major(), dim));
  }
  absl::c_sort(positions_in_layout);
  for (size_t i = 1; i < positions_in_layout.size(); ++i) {
    if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
      return false;
    }
  }
  return true;
}

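// Hashes a layout by combining its format, minor_to_major entries, tile
// dimensions, element size in bits, and memory space.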
/*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
  using tensorflow::hash;
  using tensorflow::Hash64Combine;

  size_t hash_value = hash<Format>()(layout.format());

  for (int64 minor_to_major : layout.minor_to_major()) {
    hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
  }
  for (Tile tile : layout.tiles()) {
    for (int64 tile_dim : tile.dimensions()) {
      hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
    }
  }
  hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
  hash_value = Hash64Combine(hash_value, layout.memory_space());

  return hash_value;
}

}  // namespace xla