• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout_util.h"
17 
18 #include <stddef.h>
19 
20 #include <algorithm>
21 #include <functional>
22 #include <random>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26 
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/xla/protobuf_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/hash/hash.h"
36 #include "tensorflow/core/lib/strings/numbers.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 
40 namespace xla {
41 namespace {
42 
43 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
44 // minor_to_major to the value that represents the default layout.
45 template <typename T>
SetDefaultLayoutToContainer(T * minor_to_major)46 void SetDefaultLayoutToContainer(T* minor_to_major) {
47   // The default XLA layout is major-to-minor (dim 0 is major).
48   // For more information on XLA layouts, see:
49   // https://www.tensorflow.org/performance/xla/shapes
50   const int64 size = minor_to_major->size();
51   for (int64 i = 0; i < size; ++i) {
52     (*minor_to_major)[i] = size - 1 - i;
53   }
54 }
55 
56 }  // namespace
57 
MakeLayout(absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits,int64 memory_space)58 /* static */ Layout LayoutUtil::MakeLayout(
59     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
60     int64 element_size_in_bits, int64 memory_space) {
61   Layout layout;
62   layout.set_format(DENSE);
63   for (int64 dimension_number : minor_to_major) {
64     layout.add_minor_to_major(dimension_number);
65   }
66   for (Tile tile : tiles) {
67     for (int64 dim : tile.dimensions()) {
68       if (dim < 0 && dim != Tile::kCombineDimension) {
69         LOG(FATAL) << "Tile dimension size needs to be minimum int64 value if "
70                       "it's negative. Value is "
71                    << dim;
72       }
73     }
74     *layout.add_tiles() = tile;
75   }
76   layout.set_element_size_in_bits(element_size_in_bits);
77   layout.set_memory_space(memory_space);
78   return layout;
79 }
80 
MakeDescendingLayout(int64 rank)81 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
82   std::vector<int64> layout(rank);
83   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
84   return MakeLayout(layout);
85 }
86 
MakeLayoutFromMajorToMinor(absl::Span<const int64> major_to_minor)87 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
88     absl::Span<const int64> major_to_minor) {
89   Layout layout;
90   layout.set_format(DENSE);
91   for (int i = major_to_minor.size() - 1; i >= 0; i--) {
92     layout.add_minor_to_major(major_to_minor[i]);
93   }
94   return layout;
95 }
96 
97 namespace {
98 
99 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64 rank)100 Layout CreateDefaultLayoutForRank(int64 rank) {
101   Layout layout;
102   layout.set_format(DENSE);
103   auto* minor_to_major = layout.mutable_minor_to_major();
104   minor_to_major->resize(rank, 0);
105   SetDefaultLayoutToContainer(minor_to_major);
106   return layout;
107 }
108 
109 }  // namespace
110 
GetDefaultLayoutForShape(const Shape & shape)111 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
112   if (shape.IsOpaque() || shape.IsToken()) {
113     // Opaque and token types have empty layouts.
114     return Layout();
115   }
116 
117   // A Layout proto corresponds to a single array, not a tuple.
118   CHECK(shape.IsArray());
119   return CreateDefaultLayoutForRank(shape.dimensions_size());
120 }
121 
GetDefaultLayoutForRank(int64 rank)122 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
123   return CreateDefaultLayoutForRank(rank);
124 }
125 
GetDefaultLayoutForR2()126 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
127   return CreateDefaultLayoutForRank(2);
128 }
129 
GetDefaultLayoutForR3()130 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
131   return CreateDefaultLayoutForRank(3);
132 }
133 
GetDefaultLayoutForR4()134 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
135   return CreateDefaultLayoutForRank(4);
136 }
137 
SetToDefaultLayout(Shape * shape)138 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
139   if (shape->IsTuple()) {
140     // Tuple shape.
141     for (auto& element_shape : *shape->mutable_tuple_shapes()) {
142       SetToDefaultLayout(&element_shape);
143     }
144     shape->clear_layout();
145   } else if (shape->IsArray()) {
146     shape->mutable_layout()->set_format(DENSE);
147     auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
148     minor_to_major->resize(shape->dimensions_size(), 0);
149     SetDefaultLayoutToContainer(minor_to_major);
150   } else {
151     // Opaque, token types etc. have no layout.
152     shape->clear_layout();
153   }
154 }
155 
// Returns a copy of `shape` after SetToDefaultLayout has been applied to the
// copy; the argument itself is not modified.
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape with_default_layout = shape;
  SetToDefaultLayout(&with_default_layout);
  return with_default_layout;
}
161 
SetToDefaultLayout(ProgramShape * program_shape)162 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
163   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
164     LayoutUtil::SetToDefaultLayout(&parameter_shape);
165   }
166   LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
167 }
168 
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)169 /* static */ Status LayoutUtil::ValidateLayoutInShape(
170     const Shape& shape, bool allow_missing_layouts) {
171   if (shape.IsTuple()) {
172     // Tuple shape.
173     if (shape.has_layout()) {
174       return InvalidArgument("tuple should not have a layout field");
175     }
176     for (auto& element_shape : shape.tuple_shapes()) {
177       TF_RETURN_IF_ERROR(
178           ValidateLayoutInShape(element_shape, allow_missing_layouts));
179     }
180     return Status::OK();
181   } else if (shape.IsArray()) {
182     if (!shape.has_layout()) {
183       if (allow_missing_layouts) {
184         return Status::OK();
185       }
186       return InvalidArgument("shape %s does not have a layout",
187                              ShapeUtil::HumanString(shape));
188     }
189     return ValidateLayoutForShape(shape.layout(), shape);
190   } else {
191     // Token, opaque, etc. shape.
192     if (shape.has_layout()) {
193       return InvalidArgument(
194           "shape of primitive type %s should not have a layout",
195           PrimitiveType_Name(shape.element_type()));
196     }
197     return Status::OK();
198   }
199 }
200 
ValidateLayoutForShape(const Layout & layout,const Shape & shape)201 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
202                                                        const Shape& shape) {
203   if (shape.IsTuple()) {
204     return InvalidArgument("a single Layout is not valid for tuple shapes");
205   }
206 
207   if (!shape.IsArray()) {
208     if (layout.minor_to_major_size() != 0) {
209       return InvalidArgument(
210           "shape of primitive type %s should not have a non-trivial layout",
211           PrimitiveType_Name(shape.element_type()));
212     }
213     return Status::OK();
214   }
215 
216   if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
217     return InvalidArgument("Layout has an invalid format (%d)",
218                            layout.format());
219   }
220 
221   if (layout.format() == DENSE) {
222     if (layout.minor_to_major_size() != shape.rank()) {
223       return InvalidArgument(
224           "layout minor_to_major field contains %d elements, "
225           "but shape is rank %d: {%s}; shape: %s",
226           layout.minor_to_major_size(), shape.rank(),
227           absl::StrJoin(layout.minor_to_major(), ", "),
228           shape.ShortDebugString());
229     }
230 
231     std::vector<bool> dimensions_in_layout(shape.rank(), false);
232     for (int64 i = 0; i < shape.rank(); ++i) {
233       int64 dim = layout.minor_to_major(i);
234       if (dim < 0 || dim >= shape.rank()) {
235         return InvalidArgument(
236             "layout minor_to_major field has out-of-bounds value: %s",
237             HumanString(layout));
238       }
239       if (dimensions_in_layout[dim]) {
240         return InvalidArgument(
241             "layout minor_to_major field has duplicate values: {%s}",
242             HumanString(layout));
243       }
244       dimensions_in_layout[dim] = true;
245     }
246   } else {
247     if (layout.tiles_size() != 0) {
248       return InvalidArgument("Only dense layouts can be tiled.");
249     }
250   }
251 
252   return Status::OK();
253 }
254 
ClearLayout(Shape * shape)255 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
256   shape->clear_layout();
257   for (auto& element_shape : *shape->mutable_tuple_shapes()) {
258     ClearLayout(&element_shape);
259   }
260 }
261 
ClearLayout(ProgramShape * program_shape)262 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
263   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
264     LayoutUtil::ClearLayout(&parameter_shape);
265   }
266   LayoutUtil::ClearLayout(program_shape->mutable_result());
267 }
268 
IsDenseArray(const Shape & shape)269 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
270   return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
271 }
272 
IsDense(const Layout & layout)273 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
274   return layout.format() == DENSE;
275 }
276 
IsMonotonicWithDim0Minor(const Layout & layout)277 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
278   CHECK(layout.format() == DENSE);
279   return std::is_sorted(layout.minor_to_major().begin(),
280                         layout.minor_to_major().end());
281 }
282 
IsMonotonicWithDim0Major(const Layout & layout)283 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
284   CHECK(layout.format() == DENSE);
285   return std::is_sorted(layout.minor_to_major().begin(),
286                         layout.minor_to_major().end(), std::greater<int64>());
287 }
288 
HasLayout(const Shape & shape)289 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
290   if (shape.IsTuple()) {
291     // Tuple shape: all subshapes must have a layout.
292     return absl::c_all_of(shape.tuple_shapes(),
293                           [](const Shape& s) { return HasLayout(s); });
294   } else if (!shape.IsArray()) {
295     // Opaque, token types etc. ignore layout.
296     return true;
297   }
298   return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
299 }
300 
HasLayout(const ProgramShape & program_shape)301 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
302   for (auto& parameter_shape : program_shape.parameters()) {
303     if (!LayoutUtil::HasLayout(parameter_shape)) {
304       return false;
305     }
306   }
307   return LayoutUtil::HasLayout(program_shape.result());
308 }
309 
Equal(const Layout & lhs,const Layout & rhs)310 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
311   return lhs == rhs;
312 }
313 
MinorToMajor(const Shape & shape)314 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
315     const Shape& shape) {
316   CHECK(IsDenseArray(shape));
317   return AsInt64Slice(shape.layout().minor_to_major());
318 }
319 
MinorToMajor(const Layout & layout)320 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
321     const Layout& layout) {
322   CHECK(layout.format() == DENSE);
323   return AsInt64Slice(layout.minor_to_major());
324 }
325 
Major(const Layout & layout,int64 physical_dimension_number)326 /* static */ int64 LayoutUtil::Major(const Layout& layout,
327                                      int64 physical_dimension_number) {
328   CHECK_LE(0, physical_dimension_number);
329   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
330   return Minor(layout,
331                layout.minor_to_major_size() - 1 - physical_dimension_number);
332 }
333 
Minor(const Layout & layout,int64 physical_dimension_number)334 /* static */ int64 LayoutUtil::Minor(const Layout& layout,
335                                      int64 physical_dimension_number) {
336   CHECK_EQ(layout.format(), DENSE);
337   CHECK_LE(0, physical_dimension_number);
338   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
339   return layout.minor_to_major(physical_dimension_number);
340 }
341 
MakeLogicalToPhysical(const Layout & layout)342 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
343     const Layout& layout) {
344   std::vector<int64> logical_to_physical(layout.minor_to_major_size());
345   for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
346     const int64 logical = Major(layout, physical);
347     logical_to_physical[logical] = physical;
348   }
349   return logical_to_physical;
350 }
351 
HumanString(const Layout & layout)352 /* static */ string LayoutUtil::HumanString(const Layout& layout) {
353   return layout.ToString();
354 }
355 
356 namespace {
357 
358 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)359 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
360   if (src.IsTuple() != dst->IsTuple()) {
361     return InvalidArgument(
362         "cannot copy layout from shape: shape structure differs");
363   }
364   if (src.IsTuple()) {
365     if (ShapeUtil::TupleElementCount(src) !=
366         ShapeUtil::TupleElementCount(*dst)) {
367       return InvalidArgument(
368           "cannot copy layout from shape: tuple element count differs");
369     }
370     for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
371       TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
372                                             dst->mutable_tuple_shapes(i)));
373     }
374   } else {
375     if (src.has_layout()) {
376       if (src.rank() != dst->rank()) {
377         return InvalidArgument("cannot copy layout from shape: ranks differs");
378       }
379       TF_RETURN_IF_ERROR(
380           LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
381       *dst->mutable_layout() = src.layout();
382     } else {
383       dst->clear_layout();
384     }
385   }
386   return Status::OK();
387 }
388 
389 }  // namespace
390 
391 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)392 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
393   return CopyLayoutInternal(src, dst);
394 }
395 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)396 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
397                                                    const Shape& rhs) {
398   if (lhs.IsTuple()) {
399     if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
400                               ShapeUtil::TupleElementCount(rhs)) {
401       return false;
402     }
403     for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
404       if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
405         return false;
406       }
407     }
408     return true;
409   } else if (lhs.IsArray()) {
410     return lhs.rank() == rhs.rank() &&
411            LayoutUtil::Equal(lhs.layout(), rhs.layout());
412   } else {
413     // Layouts of non-array and non-tuple shapes is ignored.
414     return true;
415   }
416 }
417 
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64> dims)418 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
419     const Layout& layout, absl::Span<const int64> dims) {
420   CHECK(IsDense(layout));
421   std::vector<int64> positions_in_layout;
422   for (int64 dim : dims) {
423     positions_in_layout.push_back(
424         PositionInContainer(layout.minor_to_major(), dim));
425   }
426   absl::c_sort(positions_in_layout);
427   for (size_t i = 1; i < positions_in_layout.size(); ++i) {
428     if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
429       return false;
430     }
431   }
432   return true;
433 }
434 
Hash(const Layout & layout)435 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
436   using tensorflow::hash;
437   using tensorflow::Hash64Combine;
438 
439   size_t hash_value = hash<Format>()(layout.format());
440 
441   for (int64 minor_to_major : layout.minor_to_major()) {
442     hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
443   }
444   for (Tile tile : layout.tiles()) {
445     for (int64 tile_dim : tile.dimensions()) {
446       hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
447     }
448   }
449   hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
450   hash_value = Hash64Combine(hash_value, layout.memory_space());
451 
452   return hash_value;
453 }
454 
455 }  // namespace xla
456