/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate() { return HloSharding(); }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64 device_id);

  // Creates a new sharding which splits a shape into tiles amongst the
  // devices specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment) {
    return HloSharding(tile_assignment);
  }

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // `shardings` must match the number of leaf nodes in `tuple_shape`. For
  // empty tuples, the `shardings` array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If `shape` is an array, returns `sharding`; otherwise returns the tuple
  // shaped sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);
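
  // Illustrative sketch of how the factory methods above are typically
  // combined (the shapes, device ids, and tile assignment below are
  // assumptions for illustration, not part of this API):
  //
  //   // Replicate an instruction across every device.
  //   HloSharding replicated = HloSharding::Replicate();
  //
  //   // Place an instruction entirely on device 0.
  //   HloSharding placed = HloSharding::AssignDevice(0);
  //
  //   // Split a rank-2 instruction into a 2x2 grid of tiles over devices
  //   // 0..3, with devices varying fastest along the last dimension.
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding tiled = HloSharding::Tile(assignment);
  //
  //   // Apply the same leaf sharding to every element of a tuple shape.
  //   HloSharding tuple_sharding =
  //       HloSharding::SingleTuple(tuple_shape, tiled);
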
  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether `device` is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64 device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString() const;

  // Validates that this sharding can be applied to a tensor with shape
  // `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsReplicated();
    });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The `count` argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if is_replicated() is true.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64 device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;
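
  // Illustrative sketch of the tile queries above, assuming a one-dimensional
  // f32[6] shape split into three tiles with tile assignment {0, 1, 2} (the
  // shape and device ids are assumptions for illustration):
  //
  //   TileIndexForDevice(1)          -> {1}
  //   TileOffsetForDevice(shape, 1)  -> {2}
  //   TileLimitForDevice(shape, 1)   -> {4}
  //   DeviceForTileIndex({2})        -> 2
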
  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object, so it is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, returns itself as the result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple()
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

 private:
  HloSharding()
      : replicated_(true),
        maximal_(true),
        tuple_(false),
        tile_assignment_({0}) {}
  // device_id values:
  //   -2: magic number to mean unassigned device, used by spatial partitioning
  //   -1: the id of the host
  //    0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64 device_id)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        tile_assignment_({1}, device_id) {}
  explicit HloSharding(const Array<int64>& tile_assignment)
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        tile_assignment_(tile_assignment) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as an argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64 num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;

  // Returns the number of tuple_elements_ entries needed to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 0 is split two ways and
  // dimension 2 is split two ways. Device 5, whose index is [1,0,0], will take
  // the tile that contains the second half of dimension 0 and the first half
  // of dimension 2.
  Array<int64> tile_assignment_;
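
  // A worked continuation of the example above (the f32[4,2,6] operand shape
  // is an assumption for illustration): with tile_assignment_ set to
  // {{{2, 3}}, {{5, 7}}} each tile has shape f32[2,2,3], and
  //
  //   TileIndexForDevice(7)          -> {1, 0, 1}
  //   TileOffsetForDevice(shape, 7)  -> {2, 0, 3}
  //   TileLimitForDevice(shape, 7)   -> {4, 2, 6}
  //   DeviceForTileIndex({1, 0, 0})  -> 5
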
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_