• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
29 
30 namespace xla {
31 namespace hlo_sharding_util {
32 
// Parallel (batch-like) dimensions identified for a gather op; produced by
// GetGatherBatchParallelDims() declared below.
33 struct GatherParallelDims {
34   absl::InlinedVector<int64, 1> indices_parallel_dims;  // parallel dims on the indices operand
35   absl::InlinedVector<int64, 1> operand_parallel_dims;  // matching parallel dims on the data operand
36   std::vector<int64> index_parallel_in_dim;  // NOTE(review): appears to map each index dim to its parallel dim -- confirm against GetGatherBatchParallelDims()
37 };
38 
39 // Returns true if the lhs sharding is preferable over the rhs sharding.
40 // The most specific sharding is tile maximal followed by single device tile
41 // maximal and finally replicated. This order aims to primarily reduce memory
42 // usage and secondly reduce total compute.
43 // Note: This does NOT provide a total ordering as we can have two different
44 // shardings with the same preference level.
45 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);
46 
47 // Tries to refine `to_merge` by combining with `old`. Returns true if the
48 // final `to_merge` is more specific than `old`.
49 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
50                    bool may_combine_partial_sharding);
51 
52 // Merges `to_merge` into `dst` only if they are compatible, and the merged
53 // sharding has >= minimum_tiles tiles. Returns true if merging happened.
54 bool MergeShardingIfCompatible(const HloSharding& to_merge,
55                                int64_t minimum_tiles, HloSharding* dst);
56 
57 // Given a map<device, occurrence_count>, selects the device with higher
58 // occurrence count (if any). If top_count is not nullptr, it will receive the
59 // count of the dominant device returned.
60 absl::optional<int64> SelectDominantDevice(
61     const std::map<int64, int64>& device_map, int64* top_count);
62 
63 // Assigns all the instructions of a computation, to a given device.
64 // This API does not recurse into called computations, and does not assign
65 // instructions which already have sharding.
66 Status AssignComputationDevice(HloComputation* computation, int64_t device);
67 
68 // Given an instruction container, returns the device which is most commonly
69 // occurring among the instructions.
70 absl::optional<int64> GetMostOccurringDevice(
71     absl::Span<HloInstruction* const> instructions);
72 
73 // Given a set of computations, tries to extract the dominant device. A device
74 // is dominant if the combined occurrence among all the instructions of the
75 // input computations, is greater/equal than/to dominant_factor (real number
76 // from 0 to 1).
77 // This API does not recurse into called computations.
78 // If no device exists that satisfies the condition, the returned optional will
79 // hold no value.
80 StatusOr<absl::optional<int64>> GetDominantDevice(
81     absl::Span<HloComputation* const> computations, double dominant_factor);
82 
83 // Returns the HloSharding with the tile dimensions and tile assignment
84 // transposed based on the specified dimension numbers. In case of a tile
85 // maximal sharding returns the original sharding.
86 HloSharding TransposeSharding(const HloSharding& sharding,
87                               const std::vector<int64>& dimensions);
88 
89 // Returns the HloSharding with the tile shape reshaped based on the source and
90 // target shapes and the tile assignment adjusted to correspond to the new tile
91 // shape or absl::nullopt if the resulting reshape would create an invalid
92 // sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
93 // maximal sharding returns the original sharding.
94 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
95                                             const Shape& target_shape,
96                                             const HloSharding& sharding);
97 
98 // Returns the HloSharding with the tile dimensions and tile assignment
99 // reversed based on the specified dimension numbers. In case of a tile
100 // maximal sharding returns the original sharding.
101 HloSharding ReverseSharding(const HloSharding& sharding,
102                             absl::Span<const int64> dimensions);
103 
104 // Returns a sharding tiled on unique dimension dim by reshaping the tile
105 // assignment of the sharding argument. Only dimensions in the dims span
106 // argument are considered for reshaping, the others are ignored.
107 // Assumptions: sharding is tile sharded, and dim must be included in dims.
108 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
109                                    absl::Span<const int64> dims);
110 
111 // Returns true if the provided module includes one or more instructions with
112 // a tile sharding.
113 bool ContainsTileSharding(const HloModule& module);
114 
115 // Returns the preferred output sharding for a gather op based on the sharding
116 // of the indices.
117 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
118                                  const HloInstruction* hlo);
119 
120 // Returns the preferred index sharding for a gather op based on the sharding
121 // of the output.
122 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
123                                 const HloInstruction* hlo);
124 
125 // Returns a new HloSharding for a gather op so that only non offset dimensions
126 // are sharded. Assume "result" is returned by this function. It is ensured that
127 // "GetIndexSharding(result, hlo)" will have the same number of elements as
128 // "result".
129 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);
130 
131 // Returns the preferred index sharding for a scatter op based on the sharding
132 // of the data.
133 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
134                                  const HloInstruction* hlo);
135 
136 // Returns the preferred data sharding for a scatter op based on the sharding
137 // of the index.
138 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
139                                 const HloInstruction* hlo);
140 
141 // Returns a new index sharding for a scatter op so that we only shard on first
142 // "number of scatter_window_dims" dimensions. Assume "result" is returned by
143 // this function. It is ensured that "ScatterDataSharding(result, hlo)" will
144 // have the same number of elements as "result".
145 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
146                                           const HloInstruction& hlo);
147 
148 // Returns a new data sharding for a scatter op so that we only shard on
149 // scatter_window_dims. Assume "result" is returned by this function. It is
150 // ensured that "ScatterIndexSharding(result, hlo)" will have the same number of
151 // elements as "result".
152 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
153                                          const HloInstruction& hlo);
154 
155 // Returns an output sharding of gather by passing through the data operand's
156 // sharding.
157 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
158     const HloSharding& data_operand_sharding, const HloInstruction& hlo,
159     const Shape& output_shape, const Shape& operand_shape);
160 
161 // Returns a data operand sharding of gather by passing through the output's
162 // sharding.
163 absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
164     const HloSharding& output_sharding, const HloInstruction& hlo);
165 
166 // Returns an output sharding of scatter by passing through the update operand's
167 // sharding.
168 absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
169     const HloSharding& update_sharding, const HloInstruction& hlo);
170 
171 // Returns an update operand sharding of scatter by passing through the output's
172 // sharding.
173 absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
174     const HloSharding& output_sharding, const HloInstruction& hlo);
175 
176 // Returns an identity value and an HloOpcode for reduce computation of scatter
177 // instruction.
178 // - If computation is add/or, return 0/false with corresponding op code;
179 // - If computation is multiply/and, return 1/true with corresponding op code.
180 // - If computation is min/max, return max value/min value with corresponding op
181 //   code.
182 // - Otherwise, return error status.
183 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
184 IdentityValueAndHloOpcodeForScatterReduceComputation(
185     const HloScatterInstruction& scatter);
186 
187 // Given a sharding and a list of devices in the topology, return a
188 // list of the devices that `sharding` applies to.
189 std::vector<int64> DevicesForSharding(
190     const HloSharding& sharding, const std::vector<int64>& available_devices);
191 
192 // Returns a sharding that replicates data across devices along the given
193 // dimensions in the original sharding.
194 HloSharding PartiallyReplicateTiledShardingOnDims(
195     const HloSharding& sharding, absl::Span<const int64> dims_to_replicate);
196 
197 // Returns a sharding that replicates data across devices along all dimensions
198 // but the given ones to keep in the original sharding.
199 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
200     const HloSharding& sharding, absl::Span<const int64> dims_to_keep);
201 
202 // Returns a sharding that removes the given tile dimensions.
203 //
204 // Precondition: if not tile maximal, the size of each tile dimension must be 1.
205 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
206                                   const std::vector<int64>& dims_to_remove);
207 
208 // Similar to TransposeSharding(), but allows removing/adding non-partitioned
209 // dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing
210 // dimension.
211 absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
212     const HloSharding& source, absl::Span<int64 const> src_to_tgt,
213     absl::Span<int64 const> tgt_to_src);
214 
215 // Returns the iota dimension if maybe_iota is a kIota instruction or
216 // equivalent to kIota.
217 absl::optional<int64> GetDimensionForIota(const HloInstruction* maybe_iota);
218 
219 // Returns identified parallel dimensions for Gather.
220 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
221     const HloInstruction& hlo);
222 
223 // Returns the parallel dimensions of the output of a gather based on the
224 // parallel dimensions of the input.
225 absl::InlinedVector<int64, 1> GatherParallelOutputDims(
226     const HloInstruction& gather, const GatherParallelDims& parallel_dim);
227 
228 // Returns the parallel dimensions of the data operand of a gather with the
229 // order of the parallel dimensions matching that of the parallel dimensions
230 // of the output.
231 absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
232     const HloInstruction& gather, const GatherParallelDims& parallel_dims);
233 
234 }  // namespace hlo_sharding_util
235 }  // namespace xla
236 
237 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
238