/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
  }

  // Creates a sharding that represents an op that is manually partitioned.
  static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
  }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64_t device_id,
                                  absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment,
                          absl::Span<const OpMetadata> metadata = {}) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
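
  // Illustrative usage sketch (editor's addition, not part of the original
  // header): a rank-2 HLO split two ways along each dimension across devices
  // 0..3 could be expressed as
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding sharding = HloSharding::Tile(assignment);
  // so tile [0,0] is assigned to device 0, tile [0,1] to device 1, and so on.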

  // Creates a new sharding where data is replicated within each replication
  // group, and sharded across replication groups according to
  // group_tile_assignment. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& group_tile_assignment,
      absl::Span<const absl::Span<const int64>> replication_groups,
      absl::Span<const OpMetadata> metadata = {});

  // Creates a partially replicated tiled sharding with device-level tile
  // assignment, where the last dimension is the additional replication
  // dimension. Replication group members will be sorted.
  static HloSharding PartialTile(
      const Array<int64>& tile_assignment_last_dim_replicate,
      absl::Span<const OpMetadata> metadata = {});
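
  // Illustrative usage sketch (editor's addition, not part of the original
  // header): with a device-level assignment of shape [2,2], e.g.
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding sharding = HloSharding::PartialTile(assignment);
  // the data is split two ways along dimension 0; devices 0 and 1 hold
  // replicas of the first half, and devices 2 and 3 hold replicas of the
  // second half.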

  // Creates a subgroup sharding with device-level tile assignment, where the
  // sharding type of each subgroup is defined by sharding_types.
  static HloSharding Subgroup(const Array<int64>& tile_assignment,
                              absl::Span<const OpSharding::Type> sharding_types,
                              absl::Span<const OpMetadata> metadata = {});

  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64_t num_tiles,
                            absl::Span<const OpMetadata> metadata = {});
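
  // Illustrative usage sketch (editor's addition, not part of the original
  // header): assuming a one-dimensional f32[16] input,
  //   HloSharding sharding =
  //       HloSharding::Tile1D(ShapeUtil::MakeShape(F32, {16}), 4);
  // splits the array into 4 tiles of 4 elements each.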

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);
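
  // Illustrative usage sketch (editor's addition, not part of the original
  // header): for a two-element tuple such as (f32[4], f32[8]), one sharding
  // per leaf could be supplied as
  //   HloSharding tuple_sharding = HloSharding::Tuple(
  //       tuple_shape,
  //       {HloSharding::Replicate(), HloSharding::AssignDevice(0)});
  // where `tuple_shape` is assumed to describe the (f32[4], f32[8]) tuple.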

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

  // If shape is an array, returns sharding; otherwise returns the tuple-shaped
  // sharding with all the leaf nodes having the same input sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64_t device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString(bool include_metadata = false) const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64_t num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns whether the sharding represents manual partitioning.
  bool IsManual() const {
    if (!IsTuple()) {
      return manual_;
    }
    return absl::c_all_of(tuple_elements_,
                          [](const HloSharding& s) { return s.IsManual(); });
  }

  // Returns whether the sharding represents manual subgroup sharding.
  bool IsManualSubgroup() const {
    if (!IsTuple()) {
      return absl::c_linear_search(sharding_types_, OpSharding::MANUAL);
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsManualSubgroup();
    });
  }

  // Returns whether the sharding represents a tiled sharding where the mapping
  // between devices and tiles is represented through 'tile_assignment()'.
  bool IsTiled() const { return !IsTileMaximal() && !IsManual(); }

  // Returns true if the sharding has partial replication and partial sharding.
  // If true, data is sharded according to other dimensions of
  // tile_assignment(), but replicated across devices along the last dimension.
  bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64_t device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

  // Returns the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64_t device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if is_replicated() is true.
  // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it
  // returns the first device in that replicated subgroup; otherwise,
  // index.size() should be the same as tile_assignment()'s rank and specifies
  // the member of the replication subgroup.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64_t device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape,
                                        int64_t device) const;
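
  // Illustrative sketch (editor's addition, not part of the original header):
  // for an f32[8] array sharded as "{devices=[2]0,1}",
  // TileOffsetForDevice(shape, 1) returns {4} and
  // TileLimitForDevice(shape, 1) returns {8}, i.e. device 1 owns elements
  // [4, 8) of the input space.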

  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object so is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, returns itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple, and all the tuple elements are common, the
  // common element will be returned. Otherwise the optional will contain no
  // value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  // Returns a copy of the sharding with no metadata. If sharding is of tuple
  // type, sub shardings will have no metadata.
  HloSharding WithoutMetadata() const;

  // Returns a copy of the sharding with specified metadata. If metadata is
  // already present, that metadata will not be replaced unless `overwrite` is
  // set to true. If sharding is of tuple type, sub sharding metadata will be
  // assigned instead.
  HloSharding WithMetadata(absl::Span<const OpMetadata> metadata,
                           bool overwrite) const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           manual_ == other.manual_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_ &&
           replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
           sharding_types_ == other.sharding_types_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };
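
  // Illustrative sketch (editor's addition, not part of the original header):
  // Hasher lets shardings be used as keys in hashed containers, e.g.
  //   std::unordered_set<HloSharding, HloSharding::Hasher> seen;
  //   seen.insert(HloSharding::Replicate());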

  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Gets the sharding types array.
  // REQUIRES: !sharding_types_.empty() && !IsTuple()
  const std::vector<OpSharding::Type>& sharding_types() const {
    return sharding_types_;
  }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  std::vector<HloSharding>& tuple_elements() { return tuple_elements_; }
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

  // Gets the tile shape on the device.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape, int64_t device) const;

  // Gets the number of tiles. If it has partial replication, this will not
  // equal the device count.
  int64 NumTiles() const;
  // Like NumTiles() but considers only some specific dimensions passed as
  // argument.
  int64 NumTiles(absl::Span<const int64> dims) const;
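
  // Illustrative sketch (editor's addition, not part of the original header):
  // a sharding "{devices=[2,2]0,1,2,3}" has NumTiles() == 4, while the
  // partially replicated "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}" has
  // NumTiles() == 2 even though it uses all four devices.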

  // Gets metadata from sharding.
  std::vector<OpMetadata>& metadata() { return metadata_; }
  const std::vector<OpMetadata>& metadata() const { return metadata_; }

 private:
  explicit HloSharding(bool manual, bool replicated,
                       absl::Span<const OpMetadata> metadata)
      : replicated_(replicated),
        maximal_(replicated),
        tuple_(false),
        manual_(manual),
        tile_assignment_({0}),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  //  0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64_t device_id, absl::Span<const OpMetadata> metadata)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        manual_(false),
        tile_assignment_({1}, device_id),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       bool replicate_on_last_tile_dim,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(replicate_on_last_tile_dim),
        metadata_(metadata.begin(), metadata.end()) {}
  explicit HloSharding(const Array<int64>& tile_assignment,
                       absl::Span<const OpSharding::Type> sharding_types,
                       absl::Span<const OpMetadata> metadata = {})
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        manual_(false),
        tile_assignment_(tile_assignment),
        replicate_on_last_tile_dim_(false),
        metadata_(metadata.begin(), metadata.end()),
        sharding_types_(sharding_types.begin(), sharding_types.end()) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        manual_(false),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings),
        replicate_on_last_tile_dim_(false) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as an argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64_t num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64_t num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  bool manual_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 1 is split two ways and
  // dimension 3 is split two ways. Core 5, whose index is [2,1,1], will take
  // the tile that contains the 2nd half of dimension 1 and the 1st half of
  // dimension 3.
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings
  // in a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
  // This flag is to support partial replication and partial sharding. If it is
  // true, tile_assignment_ will have an extra dimension in addition to the
  // data shape rank, and the added last dimension represents the subgroups of
  // replications, i.e., elements in slice [..., :] will be replicated.
  bool replicate_on_last_tile_dim_;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings. Metadata is not populated when
  // tuple_ == true; instead, metadata should be set on individual tuple
  // elements.
  std::vector<OpMetadata> metadata_;
  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual, and unreduced,
  // respectively.
  std::vector<OpSharding::Type> sharding_types_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_