/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
  }

  // Creates a sharding that represents that the op is manually partitioned.
  static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
  }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64_t device_id,
                                  absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment,
                          absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }

  // Creates a new sharding where data is replicated within each replication
  // group, and sharded across replication groups according to
  // group_tile_assignment. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& group_tile_assignment,
      absl::Span<const absl::Span<const int64>> replication_groups,
      absl::Span<const OpMetadata> metadata = {});

  // Creates a partially replicated tiled sharding with device-level tile
  // assignment, where the last dimension is the additional replication
  // dimension. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& tile_assignment_last_dim_replicate,
      absl::Span<const OpMetadata> metadata = {});
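
  // Illustrative sketch (not part of the API): building tiled shardings with
  // the factory methods above. Device numbering and shapes are made up for
  // the example.
  //
  //   // Four devices laid out as a 2x2 grid; splits both data dimensions.
  //   Array<int64> grid({2, 2});
  //   grid.FillIota(0);  // devices 0,1,2,3 in row-major order
  //   HloSharding tiled = HloSharding::Tile(grid);
  //
  //   // Same grid, but the last dimension of the assignment is treated as a
  //   // replication subgroup: data is split 2 ways along the first dimension
  //   // and each shard is replicated on 2 devices.
  //   HloSharding partial = HloSharding::PartialTile(grid);
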
  // Creates a subgroup sharding with device-level tile assignment; the
  // sharding type of each subgroup is defined by sharding_types.
  static HloSharding Subgroup(const Array<int64>& tile_assignment,
                              absl::Span<const OpSharding::Type> sharding_types,
                              absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64_t num_tiles,
                            absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If shape is an array, returns sharding; otherwise returns the tuple shaped
  // sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Create a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64_t device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString(bool include_metadata = false) const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64_t num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns whether the sharding represents manual partitioning.
  bool IsManual() const {
    if (!IsTuple()) {
      return manual_;
    }
    return absl::c_all_of(tuple_elements_,
                          [](const HloSharding& s) { return s.IsManual(); });
  }

  // Returns whether the sharding represents manual subgroup sharding.
  bool IsManualSubgroup() const {
    if (!IsTuple()) {
      return absl::c_linear_search(sharding_types_, OpSharding::MANUAL);
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsManualSubgroup();
    });
  }

  // Returns whether the sharding represents a tiled sharding where the mapping
  // between devices and tiles is represented through 'tile_assignment()'.
  bool IsTiled() const { return !IsTileMaximal() && !IsManual(); }

  // Returns if the sharding has partial replication and partial sharding. If
  // true, data is sharded according to other dimensions of tile_assignment(),
  // but replicated across devices along the last dimension.
  bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }
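
  // Illustrative sketch of what the predicates above return for the basic
  // factory shardings (for exposition only; `grid` is a 2x2 device assignment
  // as in the earlier example):
  //
  //   HloSharding::Replicate().IsReplicated();       // true
  //   HloSharding::Replicate().IsTileMaximal();      // true (one maximal tile)
  //   HloSharding::AssignDevice(3).IsTileMaximal();  // true
  //   HloSharding::Tile(grid).IsTiled();             // true
  //   HloSharding::Manual().IsManual();              // true
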
  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64_t device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64_t device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if IsReplicated() is true.
  // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
  // returns the first device in that replicated subgroup; otherwise,
  // index.size() should be the same as tile_assignment()'s rank and specifies
  // the member of the replication subgroup.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64_t device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape,
                                        int64_t device) const;

  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
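
  // Illustrative sketch of the tile queries above, using made-up values: for
  // an f32[4,6] operand tiled by a 2x2 device grid (devices 0..3 in row-major
  // order), device 3 owns the tile at index [1,1], so one would expect:
  //
  //   sharding.TileIndexForDevice(3);           // {1, 1}
  //   sharding.TileOffsetForDevice(shape, 3);   // {2, 3}  (lower extent)
  //   sharding.TileLimitForDevice(shape, 3);    // {4, 6}  (upper extent)
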
  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object so it is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, return itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  // Returns a copy of the sharding with no metadata. If sharding is of tuple
  // type, sub shardings will have no metadata.
  HloSharding WithoutMetadata() const;

  // Returns a copy of the sharding with specified metadata. If metadata is
  // already present, that metadata will not be replaced unless `overwrite` is
  // set to true. If sharding is of tuple type, sub shardings metadata will be
  // assigned instead.
  HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
                           bool overwrite) const;
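
  // Illustrative sketch of tuple shardings (names and shapes are made up for
  // the example; ShapeUtil comes from tensorflow/compiler/xla/shape_util.h):
  //
  //   Shape tuple_shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {4, 4})});
  //   // Same sharding on every leaf.
  //   HloSharding all_replicated =
  //       HloSharding::SingleTuple(tuple_shape, HloSharding::Replicate());
  //   // Per-leaf shardings, one entry per leaf of the tuple.
  //   HloSharding mixed = HloSharding::Tuple(
  //       tuple_shape,
  //       {HloSharding::Replicate(), HloSharding::AssignDevice(0)});
  //   HloSharding second_leaf = mixed.GetSubSharding(tuple_shape, {1});
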
  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           manual_ == other.manual_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_ &&
           replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
           sharding_types_ == other.sharding_types_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Gets the sharding types array.
  // REQUIRES: !sharding_types_.empty() && !IsTuple()
  const std::vector<OpSharding::Type>& sharding_types() const {
    return sharding_types_;
  }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

  // Gets the tile shape on the device.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape, int64_t device) const;

  // Gets the number of tiles. If it has partial replication, this will not
  // equal the device count.
  int64 NumTiles() const;
  // Like NumTiles() but considers only the specific dimensions passed as
  // argument.
  int64 NumTiles(absl::Span<const int64> dims) const;

  // Gets metadata from sharding.
  std::vector<OpMetadata>& metadata() { return metadata_; }
  const std::vector<OpMetadata>& metadata() const { return metadata_; }
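
  // Illustrative sketch (for exposition only): with the 2x2 `grid` from the
  // earlier examples,
  //
  //   HloSharding::Tile(grid).NumTiles();         // 4
  //   // Under partial replication the last assignment dimension is a
  //   // replication subgroup, so the tile count is expected to be smaller
  //   // than the device count:
  //   HloSharding::PartialTile(grid).NumTiles();  // 2, while 4 devices are used
  //
  //   // Hasher makes HloSharding usable as a hash-set/map key
  //   // (<unordered_set> assumed to be included):
  //   std::unordered_set<HloSharding, HloSharding::Hasher> seen;
  //   seen.insert(HloSharding::Replicate());
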
 private:
  explicit HloSharding(bool manual, bool replicated,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(replicated),
        maximal_(replicated),
        tuple_(false),
        manual_(manual),
        tile_assignment_({0}),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  //  0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64_t device_id,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        manual_(false),
        tile_assignment_({1}, device_id),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       bool replicate_on_last_tile_dim,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       absl::Span<const OpSharding::Type> sharding_types,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()),
        sharding_types_(sharding_types.begin(), sharding_types.end()) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        manual_(false),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings),
        replicate_on_last_tile_dim_(false) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64_t num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64_t num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  bool manual_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 1 is split two ways and
  // dimension 3 is split two ways (dimensions and indices here are 1-based).
  // Core 5, whose index is [2,1,1], will take the tile that contains the 2nd
  // half of dimension 1 and the 1st half of dimension 3.
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
  // This flag is to support partial replication and partial sharding. If it is
  // true, tile_assignment_ will have an extra dimension in addition to the
  // data shape rank, and the added last dimension represents the subgroups of
  // replications, i.e., elements in slice [..., :] will be replicated.
  bool replicate_on_last_tile_dim_;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings. Metadata is not to be populated when
  // tuple_ == true; instead metadata should be set on individual tuple
  // elements.
  std::vector<OpMetadata> metadata_;
  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual, and unreduced,
  // respectively.
  std::vector<OpSharding::Type> sharding_types_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_