• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout_util.h"
17 
#include <stddef.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
26 
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/xla/protobuf_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/hash/hash.h"
36 #include "tensorflow/core/lib/strings/numbers.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 
40 namespace xla {
41 namespace {
42 
43 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
44 // minor_to_major to the value that represents the default layout.
45 template <typename T>
SetDefaultLayoutToContainer(T * minor_to_major)46 void SetDefaultLayoutToContainer(T* minor_to_major) {
47   // The default XLA layout is major-to-minor (dim 0 is major).
48   // For more information on XLA layouts, see:
49   // https://www.tensorflow.org/performance/xla/shapes
50   const int64 size = minor_to_major->size();
51   for (int64 i = 0; i < size; ++i) {
52     (*minor_to_major)[i] = size - 1 - i;
53   }
54 }
55 
56 }  // namespace
57 
MakeLayout(absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits,int64 memory_space)58 /* static */ Layout LayoutUtil::MakeLayout(
59     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
60     int64 element_size_in_bits, int64 memory_space) {
61   Layout layout;
62   layout.set_format(DENSE);
63   for (int64 dimension_number : minor_to_major) {
64     layout.add_minor_to_major(dimension_number);
65   }
66   for (const Tile& tile : tiles) {
67     for (int64 dim : tile.dimensions()) {
68       if (dim < 0 && dim != Tile::kCombineDimension) {
69         LOG(FATAL) << "Tile dimension size needs to be minimum int64 value if "
70                       "it's negative. Value is "
71                    << dim;
72       }
73     }
74     *layout.add_tiles() = tile;
75   }
76   layout.set_element_size_in_bits(element_size_in_bits);
77   layout.set_memory_space(memory_space);
78   return layout;
79 }
80 
MakeDescendingLayout(int64 rank)81 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
82   std::vector<int64> layout(rank);
83   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
84   return MakeLayout(layout);
85 }
86 
MakeAscendingLayout(int64 rank)87 /* static */ Layout LayoutUtil::MakeAscendingLayout(int64 rank) {
88   std::vector<int64> layout(rank);
89   std::iota(layout.begin(), layout.end(), static_cast<int64>(0));
90   return MakeLayout(layout);
91 }
92 
MakeLayoutFromMajorToMinor(absl::Span<const int64> major_to_minor)93 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
94     absl::Span<const int64> major_to_minor) {
95   Layout layout;
96   layout.set_format(DENSE);
97   for (int i = major_to_minor.size() - 1; i >= 0; i--) {
98     layout.add_minor_to_major(major_to_minor[i]);
99   }
100   return layout;
101 }
102 
103 namespace {
104 
// Internal helper that creates a default layout for an array of the given rank.
Layout CreateDefaultLayoutForRank(int64 rank) {
  Layout layout;
  layout.set_format(DENSE);
  // Size minor_to_major to `rank`, then overwrite it with the default
  // major-to-minor permutation.
  auto* minor_to_major = layout.mutable_minor_to_major();
  minor_to_major->resize(rank, 0);
  SetDefaultLayoutToContainer(minor_to_major);
  return layout;
}
114 
115 }  // namespace
116 
// Returns the default layout for the given shape. Fails a CHECK if `shape`
// is a tuple (a single Layout proto describes exactly one array).
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
  if (shape.IsOpaque() || shape.IsToken()) {
    // Opaque and token types have empty layouts.
    return Layout();
  }

  // A Layout proto corresponds to a single array, not a tuple.
  CHECK(shape.IsArray());
  return CreateDefaultLayoutForRank(shape.dimensions_size());
}
127 
// Returns the default (major-to-minor) dense layout for an array of `rank`.
/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
  return CreateDefaultLayoutForRank(rank);
}
131 
// Convenience wrapper: default layout for a rank-2 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
  return CreateDefaultLayoutForRank(2);
}
135 
// Convenience wrapper: default layout for a rank-3 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
  return CreateDefaultLayoutForRank(3);
}
139 
// Convenience wrapper: default layout for a rank-4 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
  return CreateDefaultLayoutForRank(4);
}
143 
SetToDefaultLayout(Shape * shape)144 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
145   if (shape->IsTuple()) {
146     // Tuple shape.
147     for (auto& element_shape : *shape->mutable_tuple_shapes()) {
148       SetToDefaultLayout(&element_shape);
149     }
150     shape->clear_layout();
151   } else if (shape->IsArray()) {
152     shape->mutable_layout()->set_format(DENSE);
153     auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
154     minor_to_major->resize(shape->dimensions_size(), 0);
155     SetDefaultLayoutToContainer(minor_to_major);
156   } else {
157     // Opaque, token types etc. have no layout.
158     shape->clear_layout();
159   }
160 }
161 
// Returns a copy of `shape` in which every layout has been replaced by the
// default layout.
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape result = shape;
  SetToDefaultLayout(&result);
  return result;
}
167 
SetToDefaultLayout(ProgramShape * program_shape)168 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
169   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
170     LayoutUtil::SetToDefaultLayout(&parameter_shape);
171   }
172   LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
173 }
174 
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)175 /* static */ Status LayoutUtil::ValidateLayoutInShape(
176     const Shape& shape, bool allow_missing_layouts) {
177   if (shape.IsTuple()) {
178     // Tuple shape.
179     if (shape.has_layout()) {
180       return InvalidArgument("tuple should not have a layout field");
181     }
182     for (auto& element_shape : shape.tuple_shapes()) {
183       TF_RETURN_IF_ERROR(
184           ValidateLayoutInShape(element_shape, allow_missing_layouts));
185     }
186     return Status::OK();
187   } else if (shape.IsArray()) {
188     if (!shape.has_layout()) {
189       if (allow_missing_layouts) {
190         return Status::OK();
191       }
192       return InvalidArgument("shape %s does not have a layout",
193                              ShapeUtil::HumanString(shape));
194     }
195     return ValidateLayoutForShape(shape.layout(), shape);
196   } else {
197     // Token, opaque, etc. shape.
198     if (shape.has_layout()) {
199       return InvalidArgument(
200           "shape of primitive type %s should not have a layout",
201           PrimitiveType_Name(shape.element_type()));
202     }
203     return Status::OK();
204   }
205 }
206 
// Validates that `layout` is a legal layout for the (non-tuple) `shape`:
// the format must be valid, and for DENSE layouts minor_to_major must be a
// permutation of [0, rank). Non-dense layouts may not carry tiles.
/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
                                                       const Shape& shape) {
  // A single Layout proto describes exactly one array.
  if (shape.IsTuple()) {
    return InvalidArgument("a single Layout is not valid for tuple shapes");
  }

  // Non-array shapes (token, opaque, ...) may only carry a trivial layout.
  if (!shape.IsArray()) {
    if (layout.minor_to_major_size() != 0) {
      return InvalidArgument(
          "shape of primitive type %s should not have a non-trivial layout",
          PrimitiveType_Name(shape.element_type()));
    }
    return Status::OK();
  }

  if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
    return InvalidArgument("Layout has an invalid format (%d)",
                           layout.format());
  }

  if (layout.format() == DENSE) {
    // minor_to_major must be a permutation of [0, rank): correct length,
    // all values in bounds, and no duplicates.
    if (layout.minor_to_major_size() != shape.rank()) {
      return InvalidArgument(
          "layout minor_to_major field contains %d elements, "
          "but shape is rank %d: {%s}; shape: %s",
          layout.minor_to_major_size(), shape.rank(),
          absl::StrJoin(layout.minor_to_major(), ", "),
          shape.ShortDebugString());
    }

    // Seen-dimension bitmap used to detect duplicates.
    std::vector<bool> dimensions_in_layout(shape.rank(), false);
    for (int64 i = 0; i < shape.rank(); ++i) {
      int64 dim = layout.minor_to_major(i);
      if (dim < 0 || dim >= shape.rank()) {
        return InvalidArgument(
            "layout minor_to_major field has out-of-bounds value: %s",
            HumanString(layout));
      }
      if (dimensions_in_layout[dim]) {
        return InvalidArgument(
            "layout minor_to_major field has duplicate values: {%s}",
            HumanString(layout));
      }
      dimensions_in_layout[dim] = true;
    }
  } else {
    // Tiling is only defined for dense layouts.
    if (layout.tiles_size() != 0) {
      return InvalidArgument("Only dense layouts can be tiled.");
    }
  }

  return Status::OK();
}
260 
ClearLayout(Shape * shape)261 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
262   shape->clear_layout();
263   for (auto& element_shape : *shape->mutable_tuple_shapes()) {
264     ClearLayout(&element_shape);
265   }
266 }
267 
ClearLayout(ProgramShape * program_shape)268 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
269   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
270     LayoutUtil::ClearLayout(&parameter_shape);
271   }
272   LayoutUtil::ClearLayout(program_shape->mutable_result());
273 }
274 
// True iff `shape` is an array that has a layout in the DENSE format.
/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
  return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
}
278 
// True iff `layout` uses the DENSE format.
/* static */ bool LayoutUtil::IsDense(const Layout& layout) {
  return layout.format() == DENSE;
}
282 
IsMonotonicWithDim0Minor(const Layout & layout)283 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
284   CHECK(layout.format() == DENSE);
285   return std::is_sorted(layout.minor_to_major().begin(),
286                         layout.minor_to_major().end());
287 }
288 
IsMonotonicWithDim0Major(const Layout & layout)289 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
290   CHECK(layout.format() == DENSE);
291   return std::is_sorted(layout.minor_to_major().begin(),
292                         layout.minor_to_major().end(), std::greater<int64>());
293 }
294 
HasLayout(const Shape & shape)295 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
296   if (shape.IsTuple()) {
297     // Tuple shape: all subshapes must have a layout.
298     return absl::c_all_of(shape.tuple_shapes(),
299                           [](const Shape& s) { return HasLayout(s); });
300   } else if (!shape.IsArray()) {
301     // Opaque, token types etc. ignore layout.
302     return true;
303   }
304   return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
305 }
306 
HasLayout(const ProgramShape & program_shape)307 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
308   for (auto& parameter_shape : program_shape.parameters()) {
309     if (!LayoutUtil::HasLayout(parameter_shape)) {
310       return false;
311     }
312   }
313   return LayoutUtil::HasLayout(program_shape.result());
314 }
315 
// True iff the two layouts compare equal under Layout's equality operator.
/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
  return lhs == rhs;
}
319 
// Returns the minor_to_major field of `shape`'s layout as a span. Fails a
// CHECK unless `shape` is a dense array. The span aliases the layout's
// storage, so it must not outlive mutations of `shape`.
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    const Shape& shape) {
  CHECK(IsDenseArray(shape));
  return AsInt64Slice(shape.layout().minor_to_major());
}
325 
// Returns `layout`'s minor_to_major field as a span. Fails a CHECK unless
// the layout is DENSE. The span aliases the layout's storage.
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    const Layout& layout) {
  CHECK(layout.format() == DENSE);
  return AsInt64Slice(layout.minor_to_major());
}
331 
Major(const Layout & layout,int64 physical_dimension_number)332 /* static */ int64 LayoutUtil::Major(const Layout& layout,
333                                      int64 physical_dimension_number) {
334   CHECK_LE(0, physical_dimension_number);
335   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
336   return Minor(layout,
337                layout.minor_to_major_size() - 1 - physical_dimension_number);
338 }
339 
// Returns the logical dimension at position `physical_dimension_number`
// counted from the minor (fastest varying) end. Requires a DENSE layout and
// an in-bounds position.
/* static */ int64 LayoutUtil::Minor(const Layout& layout,
                                     int64 physical_dimension_number) {
  CHECK_EQ(layout.format(), DENSE);
  CHECK_LE(0, physical_dimension_number);
  CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
  return layout.minor_to_major(physical_dimension_number);
}
347 
MakeLogicalToPhysical(const Layout & layout)348 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
349     const Layout& layout) {
350   std::vector<int64> logical_to_physical(layout.minor_to_major_size());
351   for (int64 physical = 0, end = logical_to_physical.size(); physical < end;
352        ++physical) {
353     const int64 logical = Major(layout, physical);
354     logical_to_physical[logical] = physical;
355   }
356   return logical_to_physical;
357 }
358 
// Returns a human-readable rendering of `layout` via Layout::ToString().
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
  return layout.ToString();
}
362 
363 namespace {
364 
// Internal helper for recursively copying layouts.
//
// Copies the layout (or absence of one) from `src` onto `dst`. The shapes
// must have the same tuple structure and, for arrays, matching rank; each
// layout is validated against the destination shape before being written.
// NOTE: on error mid-traversal, `dst` may already have been partially
// updated by earlier successful recursive calls.
Status CopyLayoutInternal(const Shape& src, Shape* dst) {
  if (src.IsTuple() != dst->IsTuple()) {
    return InvalidArgument(
        "cannot copy layout from shape: shape structure differs");
  }
  if (src.IsTuple()) {
    // Recurse element-wise; tuple arity must match exactly.
    if (ShapeUtil::TupleElementCount(src) !=
        ShapeUtil::TupleElementCount(*dst)) {
      return InvalidArgument(
          "cannot copy layout from shape: tuple element count differs");
    }
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
      TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
                                            dst->mutable_tuple_shapes(i)));
    }
  } else {
    if (src.has_layout()) {
      if (src.rank() != dst->rank()) {
        return InvalidArgument("cannot copy layout from shape: ranks differs");
      }
      // Validate before overwriting so an invalid layout never lands on dst.
      TF_RETURN_IF_ERROR(
          LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
      *dst->mutable_layout() = src.layout();
    } else {
      // Source has no layout: clear any layout the destination had.
      dst->clear_layout();
    }
  }
  return Status::OK();
}
395 
396 }  // namespace
397 
// Copies the layouts from `src` onto `dst` (recursively for tuples);
// delegates the traversal to CopyLayoutInternal.
/* static */
Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
  return CopyLayoutInternal(src, dst);
}
402 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)403 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
404                                                    const Shape& rhs) {
405   if (lhs.IsTuple()) {
406     if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
407                               ShapeUtil::TupleElementCount(rhs)) {
408       return false;
409     }
410     for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
411       if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
412         return false;
413       }
414     }
415     return true;
416   } else if (lhs.IsArray()) {
417     return lhs.rank() == rhs.rank() &&
418            LayoutUtil::Equal(lhs.layout(), rhs.layout());
419   } else {
420     // Layouts of non-array and non-tuple shapes is ignored.
421     return true;
422   }
423 }
424 
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64> dims)425 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
426     const Layout& layout, absl::Span<const int64> dims) {
427   CHECK(IsDense(layout));
428   std::vector<int64> positions_in_layout;
429   for (int64 dim : dims) {
430     positions_in_layout.push_back(
431         PositionInContainer(layout.minor_to_major(), dim));
432   }
433   absl::c_sort(positions_in_layout);
434   for (size_t i = 1; i < positions_in_layout.size(); ++i) {
435     if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
436       return false;
437     }
438   }
439   return true;
440 }
441 
// Combines the layout's fields into a single hash value: format, each
// minor_to_major entry, each tile dimension, the element size in bits, and
// the memory space.
/*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
  using tensorflow::hash;
  using tensorflow::Hash64Combine;

  size_t hash_value = hash<Format>()(layout.format());

  for (int64 minor_to_major : layout.minor_to_major()) {
    hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
  }
  for (const Tile& tile : layout.tiles()) {
    for (int64 tile_dim : tile.dimensions()) {
      hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
    }
  }
  hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
  hash_value = Hash64Combine(hash_value, layout.memory_space());

  return hash_value;
}
461 
462 }  // namespace xla
463