• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
18 
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace xla {
29 
30 // Describes a tile used in tiling-based layout. Refer to
31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
32 // details.
33 class Tile {
34  public:
35   Tile() = default;
Tile(absl::Span<const int64> dimensions)36   explicit Tile(absl::Span<const int64> dimensions)
37       : dimensions_(dimensions.begin(), dimensions.end()) {}
38 
39   // De/Serialize a Tile to and from a TileProto.
CreateFromProto(const TileProto & tile_proto)40   static Tile CreateFromProto(const TileProto& tile_proto) {
41     return Tile(AsInt64Slice(tile_proto.dimensions()));
42   }
43   TileProto ToProto() const;
44 
45   bool operator==(const Tile& other) const {
46     return dimensions() == other.dimensions();
47   }
48   bool operator!=(const Tile& other) const { return !(*this == other); }
49 
50   string ToString() const;
51 
52   // Returns the bound of the tile in the given dimension index.
dimension(int i)53   int64 dimension(int i) const { return dimensions_.at(i); }
54 
55   // Returns the dimensions of the tile.
dimensions()56   absl::Span<const int64> dimensions() const { return dimensions_; }
57 
add_dimensions(int64 value)58   Tile& add_dimensions(int64 value) {
59     dimensions_.push_back(value);
60     return *this;
61   }
62 
clear_dimensions()63   Tile& clear_dimensions() {
64     dimensions_.clear();
65     return *this;
66   }
67 
68   // This dimension size means the corresponding dimension in the shape is
69   // combined with the next minor dimension before tiling is applied.
70   static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
71 
72   template <typename H>
AbslHashValue(H h,const Tile & t)73   friend H AbslHashValue(H h, const Tile& t) {
74     return H::combine(std::move(h), t.dimensions_);
75   }
76 
77  private:
78   // The bounds of the tile.
79   absl::InlinedVector<int64, 2> dimensions_;
80 };
81 
82 class Layout {
83  public:
84   Layout() = default;
85 
86   // Constructs a dense layout with the given minor-to-major order.
Layout(absl::Span<const int64> minor_to_major)87   explicit Layout(absl::Span<const int64> minor_to_major)
88       : format_(DENSE),
89         minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
90 
91   // Constructs a dense tiled layout with the given minor-to-major order and
92   // tiles.
93   Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
94          int64 element_size_in_bits = 0, int64 memory_space = 0)
format_(DENSE)95       : format_(DENSE),
96         minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
97         tiles_(tiles.begin(), tiles.end()),
98         element_size_in_bits_(element_size_in_bits),
99         memory_space_(memory_space) {}
100 
101   // Construct a shape from a LayoutProto.
102   static Layout CreateFromProto(const LayoutProto& proto);
103 
104   // Returns a LayoutProto representation of the Layout.
105   LayoutProto ToProto() const;
106 
107   // Returns a human-readable string that represents this layout.
108   string ToString() const;
109 
110   // Equal is a configurable functor to check the equality of two layouts.
111   //
112   // Examples:
113   //
114   // - Comparing two layouts ignoring their difference in tiles:
115   //   Equal().IgnoreTiles()(layout1, layout2);
116   //
117   // - Comparing two layouts ignoring their difference in tiles and element
118   //   size:
119   //   Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2);
120   class Equal {
121    public:
122     Equal() = default;
123 
124     bool operator()(const Layout& lhs, const Layout& rhs);
125 
IgnoreTiles()126     Equal& IgnoreTiles() {
127       ignore_tiles_ = true;
128       return *this;
129     }
130 
IgnoreElementSize()131     Equal& IgnoreElementSize() {
132       ignore_element_size_ = true;
133       return *this;
134     }
135 
MinorToMajorOnly()136     Equal& MinorToMajorOnly() {
137       ignore_tiles_ = true;
138       ignore_element_size_ = true;
139       ignore_memory_space_ = true;
140       return *this;
141     }
142 
IgnoreMemorySpace()143     Equal& IgnoreMemorySpace() {
144       ignore_memory_space_ = true;
145       return *this;
146     }
147 
148    private:
149     bool ignore_tiles_ = false;
150     bool ignore_element_size_ = false;
151     bool ignore_memory_space_ = false;
152   };
153 
154   bool operator==(const Layout& other) const;
155   bool operator!=(const Layout& other) const { return !(*this == other); }
156 
157   // The following methods mirror the protobuf generated code interface for the
158   // message LayoutProto. This enabled easy migration of this data structure
159   // from a proto to a proper C++ class.
160   //
161   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
162   // interface.
163 
164   // Methods for accessing the format.
format()165   Format format() const { return format_; }
set_format(Format value)166   Layout& set_format(Format value) {
167     format_ = value;
168     return *this;
169   }
170 
171   // Methods for accessing the minor-to-major array.
minor_to_major_size()172   int minor_to_major_size() const { return minor_to_major_.size(); }
minor_to_major(int index)173   int64 minor_to_major(int index) const { return minor_to_major_.at(index); }
set_minor_to_major(int index,int64 value)174   Layout& set_minor_to_major(int index, int64 value) {
175     minor_to_major_.at(index) = value;
176     return *this;
177   }
add_minor_to_major(int64 value)178   Layout& add_minor_to_major(int64 value) {
179     minor_to_major_.push_back(value);
180     return *this;
181   }
clear_minor_to_major()182   Layout& clear_minor_to_major() {
183     minor_to_major_.clear();
184     return *this;
185   }
minor_to_major()186   absl::Span<const int64> minor_to_major() const { return minor_to_major_; }
mutable_minor_to_major()187   absl::InlinedVector<int64, 6>* mutable_minor_to_major() {
188     return &minor_to_major_;
189   }
190 
191   // Methods for accessing the tile field.
tiles_size()192   int tiles_size() const { return tiles_.size(); }
tiles(int index)193   const Tile& tiles(int index) const { return tiles_.at(index); }
mutable_tiles(int index)194   Tile* mutable_tiles(int index) { return &tiles_.at(index); }
add_tiles()195   Tile* add_tiles() {
196     tiles_.push_back(Tile());
197     return &tiles_.back();
198   }
clear_tiles()199   Layout& clear_tiles() {
200     tiles_.clear();
201     return *this;
202   }
tiles()203   absl::Span<const Tile> tiles() const { return tiles_; }
mutable_tiles()204   absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
205 
element_size_in_bits()206   int64 element_size_in_bits() const { return element_size_in_bits_; }
set_element_size_in_bits(int64 value)207   Layout& set_element_size_in_bits(int64 value) {
208     element_size_in_bits_ = value;
209     return *this;
210   }
211   static constexpr int64 kDefaultMemorySpace = 0;
212   static constexpr int64 kGenericFastMemorySpace = 1;
memory_space()213   int64 memory_space() const { return memory_space_; }
set_memory_space(int64 value)214   Layout& set_memory_space(int64 value) {
215     memory_space_ = value;
216     return *this;
217   }
218 
Swap(Layout * other)219   void Swap(Layout* other) {
220     using std::swap;
221     swap(*this, *other);
222   }
223 
Clear()224   void Clear() {
225     *this = Layout();
226     format_ = INVALID_FORMAT;
227   }
228 
229   template <typename H>
AbslHashValue(H h,const Layout & l)230   friend H AbslHashValue(H h, const Layout& l) {
231     return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_,
232                       l.element_size_in_bits_);
233   }
234 
235  private:
236   // The format of this layout.
237   Format format_ = INVALID_FORMAT;
238 
239   // A map from physical dimension numbers to logical dimension numbers.
240   // The first element is the most minor physical dimension (fastest varying
241   // index) and the last the most major (slowest varying index). The contents of
242   // the vector are the indices of the *logical* dimensions in the shape.
243   //
244   // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions
245   // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}.
246   // So, the most minor physical dimension is [8,100,100,3][3], which is size 3.
247   // The second most minor is [8,100,100,3][0], which is size 8.
248   // The third most minor is [8,100,100,3][2], which is size 100.
249   // And the major dim is [8,100,100,3][1], which is size 100.
250   absl::InlinedVector<int64, 6> minor_to_major_;
251 
252   // The tiles used in tiling-based layout.
253   absl::InlinedVector<Tile, 2> tiles_;
254 
255   // The number of bits used to store an individual array element.
256   int64 element_size_in_bits_ = 0;
257 
258   // The assigned memory space.
259   int64 memory_space_ = 0;
260 };
261 
262 std::ostream& operator<<(std::ostream& out, const Tile& Tile);
263 std::ostream& operator<<(std::ostream& out, const Layout& layout);
264 
265 }  // namespace xla
266 
267 #endif  // TENSORFLOW_COMPILER_XLA_LAYOUT_H_
268