1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/comparators.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
33 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/core/platform/numbers.h"
38 
39 namespace xla {
40 namespace spmd {
41 
42 namespace {
43 
ParseOpaqueAsAttributes(const HloInstruction * hlo)44 StatusOr<absl::flat_hash_map<string, int64>> ParseOpaqueAsAttributes(
45     const HloInstruction* hlo) {
46   absl::string_view opaque = Cast<HloCustomCallInstruction>(hlo)->opaque();
47   HloLexer lexer(opaque);
48   absl::flat_hash_map<string, int64> result;
49   while (lexer.Lex() != TokKind::kEof) {
50     if (lexer.GetKind() != TokKind::kAttributeName) {
51       return InvalidArgument("Expects attribute name, %s", opaque);
52     }
53     string attr_name = lexer.GetStrVal();
54     if (lexer.Lex() != TokKind::kInt) {
55       return InvalidArgument("expects integer attribute value");
56     }
57     result[attr_name] = lexer.GetInt64Val();
58     if (lexer.Lex() != TokKind::kComma) {
59       break;
60     }
61   }
62   return result;
63 }
64 
65 constexpr char kSPMDOpRotateRight[] = "_SPMDInternalOp_RotateRight";
66 
67 }  // namespace
68 
HandleCustomCallTopK(HloInstruction * hlo)69 Status SpmdPartitioningVisitor::HandleCustomCallTopK(HloInstruction* hlo) {
70   if (!hlo->operand(0)->has_sharding()) {
71     return DefaultAction(hlo);
72   }
73 
74   const HloSharding& sharding = hlo->operand(0)->sharding();
75   // No support for partial replicate yet.
76   if (sharding.IsTileMaximal() || sharding.IsReplicated() ||
77       sharding.ReplicateOnLastTileDim()) {
78     return DefaultAction(hlo);
79   }
80 
81   const int64_t batch_dim = 0;
82   const int64_t sort_dim = 1;
83   const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
84 
85   if (shard_count <= 1) {
86     return DefaultAction(hlo);
87   }
88 
89   const int64_t batch_dim_partition = sharding.tile_assignment().dim(batch_dim);
90   const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
91   const int64_t batch_size = hlo->shape().tuple_shapes(0).dimensions(batch_dim);
92   const int64_t k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
93   const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
94 
95   if (k >= per_partition_size) {
96     return DefaultAction(hlo);
97   }
98 
99   auto input = hlo->operand(0);
100   const auto element_type = input->shape().element_type();
101 
102   auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
103       CreateFirstWithType(element_type, &b_));
104 
105   auto partition_state = partitioned_input.state();
106   auto replicated_sharding = HloSharding::Replicate();
107   // If batch dimension is partitioned, partial replicated on sort dimension.
108   if (batch_dim_partition > 1) {
109     auto sharding_grouped = GroupShardingOnDims(sharding, {batch_dim});
110     partition_state = CreatePerGroupPartitioningState(
111         partitioned_input.state(), sharding_grouped.device_groups,
112         partitioned_input.state().b);
113     auto reshape_tile_assignment = sharding.tile_assignment();
114     auto reshape_dimensions = reshape_tile_assignment.dimensions();
115     reshape_dimensions.push_back(reshape_dimensions.back());
116     reshape_dimensions[sort_dim] = 1;
117     reshape_tile_assignment.Reshape(reshape_dimensions);
118     replicated_sharding = HloSharding::PartialTile(reshape_tile_assignment);
119   }
120 
121   // Each partition needs to do TopK separately, thus the base shape
122   // becomes [batch_size, k * shard_count].
123   const Shape replicated_shape = ShapeUtil::MakeTupleShape(
124       {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
125                             {batch_size, k * shard_count}),
126        ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
127   auto custom_call_sharding =
128       sharding.GetTupleSharding(replicated_shape).ValueOrDie();
129   auto shard_shape =
130       MakePartitionedShape(replicated_shape, custom_call_sharding);
131   auto topk = b_.AddInstruction(
132       hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
133   topk->set_sharding(custom_call_sharding);
134   // Partition customcall.
135   PartitionedHlo partitioned_topk(topk, replicated_shape,
136                                   MakePartitioningState());
137   topk = partitioned_topk.hlo();
138 
139   // Get value from TopK.
140   HloInstruction* value_gte =
141       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
142           topk->shape().tuple_shapes(0), topk, 0));
143   value_gte->set_sharding(sharding);
144   // Partition GetTupleElement of value.
145   PartitionedHlo value_partitioned_gte(
146       value_gte, partitioned_topk.base_shape().tuple_shapes(0),
147       MakePartitioningState());
148   // Reshard value to be replicated.
149   auto replicated_value_gte =
150       value_partitioned_gte.Reshard(replicated_sharding).hlo();
151 
152   // Get index from TopK.
153   HloInstruction* index_gte =
154       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
155           topk->shape().tuple_shapes(1), topk, 1));
156   auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
157       ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
158       partition_state.partition_id));
159   // Add per partition offset to index, index returned from CustomCall always
160   // starts from 0.
161   auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
162       index_gte->shape(),
163       b_.AddInstruction(HloInstruction::CreateBinary(
164           partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
165           b_.AddInstruction(HloInstruction::CreateConstant(
166               LiteralUtil::CreateR0<int32>(per_partition_size))))),
167       {}));
168   index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
169       index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
170   index_gte->set_sharding(sharding);
171   // Parttion GetTupleElement of index.
172   PartitionedHlo index_partitioned_gte(
173       index_gte, partitioned_topk.base_shape().tuple_shapes(1),
174       MakePartitioningState());
175   // Reshard index to be replicated.
176   auto replicated_index_gte =
177       index_partitioned_gte.Reshard(replicated_sharding).hlo();
178 
179   // Creates replicated sort to do TopK, the input is value and index pairs
180   // from all the partitions. The reason to use Sort instead of CustomCall TopK
181   // is CustomCall only takes value as input. There will be an extra Gather
182   // to get the correct index if CustomCall is used here.
183 
184   // Create comparator for the sort.
185   XlaBuilder b("Sort.Compare");
186   XlaComputation comparator = CreateScalarComparisonComputation(
187       "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
188       &b);
189   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
190   HloModuleConfig config(program_shape);
191   TF_ASSIGN_OR_RETURN(auto new_module,
192                       HloModule::CreateFromProto(comparator.proto(), config));
193   HloCloneContext context(module_);
194   auto compare_computation =
195       module_->DeepCloneComputation(new_module->entry_computation(), &context);
196   // Each partition needs to do TopK separately, thus the base shape for sort
197   // becomes [ceil(batch_size / batch_dim_partition), k * shard_count].
198   const Shape sort_shape = ShapeUtil::MakeTupleShape(
199       {ShapeUtil::MakeShape(
200            hlo->operand(0)->shape().element_type(),
201            {CeilOfRatio(batch_size, batch_dim_partition), k * shard_count}),
202        ShapeUtil::MakeShape(S32, {CeilOfRatio(batch_size, batch_dim_partition),
203                                   k * shard_count})});
204   auto sort = b_.AddInstruction(HloInstruction::CreateSort(
205       sort_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
206       compare_computation, true));
207   sort->set_sharding(
208       replicated_sharding.GetTupleSharding(sort->shape()).ValueOrDie());
209   PartitionedHlo replicated_sort(sort, replicated_shape,
210                                  MakePartitioningState());
211 
212   // Slice value and index from top-k for output.
213   HloInstruction* sort_value_gte =
214       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
215           replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
216           0));
217   HloInstruction* sort_index_gte =
218       b_.AddInstruction(HloInstruction::CreateGetTupleElement(
219           replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
220           1));
221   // Slice value from final sort.
222   HloInstruction* slice_sort_value =
223       SliceFirstK(sort_value_gte, &b_, sort_dim, k);
224   // Slice index from final sort.
225   HloInstruction* slice_index_value =
226       SliceFirstK(sort_index_gte, &b_, sort_dim, k);
227   auto create_tuple = b_.AddInstruction(
228       HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
229   create_tuple->set_sharding(
230       replicated_sharding.GetTupleSharding(create_tuple->shape()).ValueOrDie());
231   SetPartitionedHlo(
232       hlo, PartitionedHlo(create_tuple, hlo->shape(), MakePartitioningState())
233                .Reshard(hlo->sharding()));
234 
235   return Status::OK();
236 }
237 
HandleCustomCallSPMDInternal_RotateRight(HloInstruction * hlo)238 Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
239     HloInstruction* hlo) {
240   TF_ASSIGN_OR_RETURN(auto attrs, ParseOpaqueAsAttributes(hlo));
241   auto dim_it = attrs.find("dimension");
242   TF_RET_CHECK(dim_it != attrs.end())
243       << "No dimension attribute in SPMD rotate op";
244   int64_t dim = dim_it->second;
245   auto amount_it = attrs.find("amount");
246   TF_RET_CHECK(amount_it != attrs.end())
247       << "No amount attribute in SPMD rotate op";
248 
249   PartitionedHlo input =
250       GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding());
251   const int64_t full_size = hlo->shape().dimensions(dim);
252   const int64_t shard_size = input.hlo()->shape().dimensions(dim);
253 
254   // We exclude shards that are entirely padding.
255   const int64_t participating_shards = CeilOfRatio(full_size, shard_size);
256   // The last included shard might still have padding on the right.
257   const int64_t right_padding = participating_shards * shard_size - full_size;
258   int64_t amount = amount_it->second;
259   TF_RET_CHECK(amount >= 0)
260       << "Rotate amount cannot be negative in SPMD rotate op";
261 
262   amount %= full_size;
263   if (amount == 0) {
264     SetPartitionedHlo(hlo, input);
265     return Status::OK();
266   }
267 
268   // First step: rotate `amount` on padded data. E.g., before
269   //      012|345|678|9__     (_: padding)
270   // after:
271   //      678|9__|012|345     (amount: 6)
272   auto rotate_with_padding = [&](int64_t rotate_amount) {
273     int64_t current_size = 0;
274     std::vector<HloInstruction*> concat_pieces;
275     while (current_size < shard_size) {
276       int64_t shard_distance =
277           CeilOfRatio(rotate_amount - current_size, shard_size);
278       int64_t offset_in_shard =
279           shard_distance * shard_size - rotate_amount + current_size;
280 
281       int64_t halo_size =
282           std::min(shard_size - offset_in_shard, shard_size - current_size);
283 
284       current_size += halo_size;
285       Shape halo_shape = input.hlo()->shape();
286       halo_shape.set_dimensions(dim, halo_size);
287       HloInstruction* halo = input.hlo();
288       if (halo_size != shard_size) {
289         halo_shape.set_dimensions(dim, halo_size);
290         std::vector<int64> slice_starts(hlo->shape().rank(), 0);
291         slice_starts[dim] = offset_in_shard;
292         std::vector<int64> slice_limits(
293             input.hlo()->shape().dimensions().begin(),
294             input.hlo()->shape().dimensions().end());
295         slice_limits[dim] = offset_in_shard + halo_size;
296         halo = b_.AddInstruction(HloInstruction::CreateSlice(
297             halo_shape, halo, slice_starts, slice_limits,
298             std::vector<int64>(halo_shape.rank(), 1)));
299       }
300       if (shard_distance != 0) {
301         std::vector<std::pair<int64, int64>> pairs;
302         hlo->sharding().tile_assignment().Each(
303             [&](absl::Span<const int64> indices, int64_t device) {
304               if (indices[dim] >= participating_shards) {
305                 return;
306               }
307               std::vector<int64> dst_idx(indices.begin(), indices.end());
308               dst_idx[dim] += shard_distance;
309               dst_idx[dim] %= participating_shards;
310               pairs.emplace_back(device,
311                                  hlo->sharding().tile_assignment()(dst_idx));
312             });
313         halo =
314             collective_ops_creator_.create_cross_partition_collective_permute(
315                 &b_, halo, pairs, NewChannel());
316       }
317       concat_pieces.push_back(halo);
318     }
319     if (concat_pieces.size() > 1) {
320       return b_.AddInstruction(HloInstruction::CreateConcatenate(
321           input.hlo()->shape(), concat_pieces, dim));
322     }
323     return concat_pieces[0];
324   };
325   HloInstruction* rotated0 = rotate_with_padding(amount);
326   if (right_padding == 0) {
327     SetPartitionedHlo(hlo, [&] { return rotated0; });
328     return Status::OK();
329   }
330 
331   // Second step: perform another rotate from input, with `right_padding` added
332   // to `amount`. E.g., before
333   //      012|345|678|9__     (_: padding)
334   // after:
335   //      456|789|__0|123     (amount: 6 + 2)
336   // combine (select) with first step:
337   //      678|9__|012|345
338   // now we get:
339   //      456|789|012|3__
340 
341   HloInstruction* rotated1 = rotate_with_padding(
342       (amount + right_padding) % (shard_size * participating_shards));
343   HloInstruction* shard_offset = MakePartitionOffsets(
344       hlo->shape(), hlo->sharding(), MakePartitioningState().partition_id, &b_,
345       {dim})[dim];
346   HloInstruction* iota = b_.AddInstruction(HloInstruction::CreateIota(
347       ShapeUtil::ChangeElementType(rotated0->shape(), S32), dim));
348   HloInstruction* selection_boundary =
349       b_.AddInstruction(HloInstruction::CreateBroadcast(
350           iota->shape(),
351           b_.AddInstruction(HloInstruction::CreateBinary(
352               shard_offset->shape(), HloOpcode::kSubtract,
353               b_.AddInstruction(HloInstruction::CreateConstant(
354                   LiteralUtil::CreateR0<int32>(amount))),
355               shard_offset)),
356           {}));
357   HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
358       ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
359       selection_boundary, Comparison::Direction::kLt));
360   SetPartitionedHlo(hlo, [&] {
361     return b_.AddInstruction(HloInstruction::CreateTernary(
362         rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
363   });
364   return Status::OK();
365 }
366 
CreateCustomCallSPMDInternal_RotateRight(HloInstruction * input,int64_t dim,int64_t amount)367 std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight(
368     HloInstruction* input, int64_t dim, int64_t amount) {
369   string opaque = absl::StrCat("dimension=", dim, ",amount=", amount);
370   return HloInstruction::CreateCustomCall(input->shape(), {input},
371                                           kSPMDOpRotateRight, opaque);
372 }
373 
HandleCustomCall(HloInstruction * hlo)374 Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
375   if (hlo->custom_call_target() == "SPMDFullToShardShape") {
376     // This op switches from auto partitioning to manual partitioning.
377     auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
378     if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
379       input_partitioned = input_partitioned.PadWithValue(
380           CreateR0WithType(hlo->shape().element_type(), 0, &b_));
381     }
382     auto input = input_partitioned.hlo();
383     CHECK(hlo->sharding().IsManual());
384     CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
385     auto copy = b_.AddInstruction(
386         HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
387     SetPartitionedHlo(hlo, [&] { return copy; });
388     return Status::OK();
389   }
390   if (hlo->custom_call_target() == "SPMDShardToFullShape") {
391     // This op switches from manual partitioning to auto partitioning.
392     auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
393     CHECK(input->sharding().IsManual());
394     auto copy = b_.AddInstruction(
395         HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
396     CHECK(ShapeUtil::Compatible(
397         copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
398     SetPartitionedHlo(hlo, [&] { return copy; });
399     return Status::OK();
400   }
401 
402   if (hlo->custom_call_target() == "TopK") {
403     return HandleCustomCallTopK(hlo);
404   }
405 
406   if (hlo->custom_call_target() == kSPMDOpRotateRight) {
407     return HandleCustomCallSPMDInternal_RotateRight(hlo);
408   }
409 
410   return DefaultAction(hlo);
411 }
412 
413 }  // namespace spmd
414 }  // namespace xla
415