1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 18 19 #include <vector> 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/compiler/xla/util.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace xla { 29 30 // Describes a tile used in tiling-based layout. Refer to 31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for 32 // details. 33 class Tile { 34 public: 35 Tile() = default; Tile(absl::Span<const int64> dimensions)36 explicit Tile(absl::Span<const int64> dimensions) 37 : dimensions_(dimensions.begin(), dimensions.end()) {} 38 39 // De/Serialize a Tile to and from a TileProto. CreateFromProto(const TileProto & tile_proto)40 static Tile CreateFromProto(const TileProto& tile_proto) { 41 return Tile(AsInt64Slice(tile_proto.dimensions())); 42 } 43 TileProto ToProto() const; 44 45 bool operator==(const Tile& other) const { 46 return dimensions() == other.dimensions(); 47 } 48 bool operator!=(const Tile& other) const { return !(*this == other); } 49 50 string ToString() const; 51 52 // Returns the bound of the tile in the given dimension index. dimension(int i)53 int64 dimension(int i) const { return dimensions_.at(i); } 54 55 // Returns the dimensions of the tile. dimensions()56 absl::Span<const int64> dimensions() const { return dimensions_; } 57 add_dimensions(int64_t value)58 Tile& add_dimensions(int64_t value) { 59 dimensions_.push_back(value); 60 return *this; 61 } 62 clear_dimensions()63 Tile& clear_dimensions() { 64 dimensions_.clear(); 65 return *this; 66 } 67 68 // This dimension size means the corresponding dimension in the shape is 69 // combined with the next minor dimension before tiling is applied. 70 static constexpr int64_t kCombineDimension = 71 std::numeric_limits<int64>::min(); 72 73 template <typename H> AbslHashValue(H h,const Tile & t)74 friend H AbslHashValue(H h, const Tile& t) { 75 return H::combine(std::move(h), t.dimensions_); 76 } 77 78 private: 79 // The bounds of the tile. 80 absl::InlinedVector<int64, 2> dimensions_; 81 }; 82 83 class Layout { 84 public: 85 Layout() = default; 86 87 // Constructs a dense layout with the given minor-to-major order. Layout(absl::Span<const int64> minor_to_major)88 explicit Layout(absl::Span<const int64> minor_to_major) 89 : format_(DENSE), 90 minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} 91 92 // Constructs a dense tiled layout with the given minor-to-major order and 93 // tiles. 94 Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles, 95 int64_t element_size_in_bits = 0, int64_t memory_space = 0) format_(DENSE)96 : format_(DENSE), 97 minor_to_major_(minor_to_major.begin(), minor_to_major.end()), 98 tiles_(tiles.begin(), tiles.end()), 99 element_size_in_bits_(element_size_in_bits), 100 memory_space_(memory_space) {} 101 102 // Construct a shape from a LayoutProto. 103 static Layout CreateFromProto(const LayoutProto& proto); 104 105 // Returns a LayoutProto representation of the Layout. 106 LayoutProto ToProto() const; 107 108 // Returns a human-readable string that represents this layout. 109 string ToString() const; 110 111 // Equal is a configurable functor to check the equality of two layouts. 112 // 113 // Examples: 114 // 115 // - Comparing two layouts ignoring their difference in tiles: 116 // Equal().IgnoreTiles()(layout1, layout2); 117 // 118 // - Comparing two layouts ignoring their difference in tiles and element 119 // size: 120 // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); 121 class Equal { 122 public: 123 Equal() = default; 124 125 bool operator()(const Layout& lhs, const Layout& rhs); 126 IgnoreTiles()127 Equal& IgnoreTiles() { 128 ignore_tiles_ = true; 129 return *this; 130 } 131 IgnoreElementSize()132 Equal& IgnoreElementSize() { 133 ignore_element_size_ = true; 134 return *this; 135 } 136 MinorToMajorOnly()137 Equal& MinorToMajorOnly() { 138 ignore_tiles_ = true; 139 ignore_element_size_ = true; 140 ignore_memory_space_ = true; 141 return *this; 142 } 143 IgnoreMemorySpace()144 Equal& IgnoreMemorySpace() { 145 ignore_memory_space_ = true; 146 return *this; 147 } 148 149 private: 150 bool ignore_tiles_ = false; 151 bool ignore_element_size_ = false; 152 bool ignore_memory_space_ = false; 153 }; 154 155 bool operator==(const Layout& other) const; 156 bool operator!=(const Layout& other) const { return !(*this == other); } 157 158 // The following methods mirror the protobuf generated code interface for the 159 // message LayoutProto. This enabled easy migration of this data structure 160 // from a proto to a proper C++ class. 161 // 162 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 163 // interface. 164 165 // Methods for accessing the format. format()166 Format format() const { return format_; } set_format(Format value)167 Layout& set_format(Format value) { 168 format_ = value; 169 return *this; 170 } 171 172 // Methods for accessing the minor-to-major array. minor_to_major_size()173 int minor_to_major_size() const { return minor_to_major_.size(); } minor_to_major(int index)174 int64 minor_to_major(int index) const { return minor_to_major_.at(index); } set_minor_to_major(int index,int64_t value)175 Layout& set_minor_to_major(int index, int64_t value) { 176 minor_to_major_.at(index) = value; 177 return *this; 178 } add_minor_to_major(int64_t value)179 Layout& add_minor_to_major(int64_t value) { 180 minor_to_major_.push_back(value); 181 return *this; 182 } clear_minor_to_major()183 Layout& clear_minor_to_major() { 184 minor_to_major_.clear(); 185 return *this; 186 } minor_to_major()187 absl::Span<const int64> minor_to_major() const { return minor_to_major_; } mutable_minor_to_major()188 absl::InlinedVector<int64, 6>* mutable_minor_to_major() { 189 return &minor_to_major_; 190 } 191 192 // Methods for accessing the tile field. tiles_size()193 int tiles_size() const { return tiles_.size(); } tiles(int index)194 const Tile& tiles(int index) const { return tiles_.at(index); } mutable_tiles(int index)195 Tile* mutable_tiles(int index) { return &tiles_.at(index); } add_tiles()196 Tile* add_tiles() { 197 tiles_.push_back(Tile()); 198 return &tiles_.back(); 199 } clear_tiles()200 Layout& clear_tiles() { 201 tiles_.clear(); 202 return *this; 203 } tiles()204 absl::Span<const Tile> tiles() const { return tiles_; } mutable_tiles()205 absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; } 206 element_size_in_bits()207 int64 element_size_in_bits() const { return element_size_in_bits_; } set_element_size_in_bits(int64_t value)208 Layout& set_element_size_in_bits(int64_t value) { 209 element_size_in_bits_ = value; 210 return *this; 211 } 212 static constexpr int64_t kDefaultMemorySpace = 0; 213 static constexpr int64_t kGenericFastMemorySpace = 1; memory_space()214 int64 memory_space() const { return memory_space_; } set_memory_space(int64_t value)215 Layout& set_memory_space(int64_t value) { 216 memory_space_ = value; 217 return *this; 218 } 219 Swap(Layout * other)220 void Swap(Layout* other) { 221 using std::swap; 222 swap(*this, *other); 223 } 224 Clear()225 void Clear() { 226 *this = Layout(); 227 format_ = INVALID_FORMAT; 228 } 229 230 template <typename H> AbslHashValue(H h,const Layout & l)231 friend H AbslHashValue(H h, const Layout& l) { 232 return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_, 233 l.element_size_in_bits_); 234 } 235 236 private: 237 // The format of this layout. 238 Format format_ = INVALID_FORMAT; 239 240 // A map from physical dimension numbers to logical dimension numbers. 241 // The first element is the most minor physical dimension (fastest varying 242 // index) and the last the most major (slowest varying index). The contents of 243 // the vector are the indices of the *logical* dimensions in the shape. 244 // 245 // For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions 246 // are [8,100,100,3] and minor_to_major_ is {3,0,2,1}. 247 // So, the most minor physical dimension is [8,100,100,3][3], which is size 3. 248 // The second most minor is [8,100,100,3][0], which is size 8. 249 // The third most minor is [8,100,100,3][2], which is size 100. 250 // And the major dim is [8,100,100,3][1], which is size 100. 251 absl::InlinedVector<int64, 6> minor_to_major_; 252 253 // The tiles used in tiling-based layout. 254 absl::InlinedVector<Tile, 2> tiles_; 255 256 // The number of bits used to store an individual array element. 257 int64 element_size_in_bits_ = 0; 258 259 // The assigned memory space. 260 int64 memory_space_ = 0; 261 }; 262 263 std::ostream& operator<<(std::ostream& out, const Tile& Tile); 264 std::ostream& operator<<(std::ostream& out, const Layout& layout); 265 266 } // namespace xla 267 268 #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 269