/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_query.h"
24 #include "tensorflow/compiler/xla/service/reduce_scatter_utils.h"
25 
namespace xla {
namespace gpu {

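// Rewrites an all-reduce whose result is consumed by a dynamic-slice along a
// single dimension, with slice offsets derived from the partition/replica id,
// into an equivalent reduce-scatter.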
StatusOr<bool> ReduceScatterCreator::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  const HloModuleConfig &config = module->config();
  int64_t next_channel_id = hlo_query::NextChannelId(*module);

  bool changed = false;
  for (HloComputation *computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction *instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() != HloOpcode::kAllReduce) {
        continue;
      }
      auto *ar = Cast<HloAllReduceInstruction>(instruction);
      auto ar_spec = MatchReduceScatter(ar, config.num_partitions(),
                                        config.replica_count(),
                                        /*allow_multiple_split_dims=*/false,
                                        /*allow_intervening_reshape=*/true);
      if (!ar_spec) {
        VLOG(2) << "Cannot match reduce-scatter " << ar->ToString();
        continue;
      }

      HloInstruction *ds = ar_spec->dynamic_slice;

      // Convert the all-reduce to a reduce-scatter. The output shape of the
      // reduce-scatter is the same as the all-reduce shape, except that the
      // split dimension is shrunk to the size of the dynamic-slice result.
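      // For example (shapes here are illustrative, not from this file), with
      // group_size=4 and split_dim=0:
      //   %ar = f32[16,8] all-reduce(%x)
      //   %ds = f32[4,8] dynamic-slice(%ar, %id_offset, %zero)
      // becomes
      //   %rs = f32[4,8] reduce-scatter(%x), dimensions={0}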
      const int64_t split_dim = ar_spec->split_dim;
      Shape scatter_shape = ar->shape();
      const int64_t split_dim_size = scatter_shape.dimensions(split_dim);
      HloInstruction *rs_input = ar->mutable_operand(0);
      const int64_t scatter_dim_size = split_dim_size / ar_spec->group_size;
      TF_RET_CHECK(scatter_dim_size * ar_spec->group_size <= split_dim_size);
      if (split_dim_size % ar_spec->group_size != 0) {
        // The dynamic-slice does not evenly split the scatter dim. In that
        // case, create a reduce-scatter with the relevant slice of the
        // all-reduce input.
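        // For example (numbers are illustrative): split_dim_size=10 with
        // group_size=4 gives scatter_dim_size=2, so only the leading 2*4=8
        // elements along the split dimension feed the reduce-scatter.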
        scatter_shape.set_dimensions(split_dim,
                                     scatter_dim_size * ar_spec->group_size);
        rs_input = computation->AddInstruction(HloInstruction::CreateSlice(
            scatter_shape, rs_input,
            std::vector<int64_t>(scatter_shape.rank(), 0),
            scatter_shape.dimensions(),
            std::vector<int64_t>(scatter_shape.rank(), 1)));
      }
      scatter_shape.set_dimensions(split_dim, scatter_dim_size);

      std::optional<int64_t> channel_id;
      if (ar->channel_id()) {
        // The all-reduce's channel_id cannot be reused for the reduce-scatter;
        // allocate the next unused one.
        channel_id = next_channel_id++;
      }

      HloInstruction *ars =
          computation->AddInstruction(HloInstruction::CreateReduceScatter(
              scatter_shape, {rs_input}, ar->to_apply(), ar->replica_groups(),
              ar->constrain_layout(), channel_id, ar->use_global_device_ids(),
              ar_spec->split_dim));

      // If there was an intervening reshape, the non-split dimensions must
      // match that reshape; reshaping the reduce-scatter result to the
      // dynamic-slice shape achieves exactly that.
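      // Schematically: dynamic-slice(reshape(all-reduce(x)))
      //            =>  reshape(reduce-scatter(x))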
      HloInstruction *result = ars;
      HloInstruction *reshape = nullptr;
      if (ds->operand(0) != ar) {
        reshape = ds->mutable_operand(0);
        result = computation->AddInstruction(
            HloInstruction::CreateReshape(ds->shape(), result));
      }

      // Note that RemoveInstructionAndUnusedOperands may not always remove the
      // all-reduce operand of the dynamic-slice, so remove all the dead
      // instructions manually.
      TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result));
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds));
      if (reshape) {
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape));
      }
      TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar));
      changed = true;
    }
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla