/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <float.h>

#include <cmath>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {

// Pad each partition so that its size is a multiple of num_partitions.
// For example, if the input is {0, 1, 2, 3, 4, 5} and num_partitions = 2,
// after padding, it becomes {0, 1, 2, 3} in partition 0 and {4, 5, 0, 0} in
// partition 1.
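// For intuition, this step corresponds to the following numpy sketch
// (illustration only, not part of this file's logic; `full` and `m` are
// hypothetical names):
//   import numpy as np
//   m = 2                                # num_partitions
//   full = np.arange(6)                  # global data {0, 1, 2, 3, 4, 5}
//   per = len(full) // m                 # 3 elements per partition
//   pad = -per % m                       # pad each shard up to a multiple of m
//   shards = np.split(np.pad(full, (0, pad * m)), m)
//   # shards[0] == [0, 1, 2, 3], shards[1] == [4, 5, 0, 0]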
absl::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  int64_t size_padded_per_partition =
      CeilOfRatio(size_per_partition, num_partitions) * num_partitions;
  if (size_per_partition == size_padded_per_partition) {
    return hlo;
  }
  // 1. Calculate left_halo size.
  // left-halo size is 0
  OffsetCalculation left_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

  // 2. Calculate right_halo size.
  // D = size_padded_per_partition
  // S = size_per_partition
  // i = shard_ordinal
  // right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S)
  OffsetCalculation right_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition,
          2 * (size_padded_per_partition - size_per_partition), 1));
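  // Worked example for the running case (S = 3, D = 4): the slice offset
  // below is (D - S) * i, so shard 0 keeps {0, 1, 2, 3} (offset 0) and
  // shard 1 keeps {4, 5, 0, 0} (offset 1) out of its halo-extended data,
  // with out-of-range positions filled by padding.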

  auto concat = hlo;
  // 3. Halo exchange.
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function,
                   hlo->shape().rank() - 1, sharding, collective_ops_creator,
                   next_channel_id, b);

  if (halo_exchange_result.has_value()) {
    concat = halo_exchange_result.value();
  } else {
    return absl::nullopt;
  }

  // 4. Slice the valid result.
  // Slice offset is (D - S) * i
  OffsetCalculation start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition, 0, 1));
  auto slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1,
                             size_padded_per_partition);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32);
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(sharding, partition_id, b);
  slice_offsets[concat->shape().rank() - 1] =
      start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinals[concat->shape().rank() - 1], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}

// If partition 0 has {0, 1, 2, 3} and num_partitions is 2, after shuffling,
// the data becomes {0, 2, 1, 3}.
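// For intuition, the one-hot matmul below implements the same permutation as
// this numpy sketch (illustration only; `shard` and `m` are hypothetical):
//   shard = np.array([0, 1, 2, 3]); m = 2
//   shuffled = shard.reshape(-1, m).T.reshape(-1)   # -> [0, 2, 1, 3]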
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64_t num_partitions,
                                                      SpmdBuilder* b) {
  int64_t size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));

  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));

  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}

// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
// num_partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
// and partition 1 will have {1, 3, 5, 0}.
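// For intuition, the all-to-all exchanges same-sized blocks between
// partitions, as in this numpy sketch (illustration only; `p0`, `p1`, and
// `m` are hypothetical):
//   m = 2; p0 = np.array([0, 2, 1, 3]); p1 = np.array([4, 0, 5, 0])
//   blocks = [np.split(p, m) for p in (p0, p1)]
//   out = [np.concatenate([blocks[s][d] for s in range(m)]) for d in range(m)]
//   # out[0] == [0, 2, 4, 0], out[1] == [1, 3, 5, 0]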
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64_t num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  std::vector<std::vector<int64>> groups(1);
  std::vector<int64> partition_subgroups(num_partitions);
  std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0);
  groups[0] = partition_subgroups;
  auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);
  return all_to_all;
}

HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64_t num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  /* n = size_per_replica
     m = num_partitions
     factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
                     tf.cast(tf.range(n), dtype=tf.complex64) /
                     (n * m))
  */
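  // Equivalently, in numpy (illustration only; `p` stands for this shard's
  // partition id, and `n`, `m` are as above):
  //   k = np.arange(n)                 # position within this shard
  //   factor = np.exp(-2j * np.pi * p * k / (n * m))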
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  int64_t per_replica_size = hlo->shape().dimensions().back();
  auto constant_factor =
      add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0(
          complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size)))));
  constant_factor = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), constant_factor, /*broadcast_dimensions=*/{}));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // TODO(wangtao): multiply before broadcast.
  auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{}));
  auto exp_operand = add_hlo(
      HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply,
                                   constant_factor, broadcast_partition_id));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}

// Pseudocode for the while loop:
// def body(dest_transform, dest_core_position, source_transform,
//             source_core_position, i):
//      factor = tf.exp(-2.0j * np.pi *
//                      tf.cast(dest_core_position, tf.complex64) *
//                tf.cast(source_core_position, tf.complex64) / num_partitions)
//      dest_transform += factor * source_transform
//      source_core_position = tf.raw_ops.CollectivePermute(
//          input=source_core_position,
//          source_target_pairs=source_target_pairs,
//          name='source_core_position_permute')
//      source_transform = tf.raw_ops.CollectivePermute(
//          input=source_transform,
//          source_target_pairs=source_target_pairs,
//          name='source_transform_permute')
//      i += 1
//      return (dest_transform, dest_core_position, source_transform,
//              source_core_position, i)
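//
// After num_partitions iterations, each partition d has accumulated, per
// element, a length-m DFT across partitions (with m = num_partitions):
//   dest[d] = sum(np.exp(-2j * np.pi * d * s / m) * source[s]
//                 for s in range(m))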
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t num_partitions, HloInstruction* partition_id,
    int64* next_channel_id, HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  /*
    factor = tf.exp(-2.0j * np.pi *
                    tf.cast(dest_partition_id, tf.complex64) *
                    tf.cast(source_partition_id, tf.complex64) /
                    num_partitions)
    dest_transform += factor * source_transform
  */
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions))));

  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective permute for source_partition_id and source_transform.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64_t src_device) {
        std::vector<int64> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64_t dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });

  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);

  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);

  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));

  // Build while loop condition.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32>(num_partitions))),
      ComparisonDirection::kLt));

  // Build while loop.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));

  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}

// Slice valid data in each partition.
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  std::vector<int64> start_indices(target_shape.rank(), 0);
  std::vector<int64> strides(target_shape.rank(), 1);
  return b->AddInstruction(HloInstruction::CreateSlice(
      target_shape, hlo, start_indices, target_shape.dimensions(), strides));
}

}  // namespace

// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
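// The overall decomposition is the classic decimation-in-time split of a
// length n*m FFT into m local FFTs of size n plus a cross-partition DFT of
// size m. A numpy sketch of the whole routine (illustration only; names are
// hypothetical):
//   def spmd_fft(x, m):               # x: full input, m: num_partitions
//     n = len(x) // m
//     shards = [x[p::m] for p in range(m)]        # steps 1.a-1.d: reshuffle
//     f = [np.fft.fft(s) for s in shards]         # step 2: local FFT, size n
//     k = np.arange(n)
//     f = [f[p] * np.exp(-2j * np.pi * p * k / (n * m))  # phase correction
//          for p in range(m)]
//     return np.concatenate(                      # step 3: DFT across shards
//         [sum(np.exp(-2j * np.pi * d * s / m) * f[s] for s in range(m))
//          for d in range(m)])                    # shard d holds X[d*n:(d+1)*n]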
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }

  // Only support the case where input_length equals fft_length.
  int64_t input_length = hlo->operand(0)->shape().dimensions().back();
  int64_t fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }

  // Only support partitioning along the last dimension.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }

  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));

  // 1.a. Use right halo exchange to shuffle the data first, then slice out
  // the valid data. The shuffle ensures an in-order transform: the order of
  // the data is the same before and after the transform. It requires that
  // the size of the data per partition be divisible by the number of
  // partitions. For example, if the input is {0, 1, 2, 3, 4, 5} and
  // num_partitions is 2, after the halo exchange partition 0 has {0, 1, 2, 3}
  // and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1 are
  // padding; the padding effectively appends zeros to the end of the full
  // data.
  auto result = partitioned_input.hlo();
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);

  if (padded_hlo.has_value()) {
    result = padded_hlo.value();
  }

  // 1.b Shuffle data within each partition using one hot and matmul.
  // If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling,
  // the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c All-to-all.
  // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
  // num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
  // and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d Slice valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);

  // 2. Do local fft transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));

  // Multiply by the correction factor for local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));

  // 3. Second phase FFT with collective permute. fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);

  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return Status::OK();
}

}  // namespace spmd
}  // namespace xla