1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 18 19 #include <vector> 20 21 #include "absl/types/span.h" 22 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/compiler/xla/util.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace xla { 29 30 // Describes a tile used in tiling-based layout. Refer to 31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for 32 // details. 33 class Tile { 34 public: 35 Tile() = default; Tile(absl::Span<const int64> dimensions)36 explicit Tile(absl::Span<const int64> dimensions) 37 : dimensions_(dimensions.begin(), dimensions.end()) {} 38 39 // De/Serialize a Tile to and from a TileProto. CreateFromProto(const TileProto & tile_proto)40 static Tile CreateFromProto(const TileProto& tile_proto) { 41 return Tile(AsInt64Slice(tile_proto.dimensions())); 42 } 43 TileProto ToProto() const; 44 45 bool operator==(const Tile& other) const { 46 return dimensions() == other.dimensions(); 47 } 48 bool operator!=(const Tile& other) const { return !(*this == other); } 49 50 string ToString() const; 51 52 // Returns the bound of the tile in the given dimension index. dimension(int i)53 int64 dimension(int i) const { return dimensions_.at(i); } 54 55 // Returns the dimensions of the tile. dimensions()56 const std::vector<int64>& dimensions() const { return dimensions_; } 57 add_dimensions(int64 value)58 Tile& add_dimensions(int64 value) { 59 dimensions_.push_back(value); 60 return *this; 61 } 62 clear_dimensions()63 Tile& clear_dimensions() { 64 dimensions_.clear(); 65 return *this; 66 } 67 68 // This dimension size means the corresponding dimension in the shape is 69 // combined with the next minor dimension before tiling is applied. 70 static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min(); 71 72 private: 73 // The bounds of the tile. 74 std::vector<int64> dimensions_; 75 }; 76 77 class Layout { 78 public: 79 Layout() = default; 80 81 // Constructs a dense layout with the given minor-to-major order. Layout(absl::Span<const int64> minor_to_major)82 explicit Layout(absl::Span<const int64> minor_to_major) 83 : format_(DENSE), 84 minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} 85 86 // Constructs a dense tiled layout with the given minor-to-major order and 87 // tiles. 88 Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles, 89 int64 element_size_in_bits = 0) format_(DENSE)90 : format_(DENSE), 91 minor_to_major_(minor_to_major.begin(), minor_to_major.end()), 92 tiles_(tiles.begin(), tiles.end()), 93 element_size_in_bits_(element_size_in_bits) {} 94 95 // Construct a shape from a LayoutProto. 96 static Layout CreateFromProto(const LayoutProto& proto); 97 98 // Returns a LayoutProto representation of the Layout. 99 LayoutProto ToProto() const; 100 101 // Returns a human-readable string that represents this layout. 102 string ToString() const; 103 104 // Equal is a configurable functor to check the equality of two layouts. 105 // 106 // Examples: 107 // 108 // - Comparing two layouts ignoring their difference in tiles: 109 // Equal().IgnoreTiles()(layout1, layout2); 110 // 111 // - Comparing two layouts ignoring their difference in tiles and element 112 // size: 113 // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); 114 class Equal { 115 public: 116 Equal() = default; 117 118 bool operator()(const Layout& lhs, const Layout& rhs); 119 IgnoreTiles()120 Equal& IgnoreTiles() { 121 ignore_tiles_ = true; 122 return *this; 123 } 124 IgnoreElementSize()125 Equal& IgnoreElementSize() { 126 ignore_element_size_ = true; 127 return *this; 128 } 129 130 private: 131 bool ignore_tiles_ = false; 132 bool ignore_element_size_ = false; 133 }; 134 135 bool operator==(const Layout& other) const; 136 bool operator!=(const Layout& other) const { return !(*this == other); } 137 138 // The following methods mirror the protobuf generated code interface for the 139 // message LayoutProto. This enabled easy migration of this data structure 140 // from a proto to a proper C++ class. 141 // 142 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 143 // interface. 144 145 // Methods for accessing the format. format()146 Format format() const { return format_; } set_format(Format value)147 Layout& set_format(Format value) { 148 format_ = value; 149 return *this; 150 } 151 152 // Methods for accessing the minor-to-major array. minor_to_major_size()153 int minor_to_major_size() const { return minor_to_major_.size(); } minor_to_major(int index)154 int64 minor_to_major(int index) const { return minor_to_major_.at(index); } set_minor_to_major(int index,int64 value)155 Layout& set_minor_to_major(int index, int64 value) { 156 minor_to_major_.at(index) = value; 157 return *this; 158 } add_minor_to_major(int64 value)159 Layout& add_minor_to_major(int64 value) { 160 minor_to_major_.push_back(value); 161 return *this; 162 } clear_minor_to_major()163 Layout& clear_minor_to_major() { 164 minor_to_major_.clear(); 165 return *this; 166 } minor_to_major()167 const std::vector<int64>& minor_to_major() const { return minor_to_major_; } mutable_minor_to_major()168 std::vector<int64>* mutable_minor_to_major() { return &minor_to_major_; } 169 170 // Methods for accessing the tile field. tiles_size()171 int tiles_size() const { return tiles_.size(); } tiles(int index)172 const Tile& tiles(int index) const { return tiles_.at(index); } mutable_tiles(int index)173 Tile* mutable_tiles(int index) { return &tiles_.at(index); } add_tiles()174 Tile* add_tiles() { 175 tiles_.push_back(Tile()); 176 return &tiles_.back(); 177 } clear_tiles()178 Layout& clear_tiles() { 179 tiles_.clear(); 180 return *this; 181 } tiles()182 const std::vector<Tile>& tiles() const { return tiles_; } mutable_tiles()183 std::vector<Tile>* mutable_tiles() { return &tiles_; } 184 185 // Methods for accessing the int64 fields. max_sparse_elements()186 int64 max_sparse_elements() const { return max_sparse_elements_; } set_max_sparse_elements(int64 value)187 Layout& set_max_sparse_elements(int64 value) { 188 max_sparse_elements_ = value; 189 return *this; 190 } element_size_in_bits()191 int64 element_size_in_bits() const { return element_size_in_bits_; } set_element_size_in_bits(int64 value)192 Layout& set_element_size_in_bits(int64 value) { 193 element_size_in_bits_ = value; 194 return *this; 195 } 196 Swap(Layout * other)197 void Swap(Layout* other) { 198 using std::swap; 199 swap(*this, *other); 200 } 201 Clear()202 void Clear() { 203 format_ = INVALID_FORMAT; 204 minor_to_major_.clear(); 205 max_sparse_elements_ = 0; 206 element_size_in_bits_ = 0; 207 } 208 209 private: 210 // The format of this layout. 211 Format format_ = INVALID_FORMAT; 212 213 // Sequence of dimension numbers, from minor (fastest varying index) to major 214 // (slowest varying index). 215 std::vector<int64> minor_to_major_; 216 217 // The maximum number of elements that can be stored for SPARSE formats. This 218 // can be used to determine the maximum size in bytes of arrays stored in 219 // memory. This field must be zero unless the format is SPARSE. 220 int64 max_sparse_elements_ = 0; 221 222 // The tiles used in tiling-based layout. 223 std::vector<Tile> tiles_; 224 225 // The number of bits used to store an individual array element. 226 int64 element_size_in_bits_ = 0; 227 }; 228 229 std::ostream& operator<<(std::ostream& out, const Tile& Tile); 230 std::ostream& operator<<(std::ostream& out, const Layout& layout); 231 232 } // namespace xla 233 234 #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ 235