1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/container/inlined_vector.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/types.h" 25 #include "tensorflow/compiler/xla/util.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 28 namespace xla { 29 30 // Describes a tile used in tiling-based layout. Refer to 31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for 32 // details. 33 class Tile { 34 public: 35 Tile() = default; Tile(absl::Span<const int64_t> dimensions)36 explicit Tile(absl::Span<const int64_t> dimensions) 37 : dimensions_(dimensions.begin(), dimensions.end()) {} 38 39 // De/Serialize a Tile to and from a TileProto. CreateFromProto(const TileProto & tile_proto)40 static Tile CreateFromProto(const TileProto& tile_proto) { 41 return Tile(tile_proto.dimensions()); 42 } 43 TileProto ToProto() const; 44 45 bool operator==(const Tile& other) const { 46 return dimensions() == other.dimensions(); 47 } 48 bool operator!=(const Tile& other) const { return !(*this == other); } 49 50 std::string ToString() const; 51 52 // Returns the bound of the tile in the given dimension index. dimension(int i)53 int64_t dimension(int i) const { return dimensions_.at(i); } 54 55 // Returns the dimensions of the tile. dimensions()56 absl::Span<const int64_t> dimensions() const { return dimensions_; } 57 add_dimensions(int64_t value)58 Tile& add_dimensions(int64_t value) { 59 dimensions_.push_back(value); 60 return *this; 61 } 62 clear_dimensions()63 Tile& clear_dimensions() { 64 dimensions_.clear(); 65 return *this; 66 } 67 68 // This dimension size means the corresponding dimension in the shape is 69 // combined with the next minor dimension before tiling is applied. 70 static constexpr int64_t kCombineDimension = 71 std::numeric_limits<int64_t>::min(); 72 73 template <typename H> AbslHashValue(H h,const Tile & t)74 friend H AbslHashValue(H h, const Tile& t) { 75 return H::combine(std::move(h), t.dimensions_); 76 } 77 78 private: 79 // The bounds of the tile. 80 absl::InlinedVector<int64_t, 2> dimensions_; 81 }; 82 83 class Layout { 84 public: 85 Layout() = default; 86 87 // Constructs a dense layout with the given minor-to-major order. Layout(absl::Span<const int64_t> minor_to_major)88 explicit Layout(absl::Span<const int64_t> minor_to_major) 89 : minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} 90 91 // Constructs a dense tiled layout with the given minor-to-major order, dim 92 // level types, and tiles. 93 Layout(absl::Span<const int64_t> minor_to_major, 94 absl::Span<const DimLevelType> dim_level_types, 95 absl::Span<const Tile> tiles, int64_t element_size_in_bits = 0, 96 int64_t memory_space = 0) 97 : dim_level_types_(dim_level_types.begin(), dim_level_types.end()), 98 minor_to_major_(minor_to_major.begin(), minor_to_major.end()), 99 tiles_(tiles.begin(), tiles.end()), 100 element_size_in_bits_(element_size_in_bits), 101 memory_space_(memory_space) {} 102 103 // Construct a shape from a LayoutProto. 104 static Layout CreateFromProto(const LayoutProto& proto); 105 106 // Returns a LayoutProto representation of the Layout. 107 LayoutProto ToProto() const; 108 109 // Returns a human-readable string that represents this layout. 110 std::string ToString() const; 111 112 // Equal is a configurable functor to check the equality of two layouts. 113 // 114 // Examples: 115 // 116 // - Comparing two layouts ignoring their difference in tiles: 117 // Equal().IgnoreTiles()(layout1, layout2); 118 // 119 // - Comparing two layouts ignoring their difference in tiles and element 120 // size: 121 // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); 122 class Equal { 123 public: 124 Equal() = default; 125 126 bool operator()(const Layout& lhs, const Layout& rhs); 127 IgnoreTiles()128 Equal& IgnoreTiles() { 129 ignore_tiles_ = true; 130 return *this; 131 } 132 IgnoreElementSize()133 Equal& IgnoreElementSize() { 134 ignore_element_size_ = true; 135 return *this; 136 } 137 MinorToMajorOnly()138 Equal& MinorToMajorOnly() { 139 ignore_tiles_ = true; 140 ignore_element_size_ = true; 141 ignore_memory_space_ = true; 142 return *this; 143 } 144 IgnoreMemorySpace()145 Equal& IgnoreMemorySpace() { 146 ignore_memory_space_ = true; 147 return *this; 148 } 149 150 private: 151 bool ignore_tiles_ = false; 152 bool ignore_element_size_ = false; 153 bool ignore_memory_space_ = false; 154 }; 155 156 bool operator==(const Layout& other) const; 157 bool operator!=(const Layout& other) const { return !(*this == other); } 158 159 // The following methods mirror the protobuf generated code interface for the 160 // message LayoutProto. This enabled easy migration of this data structure 161 // from a proto to a proper C++ class. 162 // 163 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 164 // interface. 165 166 // Methods for accessing the DimLevelType array. dim_level_types_size()167 int dim_level_types_size() const { return dim_level_types_.size(); } dim_level_type(int index)168 DimLevelType dim_level_type(int index) const { 169 return dim_level_types_.at(index); 170 } set_dim_level_type(int index,DimLevelType dim_level_type)171 Layout& set_dim_level_type(int index, DimLevelType dim_level_type) { 172 dim_level_types_.at(index) = dim_level_type; 173 return *this; 174 } add_dim_level_type(DimLevelType dim_level_type)175 Layout& add_dim_level_type(DimLevelType dim_level_type) { 176 dim_level_types_.push_back(dim_level_type); 177 return *this; 178 } clear_dim_level_types()179 Layout& clear_dim_level_types() { 180 dim_level_types_.clear(); 181 return *this; 182 } dim_level_types()183 absl::Span<const DimLevelType> dim_level_types() const { 184 return dim_level_types_; 185 } mutable_dim_level_types()186 DimLevelTypeVector* mutable_dim_level_types() { return &dim_level_types_; } 187 188 // Methods for accessing the minor-to-major array. minor_to_major_size()189 int minor_to_major_size() const { return minor_to_major_.size(); } minor_to_major(int index)190 int64_t minor_to_major(int index) const { return minor_to_major_.at(index); } set_minor_to_major(int index,int64_t value)191 Layout& set_minor_to_major(int index, int64_t value) { 192 minor_to_major_.at(index) = value; 193 return *this; 194 } add_minor_to_major(int64_t value)195 Layout& add_minor_to_major(int64_t value) { 196 minor_to_major_.push_back(value); 197 return *this; 198 } clear_minor_to_major()199 Layout& clear_minor_to_major() { 200 minor_to_major_.clear(); 201 return *this; 202 } minor_to_major()203 absl::Span<const int64_t> minor_to_major() const { return minor_to_major_; } mutable_minor_to_major()204 DimensionVector* mutable_minor_to_major() { return &minor_to_major_; } 205 206 // Methods for accessing the tile field. tiles_size()207 int tiles_size() const { return tiles_.size(); } tiles(int index)208 const Tile& tiles(int index) const { return tiles_.at(index); } mutable_tiles(int index)209 Tile* mutable_tiles(int index) { return &tiles_.at(index); } add_tiles()210 Tile* add_tiles() { 211 tiles_.push_back(Tile()); 212 return &tiles_.back(); 213 } clear_tiles()214 Layout& clear_tiles() { 215 tiles_.clear(); 216 return *this; 217 } tiles()218 absl::Span<const Tile> tiles() const { return tiles_; } mutable_tiles()219 absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; } 220 element_size_in_bits()221 int64_t element_size_in_bits() const { return element_size_in_bits_; } set_element_size_in_bits(int64_t value)222 Layout& set_element_size_in_bits(int64_t value) { 223 element_size_in_bits_ = value; 224 return *this; 225 } 226 static constexpr int64_t kDefaultMemorySpace = 0; 227 static constexpr int64_t kGenericFastMemorySpace = 1; memory_space()228 int64_t memory_space() const { return memory_space_; } set_memory_space(int64_t value)229 Layout& set_memory_space(int64_t value) { 230 memory_space_ = value; 231 return *this; 232 } 233 Swap(Layout * other)234 void Swap(Layout* other) { 235 using std::swap; 236 swap(*this, *other); 237 } 238 Clear()239 void Clear() { *this = Layout(); } 240 241 template <typename H> AbslHashValue(H h,const Layout & l)242 friend H AbslHashValue(H h, const Layout& l) { 243 return H::combine(std::move(h), l.minor_to_major_, l.tiles_, 244 l.element_size_in_bits_, l.memory_space_); 245 } 246 247 private: 248 // The list of dimension level types, indicating the method that will be used 249 // to represent each dimension of the array. 250 DimLevelTypeVector dim_level_types_; 251 252 // A map from physical dimension numbers to logical dimension numbers. 253 // The first element is the most minor physical dimension (fastest varying 254 // index) and the last the most major (slowest varying index). The contents of 255 // the vector are the indices of the *logical* dimensions in the shape. 256 // 257 // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions 258 // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}. 259 // So, the most minor physical dimension is [8,100,100,3][3], which is size 3. 260 // The second most minor is [8,100,100,3][0], which is size 8. 261 // The third most minor is [8,100,100,3][2], which is size 100. 262 // And the major dim is [8,100,100,3][1], which is size 100. 263 DimensionVector minor_to_major_; 264 265 // The tiles used in tiling-based layout. 266 absl::InlinedVector<Tile, 2> tiles_; 267 268 // The number of bits used to store an individual array element. 269 int64_t element_size_in_bits_ = 0; 270 271 // The assigned memory space. 272 int64_t memory_space_ = 0; 273 }; 274 275 std::ostream& operator<<(std::ostream& out, const Tile& Tile); 276 std::ostream& operator<<(std::ostream& out, const Layout& layout); 277 278 } // namespace xla 279 280 #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 281