/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <float.h>

#include <cmath>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {
// Pad each partition to a size that is a multiple of num_partitions.
// For example, if the input is {0, 1, 2, 3, 4, 5} and num_partitions = 2,
// after padding, partition 0 has {0, 1, 2, 3} and partition 1 has
// {4, 5, 0, 0}.
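// The padding is needed because the within-partition shuffle below reshapes
// the last dimension into (size_per_partition / num_partitions,
// num_partitions), so the per-partition size must be divisible by
// num_partitions.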
absl::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  int64_t size_padded_per_partition =
      CeilOfRatio(size_per_partition, num_partitions) * num_partitions;
  if (size_per_partition == size_padded_per_partition) {
    return hlo;
  }
  // 1. Calculate left_halo size.
  // left-halo size is 0.
  OffsetCalculation left_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

  // 2. Calculate right_halo size.
  // D = size_padded_per_partition
  // S = size_per_partition
  // i = shard_ordinal
  // right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S)
  OffsetCalculation right_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition,
          2 * (size_padded_per_partition - size_per_partition), 1));

  auto concat = hlo;
  // 3. Halo exchange.
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function,
                   hlo->shape().rank() - 1, sharding, collective_ops_creator,
                   next_channel_id, b);

  if (halo_exchange_result.has_value()) {
    concat = halo_exchange_result.value();
  } else {
    return absl::nullopt;
  }

  // 4. Slice the valid result.
  // Slice offset is (D - S) * i.
  OffsetCalculation start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition, 0, 1));
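  // Worked example for the case in the comment at the top of this function
  // ({0, 1, 2, 3, 4, 5}, num_partitions = 2, S = 3, D = 4): shard 0 receives
  // a right halo of (D - S) * 2 = 2 elements, giving {0, 1, 2, 3, 4}, and
  // slices 4 elements at offset 0 -> {0, 1, 2, 3}; shard 1 receives a right
  // halo of (D - S) * 3 = 3 elements (filled with zeros for positions past
  // the end of the data, which is how partition 1 ends up with the padding
  // zeros in the example above), giving {3, 4, 5, 0, 0, 0}, and slices 4
  // elements at offset (D - S) * 1 = 1 -> {4, 5, 0, 0}.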
  auto slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1,
                             size_padded_per_partition);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32);
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(sharding, partition_id, b);
  slice_offsets[concat->shape().rank() - 1] =
      start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinals[concat->shape().rank() - 1], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}

// If partition 0 has {0, 1, 2, 3} and num_partitions is 2, after shuffling,
// the data becomes {0, 2, 1, 3}.
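// How it works, traced for size_per_partition = 4 and num_partitions = 2:
// iota [0, 1, 2, 3] is reshaped to [[0, 1], [2, 3]], transposed to
// [[0, 2], [1, 3]], and flattened to perm = [0, 2, 1, 3]. A broadcast and
// compare then build the permutation matrix P with P[k][j] = (perm[j] == k),
// so the dot at the end computes out[..., j] = in[..., perm[j]], applying
// the permutation as a matmul.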
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64_t num_partitions,
                                                      SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));

  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));

  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}

// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
// num_partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
// and partition 1 will have {1, 3, 5, 0}.
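// Together, the within-partition shuffle above and this all-to-all implement
// a global decimation: partition p ends up holding the elements whose global
// index is congruent to p modulo num_partitions ({0, 2, 4} on partition 0 and
// {1, 3, 5} on partition 1 in the example), plus any zero padding.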
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64_t num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  std::vector<std::vector<int64>> groups(1);
  std::vector<int64> partition_subgroups(num_partitions);
  std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0);
  groups[0] = partition_subgroups;
  auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);
  return all_to_all;
}

HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64_t num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  /* n = size_per_replica
     m = num_partitions
     factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
                     tf.cast(tf.range(n), dtype=tf.complex64) /
                     (n * m))
  */
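  // This is the standard Cooley-Tukey twiddle factor: after decimation,
  // partition p holds the subsequence x[q * m + p], and output bin k of its
  // local length-n FFT must be scaled by exp(-2*pi*i * p * k / (n * m)).
  // The code below builds exactly that: the imaginary constant -2*pi/(n*m),
  // times partition_id, times an iota over the last dimension, fed into exp.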
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  int64_t per_replica_size = hlo->shape().dimensions().back();
  auto constant_factor =
      add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0(
          complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size)))));
  constant_factor = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), constant_factor, /*broadcast_dimensions=*/{}));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // TODO(wangtao): multiply before broadcast.
  auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{}));
  auto exp_operand = add_hlo(
      HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply,
                                   constant_factor, broadcast_partition_id));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}

// Pseudocode for the while loop:
// def body(dest_transform, dest_core_position, source_transform,
//          source_core_position, i):
//   factor = tf.exp(-2.0j * np.pi *
//                   tf.cast(dest_core_position, tf.complex64) *
//                   tf.cast(source_core_position, tf.complex64) /
//                   num_partitions)
//   dest_transform += factor * source_transform
//   source_core_position = tf.raw_ops.CollectivePermute(
//       input=source_core_position,
//       source_target_pairs=source_target_pairs,
//       name='source_core_position_permute')
//   source_transform = tf.raw_ops.CollectivePermute(
//       input=source_transform,
//       source_target_pairs=source_target_pairs,
//       name='source_transform_permute')
//   i += 1
//   return (dest_transform, dest_core_position, source_transform,
//           source_core_position, i)
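// Each of the num_partitions iterations rotates source_transform and
// source_core_position to the next partition, so destination partition d
// accumulates sum_s exp(-2*pi*i * d * s / num_partitions) * Z_s over all
// source partitions s: a length-num_partitions DFT across partitions,
// implemented as num_partitions collective-permute steps.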
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t num_partitions, HloInstruction* partition_id,
    int64* next_channel_id, HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  /*
  factor = tf.exp(-2.0j * np.pi *
                  tf.cast(dest_partition_id, tf.complex64) *
                  tf.cast(source_partition_id, tf.complex64) /
                  num_partitions)
  dest_transform += factor * source_transform
  */
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions))));

  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective permute for source_partition_id and source_transform.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64_t src_device) {
        std::vector<int64> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64_t dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });

  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);

  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);

  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));

  // Build while loop condition.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32>(num_partitions))),
      ComparisonDirection::kLt));

  // Build while loop.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));

  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}

// Slice valid data in each partition.
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  std::vector<int64> start_indices(target_shape.rank(), 0);
  std::vector<int64> strides(target_shape.rank(), 1);
  return b->AddInstruction(HloInstruction::CreateSlice(
      target_shape, hlo, start_indices, target_shape.dimensions(), strides));
}

}  // namespace

// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
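// A sketch of the underlying factorization (standard Cooley-Tukey, with the
// radix split across partitions): with m = num_partitions partitions each
// holding n elements, write a global input index as j = q * m + p and an
// output index as k = k2 * n + k1. Then
//   X[k2 * n + k1] = sum_p exp(-2*pi*i * p * k1 / (n * m)) *
//                    exp(-2*pi*i * p * k2 / m) * Y_p[k1],
// where Y_p is the length-n FFT of the decimated subsequence on partition p.
// Step 1 below performs the decimation, step 2 computes Y_p and applies the
// exp(-2*pi*i * p * k1 / (n * m)) twiddle factors, and step 3 computes the
// length-m DFT across partitions, leaving output block k2 on partition k2.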
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }

  // Only the case where input_length equals fft_length is supported.
  int64_t input_length = hlo->operand(0)->shape().dimensions().back();
  int64_t fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }

  // Only partitioning along the last dimension is supported.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }

  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));

  // 1.a. Use right halo exchange to shuffle data first and slice the valid
  // data. Data shuffling ensures an in-order transform: the sequence of data
  // is the same before and after the transform. The data shuffling requires
  // that the data size per partition be divisible by the number of
  // partitions. For example, if the input is {0, 1, 2, 3, 4, 5} and
  // num_partitions is 2, after the halo exchange partition 0 has {0, 1, 2, 3}
  // and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1 are padding
  // data. Zero padding appends zeros to the end of the full data.
  auto result = partitioned_input.hlo();
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);

  if (padded_hlo.has_value()) {
    result = padded_hlo.value();
  }

  // 1.b. Shuffle data within each partition using one-hot and matmul.
  // If partition 0 has {0, 1, 2, 3} and num_partitions is 2, after shuffling,
  // the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c. All-to-all.
  // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
  // num_partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
  // and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d. Slice valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);

  // 2. Do the local FFT transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));

  // Multiply by the correction factor for local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));

  // 3. Second-phase FFT with collective permute. fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);

  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return Status::OK();
}

}  // namespace spmd
}  // namespace xla