1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 18 19 #include <vector> 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/compiler/xla/util.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace xla { 29 30 // Describes a tile used in tiling-based layout. Refer to 31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for 32 // details. 33 class Tile { 34 public: 35 Tile() = default; Tile(absl::Span<const int64> dimensions)36 explicit Tile(absl::Span<const int64> dimensions) 37 : dimensions_(dimensions.begin(), dimensions.end()) {} 38 39 // De/Serialize a Tile to and from a TileProto. CreateFromProto(const TileProto & tile_proto)40 static Tile CreateFromProto(const TileProto& tile_proto) { 41 return Tile(AsInt64Slice(tile_proto.dimensions())); 42 } 43 TileProto ToProto() const; 44 45 bool operator==(const Tile& other) const { 46 return dimensions() == other.dimensions(); 47 } 48 bool operator!=(const Tile& other) const { return !(*this == other); } 49 50 string ToString() const; 51 52 // Returns the bound of the tile in the given dimension index. dimension(int i)53 int64 dimension(int i) const { return dimensions_.at(i); } 54 55 // Returns the dimensions of the tile. dimensions()56 absl::Span<const int64> dimensions() const { return dimensions_; } 57 add_dimensions(int64 value)58 Tile& add_dimensions(int64 value) { 59 dimensions_.push_back(value); 60 return *this; 61 } 62 clear_dimensions()63 Tile& clear_dimensions() { 64 dimensions_.clear(); 65 return *this; 66 } 67 68 // This dimension size means the corresponding dimension in the shape is 69 // combined with the next minor dimension before tiling is applied. 70 static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min(); 71 72 template <typename H> AbslHashValue(H h,const Tile & t)73 friend H AbslHashValue(H h, const Tile& t) { 74 return H::combine(std::move(h), t.dimensions_); 75 } 76 77 private: 78 // The bounds of the tile. 79 absl::InlinedVector<int64, 2> dimensions_; 80 }; 81 82 class Layout { 83 public: 84 Layout() = default; 85 86 // Constructs a dense layout with the given minor-to-major order. Layout(absl::Span<const int64> minor_to_major)87 explicit Layout(absl::Span<const int64> minor_to_major) 88 : format_(DENSE), 89 minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} 90 91 // Constructs a dense tiled layout with the given minor-to-major order and 92 // tiles. 93 Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles, 94 int64 element_size_in_bits = 0, int64 memory_space = 0) format_(DENSE)95 : format_(DENSE), 96 minor_to_major_(minor_to_major.begin(), minor_to_major.end()), 97 tiles_(tiles.begin(), tiles.end()), 98 element_size_in_bits_(element_size_in_bits), 99 memory_space_(memory_space) {} 100 101 // Construct a shape from a LayoutProto. 102 static Layout CreateFromProto(const LayoutProto& proto); 103 104 // Returns a LayoutProto representation of the Layout. 105 LayoutProto ToProto() const; 106 107 // Returns a human-readable string that represents this layout. 108 string ToString() const; 109 110 // Equal is a configurable functor to check the equality of two layouts. 111 // 112 // Examples: 113 // 114 // - Comparing two layouts ignoring their difference in tiles: 115 // Equal().IgnoreTiles()(layout1, layout2); 116 // 117 // - Comparing two layouts ignoring their difference in tiles and element 118 // size: 119 // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); 120 class Equal { 121 public: 122 Equal() = default; 123 124 bool operator()(const Layout& lhs, const Layout& rhs); 125 IgnoreTiles()126 Equal& IgnoreTiles() { 127 ignore_tiles_ = true; 128 return *this; 129 } 130 IgnoreElementSize()131 Equal& IgnoreElementSize() { 132 ignore_element_size_ = true; 133 return *this; 134 } 135 MinorToMajorOnly()136 Equal& MinorToMajorOnly() { 137 ignore_tiles_ = true; 138 ignore_element_size_ = true; 139 ignore_memory_space_ = true; 140 return *this; 141 } 142 IgnoreMemorySpace()143 Equal& IgnoreMemorySpace() { 144 ignore_memory_space_ = true; 145 return *this; 146 } 147 148 private: 149 bool ignore_tiles_ = false; 150 bool ignore_element_size_ = false; 151 bool ignore_memory_space_ = false; 152 }; 153 154 bool operator==(const Layout& other) const; 155 bool operator!=(const Layout& other) const { return !(*this == other); } 156 157 // The following methods mirror the protobuf generated code interface for the 158 // message LayoutProto. This enabled easy migration of this data structure 159 // from a proto to a proper C++ class. 160 // 161 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 162 // interface. 163 164 // Methods for accessing the format. format()165 Format format() const { return format_; } set_format(Format value)166 Layout& set_format(Format value) { 167 format_ = value; 168 return *this; 169 } 170 171 // Methods for accessing the minor-to-major array. minor_to_major_size()172 int minor_to_major_size() const { return minor_to_major_.size(); } minor_to_major(int index)173 int64 minor_to_major(int index) const { return minor_to_major_.at(index); } set_minor_to_major(int index,int64 value)174 Layout& set_minor_to_major(int index, int64 value) { 175 minor_to_major_.at(index) = value; 176 return *this; 177 } add_minor_to_major(int64 value)178 Layout& add_minor_to_major(int64 value) { 179 minor_to_major_.push_back(value); 180 return *this; 181 } clear_minor_to_major()182 Layout& clear_minor_to_major() { 183 minor_to_major_.clear(); 184 return *this; 185 } minor_to_major()186 absl::Span<const int64> minor_to_major() const { return minor_to_major_; } mutable_minor_to_major()187 absl::InlinedVector<int64, 6>* mutable_minor_to_major() { 188 return &minor_to_major_; 189 } 190 191 // Methods for accessing the tile field. tiles_size()192 int tiles_size() const { return tiles_.size(); } tiles(int index)193 const Tile& tiles(int index) const { return tiles_.at(index); } mutable_tiles(int index)194 Tile* mutable_tiles(int index) { return &tiles_.at(index); } add_tiles()195 Tile* add_tiles() { 196 tiles_.push_back(Tile()); 197 return &tiles_.back(); 198 } clear_tiles()199 Layout& clear_tiles() { 200 tiles_.clear(); 201 return *this; 202 } tiles()203 absl::Span<const Tile> tiles() const { return tiles_; } mutable_tiles()204 absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; } 205 element_size_in_bits()206 int64 element_size_in_bits() const { return element_size_in_bits_; } set_element_size_in_bits(int64 value)207 Layout& set_element_size_in_bits(int64 value) { 208 element_size_in_bits_ = value; 209 return *this; 210 } 211 static constexpr int64 kDefaultMemorySpace = 0; memory_space()212 int64 memory_space() const { return memory_space_; } set_memory_space(int64 value)213 Layout& set_memory_space(int64 value) { 214 memory_space_ = value; 215 return *this; 216 } 217 Swap(Layout * other)218 void Swap(Layout* other) { 219 using std::swap; 220 swap(*this, *other); 221 } 222 Clear()223 void Clear() { 224 *this = Layout(); 225 format_ = INVALID_FORMAT; 226 } 227 228 template <typename H> AbslHashValue(H h,const Layout & l)229 friend H AbslHashValue(H h, const Layout& l) { 230 return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_, 231 l.element_size_in_bits_); 232 } 233 234 private: 235 // The format of this layout. 236 Format format_ = INVALID_FORMAT; 237 238 // A map from physical dimension numbers to logical dimension numbers. 239 // The first element is the most minor physical dimension (fastest varying 240 // index) and the last the most major (slowest varying index). The contents of 241 // the vector are the indices of the *logical* dimensions in the shape. 242 // 243 // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions 244 // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}. 245 // So, the most minor physical dimension is [8,100,100,3][3], which is size 3. 246 // The second most minor is [8,100,100,3][0], which is size 8. 247 // The third most minor is [8,100,100,3][2], which is size 100. 248 // And the major dim is [8,100,100,3][1], which is size 100. 249 absl::InlinedVector<int64, 6> minor_to_major_; 250 251 // The tiles used in tiling-based layout. 252 absl::InlinedVector<Tile, 2> tiles_; 253 254 // The number of bits used to store an individual array element. 255 int64 element_size_in_bits_ = 0; 256 257 // The assigned memory space. 258 int64 memory_space_ = 0; 259 }; 260 261 std::ostream& operator<<(std::ostream& out, const Tile& Tile); 262 std::ostream& operator<<(std::ostream& out, const Layout& layout); 263 264 } // namespace xla 265 266 #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 267