/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate() { return HloSharding(); }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64 device_id);

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment) {
    return HloSharding(tile_assignment);
  }

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If shape is an array, returns sharding, otherwise returns the tuple shaped
  // sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

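  // Illustrative usage sketch of the factory methods above (not part of the
  // API; the 2x2 device grid and device ids 0..3 are assumptions made for the
  // example):
  //
  //   // Replicate a value across all devices.
  //   HloSharding replicated = HloSharding::Replicate();
  //
  //   // Place a value entirely on device 0.
  //   HloSharding on_device0 = HloSharding::AssignDevice(0);
  //
  //   // Split a rank-2 HLO into a 2x2 grid of tiles; tile [i,j] is assigned
  //   // to device assignment(i, j), i.e. devices 0..3.
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding tiled = HloSharding::Tile(assignment);
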
  // Create a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64 device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString() const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if is_replicated() is true.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64 device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;

  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

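  // Illustrative sketch of the device/tile queries above (assumes the 2x2
  // `tiled` sharding from the earlier example and a [4,4] F32 shape; ShapeUtil
  // comes from shape_util.h, which this header does not include):
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
  //   // Device 3 owns tile index [1,1], i.e. rows 2..3 and columns 2..3.
  //   std::vector<int64> index = tiled.TileIndexForDevice(3);           // {1, 1}
  //   std::vector<int64> offset = tiled.TileOffsetForDevice(shape, 3);  // {2, 2}
  //   std::vector<int64> limit = tiled.TileLimitForDevice(shape, 3);    // {4, 4}
  //   int64 device = tiled.DeviceForTileIndex({1, 1});                  // 3
  //   // Four devices are in use, so there is no unique device.
  //   absl::optional<int64> unique = tiled.UniqueDevice();              // nullopt
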
  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object so is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, return itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

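  // Illustrative sketch of tuple shardings (assumes a (f32[4], f32[8]) tuple
  // shape; ShapeUtil is from shape_util.h and is not included by this header):
  //
  //   Shape tuple_shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
  //   // Assign the first leaf to device 0 and replicate the second.
  //   HloSharding tuple_sharding = HloSharding::Tuple(
  //       tuple_shape,
  //       {HloSharding::AssignDevice(0), HloSharding::Replicate()});
  //   // Recover the per-leaf shardings.
  //   HloSharding leaf0 = tuple_sharding.GetSubSharding(tuple_shape, {0});
  //   HloSharding leaf1 = tuple_sharding.GetSubSharding(tuple_shape, {1});
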
 private:
  HloSharding()
      : replicated_(true),
        maximal_(true),
        tuple_(false),
        tile_assignment_({0}) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  // 0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64 device_id)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        tile_assignment_({1}, device_id) {}
  explicit HloSharding(const Array<int64>& tile_assignment)
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        tile_assignment_(tile_assignment) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64 num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that, counting dimensions from 1,
  // dimension 1 is split two ways and dimension 3 is split two ways. Core 5,
  // whose index is [2,1,1], will take the tile that contains the 2nd half of
  // dimension 1 and the 1st half of dimension 3.
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_