• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
18 
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 
30 // Describes a tile used in tiling-based layout. Refer to
31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
32 // details.
33 class Tile {
34  public:
35   Tile() = default;
Tile(absl::Span<const int64_t> dimensions)36   explicit Tile(absl::Span<const int64_t> dimensions)
37       : dimensions_(dimensions.begin(), dimensions.end()) {}
38 
39   // De/Serialize a Tile to and from a TileProto.
CreateFromProto(const TileProto & tile_proto)40   static Tile CreateFromProto(const TileProto& tile_proto) {
41     return Tile(tile_proto.dimensions());
42   }
43   TileProto ToProto() const;
44 
45   bool operator==(const Tile& other) const {
46     return dimensions() == other.dimensions();
47   }
48   bool operator!=(const Tile& other) const { return !(*this == other); }
49 
50   std::string ToString() const;
51 
52   // Returns the bound of the tile in the given dimension index.
dimension(int i)53   int64_t dimension(int i) const { return dimensions_.at(i); }
54 
55   // Returns the dimensions of the tile.
dimensions()56   absl::Span<const int64_t> dimensions() const { return dimensions_; }
57 
add_dimensions(int64_t value)58   Tile& add_dimensions(int64_t value) {
59     dimensions_.push_back(value);
60     return *this;
61   }
62 
clear_dimensions()63   Tile& clear_dimensions() {
64     dimensions_.clear();
65     return *this;
66   }
67 
68   // This dimension size means the corresponding dimension in the shape is
69   // combined with the next minor dimension before tiling is applied.
70   static constexpr int64_t kCombineDimension =
71       std::numeric_limits<int64_t>::min();
72 
73   template <typename H>
AbslHashValue(H h,const Tile & t)74   friend H AbslHashValue(H h, const Tile& t) {
75     return H::combine(std::move(h), t.dimensions_);
76   }
77 
78  private:
79   // The bounds of the tile.
80   absl::InlinedVector<int64_t, 2> dimensions_;
81 };
82 
83 class Layout {
84  public:
85   Layout() = default;
86 
87   // Constructs a dense layout with the given minor-to-major order.
Layout(absl::Span<const int64_t> minor_to_major)88   explicit Layout(absl::Span<const int64_t> minor_to_major)
89       : minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
90 
91   // Constructs a dense tiled layout with the given minor-to-major order, dim
92   // level types, and tiles.
93   Layout(absl::Span<const int64_t> minor_to_major,
94          absl::Span<const DimLevelType> dim_level_types,
95          absl::Span<const Tile> tiles, int64_t element_size_in_bits = 0,
96          int64_t memory_space = 0)
97       : dim_level_types_(dim_level_types.begin(), dim_level_types.end()),
98         minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
99         tiles_(tiles.begin(), tiles.end()),
100         element_size_in_bits_(element_size_in_bits),
101         memory_space_(memory_space) {}
102 
103   // Construct a shape from a LayoutProto.
104   static Layout CreateFromProto(const LayoutProto& proto);
105 
106   // Returns a LayoutProto representation of the Layout.
107   LayoutProto ToProto() const;
108 
109   // Returns a human-readable string that represents this layout.
110   std::string ToString() const;
111 
112   // Equal is a configurable functor to check the equality of two layouts.
113   //
114   // Examples:
115   //
116   // - Comparing two layouts ignoring their difference in tiles:
117   //   Equal().IgnoreTiles()(layout1, layout2);
118   //
119   // - Comparing two layouts ignoring their difference in tiles and element
120   //   size:
121   //   Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2);
122   class Equal {
123    public:
124     Equal() = default;
125 
126     bool operator()(const Layout& lhs, const Layout& rhs);
127 
IgnoreTiles()128     Equal& IgnoreTiles() {
129       ignore_tiles_ = true;
130       return *this;
131     }
132 
IgnoreElementSize()133     Equal& IgnoreElementSize() {
134       ignore_element_size_ = true;
135       return *this;
136     }
137 
MinorToMajorOnly()138     Equal& MinorToMajorOnly() {
139       ignore_tiles_ = true;
140       ignore_element_size_ = true;
141       ignore_memory_space_ = true;
142       return *this;
143     }
144 
IgnoreMemorySpace()145     Equal& IgnoreMemorySpace() {
146       ignore_memory_space_ = true;
147       return *this;
148     }
149 
150    private:
151     bool ignore_tiles_ = false;
152     bool ignore_element_size_ = false;
153     bool ignore_memory_space_ = false;
154   };
155 
156   bool operator==(const Layout& other) const;
157   bool operator!=(const Layout& other) const { return !(*this == other); }
158 
159   // The following methods mirror the protobuf generated code interface for the
160   // message LayoutProto. This enabled easy migration of this data structure
161   // from a proto to a proper C++ class.
162   //
163   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
164   // interface.
165 
166   // Methods for accessing the DimLevelType array.
dim_level_types_size()167   int dim_level_types_size() const { return dim_level_types_.size(); }
dim_level_type(int index)168   DimLevelType dim_level_type(int index) const {
169     return dim_level_types_.at(index);
170   }
set_dim_level_type(int index,DimLevelType dim_level_type)171   Layout& set_dim_level_type(int index, DimLevelType dim_level_type) {
172     dim_level_types_.at(index) = dim_level_type;
173     return *this;
174   }
add_dim_level_type(DimLevelType dim_level_type)175   Layout& add_dim_level_type(DimLevelType dim_level_type) {
176     dim_level_types_.push_back(dim_level_type);
177     return *this;
178   }
clear_dim_level_types()179   Layout& clear_dim_level_types() {
180     dim_level_types_.clear();
181     return *this;
182   }
dim_level_types()183   absl::Span<const DimLevelType> dim_level_types() const {
184     return dim_level_types_;
185   }
mutable_dim_level_types()186   DimLevelTypeVector* mutable_dim_level_types() { return &dim_level_types_; }
187 
188   // Methods for accessing the minor-to-major array.
minor_to_major_size()189   int minor_to_major_size() const { return minor_to_major_.size(); }
minor_to_major(int index)190   int64_t minor_to_major(int index) const { return minor_to_major_.at(index); }
set_minor_to_major(int index,int64_t value)191   Layout& set_minor_to_major(int index, int64_t value) {
192     minor_to_major_.at(index) = value;
193     return *this;
194   }
add_minor_to_major(int64_t value)195   Layout& add_minor_to_major(int64_t value) {
196     minor_to_major_.push_back(value);
197     return *this;
198   }
clear_minor_to_major()199   Layout& clear_minor_to_major() {
200     minor_to_major_.clear();
201     return *this;
202   }
minor_to_major()203   absl::Span<const int64_t> minor_to_major() const { return minor_to_major_; }
mutable_minor_to_major()204   DimensionVector* mutable_minor_to_major() { return &minor_to_major_; }
205 
206   // Methods for accessing the tile field.
tiles_size()207   int tiles_size() const { return tiles_.size(); }
tiles(int index)208   const Tile& tiles(int index) const { return tiles_.at(index); }
mutable_tiles(int index)209   Tile* mutable_tiles(int index) { return &tiles_.at(index); }
add_tiles()210   Tile* add_tiles() {
211     tiles_.push_back(Tile());
212     return &tiles_.back();
213   }
clear_tiles()214   Layout& clear_tiles() {
215     tiles_.clear();
216     return *this;
217   }
tiles()218   absl::Span<const Tile> tiles() const { return tiles_; }
mutable_tiles()219   absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
220 
element_size_in_bits()221   int64_t element_size_in_bits() const { return element_size_in_bits_; }
set_element_size_in_bits(int64_t value)222   Layout& set_element_size_in_bits(int64_t value) {
223     element_size_in_bits_ = value;
224     return *this;
225   }
226   static constexpr int64_t kDefaultMemorySpace = 0;
227   static constexpr int64_t kGenericFastMemorySpace = 1;
memory_space()228   int64_t memory_space() const { return memory_space_; }
set_memory_space(int64_t value)229   Layout& set_memory_space(int64_t value) {
230     memory_space_ = value;
231     return *this;
232   }
233 
Swap(Layout * other)234   void Swap(Layout* other) {
235     using std::swap;
236     swap(*this, *other);
237   }
238 
Clear()239   void Clear() { *this = Layout(); }
240 
241   template <typename H>
AbslHashValue(H h,const Layout & l)242   friend H AbslHashValue(H h, const Layout& l) {
243     return H::combine(std::move(h), l.minor_to_major_, l.tiles_,
244                       l.element_size_in_bits_, l.memory_space_);
245   }
246 
247  private:
248   // The list of dimension level types, indicating the method that will be used
249   // to represent each dimension of the array.
250   DimLevelTypeVector dim_level_types_;
251 
252   // A map from physical dimension numbers to logical dimension numbers.
253   // The first element is the most minor physical dimension (fastest varying
254   // index) and the last the most major (slowest varying index). The contents of
255   // the vector are the indices of the *logical* dimensions in the shape.
256   //
257   // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions
258   // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}.
259   // So, the most minor physical dimension is [8,100,100,3][3], which is size 3.
260   // The second most minor is [8,100,100,3][0], which is size 8.
261   // The third most minor is [8,100,100,3][2], which is size 100.
262   // And the major dim is [8,100,100,3][1], which is size 100.
263   DimensionVector minor_to_major_;
264 
265   // The tiles used in tiling-based layout.
266   absl::InlinedVector<Tile, 2> tiles_;
267 
268   // The number of bits used to store an individual array element.
269   int64_t element_size_in_bits_ = 0;
270 
271   // The assigned memory space.
272   int64_t memory_space_ = 0;
273 };
274 
275 std::ostream& operator<<(std::ostream& out, const Tile& Tile);
276 std::ostream& operator<<(std::ostream& out, const Layout& layout);
277 
278 }  // namespace xla
279 
280 #endif  // TENSORFLOW_COMPILER_XLA_LAYOUT_H_
281