1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h"
17
18 #include <vector>
19
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27
28 namespace xla {
29 namespace gpu {
30
31 namespace {
// The parameter space on the GPU device is limited. We pick an arbitrary low
// constant here to try to prevent exceeding this parameter space. For a proper
// fix, we would have to take into account which parameters share a buffer, and
// how big these buffers are.
//
// In this file the constant serves two purposes: a concatenate is only
// considered for splitting when it has more than 'kMaxParameters' operands,
// and every concatenate created by the split takes at most 'kMaxParameters'
// operands.
constexpr int32_t kMaxParameters = 128;
37
SplitConcatenate(HloInstruction * concat,HloComputation * comp)38 StatusOr<bool> SplitConcatenate(HloInstruction* concat, HloComputation* comp) {
39 auto operands = concat->operands();
40 std::vector<HloInstruction*> operands_to_split(operands.begin(),
41 operands.end());
42 while (operands_to_split.size() > 1) {
43 std::vector<HloInstruction*> new_operands;
44 absl::Span<HloInstruction*> operands_span(operands_to_split);
45 for (int64_t offset = 0; offset < operands_to_split.size();
46 offset += kMaxParameters) {
47 // Check if there is a remainder of operands that does not completely fill
48 // one "batch" of exactly 'kMaxParameters' operands. If there are only
49 // less than 'kMaxParameters' operands left, then we still put them into a
50 // concat together. Otherwise, we spare them for another round so that
51 // they can be put together into a concat with some of the newly created
52 // concats.
53 if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) {
54 new_operands.insert(new_operands.end(),
55 operands_to_split.begin() + offset,
56 operands_to_split.end());
57 } else {
58 Shape new_shape = concat->shape();
59 int64_t concat_dimension_size = 0;
60 for (int64_t i = 0;
61 i < kMaxParameters && offset + i < operands_to_split.size(); ++i) {
62 concat_dimension_size +=
63 operands_to_split[i + offset]->shape().dimensions(
64 concat->concatenate_dimension());
65 }
66 new_shape.set_dimensions(concat->concatenate_dimension(),
67 concat_dimension_size);
68 auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands(
69 new_shape, operands_span.subspan(offset, kMaxParameters)));
70 new_operands.push_back(new_concat);
71 }
72 }
73 operands_to_split = new_operands;
74 }
75 TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0]));
76 return true;
77 }
78
GetRelevantVariadicOps(HloComputation * comp)79 std::vector<HloInstruction*> GetRelevantVariadicOps(HloComputation* comp) {
80 std::vector<HloInstruction*> ops;
81 for (HloInstruction* instr : comp->instructions()) {
82 if (instr->opcode() == HloOpcode::kConcatenate &&
83 instr->operand_count() > kMaxParameters) {
84 ops.push_back(instr);
85 }
86 }
87 return ops;
88 }
89
90 } // namespace
91
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)92 StatusOr<bool> VariadicOpSplitter::Run(
93 HloModule* module,
94 const absl::flat_hash_set<absl::string_view>& execution_threads) {
95 bool changed = false;
96 for (HloComputation* comp :
97 module->MakeNonfusionComputations(execution_threads)) {
98 for (HloInstruction* op : GetRelevantVariadicOps(comp)) {
99 // TODO(b/112613927): Handle also other ops than concatenate.
100 TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp));
101 changed |= result;
102 }
103 }
104 return changed;
105 }
106
107 } // namespace gpu
108 } // namespace xla
109