1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/tuple_util.h"
25
26 namespace xla {
27
28 using absl::StrCat;
29
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)30 static StatusOr<HloComputation*> WidenWhileCondition(
31 HloComputation* narrow_condition, const Shape& wide_shape) {
32 const Shape& narrow_shape =
33 narrow_condition->parameter_instruction(0)->shape();
34
35 HloComputation* wide_while_cond = [&]() {
36 HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
37 builder.AddInstruction(
38 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
39
40 // This is needed so that the root instruction is shaped as a PRED[] -- we
41 // need to get this right to begin with since we can't mutate the type of
42 // the root instruction later. We later change the root instruction to
43 // something more appropriate.
44 builder.AddInstruction(
45 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
46 return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
47 }();
48
49 HloInstruction* truncated_parameter =
50 TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
51 narrow_shape.tuple_shapes_size());
52 HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
53 HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
54 {truncated_parameter}, narrow_condition));
55
56 wide_while_cond->set_root_instruction(call_narrow_cond);
57
58 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
59 return wide_while_cond;
60 }
61
62 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)63 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
64 const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
65
66 HloComputation* wide_while_body = [&]() {
67 HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
68 builder.AddInstruction(
69 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
70 return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
71 }();
72
73 HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
74 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
75 wide_parameter, narrow_shape.tuple_shapes_size());
76 HloInstruction* call_narrow_body =
77 wide_while_body->AddInstruction(HloInstruction::CreateCall(
78 narrow_shape, {truncated_parameter}, narrow_body));
79
80 std::vector<HloInstruction*> live_through_values;
81 for (int i = narrow_shape.tuple_shapes_size();
82 i < wide_shape.tuple_shapes_size(); i++) {
83 live_through_values.push_back(
84 wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
85 wide_shape.tuple_shapes(i), wide_parameter, i)));
86 }
87
88 wide_while_body->set_root_instruction(
89 TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
90
91 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
92 CallInliner::Inline(call_narrow_body));
93 return {{wide_while_body, std::move(inlined_instructions_map)}};
94 }
95
96 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)97 WhileUtil::MakeInstructionsLiveIn(
98 HloInstruction* while_instr,
99 absl::Span<HloInstruction* const> instructions) {
100 CHECK(while_instr->shape().IsTuple());
101
102 int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
103 Shape new_while_shape = while_instr->shape();
104 for (auto* instruction : instructions) {
105 *new_while_shape.add_tuple_shapes() = instruction->shape();
106 }
107
108 TF_ASSIGN_OR_RETURN(
109 HloComputation * new_while_condition,
110 WidenWhileCondition(while_instr->while_condition(), new_while_shape));
111
112 HloComputation* new_while_body;
113 CallInliner::InlinedInstructionMap inlined_instructions_map;
114 TF_ASSIGN_OR_RETURN(
115 std::tie(new_while_body, inlined_instructions_map),
116 WidenWhileBody(while_instr->while_body(), new_while_shape));
117
118 HloInstruction* new_while_init =
119 TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
120 HloComputation* containing_computation = while_instr->parent();
121 HloInstruction* new_while = containing_computation->AddInstruction(
122 HloInstruction::CreateWhile(new_while_shape, new_while_condition,
123 new_while_body, new_while_init));
124
125 // We want to get rid of the old while instruction even if it has side
126 // effecting operations so we do a manual HloComputation::RemoveInstruction
127 // instead of relying on HloComputation::ReplaceInstruction.
128 HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
129 new_while, while_instr->shape().tuple_shapes_size());
130 TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
131 TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
132
133 HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
134 std::vector<HloInstruction*> live_in_instructions;
135 for (int64 i = elements_in_old_while_shape;
136 i < new_while_shape.tuple_shapes_size(); i++) {
137 live_in_instructions.push_back(
138 new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
139 instructions[i - elements_in_old_while_shape]->shape(),
140 while_body_param, i)));
141 }
142
143 WhileUtil::MakeInstructionsLiveInResult result;
144
145 result.new_while_instr = new_while;
146 result.replacement_instr = replacement_instr;
147 result.while_body_live_in_values = std::move(live_in_instructions);
148 result.while_body_instruction_map = std::move(inlined_instructions_map);
149
150 return std::move(result);
151 }
152
153 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32 trip_count)154 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
155 int32 trip_count) {
156 Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
157
158 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
159 CreateComputationWithSignature(
160 {&loop_state_shape}, scalar_pred, "while_cond"));
161
162 HloInstruction* trip_count_constant = cond_computation->AddInstruction(
163 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
164
165 HloInstruction* param = cond_computation->parameter_instruction(0);
166 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
167 MakeGetTupleElementHlo(param, 0));
168
169 TF_ASSIGN_OR_RETURN(
170 HloInstruction * compare,
171 MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
172 cond_computation->set_root_instruction(compare);
173 return std::move(cond_computation);
174 }
175
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)176 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
177 const Shape& loop_state_shape,
178 const std::function<StatusOr<WhileUtil::LoopStateTy>(
179 HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
180 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
181 CreateComputationWithSignature(
182 {&loop_state_shape}, loop_state_shape, "while_body"));
183 HloInstruction* one = body_computation->AddInstruction(
184 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
185 HloInstruction* param = body_computation->parameter_instruction(0);
186 TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
187 MakeGetTupleElementHlo(param, 0));
188 TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
189 MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
190
191 std::vector<HloInstruction*> loop_body_generator_args;
192 for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
193 TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
194 MakeGetTupleElementHlo(param, i));
195 loop_body_generator_args.push_back(tuple_element);
196 }
197 TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
198 loop_body_generator(indvar, loop_body_generator_args));
199 next_state.insert(next_state.begin(), next_indvar);
200 HloInstruction* next_state_tuple =
201 body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
202 body_computation->set_root_instruction(next_state_tuple);
203
204 return std::move(body_computation);
205 }
206
MakeInitTupleFromInitValues(HloComputation * computation,const WhileUtil::LoopStateTy & init_values)207 static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
208 HloComputation* computation, const WhileUtil::LoopStateTy& init_values) {
209 std::vector<HloInstruction*> init_values_with_indvar;
210 init_values_with_indvar.reserve(init_values.size() + 1);
211 HloInstruction* zero = computation->AddInstruction(
212 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
213 init_values_with_indvar.push_back(zero);
214 absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
215 return computation->AddInstruction(
216 HloInstruction::CreateTuple(init_values_with_indvar));
217 }
218
219 // Returns a tuple shape containing a S32, and a shape from each value in
220 // `init_values`. If a shape from a value in `init_values` doesn't have a
221 // layout, use a default layout for the shape.
MakeLoopStateShapeWithLayout(const WhileUtil::LoopStateTy & init_values)222 static Shape MakeLoopStateShapeWithLayout(
223 const WhileUtil::LoopStateTy& init_values) {
224 std::vector<Shape> loop_state_shape_components;
225 loop_state_shape_components.reserve(init_values.size() + 1);
226 loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
227 absl::c_transform(init_values,
228 std::back_inserter(loop_state_shape_components),
229 [](HloInstruction* instr) {
230 Shape shape = instr->shape();
231 if (!shape.has_layout()) {
232 LayoutUtil::SetToDefaultLayout(&shape);
233 }
234 return shape;
235 });
236 return ShapeUtil::MakeTupleShape(loop_state_shape_components);
237 }
238
MakeCountedLoop(HloComputation * computation,int32 trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)239 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
240 HloComputation* computation, int32 trip_count,
241 const WhileUtil::LoopStateTy& init_values,
242 const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
243 const OpMetadata& metadata) {
244 CHECK_GE(trip_count, 0);
245
246 // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation
247 // use loop_state_shape to create a literal, which requires loop_state_shape
248 // to have a layout.
249 Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values);
250 TF_ASSIGN_OR_RETURN(
251 std::unique_ptr<HloComputation> cond,
252 MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
253 TF_ASSIGN_OR_RETURN(
254 std::unique_ptr<HloComputation> body,
255 MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
256 TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple,
257 MakeInitTupleFromInitValues(computation, init_values));
258 HloModule* module = computation->parent();
259 HloInstruction* while_instr =
260 computation->AddInstruction(HloInstruction::CreateWhile(
261 loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
262 module->AddEmbeddedComputation(std::move(body)), init_tuple));
263 while_instr->set_metadata(metadata);
264
265 std::vector<HloInstruction*> result;
266 for (int64 i = 0, e = init_values.size(); i < e; i++) {
267 TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
268 MakeGetTupleElementHlo(while_instr, i + 1));
269 result.push_back(user_state);
270 }
271 return result;
272 }
273
GetInvariantGTEsForWhileBody(const HloComputation & while_body)274 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
275 const HloComputation& while_body) {
276 std::vector<HloInstruction*> result;
277 const HloInstruction::InstructionVector root_operands =
278 while_body.root_instruction()->operands();
279 for (int i = 0; i < root_operands.size(); i++) {
280 HloInstruction* instr = root_operands[i];
281 if (instr->opcode() == HloOpcode::kGetTupleElement &&
282 instr->tuple_index() == i &&
283 instr->operand(0) == while_body.parameter_instruction(0)) {
284 result.push_back(instr);
285 }
286 }
287 return result;
288 }
289
290 /*static*/ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)291 WhileUtil::GetGTEsMapForWhileConditional(
292 const HloComputation& while_conditional) {
293 absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> result;
294 for (HloInstruction* user :
295 while_conditional.parameter_instruction(0)->users()) {
296 if (user->opcode() == HloOpcode::kGetTupleElement) {
297 result[user->tuple_index()].push_back(user);
298 }
299 }
300 return result;
301 }
302
303 } // namespace xla
304