1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/layout_util.h"

#include <stddef.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
39
40 namespace xla {
41 namespace {
42
43 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
44 // minor_to_major to the value that represents the default layout.
45 template <typename T>
SetDefaultLayoutToContainer(T * minor_to_major)46 void SetDefaultLayoutToContainer(T* minor_to_major) {
47 // The default XLA layout is major-to-minor (dim 0 is major).
48 // For more information on XLA layouts, see:
49 // https://www.tensorflow.org/performance/xla/shapes
50 const int64 size = minor_to_major->size();
51 for (int64 i = 0; i < size; ++i) {
52 (*minor_to_major)[i] = size - 1 - i;
53 }
54 }
55
56 } // namespace
57
MakeLayout(absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits,int64 memory_space)58 /* static */ Layout LayoutUtil::MakeLayout(
59 absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
60 int64 element_size_in_bits, int64 memory_space) {
61 Layout layout;
62 layout.set_format(DENSE);
63 for (int64 dimension_number : minor_to_major) {
64 layout.add_minor_to_major(dimension_number);
65 }
66 for (const Tile& tile : tiles) {
67 for (int64 dim : tile.dimensions()) {
68 if (dim < 0 && dim != Tile::kCombineDimension) {
69 LOG(FATAL) << "Tile dimension size needs to be minimum int64 value if "
70 "it's negative. Value is "
71 << dim;
72 }
73 }
74 *layout.add_tiles() = tile;
75 }
76 layout.set_element_size_in_bits(element_size_in_bits);
77 layout.set_memory_space(memory_space);
78 return layout;
79 }
80
MakeDescendingLayout(int64 rank)81 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
82 std::vector<int64> layout(rank);
83 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
84 return MakeLayout(layout);
85 }
86
MakeAscendingLayout(int64 rank)87 /* static */ Layout LayoutUtil::MakeAscendingLayout(int64 rank) {
88 std::vector<int64> layout(rank);
89 std::iota(layout.begin(), layout.end(), static_cast<int64>(0));
90 return MakeLayout(layout);
91 }
92
MakeLayoutFromMajorToMinor(absl::Span<const int64> major_to_minor)93 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
94 absl::Span<const int64> major_to_minor) {
95 Layout layout;
96 layout.set_format(DENSE);
97 for (int i = major_to_minor.size() - 1; i >= 0; i--) {
98 layout.add_minor_to_major(major_to_minor[i]);
99 }
100 return layout;
101 }
102
103 namespace {
104
105 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64 rank)106 Layout CreateDefaultLayoutForRank(int64 rank) {
107 Layout layout;
108 layout.set_format(DENSE);
109 auto* minor_to_major = layout.mutable_minor_to_major();
110 minor_to_major->resize(rank, 0);
111 SetDefaultLayoutToContainer(minor_to_major);
112 return layout;
113 }
114
115 } // namespace
116
GetDefaultLayoutForShape(const Shape & shape)117 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
118 if (shape.IsOpaque() || shape.IsToken()) {
119 // Opaque and token types have empty layouts.
120 return Layout();
121 }
122
123 // A Layout proto corresponds to a single array, not a tuple.
124 CHECK(shape.IsArray());
125 return CreateDefaultLayoutForRank(shape.dimensions_size());
126 }
127
GetDefaultLayoutForRank(int64 rank)128 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
129 return CreateDefaultLayoutForRank(rank);
130 }
131
GetDefaultLayoutForR2()132 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
133 return CreateDefaultLayoutForRank(2);
134 }
135
GetDefaultLayoutForR3()136 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
137 return CreateDefaultLayoutForRank(3);
138 }
139
GetDefaultLayoutForR4()140 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
141 return CreateDefaultLayoutForRank(4);
142 }
143
SetToDefaultLayout(Shape * shape)144 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
145 if (shape->IsTuple()) {
146 // Tuple shape.
147 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
148 SetToDefaultLayout(&element_shape);
149 }
150 shape->clear_layout();
151 } else if (shape->IsArray()) {
152 shape->mutable_layout()->set_format(DENSE);
153 auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
154 minor_to_major->resize(shape->dimensions_size(), 0);
155 SetDefaultLayoutToContainer(minor_to_major);
156 } else {
157 // Opaque, token types etc. have no layout.
158 shape->clear_layout();
159 }
160 }
161
// Returns a copy of `shape` whose layout(s) have been reset with
// SetToDefaultLayout. `shape` itself is not modified.
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape copy(shape);
  LayoutUtil::SetToDefaultLayout(&copy);
  return copy;
}
167
SetToDefaultLayout(ProgramShape * program_shape)168 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
169 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
170 LayoutUtil::SetToDefaultLayout(¶meter_shape);
171 }
172 LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
173 }
174
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)175 /* static */ Status LayoutUtil::ValidateLayoutInShape(
176 const Shape& shape, bool allow_missing_layouts) {
177 if (shape.IsTuple()) {
178 // Tuple shape.
179 if (shape.has_layout()) {
180 return InvalidArgument("tuple should not have a layout field");
181 }
182 for (auto& element_shape : shape.tuple_shapes()) {
183 TF_RETURN_IF_ERROR(
184 ValidateLayoutInShape(element_shape, allow_missing_layouts));
185 }
186 return Status::OK();
187 } else if (shape.IsArray()) {
188 if (!shape.has_layout()) {
189 if (allow_missing_layouts) {
190 return Status::OK();
191 }
192 return InvalidArgument("shape %s does not have a layout",
193 ShapeUtil::HumanString(shape));
194 }
195 return ValidateLayoutForShape(shape.layout(), shape);
196 } else {
197 // Token, opaque, etc. shape.
198 if (shape.has_layout()) {
199 return InvalidArgument(
200 "shape of primitive type %s should not have a layout",
201 PrimitiveType_Name(shape.element_type()));
202 }
203 return Status::OK();
204 }
205 }
206
ValidateLayoutForShape(const Layout & layout,const Shape & shape)207 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
208 const Shape& shape) {
209 if (shape.IsTuple()) {
210 return InvalidArgument("a single Layout is not valid for tuple shapes");
211 }
212
213 if (!shape.IsArray()) {
214 if (layout.minor_to_major_size() != 0) {
215 return InvalidArgument(
216 "shape of primitive type %s should not have a non-trivial layout",
217 PrimitiveType_Name(shape.element_type()));
218 }
219 return Status::OK();
220 }
221
222 if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
223 return InvalidArgument("Layout has an invalid format (%d)",
224 layout.format());
225 }
226
227 if (layout.format() == DENSE) {
228 if (layout.minor_to_major_size() != shape.rank()) {
229 return InvalidArgument(
230 "layout minor_to_major field contains %d elements, "
231 "but shape is rank %d: {%s}; shape: %s",
232 layout.minor_to_major_size(), shape.rank(),
233 absl::StrJoin(layout.minor_to_major(), ", "),
234 shape.ShortDebugString());
235 }
236
237 std::vector<bool> dimensions_in_layout(shape.rank(), false);
238 for (int64 i = 0; i < shape.rank(); ++i) {
239 int64 dim = layout.minor_to_major(i);
240 if (dim < 0 || dim >= shape.rank()) {
241 return InvalidArgument(
242 "layout minor_to_major field has out-of-bounds value: %s",
243 HumanString(layout));
244 }
245 if (dimensions_in_layout[dim]) {
246 return InvalidArgument(
247 "layout minor_to_major field has duplicate values: {%s}",
248 HumanString(layout));
249 }
250 dimensions_in_layout[dim] = true;
251 }
252 } else {
253 if (layout.tiles_size() != 0) {
254 return InvalidArgument("Only dense layouts can be tiled.");
255 }
256 }
257
258 return Status::OK();
259 }
260
ClearLayout(Shape * shape)261 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
262 shape->clear_layout();
263 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
264 ClearLayout(&element_shape);
265 }
266 }
267
ClearLayout(ProgramShape * program_shape)268 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
269 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
270 LayoutUtil::ClearLayout(¶meter_shape);
271 }
272 LayoutUtil::ClearLayout(program_shape->mutable_result());
273 }
274
IsDenseArray(const Shape & shape)275 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
276 return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
277 }
278
IsDense(const Layout & layout)279 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
280 return layout.format() == DENSE;
281 }
282
IsMonotonicWithDim0Minor(const Layout & layout)283 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
284 CHECK(layout.format() == DENSE);
285 return std::is_sorted(layout.minor_to_major().begin(),
286 layout.minor_to_major().end());
287 }
288
IsMonotonicWithDim0Major(const Layout & layout)289 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
290 CHECK(layout.format() == DENSE);
291 return std::is_sorted(layout.minor_to_major().begin(),
292 layout.minor_to_major().end(), std::greater<int64>());
293 }
294
HasLayout(const Shape & shape)295 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
296 if (shape.IsTuple()) {
297 // Tuple shape: all subshapes must have a layout.
298 return absl::c_all_of(shape.tuple_shapes(),
299 [](const Shape& s) { return HasLayout(s); });
300 } else if (!shape.IsArray()) {
301 // Opaque, token types etc. ignore layout.
302 return true;
303 }
304 return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
305 }
306
HasLayout(const ProgramShape & program_shape)307 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
308 for (auto& parameter_shape : program_shape.parameters()) {
309 if (!LayoutUtil::HasLayout(parameter_shape)) {
310 return false;
311 }
312 }
313 return LayoutUtil::HasLayout(program_shape.result());
314 }
315
Equal(const Layout & lhs,const Layout & rhs)316 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
317 return lhs == rhs;
318 }
319
MinorToMajor(const Shape & shape)320 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
321 const Shape& shape) {
322 CHECK(IsDenseArray(shape));
323 return AsInt64Slice(shape.layout().minor_to_major());
324 }
325
MinorToMajor(const Layout & layout)326 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
327 const Layout& layout) {
328 CHECK(layout.format() == DENSE);
329 return AsInt64Slice(layout.minor_to_major());
330 }
331
Major(const Layout & layout,int64 physical_dimension_number)332 /* static */ int64 LayoutUtil::Major(const Layout& layout,
333 int64 physical_dimension_number) {
334 CHECK_LE(0, physical_dimension_number);
335 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
336 return Minor(layout,
337 layout.minor_to_major_size() - 1 - physical_dimension_number);
338 }
339
Minor(const Layout & layout,int64 physical_dimension_number)340 /* static */ int64 LayoutUtil::Minor(const Layout& layout,
341 int64 physical_dimension_number) {
342 CHECK_EQ(layout.format(), DENSE);
343 CHECK_LE(0, physical_dimension_number);
344 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
345 return layout.minor_to_major(physical_dimension_number);
346 }
347
MakeLogicalToPhysical(const Layout & layout)348 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
349 const Layout& layout) {
350 std::vector<int64> logical_to_physical(layout.minor_to_major_size());
351 for (int64 physical = 0, end = logical_to_physical.size(); physical < end;
352 ++physical) {
353 const int64 logical = Major(layout, physical);
354 logical_to_physical[logical] = physical;
355 }
356 return logical_to_physical;
357 }
358
HumanString(const Layout & layout)359 /* static */ string LayoutUtil::HumanString(const Layout& layout) {
360 return layout.ToString();
361 }
362
363 namespace {
364
365 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)366 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
367 if (src.IsTuple() != dst->IsTuple()) {
368 return InvalidArgument(
369 "cannot copy layout from shape: shape structure differs");
370 }
371 if (src.IsTuple()) {
372 if (ShapeUtil::TupleElementCount(src) !=
373 ShapeUtil::TupleElementCount(*dst)) {
374 return InvalidArgument(
375 "cannot copy layout from shape: tuple element count differs");
376 }
377 for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
378 TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
379 dst->mutable_tuple_shapes(i)));
380 }
381 } else {
382 if (src.has_layout()) {
383 if (src.rank() != dst->rank()) {
384 return InvalidArgument("cannot copy layout from shape: ranks differs");
385 }
386 TF_RETURN_IF_ERROR(
387 LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
388 *dst->mutable_layout() = src.layout();
389 } else {
390 dst->clear_layout();
391 }
392 }
393 return Status::OK();
394 }
395
396 } // namespace
397
398 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)399 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
400 return CopyLayoutInternal(src, dst);
401 }
402
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)403 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
404 const Shape& rhs) {
405 if (lhs.IsTuple()) {
406 if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
407 ShapeUtil::TupleElementCount(rhs)) {
408 return false;
409 }
410 for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
411 if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
412 return false;
413 }
414 }
415 return true;
416 } else if (lhs.IsArray()) {
417 return lhs.rank() == rhs.rank() &&
418 LayoutUtil::Equal(lhs.layout(), rhs.layout());
419 } else {
420 // Layouts of non-array and non-tuple shapes is ignored.
421 return true;
422 }
423 }
424
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64> dims)425 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
426 const Layout& layout, absl::Span<const int64> dims) {
427 CHECK(IsDense(layout));
428 std::vector<int64> positions_in_layout;
429 for (int64 dim : dims) {
430 positions_in_layout.push_back(
431 PositionInContainer(layout.minor_to_major(), dim));
432 }
433 absl::c_sort(positions_in_layout);
434 for (size_t i = 1; i < positions_in_layout.size(); ++i) {
435 if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
436 return false;
437 }
438 }
439 return true;
440 }
441
Hash(const Layout & layout)442 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
443 using tensorflow::hash;
444 using tensorflow::Hash64Combine;
445
446 size_t hash_value = hash<Format>()(layout.format());
447
448 for (int64 minor_to_major : layout.minor_to_major()) {
449 hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
450 }
451 for (const Tile& tile : layout.tiles()) {
452 for (int64 tile_dim : tile.dimensions()) {
453 hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
454 }
455 }
456 hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
457 hash_value = Hash64Combine(hash_value, layout.memory_space());
458
459 return hash_value;
460 }
461
462 } // namespace xla
463