1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/comparators.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
33 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/core/platform/numbers.h"
38
39 namespace xla {
40 namespace spmd {
41
42 namespace {
43
ParseOpaqueAsAttributes(const HloInstruction * hlo)44 StatusOr<absl::flat_hash_map<string, int64>> ParseOpaqueAsAttributes(
45 const HloInstruction* hlo) {
46 absl::string_view opaque = Cast<HloCustomCallInstruction>(hlo)->opaque();
47 HloLexer lexer(opaque);
48 absl::flat_hash_map<string, int64> result;
49 while (lexer.Lex() != TokKind::kEof) {
50 if (lexer.GetKind() != TokKind::kAttributeName) {
51 return InvalidArgument("Expects attribute name, %s", opaque);
52 }
53 string attr_name = lexer.GetStrVal();
54 if (lexer.Lex() != TokKind::kInt) {
55 return InvalidArgument("expects integer attribute value");
56 }
57 result[attr_name] = lexer.GetInt64Val();
58 if (lexer.Lex() != TokKind::kComma) {
59 break;
60 }
61 }
62 return result;
63 }
64
65 constexpr char kSPMDOpRotateRight[] = "_SPMDInternalOp_RotateRight";
66
67 } // namespace
68
HandleCustomCallTopK(HloInstruction * hlo)69 Status SpmdPartitioningVisitor::HandleCustomCallTopK(HloInstruction* hlo) {
70 if (!hlo->operand(0)->has_sharding()) {
71 return DefaultAction(hlo);
72 }
73
74 const HloSharding& sharding = hlo->operand(0)->sharding();
75 // No support for partial replicate yet.
76 if (sharding.IsTileMaximal() || sharding.IsReplicated() ||
77 sharding.ReplicateOnLastTileDim()) {
78 return DefaultAction(hlo);
79 }
80
81 const int64_t batch_dim = 0;
82 const int64_t sort_dim = 1;
83 const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
84
85 if (shard_count <= 1) {
86 return DefaultAction(hlo);
87 }
88
89 const int64_t batch_dim_partition = sharding.tile_assignment().dim(batch_dim);
90 const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
91 const int64_t batch_size = hlo->shape().tuple_shapes(0).dimensions(batch_dim);
92 const int64_t k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
93 const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
94
95 if (k >= per_partition_size) {
96 return DefaultAction(hlo);
97 }
98
99 auto input = hlo->operand(0);
100 const auto element_type = input->shape().element_type();
101
102 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
103 CreateFirstWithType(element_type, &b_));
104
105 auto partition_state = partitioned_input.state();
106 auto replicated_sharding = HloSharding::Replicate();
107 // If batch dimension is partitioned, partial replicated on sort dimension.
108 if (batch_dim_partition > 1) {
109 auto sharding_grouped = GroupShardingOnDims(sharding, {batch_dim});
110 partition_state = CreatePerGroupPartitioningState(
111 partitioned_input.state(), sharding_grouped.device_groups,
112 partitioned_input.state().b);
113 auto reshape_tile_assignment = sharding.tile_assignment();
114 auto reshape_dimensions = reshape_tile_assignment.dimensions();
115 reshape_dimensions.push_back(reshape_dimensions.back());
116 reshape_dimensions[sort_dim] = 1;
117 reshape_tile_assignment.Reshape(reshape_dimensions);
118 replicated_sharding = HloSharding::PartialTile(reshape_tile_assignment);
119 }
120
121 // Each partition needs to do TopK separately, thus the base shape
122 // becomes [batch_size, k * shard_count].
123 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
124 {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
125 {batch_size, k * shard_count}),
126 ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
127 auto custom_call_sharding =
128 sharding.GetTupleSharding(replicated_shape).ValueOrDie();
129 auto shard_shape =
130 MakePartitionedShape(replicated_shape, custom_call_sharding);
131 auto topk = b_.AddInstruction(
132 hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
133 topk->set_sharding(custom_call_sharding);
134 // Partition customcall.
135 PartitionedHlo partitioned_topk(topk, replicated_shape,
136 MakePartitioningState());
137 topk = partitioned_topk.hlo();
138
139 // Get value from TopK.
140 HloInstruction* value_gte =
141 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
142 topk->shape().tuple_shapes(0), topk, 0));
143 value_gte->set_sharding(sharding);
144 // Partition GetTupleElement of value.
145 PartitionedHlo value_partitioned_gte(
146 value_gte, partitioned_topk.base_shape().tuple_shapes(0),
147 MakePartitioningState());
148 // Reshard value to be replicated.
149 auto replicated_value_gte =
150 value_partitioned_gte.Reshard(replicated_sharding).hlo();
151
152 // Get index from TopK.
153 HloInstruction* index_gte =
154 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
155 topk->shape().tuple_shapes(1), topk, 1));
156 auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
157 ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
158 partition_state.partition_id));
159 // Add per partition offset to index, index returned from CustomCall always
160 // starts from 0.
161 auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
162 index_gte->shape(),
163 b_.AddInstruction(HloInstruction::CreateBinary(
164 partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
165 b_.AddInstruction(HloInstruction::CreateConstant(
166 LiteralUtil::CreateR0<int32>(per_partition_size))))),
167 {}));
168 index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
169 index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
170 index_gte->set_sharding(sharding);
171 // Parttion GetTupleElement of index.
172 PartitionedHlo index_partitioned_gte(
173 index_gte, partitioned_topk.base_shape().tuple_shapes(1),
174 MakePartitioningState());
175 // Reshard index to be replicated.
176 auto replicated_index_gte =
177 index_partitioned_gte.Reshard(replicated_sharding).hlo();
178
179 // Creates replicated sort to do TopK, the input is value and index pairs
180 // from all the partitions. The reason to use Sort instead of CustomCall TopK
181 // is CustomCall only takes value as input. There will be an extra Gather
182 // to get the correct index if CustomCall is used here.
183
184 // Create comparator for the sort.
185 XlaBuilder b("Sort.Compare");
186 XlaComputation comparator = CreateScalarComparisonComputation(
187 "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
188 &b);
189 TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
190 HloModuleConfig config(program_shape);
191 TF_ASSIGN_OR_RETURN(auto new_module,
192 HloModule::CreateFromProto(comparator.proto(), config));
193 HloCloneContext context(module_);
194 auto compare_computation =
195 module_->DeepCloneComputation(new_module->entry_computation(), &context);
196 // Each partition needs to do TopK separately, thus the base shape for sort
197 // becomes [ceil(batch_size / batch_dim_partition), k * shard_count].
198 const Shape sort_shape = ShapeUtil::MakeTupleShape(
199 {ShapeUtil::MakeShape(
200 hlo->operand(0)->shape().element_type(),
201 {CeilOfRatio(batch_size, batch_dim_partition), k * shard_count}),
202 ShapeUtil::MakeShape(S32, {CeilOfRatio(batch_size, batch_dim_partition),
203 k * shard_count})});
204 auto sort = b_.AddInstruction(HloInstruction::CreateSort(
205 sort_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
206 compare_computation, true));
207 sort->set_sharding(
208 replicated_sharding.GetTupleSharding(sort->shape()).ValueOrDie());
209 PartitionedHlo replicated_sort(sort, replicated_shape,
210 MakePartitioningState());
211
212 // Slice value and index from top-k for output.
213 HloInstruction* sort_value_gte =
214 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
215 replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
216 0));
217 HloInstruction* sort_index_gte =
218 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
219 replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
220 1));
221 // Slice value from final sort.
222 HloInstruction* slice_sort_value =
223 SliceFirstK(sort_value_gte, &b_, sort_dim, k);
224 // Slice index from final sort.
225 HloInstruction* slice_index_value =
226 SliceFirstK(sort_index_gte, &b_, sort_dim, k);
227 auto create_tuple = b_.AddInstruction(
228 HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
229 create_tuple->set_sharding(
230 replicated_sharding.GetTupleSharding(create_tuple->shape()).ValueOrDie());
231 SetPartitionedHlo(
232 hlo, PartitionedHlo(create_tuple, hlo->shape(), MakePartitioningState())
233 .Reshard(hlo->sharding()));
234
235 return Status::OK();
236 }
237
HandleCustomCallSPMDInternal_RotateRight(HloInstruction * hlo)238 Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
239 HloInstruction* hlo) {
240 TF_ASSIGN_OR_RETURN(auto attrs, ParseOpaqueAsAttributes(hlo));
241 auto dim_it = attrs.find("dimension");
242 TF_RET_CHECK(dim_it != attrs.end())
243 << "No dimension attribute in SPMD rotate op";
244 int64_t dim = dim_it->second;
245 auto amount_it = attrs.find("amount");
246 TF_RET_CHECK(amount_it != attrs.end())
247 << "No amount attribute in SPMD rotate op";
248
249 PartitionedHlo input =
250 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding());
251 const int64_t full_size = hlo->shape().dimensions(dim);
252 const int64_t shard_size = input.hlo()->shape().dimensions(dim);
253
254 // We exclude shards that are entirely padding.
255 const int64_t participating_shards = CeilOfRatio(full_size, shard_size);
256 // The last included shard might still have padding on the right.
257 const int64_t right_padding = participating_shards * shard_size - full_size;
258 int64_t amount = amount_it->second;
259 TF_RET_CHECK(amount >= 0)
260 << "Rotate amount cannot be negative in SPMD rotate op";
261
262 amount %= full_size;
263 if (amount == 0) {
264 SetPartitionedHlo(hlo, input);
265 return Status::OK();
266 }
267
268 // First step: rotate `amount` on padded data. E.g., before
269 // 012|345|678|9__ (_: padding)
270 // after:
271 // 678|9__|012|345 (amount: 6)
272 auto rotate_with_padding = [&](int64_t rotate_amount) {
273 int64_t current_size = 0;
274 std::vector<HloInstruction*> concat_pieces;
275 while (current_size < shard_size) {
276 int64_t shard_distance =
277 CeilOfRatio(rotate_amount - current_size, shard_size);
278 int64_t offset_in_shard =
279 shard_distance * shard_size - rotate_amount + current_size;
280
281 int64_t halo_size =
282 std::min(shard_size - offset_in_shard, shard_size - current_size);
283
284 current_size += halo_size;
285 Shape halo_shape = input.hlo()->shape();
286 halo_shape.set_dimensions(dim, halo_size);
287 HloInstruction* halo = input.hlo();
288 if (halo_size != shard_size) {
289 halo_shape.set_dimensions(dim, halo_size);
290 std::vector<int64> slice_starts(hlo->shape().rank(), 0);
291 slice_starts[dim] = offset_in_shard;
292 std::vector<int64> slice_limits(
293 input.hlo()->shape().dimensions().begin(),
294 input.hlo()->shape().dimensions().end());
295 slice_limits[dim] = offset_in_shard + halo_size;
296 halo = b_.AddInstruction(HloInstruction::CreateSlice(
297 halo_shape, halo, slice_starts, slice_limits,
298 std::vector<int64>(halo_shape.rank(), 1)));
299 }
300 if (shard_distance != 0) {
301 std::vector<std::pair<int64, int64>> pairs;
302 hlo->sharding().tile_assignment().Each(
303 [&](absl::Span<const int64> indices, int64_t device) {
304 if (indices[dim] >= participating_shards) {
305 return;
306 }
307 std::vector<int64> dst_idx(indices.begin(), indices.end());
308 dst_idx[dim] += shard_distance;
309 dst_idx[dim] %= participating_shards;
310 pairs.emplace_back(device,
311 hlo->sharding().tile_assignment()(dst_idx));
312 });
313 halo =
314 collective_ops_creator_.create_cross_partition_collective_permute(
315 &b_, halo, pairs, NewChannel());
316 }
317 concat_pieces.push_back(halo);
318 }
319 if (concat_pieces.size() > 1) {
320 return b_.AddInstruction(HloInstruction::CreateConcatenate(
321 input.hlo()->shape(), concat_pieces, dim));
322 }
323 return concat_pieces[0];
324 };
325 HloInstruction* rotated0 = rotate_with_padding(amount);
326 if (right_padding == 0) {
327 SetPartitionedHlo(hlo, [&] { return rotated0; });
328 return Status::OK();
329 }
330
331 // Second step: perform another rotate from input, with `right_padding` added
332 // to `amount`. E.g., before
333 // 012|345|678|9__ (_: padding)
334 // after:
335 // 456|789|__0|123 (amount: 6 + 2)
336 // combine (select) with first step:
337 // 678|9__|012|345
338 // now we get:
339 // 456|789|012|3__
340
341 HloInstruction* rotated1 = rotate_with_padding(
342 (amount + right_padding) % (shard_size * participating_shards));
343 HloInstruction* shard_offset = MakePartitionOffsets(
344 hlo->shape(), hlo->sharding(), MakePartitioningState().partition_id, &b_,
345 {dim})[dim];
346 HloInstruction* iota = b_.AddInstruction(HloInstruction::CreateIota(
347 ShapeUtil::ChangeElementType(rotated0->shape(), S32), dim));
348 HloInstruction* selection_boundary =
349 b_.AddInstruction(HloInstruction::CreateBroadcast(
350 iota->shape(),
351 b_.AddInstruction(HloInstruction::CreateBinary(
352 shard_offset->shape(), HloOpcode::kSubtract,
353 b_.AddInstruction(HloInstruction::CreateConstant(
354 LiteralUtil::CreateR0<int32>(amount))),
355 shard_offset)),
356 {}));
357 HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
358 ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
359 selection_boundary, Comparison::Direction::kLt));
360 SetPartitionedHlo(hlo, [&] {
361 return b_.AddInstruction(HloInstruction::CreateTernary(
362 rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
363 });
364 return Status::OK();
365 }
366
CreateCustomCallSPMDInternal_RotateRight(HloInstruction * input,int64_t dim,int64_t amount)367 std::unique_ptr<HloInstruction> CreateCustomCallSPMDInternal_RotateRight(
368 HloInstruction* input, int64_t dim, int64_t amount) {
369 string opaque = absl::StrCat("dimension=", dim, ",amount=", amount);
370 return HloInstruction::CreateCustomCall(input->shape(), {input},
371 kSPMDOpRotateRight, opaque);
372 }
373
HandleCustomCall(HloInstruction * hlo)374 Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
375 if (hlo->custom_call_target() == "SPMDFullToShardShape") {
376 // This op switches from auto partitioning to manual partitioning.
377 auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
378 if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
379 input_partitioned = input_partitioned.PadWithValue(
380 CreateR0WithType(hlo->shape().element_type(), 0, &b_));
381 }
382 auto input = input_partitioned.hlo();
383 CHECK(hlo->sharding().IsManual());
384 CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
385 auto copy = b_.AddInstruction(
386 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
387 SetPartitionedHlo(hlo, [&] { return copy; });
388 return Status::OK();
389 }
390 if (hlo->custom_call_target() == "SPMDShardToFullShape") {
391 // This op switches from manual partitioning to auto partitioning.
392 auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
393 CHECK(input->sharding().IsManual());
394 auto copy = b_.AddInstruction(
395 HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
396 CHECK(ShapeUtil::Compatible(
397 copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
398 SetPartitionedHlo(hlo, [&] { return copy; });
399 return Status::OK();
400 }
401
402 if (hlo->custom_call_target() == "TopK") {
403 return HandleCustomCallTopK(hlo);
404 }
405
406 if (hlo->custom_call_target() == kSPMDOpRotateRight) {
407 return HandleCustomCallSPMDInternal_RotateRight(hlo);
408 }
409
410 return DefaultAction(hlo);
411 }
412
413 } // namespace spmd
414 } // namespace xla
415