/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO shardings describe how an HLO instruction is split across multiple
// computations.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// HLO shardings describe how an HLO instruction is split across multiple
// computations.
class HloSharding {
 public:
  // Creates a trivial sharding that replicates a maximal tile across all
  // devices.
  static HloSharding Replicate() { return HloSharding(); }

  // Creates a sharding that emulates device placement; a tile shape equal to
  // the input shape (one tile) assigned to a single device.
  static HloSharding AssignDevice(int64 device_id);

  // Creates a new sharding which splits a shape into tiles amongst the devices
  // specified by `tile_assignment`.
  static HloSharding Tile(const Array<int64>& tile_assignment) {
    return HloSharding(tile_assignment);
  }

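  // Example: constructing common shardings (illustrative sketch only; the
  // device ids 0..3 are arbitrary and must exist in the target module):
  //
  //   // Replicate the operand on every device.
  //   HloSharding replicated = HloSharding::Replicate();
  //   // Place the whole operand on device 2.
  //   HloSharding placed = HloSharding::AssignDevice(2);
  //   // Split a rank-2 operand into a 2x2 grid over devices 0..3.
  //   Array<int64> assignment({{0, 1}, {2, 3}});
  //   HloSharding tiled = HloSharding::Tile(assignment);
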
  // Creates a new sharding which splits a one-dimensional input shape into
  // `num_tiles` tiles.
  static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);

  // Creates a new sharding for a tuple type. The given ShapeTree must have
  // elements for every leaf shape contained in the tuple.
  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);

  // Creates a new sharding for a tuple type. The number of elements in
  // shardings must match the number of leaf nodes in tuple_shape. For
  // empty tuples, the shardings array must have one element.
  static HloSharding Tuple(const Shape& tuple_shape,
                           absl::Span<const HloSharding> shardings);

  // Creates a new sharding for a tuple type, with a single input sharding
  // repeated on each leaf.
  static HloSharding SingleTuple(const Shape& tuple_shape,
                                 const HloSharding& sharding);

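  // Example: building a sharding for a two-element tuple (illustrative; the
  // shapes and per-leaf shardings are made up, and ShapeUtil comes from
  // "tensorflow/compiler/xla/shape_util.h"):
  //
  //   Shape tuple_shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {16})});
  //   HloSharding per_leaf = HloSharding::Tuple(
  //       tuple_shape,
  //       {HloSharding::AssignDevice(0), HloSharding::Replicate()});
  //   // Or repeat one sharding on every leaf:
  //   HloSharding uniform =
  //       HloSharding::SingleTuple(tuple_shape, HloSharding::Replicate());
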
  // If shape is an array, returns sharding; otherwise returns a tuple-shaped
  // sharding with every leaf node set to the given sharding.
  static HloSharding Single(const Shape& shape, const HloSharding& sharding);

  // Creates a new sharding from a protobuf OpSharding.
  static StatusOr<HloSharding> FromProto(const OpSharding& proto);

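  // Example: round-tripping a sharding through its proto form (illustrative):
  //
  //   HloSharding original = HloSharding::AssignDevice(0);
  //   OpSharding proto = original.ToProto();
  //   StatusOr<HloSharding> restored = HloSharding::FromProto(proto);
  //   // On success, restored.ValueOrDie() == original.
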
  // Checks whether device is a reserved device number. A reserved device
  // number usually has a special meaning, with dedicated handling logic.
  static bool IsReservedDevice(int64 device) { return device < 0; }

  OpSharding ToProto() const;

  // Note that this string canonically has outer curly braces, e.g.
  // "{replicated}".
  string ToString() const;

  // Validate that this sharding can be applied to a tensor with shape `shape`.
  Status Validate(const Shape& shape, int64 num_devices) const;

  // Returns true if the sharding has tuple type.
  bool IsTuple() const { return tuple_; }

  // Returns true if the sharding is trivial: replicate on all devices.
  bool IsReplicated() const {
    if (!IsTuple()) {
      return replicated_;
    }
    return absl::c_all_of(
        tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
  }

  // Returns true if the tile size is the same as the input size.
  bool IsTileMaximal() const {
    if (!IsTuple()) {
      return maximal_;
    }
    return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
      return s.IsTileMaximal();
    });
  }

  // Returns true if the sharding defines an operation on the given device.
  bool UsesDevice(int64 device) const;

  // Retrieves a histogram of the devices used by the sharding. The returned
  // map has the device number as key, and the occurrence count as value.
  // If a sharding does not have a device, it will not be included in the
  // histogram. The count argument, if not nullptr, will receive the total
  // number of elements this sharding is made of (one for array, N leaves for
  // tuples).
  std::map<int64, int64> UsedDevices(int64* count) const;

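  // Example: inspecting which devices a sharding touches (illustrative):
  //
  //   int64 leaf_count = 0;
  //   std::map<int64, int64> histogram = sharding.UsedDevices(&leaf_count);
  //   for (const auto& device_and_count : histogram) {
  //     // device_and_count.first is a device id, .second how often it is used.
  //   }
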
  // Returns the index of the tile that should be executed on the given device.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileIndexForDevice(int64 device) const;

  // Returns the device that should execute the given tile.
  // It is an error to call this if IsReplicated() is true.
  // REQUIRES: !IsTuple()
  int64 DeviceForTileIndex(absl::Span<const int64> index) const;

  // Given a device ID, returns the offset within the specified shape of the
  // tile that should be executed on the given core. This returns the lower
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileOffsetForDevice(const Shape& shape,
                                         int64 device) const;

  // Given a device ID, returns the limit within the specified shape of the
  // tile that should be executed on the given core. This returns the upper
  // extent of the tile in the input space.
  // REQUIRES: !IsTuple()
  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;

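  // Example: computing the data each device owns under a tiled sharding
  // (illustrative; assumes the 2x2 tiling over devices 0..3 from the Tile()
  // example above, applied to an f32[8,8] operand, and uses ShapeUtil from
  // "tensorflow/compiler/xla/shape_util.h"):
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {8, 8});
  //   std::vector<int64> index = tiled.TileIndexForDevice(/*device=*/3);
  //   // index == {1, 1}
  //   std::vector<int64> offset = tiled.TileOffsetForDevice(shape, /*device=*/3);
  //   std::vector<int64> limit = tiled.TileLimitForDevice(shape, /*device=*/3);
  //   // Device 3 owns the half-open region [offset, limit), here [4,8) x [4,8).
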
  // Returns the single device this op operates on. If the sharding does not
  // span a single device, the return value will be empty.
  // In order for a sharding to span a single device, every leaf sharding must
  // be maximal and not replicated, and the used device must match.
  absl::optional<int64> UniqueDevice() const;

  // Retrieves the unique device or fails with a CHECK.
  int64 GetUniqueDevice() const;

  // Returns true if this op only uses a single device.
  bool HasUniqueDevice() const { return UniqueDevice().has_value(); }

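  // Example: querying for a unique device (illustrative):
  //
  //   if (sharding.HasUniqueDevice()) {
  //     int64 device = sharding.GetUniqueDevice();
  //     // ... place the computation on `device` ...
  //   }
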
  // Returns the ShapeTree containing the shardings for each element of this
  // tuple, if IsTuple, or a ShapeTree with a single element containing this
  // sharding. Only the leaf elements are populated. This creates a new
  // ShapeTree object so is not cheap.
  StatusOr<ShapeTree<HloSharding>> AsShapeTree(const Shape& shape) const;
  ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    return AsShapeTree(shape).ValueOrDie();
  }

  // Retrieves the sub sharding at a given index, out of a tuple sharding.
  // REQUIRES: IsTuple()
  HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;

  // If the current sharding is a tuple sharding, returns itself as result.
  // Otherwise returns a tuple sharding for the input shape, with all the
  // leaves having this object's sharding.
  StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;

  // Extracts the sharding that is common within the current sharding.
  // If the current sharding is not a tuple sharding, the current sharding will
  // be returned. If it is a tuple and all the tuple elements are identical,
  // that common element will be returned. Otherwise the optional will contain
  // no value.
  absl::optional<HloSharding> ExtractSingleSharding() const;

  bool operator==(const HloSharding& other) const {
    return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
           tile_assignment_ == other.tile_assignment_ &&
           tuple_elements_ == other.tuple_elements_;
  }
  bool operator!=(const HloSharding& other) const { return !(*this == other); }

  size_t Hash() const;

  struct Hasher {
    size_t operator()(const HloSharding& sharding) const {
      return sharding.Hash();
    }
  };

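  // Example: using Hasher to key a hash container by sharding (illustrative;
  // requires <unordered_set>):
  //
  //   std::unordered_set<HloSharding, HloSharding::Hasher> seen;
  //   seen.insert(HloSharding::Replicate());
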
  // Gets the tile assignment tensor.
  // REQUIRES: !IsReplicated() && !IsTuple()
  const Array<int64>& tile_assignment() const { return tile_assignment_; }

  // Returns the flattened list of all the leaf shardings in a tuple shape, by
  // pre-order walk (ShapeTree iterator order).
  // REQUIRES: IsTuple().
  const std::vector<HloSharding>& tuple_elements() const {
    return tuple_elements_;
  }

  // Gets the tile shape.
  // REQUIRES: !IsTuple()
  Shape TileShape(const Shape& shape) const;

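  // Example: per-device tile shape (illustrative; continues the 2x2 tiling of
  // an f32[8,8] operand from the examples above):
  //
  //   Shape tile = tiled.TileShape(ShapeUtil::MakeShape(F32, {8, 8}));
  //   // tile is f32[4,4]: each of the four devices holds one quarter.
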
 private:
  HloSharding()
      : replicated_(true),
        maximal_(true),
        tuple_(false),
        tile_assignment_({0}) {}
  // device_id values:
  // -2: magic number to mean unassigned device, used by spatial partitioning
  // -1: the id of the host
  //  0 or positive: the id of a device
  // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
  // we have fully switched to the side-effect tokens.
  explicit HloSharding(int64 device_id)
      : replicated_(false),
        maximal_(true),
        tuple_(false),
        tile_assignment_({1}, device_id) {}
  explicit HloSharding(const Array<int64>& tile_assignment)
      : replicated_(false),
        maximal_(false),
        tuple_(false),
        tile_assignment_(tile_assignment) {}
  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
        tile_assignment_({0}),
        tuple_elements_(tuple_shardings) {}

  // Checks that the number of elements in tuple_elements_ is consistent with
  // the tuple shape passed as argument.
  Status CheckLeafCount(const Shape& shape) const;

  // Internal helper to validate a tuple sharding.
  Status ValidateTuple(const Shape& shape, int64 num_devices) const;

  // Internal helper to validate a non-tuple (leaf) sharding.
  Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;

  // Returns the number of tuple_elements_ entries to fit the shape.
  static int64 RequiredLeaves(const Shape& shape);

  bool replicated_;
  bool maximal_;
  bool tuple_;
  // This field is only used if replicated_ is false. If maximal_ is true, then
  // the field contains a rank 1 array with a single element, which is the
  // device the HLO is assigned to. If maximal_ is false, the field contains an
  // array with the same rank as the corresponding HLO. The dimension sizes of
  // the array describe the number of ways the HLO is partitioned along each
  // dimension. The values of the array specify which device each tile of
  // the HLO is assigned to. The index of each value determines which tile it
  // takes.
  // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is
  // "{devices=[2,1,2]2,3,5,7}") means that dimension 0 is split two ways and
  // dimension 2 is split two ways, while dimension 1 is not split. Device 5,
  // whose index in the array is [1,0,0], takes the tile covering the second
  // half of dimension 0 and the first half of dimension 2.
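  // The sharding above could be built through the public API as (illustrative):
  //
  //   Array<int64> assignment({{{2, 3}}, {{5, 7}}});
  //   HloSharding sharding = HloSharding::Tile(assignment);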
  Array<int64> tile_assignment_;
  // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
  // present for the root. This is a flattened list of all the leaf shardings in
  // a tuple shape, by pre-order walk (ShapeTree iterator order).
  std::vector<HloSharding> tuple_elements_;
};

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_