• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
18 
#include <iosfwd>
#include <limits>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
27 
28 namespace xla {
29 
30 // Describes a tile used in tiling-based layout. Refer to
31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
32 // details.
33 class Tile {
34  public:
35   Tile() = default;
Tile(absl::Span<const int64> dimensions)36   explicit Tile(absl::Span<const int64> dimensions)
37       : dimensions_(dimensions.begin(), dimensions.end()) {}
38 
39   // De/Serialize a Tile to and from a TileProto.
CreateFromProto(const TileProto & tile_proto)40   static Tile CreateFromProto(const TileProto& tile_proto) {
41     return Tile(AsInt64Slice(tile_proto.dimensions()));
42   }
43   TileProto ToProto() const;
44 
45   bool operator==(const Tile& other) const {
46     return dimensions() == other.dimensions();
47   }
48   bool operator!=(const Tile& other) const { return !(*this == other); }
49 
50   string ToString() const;
51 
52   // Returns the bound of the tile in the given dimension index.
dimension(int i)53   int64 dimension(int i) const { return dimensions_.at(i); }
54 
55   // Returns the dimensions of the tile.
dimensions()56   absl::Span<const int64> dimensions() const { return dimensions_; }
57 
add_dimensions(int64_t value)58   Tile& add_dimensions(int64_t value) {
59     dimensions_.push_back(value);
60     return *this;
61   }
62 
clear_dimensions()63   Tile& clear_dimensions() {
64     dimensions_.clear();
65     return *this;
66   }
67 
68   // This dimension size means the corresponding dimension in the shape is
69   // combined with the next minor dimension before tiling is applied.
70   static constexpr int64_t kCombineDimension =
71       std::numeric_limits<int64>::min();
72 
73   template <typename H>
AbslHashValue(H h,const Tile & t)74   friend H AbslHashValue(H h, const Tile& t) {
75     return H::combine(std::move(h), t.dimensions_);
76   }
77 
78  private:
79   // The bounds of the tile.
80   absl::InlinedVector<int64, 2> dimensions_;
81 };
82 
83 class Layout {
84  public:
85   Layout() = default;
86 
87   // Constructs a dense layout with the given minor-to-major order.
Layout(absl::Span<const int64> minor_to_major)88   explicit Layout(absl::Span<const int64> minor_to_major)
89       : format_(DENSE),
90         minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
91 
92   // Constructs a dense tiled layout with the given minor-to-major order and
93   // tiles.
94   Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
95          int64_t element_size_in_bits = 0, int64_t memory_space = 0)
format_(DENSE)96       : format_(DENSE),
97         minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
98         tiles_(tiles.begin(), tiles.end()),
99         element_size_in_bits_(element_size_in_bits),
100         memory_space_(memory_space) {}
101 
102   // Construct a shape from a LayoutProto.
103   static Layout CreateFromProto(const LayoutProto& proto);
104 
105   // Returns a LayoutProto representation of the Layout.
106   LayoutProto ToProto() const;
107 
108   // Returns a human-readable string that represents this layout.
109   string ToString() const;
110 
111   // Equal is a configurable functor to check the equality of two layouts.
112   //
113   // Examples:
114   //
115   // - Comparing two layouts ignoring their difference in tiles:
116   //   Equal().IgnoreTiles()(layout1, layout2);
117   //
118   // - Comparing two layouts ignoring their difference in tiles and element
119   //   size:
120   //   Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2);
121   class Equal {
122    public:
123     Equal() = default;
124 
125     bool operator()(const Layout& lhs, const Layout& rhs);
126 
IgnoreTiles()127     Equal& IgnoreTiles() {
128       ignore_tiles_ = true;
129       return *this;
130     }
131 
IgnoreElementSize()132     Equal& IgnoreElementSize() {
133       ignore_element_size_ = true;
134       return *this;
135     }
136 
MinorToMajorOnly()137     Equal& MinorToMajorOnly() {
138       ignore_tiles_ = true;
139       ignore_element_size_ = true;
140       ignore_memory_space_ = true;
141       return *this;
142     }
143 
IgnoreMemorySpace()144     Equal& IgnoreMemorySpace() {
145       ignore_memory_space_ = true;
146       return *this;
147     }
148 
149    private:
150     bool ignore_tiles_ = false;
151     bool ignore_element_size_ = false;
152     bool ignore_memory_space_ = false;
153   };
154 
155   bool operator==(const Layout& other) const;
156   bool operator!=(const Layout& other) const { return !(*this == other); }
157 
158   // The following methods mirror the protobuf generated code interface for the
159   // message LayoutProto. This enabled easy migration of this data structure
160   // from a proto to a proper C++ class.
161   //
162   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
163   // interface.
164 
165   // Methods for accessing the format.
format()166   Format format() const { return format_; }
set_format(Format value)167   Layout& set_format(Format value) {
168     format_ = value;
169     return *this;
170   }
171 
172   // Methods for accessing the minor-to-major array.
minor_to_major_size()173   int minor_to_major_size() const { return minor_to_major_.size(); }
minor_to_major(int index)174   int64 minor_to_major(int index) const { return minor_to_major_.at(index); }
set_minor_to_major(int index,int64_t value)175   Layout& set_minor_to_major(int index, int64_t value) {
176     minor_to_major_.at(index) = value;
177     return *this;
178   }
add_minor_to_major(int64_t value)179   Layout& add_minor_to_major(int64_t value) {
180     minor_to_major_.push_back(value);
181     return *this;
182   }
clear_minor_to_major()183   Layout& clear_minor_to_major() {
184     minor_to_major_.clear();
185     return *this;
186   }
minor_to_major()187   absl::Span<const int64> minor_to_major() const { return minor_to_major_; }
mutable_minor_to_major()188   absl::InlinedVector<int64, 6>* mutable_minor_to_major() {
189     return &minor_to_major_;
190   }
191 
192   // Methods for accessing the tile field.
tiles_size()193   int tiles_size() const { return tiles_.size(); }
tiles(int index)194   const Tile& tiles(int index) const { return tiles_.at(index); }
mutable_tiles(int index)195   Tile* mutable_tiles(int index) { return &tiles_.at(index); }
add_tiles()196   Tile* add_tiles() {
197     tiles_.push_back(Tile());
198     return &tiles_.back();
199   }
clear_tiles()200   Layout& clear_tiles() {
201     tiles_.clear();
202     return *this;
203   }
tiles()204   absl::Span<const Tile> tiles() const { return tiles_; }
mutable_tiles()205   absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
206 
element_size_in_bits()207   int64 element_size_in_bits() const { return element_size_in_bits_; }
set_element_size_in_bits(int64_t value)208   Layout& set_element_size_in_bits(int64_t value) {
209     element_size_in_bits_ = value;
210     return *this;
211   }
212   static constexpr int64_t kDefaultMemorySpace = 0;
213   static constexpr int64_t kGenericFastMemorySpace = 1;
memory_space()214   int64 memory_space() const { return memory_space_; }
set_memory_space(int64_t value)215   Layout& set_memory_space(int64_t value) {
216     memory_space_ = value;
217     return *this;
218   }
219 
Swap(Layout * other)220   void Swap(Layout* other) {
221     using std::swap;
222     swap(*this, *other);
223   }
224 
Clear()225   void Clear() {
226     *this = Layout();
227     format_ = INVALID_FORMAT;
228   }
229 
230   template <typename H>
AbslHashValue(H h,const Layout & l)231   friend H AbslHashValue(H h, const Layout& l) {
232     return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_,
233                       l.element_size_in_bits_);
234   }
235 
236  private:
237   // The format of this layout.
238   Format format_ = INVALID_FORMAT;
239 
240   // A map from physical dimension numbers to logical dimension numbers.
241   // The first element is the most minor physical dimension (fastest varying
242   // index) and the last the most major (slowest varying index). The contents of
243   // the vector are the indices of the *logical* dimensions in the shape.
244   //
245   // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions
246   // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}.
247   // So, the most minor physical dimension is [8,100,100,3][3], which is size 3.
248   // The second most minor is [8,100,100,3][0], which is size 8.
249   // The third most minor is [8,100,100,3][2], which is size 100.
250   // And the major dim is [8,100,100,3][1], which is size 100.
251   absl::InlinedVector<int64, 6> minor_to_major_;
252 
253   // The tiles used in tiling-based layout.
254   absl::InlinedVector<Tile, 2> tiles_;
255 
256   // The number of bits used to store an individual array element.
257   int64 element_size_in_bits_ = 0;
258 
259   // The assigned memory space.
260   int64 memory_space_ = 0;
261 };
262 
263 std::ostream& operator<<(std::ostream& out, const Tile& Tile);
264 std::ostream& operator<<(std::ostream& out, const Layout& layout);
265 
266 }  // namespace xla
267 
268 #endif  // TENSORFLOW_COMPILER_XLA_LAYOUT_H_
269