• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout_util.h"
17 
18 #include <stddef.h>
19 #include <algorithm>
20 #include <functional>
21 #include <random>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/protobuf_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38 
39 namespace xla {
40 namespace {
41 
42 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
43 // minor_to_major to the value that represents the default layout.
SetDefaultLayoutToContainer(std::vector<int64> * minor_to_major)44 void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
45   // The default XLA layout is major-to-minor (dim 0 is major).
46   // For more information on XLA layouts, see:
47   // https://www.tensorflow.org/performance/xla/shapes
48   const int64 size = minor_to_major->size();
49   for (int64 i = 0; i < size; ++i) {
50     (*minor_to_major)[i] = size - 1 - i;
51   }
52 }
53 
54 }  // namespace
55 
MakeLayout(absl::Span<const int64> minor_to_major,absl::Span<const Tile> tiles,int64 element_size_in_bits)56 /* static */ Layout LayoutUtil::MakeLayout(
57     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
58     int64 element_size_in_bits) {
59   Layout layout;
60   layout.set_format(DENSE);
61   for (int64 dimension_number : minor_to_major) {
62     layout.add_minor_to_major(dimension_number);
63   }
64   for (Tile tile : tiles) {
65     for (int64 dim : tile.dimensions()) {
66       if (dim < 0 && dim != Tile::kCombineDimension) {
67         LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if "
68                       "it's negative. Value is "
69                    << dim;
70       }
71     }
72     *layout.add_tiles() = tile;
73   }
74   layout.set_element_size_in_bits(element_size_in_bits);
75   return layout;
76 }
77 
MakeDescendingLayout(int64 rank)78 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
79   std::vector<int64> layout(rank);
80   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
81   return MakeLayout(layout);
82 }
83 
MakeLayoutFromMajorToMinor(absl::Span<const int64> major_to_minor)84 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
85     absl::Span<const int64> major_to_minor) {
86   Layout layout;
87   layout.set_format(DENSE);
88   for (int i = major_to_minor.size() - 1; i >= 0; i--) {
89     layout.add_minor_to_major(major_to_minor[i]);
90   }
91   return layout;
92 }
93 
MakeSparseLayout(int64 max_sparse_elements)94 /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
95   Layout layout;
96   layout.set_format(SPARSE);
97   layout.set_max_sparse_elements(max_sparse_elements);
98   return layout;
99 }
100 
101 namespace {
102 
103 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64 rank)104 Layout CreateDefaultLayoutForRank(int64 rank) {
105   Layout layout;
106   layout.set_format(DENSE);
107   std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
108   minor_to_major->resize(rank, 0);
109   SetDefaultLayoutToContainer(minor_to_major);
110   return layout;
111 }
112 
113 }  // namespace
114 
GetDefaultLayoutForShape(const Shape & shape)115 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
116   if (shape.IsOpaque() || shape.IsToken()) {
117     // Opaque and token types have empty layouts.
118     return Layout();
119   }
120 
121   // A Layout proto corresponds to a single array, not a tuple.
122   CHECK(shape.IsArray());
123   return CreateDefaultLayoutForRank(shape.dimensions_size());
124 }
125 
GetDefaultLayoutForRank(int64 rank)126 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
127   return CreateDefaultLayoutForRank(rank);
128 }
129 
GetDefaultLayoutForR2()130 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
131   return CreateDefaultLayoutForRank(2);
132 }
133 
GetDefaultLayoutForR3()134 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
135   return CreateDefaultLayoutForRank(3);
136 }
137 
GetDefaultLayoutForR4()138 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
139   return CreateDefaultLayoutForRank(4);
140 }
141 
SetToDefaultLayout(Shape * shape)142 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
143   if (shape->IsTuple()) {
144     // Tuple shape.
145     for (auto& element_shape : *shape->mutable_tuple_shapes()) {
146       SetToDefaultLayout(&element_shape);
147     }
148     shape->clear_layout();
149   } else if (shape->IsArray()) {
150     shape->mutable_layout()->set_format(DENSE);
151     auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
152     minor_to_major->resize(shape->dimensions_size(), 0);
153     SetDefaultLayoutToContainer(minor_to_major);
154   } else {
155     // Opaque, token types etc. have no layout.
156     shape->clear_layout();
157   }
158 }
159 
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  // Operate on a copy so the caller's shape is left untouched.
  Shape with_default = shape;
  SetToDefaultLayout(&with_default);
  return with_default;
}
165 
SetToDefaultLayout(ProgramShape * program_shape)166 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
167   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
168     LayoutUtil::SetToDefaultLayout(&parameter_shape);
169   }
170   LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
171 }
172 
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)173 /* static */ Status LayoutUtil::ValidateLayoutInShape(
174     const Shape& shape, bool allow_missing_layouts) {
175   if (shape.IsTuple()) {
176     // Tuple shape.
177     if (shape.has_layout()) {
178       return InvalidArgument("tuple should not have a layout field");
179     }
180     for (auto& element_shape : shape.tuple_shapes()) {
181       TF_RETURN_IF_ERROR(
182           ValidateLayoutInShape(element_shape, allow_missing_layouts));
183     }
184     return Status::OK();
185   } else if (shape.IsArray()) {
186     if (!shape.has_layout()) {
187       if (allow_missing_layouts) {
188         return Status::OK();
189       }
190       return InvalidArgument("shape %s does not have a layout",
191                              ShapeUtil::HumanString(shape));
192     }
193     return ValidateLayoutForShape(shape.layout(), shape);
194   } else {
195     // Token, opaque, etc. shape.
196     if (shape.has_layout()) {
197       return InvalidArgument(
198           "shape of primitive type %s should not have a layout",
199           PrimitiveType_Name(shape.element_type()));
200     }
201     return Status::OK();
202   }
203 }
204 
ValidateLayoutForShape(const Layout & layout,const Shape & shape)205 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
206                                                        const Shape& shape) {
207   if (shape.IsTuple()) {
208     return InvalidArgument("a single Layout is not valid for tuple shapes");
209   }
210 
211   if (!shape.IsArray()) {
212     if (layout.minor_to_major_size() != 0) {
213       return InvalidArgument(
214           "shape of primitive type %s should not have a non-trivial layout",
215           PrimitiveType_Name(shape.element_type()));
216     }
217     return Status::OK();
218   }
219 
220   if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
221     return InvalidArgument("Layout has an invalid format (%d)",
222                            layout.format());
223   }
224 
225   if (layout.format() == DENSE) {
226     if (layout.minor_to_major_size() != shape.rank()) {
227       return InvalidArgument(
228           "layout minor_to_major field contains %d elements, "
229           "but shape is rank %d: {%s}; shape: %s",
230           layout.minor_to_major_size(), shape.rank(),
231           absl::StrJoin(layout.minor_to_major(), ", "),
232           shape.ShortDebugString());
233     }
234 
235     std::vector<bool> dimensions_in_layout(shape.rank(), false);
236     for (int64 i = 0; i < shape.rank(); ++i) {
237       int64 dim = layout.minor_to_major(i);
238       if (dim < 0 || dim >= shape.rank()) {
239         return InvalidArgument(
240             "layout minor_to_major field has out-of-bounds value: %s",
241             HumanString(layout));
242       }
243       if (dimensions_in_layout[dim]) {
244         return InvalidArgument(
245             "layout minor_to_major field has duplicate values: {%s}",
246             HumanString(layout));
247       }
248       dimensions_in_layout[dim] = true;
249     }
250   } else {
251     if (layout.tiles_size() != 0) {
252       return InvalidArgument("Only dense layouts can be tiled.");
253     }
254   }
255 
256   return Status::OK();
257 }
258 
ClearLayout(Shape * shape)259 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
260   shape->clear_layout();
261   for (auto& element_shape : *shape->mutable_tuple_shapes()) {
262     ClearLayout(&element_shape);
263   }
264 }
265 
ClearLayout(ProgramShape * program_shape)266 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
267   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
268     LayoutUtil::ClearLayout(&parameter_shape);
269   }
270   LayoutUtil::ClearLayout(program_shape->mutable_result());
271 }
272 
IsDenseArray(const Shape & shape)273 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
274   return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
275 }
276 
IsDense(const Layout & layout)277 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
278   return layout.format() == DENSE;
279 }
280 
IsMonotonicWithDim0Minor(const Layout & layout)281 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
282   CHECK(layout.format() == DENSE);
283   return std::is_sorted(layout.minor_to_major().begin(),
284                         layout.minor_to_major().end());
285 }
286 
IsMonotonicWithDim0Major(const Layout & layout)287 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
288   CHECK(layout.format() == DENSE);
289   return std::is_sorted(layout.minor_to_major().begin(),
290                         layout.minor_to_major().end(), std::greater<int64>());
291 }
292 
IsSparseArray(const Shape & shape)293 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
294   return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
295 }
296 
IsSparse(const Layout & layout)297 /* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
298   return layout.format() == SPARSE;
299 }
300 
MaxSparseElements(const Layout & layout)301 /* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
302   CHECK(IsSparse(layout));
303   return layout.max_sparse_elements();
304 }
305 
HasLayout(const Shape & shape)306 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
307   if (shape.IsTuple()) {
308     // Tuple shape: all subshapes must have a layout.
309     return absl::c_all_of(shape.tuple_shapes(),
310                           [](const Shape& s) { return HasLayout(s); });
311   } else if (!shape.IsArray()) {
312     // Opaque, token types etc. ignore layout.
313     return true;
314   }
315   return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
316 }
317 
HasLayout(const ProgramShape & program_shape)318 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
319   for (auto& parameter_shape : program_shape.parameters()) {
320     if (!LayoutUtil::HasLayout(parameter_shape)) {
321       return false;
322     }
323   }
324   return LayoutUtil::HasLayout(program_shape.result());
325 }
326 
Equal(const Layout & lhs,const Layout & rhs)327 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
328   return lhs == rhs;
329 }
330 
MinorToMajor(const Shape & shape)331 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
332     const Shape& shape) {
333   CHECK(IsDenseArray(shape));
334   return AsInt64Slice(shape.layout().minor_to_major());
335 }
336 
MinorToMajor(const Layout & layout)337 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
338     const Layout& layout) {
339   CHECK(layout.format() == DENSE);
340   return AsInt64Slice(layout.minor_to_major());
341 }
342 
Major(const Layout & layout,int64 physical_dimension_number)343 /* static */ int64 LayoutUtil::Major(const Layout& layout,
344                                      int64 physical_dimension_number) {
345   CHECK_LE(0, physical_dimension_number);
346   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
347   return Minor(layout,
348                layout.minor_to_major_size() - 1 - physical_dimension_number);
349 }
350 
Minor(const Layout & layout,int64 physical_dimension_number)351 /* static */ int64 LayoutUtil::Minor(const Layout& layout,
352                                      int64 physical_dimension_number) {
353   CHECK_EQ(layout.format(), DENSE);
354   CHECK_LE(0, physical_dimension_number);
355   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
356   return layout.minor_to_major(physical_dimension_number);
357 }
358 
MakeLogicalToPhysical(const Layout & layout)359 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
360     const Layout& layout) {
361   std::vector<int64> logical_to_physical(layout.minor_to_major_size());
362   for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
363     const int64 logical = Major(layout, physical);
364     logical_to_physical[logical] = physical;
365   }
366   return logical_to_physical;
367 }
368 
HumanString(const Layout & layout)369 /* static */ string LayoutUtil::HumanString(const Layout& layout) {
370   return layout.ToString();
371 }
372 
373 namespace {
374 
375 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)376 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
377   if (src.IsTuple() != dst->IsTuple()) {
378     return InvalidArgument(
379         "cannot copy layout from shape: shape structure differs");
380   }
381   if (src.IsTuple()) {
382     if (ShapeUtil::TupleElementCount(src) !=
383         ShapeUtil::TupleElementCount(*dst)) {
384       return InvalidArgument(
385           "cannot copy layout from shape: tuple element count differs");
386     }
387     for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
388       TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
389                                             dst->mutable_tuple_shapes(i)));
390     }
391   } else {
392     if (src.has_layout()) {
393       if (src.rank() != dst->rank()) {
394         return InvalidArgument("cannot copy layout from shape: ranks differs");
395       }
396       TF_RETURN_IF_ERROR(
397           LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
398       *dst->mutable_layout() = src.layout();
399     } else {
400       dst->clear_layout();
401     }
402   }
403   return Status::OK();
404 }
405 
406 }  // namespace
407 
408 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)409 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
410   return CopyLayoutInternal(src, dst);
411 }
412 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)413 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
414                                                    const Shape& rhs) {
415   if (lhs.IsTuple()) {
416     if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
417                               ShapeUtil::TupleElementCount(rhs)) {
418       return false;
419     }
420     for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
421       if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
422         return false;
423       }
424     }
425     return true;
426   } else if (lhs.IsArray()) {
427     return lhs.rank() == rhs.rank() &&
428            LayoutUtil::Equal(lhs.layout(), rhs.layout());
429   } else {
430     // Layouts of non-array and non-tuple shapes is ignored.
431     return true;
432   }
433 }
434 
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64> dims)435 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
436     const Layout& layout, absl::Span<const int64> dims) {
437   CHECK(IsDense(layout));
438   std::vector<int64> positions_in_layout;
439   for (int64 dim : dims) {
440     positions_in_layout.push_back(
441         PositionInContainer(layout.minor_to_major(), dim));
442   }
443   absl::c_sort(positions_in_layout);
444   for (size_t i = 1; i < positions_in_layout.size(); ++i) {
445     if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
446       return false;
447     }
448   }
449   return true;
450 }
451 
Hash(const Layout & layout)452 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
453   using tensorflow::hash;
454   using tensorflow::Hash64Combine;
455 
456   size_t hash_value = hash<Format>()(layout.format());
457 
458   for (int64 minor_to_major : layout.minor_to_major()) {
459     hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
460   }
461   hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
462 
463   for (Tile tile : layout.tiles()) {
464     for (int64 tile_dim : tile.dimensions()) {
465       hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
466     }
467   }
468   hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
469 
470   return hash_value;
471 }
472 
473 }  // namespace xla
474