• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/inlined_vector.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/tuple_util.h"
25 
26 namespace xla {
27 
28 using absl::StrCat;
29 
WidenWhileCondition(HloComputation * narrow_condition,const Shape & wide_shape)30 static StatusOr<HloComputation*> WidenWhileCondition(
31     HloComputation* narrow_condition, const Shape& wide_shape) {
32   const Shape& narrow_shape =
33       narrow_condition->parameter_instruction(0)->shape();
34 
35   HloComputation* wide_while_cond = [&]() {
36     HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
37     builder.AddInstruction(
38         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
39 
40     // This is needed so that the root instruction is shaped as a PRED[] -- we
41     // need to get this right to begin with since we can't mutate the type of
42     // the root instruction later.  We later change the root instruction to
43     // something more appropriate.
44     builder.AddInstruction(
45         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
46     return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
47   }();
48 
49   HloInstruction* truncated_parameter =
50       TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
51                                narrow_shape.tuple_shapes_size());
52   HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
53       HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
54                                  {truncated_parameter}, narrow_condition));
55 
56   wide_while_cond->set_root_instruction(call_narrow_cond);
57 
58   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
59   return wide_while_cond;
60 }
61 
62 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation * narrow_body,const Shape & wide_shape)63 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
64   const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
65 
66   HloComputation* wide_while_body = [&]() {
67     HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
68     builder.AddInstruction(
69         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
70     return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
71   }();
72 
73   HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
74   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
75       wide_parameter, narrow_shape.tuple_shapes_size());
76   HloInstruction* call_narrow_body =
77       wide_while_body->AddInstruction(HloInstruction::CreateCall(
78           narrow_shape, {truncated_parameter}, narrow_body));
79 
80   std::vector<HloInstruction*> live_through_values;
81   for (int i = narrow_shape.tuple_shapes_size();
82        i < wide_shape.tuple_shapes_size(); i++) {
83     live_through_values.push_back(
84         wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
85             wide_shape.tuple_shapes(i), wide_parameter, i)));
86   }
87 
88   wide_while_body->set_root_instruction(
89       TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
90 
91   TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
92                       CallInliner::Inline(call_narrow_body));
93   return {{wide_while_body, std::move(inlined_instructions_map)}};
94 }
95 
96 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
MakeInstructionsLiveIn(HloInstruction * while_instr,absl::Span<HloInstruction * const> instructions)97 WhileUtil::MakeInstructionsLiveIn(
98     HloInstruction* while_instr,
99     absl::Span<HloInstruction* const> instructions) {
100   CHECK(while_instr->shape().IsTuple());
101 
102   int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
103   Shape new_while_shape = while_instr->shape();
104   for (auto* instruction : instructions) {
105     *new_while_shape.add_tuple_shapes() = instruction->shape();
106   }
107 
108   TF_ASSIGN_OR_RETURN(
109       HloComputation * new_while_condition,
110       WidenWhileCondition(while_instr->while_condition(), new_while_shape));
111 
112   HloComputation* new_while_body;
113   CallInliner::InlinedInstructionMap inlined_instructions_map;
114   TF_ASSIGN_OR_RETURN(
115       std::tie(new_while_body, inlined_instructions_map),
116       WidenWhileBody(while_instr->while_body(), new_while_shape));
117 
118   HloInstruction* new_while_init =
119       TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
120   HloComputation* containing_computation = while_instr->parent();
121   HloInstruction* new_while = containing_computation->AddInstruction(
122       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
123                                   new_while_body, new_while_init));
124 
125   // We want to get rid of the old while instruction even if it has side
126   // effecting operations so we do a manual HloComputation::RemoveInstruction
127   // instead of relying on HloComputation::ReplaceInstruction.
128   HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
129       new_while, while_instr->shape().tuple_shapes_size());
130   TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
131   TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
132 
133   HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
134   std::vector<HloInstruction*> live_in_instructions;
135   for (int64 i = elements_in_old_while_shape;
136        i < new_while_shape.tuple_shapes_size(); i++) {
137     live_in_instructions.push_back(
138         new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
139             instructions[i - elements_in_old_while_shape]->shape(),
140             while_body_param, i)));
141   }
142 
143   WhileUtil::MakeInstructionsLiveInResult result;
144 
145   result.new_while_instr = new_while;
146   result.replacement_instr = replacement_instr;
147   result.while_body_live_in_values = std::move(live_in_instructions);
148   result.while_body_instruction_map = std::move(inlined_instructions_map);
149 
150   return std::move(result);
151 }
152 
153 static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape & loop_state_shape,int32 trip_count)154 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
155                                     int32 trip_count) {
156   Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
157 
158   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
159                       CreateComputationWithSignature(
160                           {&loop_state_shape}, scalar_pred, "while_cond"));
161 
162   HloInstruction* trip_count_constant = cond_computation->AddInstruction(
163       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
164 
165   HloInstruction* param = cond_computation->parameter_instruction(0);
166   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
167                       MakeGetTupleElementHlo(param, 0));
168 
169   TF_ASSIGN_OR_RETURN(
170       HloInstruction * compare,
171       MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
172   cond_computation->set_root_instruction(compare);
173   return std::move(cond_computation);
174 }
175 
MakeCountedLoopBodyComputation(const Shape & loop_state_shape,const std::function<StatusOr<WhileUtil::LoopStateTy> (HloInstruction *,const WhileUtil::LoopStateTy &)> & loop_body_generator)176 static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
177     const Shape& loop_state_shape,
178     const std::function<StatusOr<WhileUtil::LoopStateTy>(
179         HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) {
180   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> body_computation,
181                       CreateComputationWithSignature(
182                           {&loop_state_shape}, loop_state_shape, "while_body"));
183   HloInstruction* one = body_computation->AddInstruction(
184       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
185   HloInstruction* param = body_computation->parameter_instruction(0);
186   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
187                       MakeGetTupleElementHlo(param, 0));
188   TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar,
189                       MakeBinaryHlo(HloOpcode::kAdd, indvar, one));
190 
191   std::vector<HloInstruction*> loop_body_generator_args;
192   for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) {
193     TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element,
194                         MakeGetTupleElementHlo(param, i));
195     loop_body_generator_args.push_back(tuple_element);
196   }
197   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> next_state,
198                       loop_body_generator(indvar, loop_body_generator_args));
199   next_state.insert(next_state.begin(), next_indvar);
200   HloInstruction* next_state_tuple =
201       body_computation->AddInstruction(HloInstruction::CreateTuple(next_state));
202   body_computation->set_root_instruction(next_state_tuple);
203 
204   return std::move(body_computation);
205 }
206 
MakeInitTupleFromInitValues(HloComputation * computation,const WhileUtil::LoopStateTy & init_values)207 static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
208     HloComputation* computation, const WhileUtil::LoopStateTy& init_values) {
209   std::vector<HloInstruction*> init_values_with_indvar;
210   init_values_with_indvar.reserve(init_values.size() + 1);
211   HloInstruction* zero = computation->AddInstruction(
212       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
213   init_values_with_indvar.push_back(zero);
214   absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
215   return computation->AddInstruction(
216       HloInstruction::CreateTuple(init_values_with_indvar));
217 }
218 
219 // Returns a tuple shape containing a S32, and a shape from each value in
220 // `init_values`. If a shape from a value in `init_values` doesn't have a
221 // layout, use a default layout for the shape.
MakeLoopStateShapeWithLayout(const WhileUtil::LoopStateTy & init_values)222 static Shape MakeLoopStateShapeWithLayout(
223     const WhileUtil::LoopStateTy& init_values) {
224   std::vector<Shape> loop_state_shape_components;
225   loop_state_shape_components.reserve(init_values.size() + 1);
226   loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
227   absl::c_transform(init_values,
228                     std::back_inserter(loop_state_shape_components),
229                     [](HloInstruction* instr) {
230                       Shape shape = instr->shape();
231                       if (!shape.has_layout()) {
232                         LayoutUtil::SetToDefaultLayout(&shape);
233                       }
234                       return shape;
235                     });
236   return ShapeUtil::MakeTupleShape(loop_state_shape_components);
237 }
238 
MakeCountedLoop(HloComputation * computation,int32 trip_count,const WhileUtil::LoopStateTy & init_values,const WhileUtil::LoopBodyGeneratorTy & loop_body_generator,const OpMetadata & metadata)239 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
240     HloComputation* computation, int32 trip_count,
241     const WhileUtil::LoopStateTy& init_values,
242     const WhileUtil::LoopBodyGeneratorTy& loop_body_generator,
243     const OpMetadata& metadata) {
244   CHECK_GE(trip_count, 0);
245 
246   // Both MakeCountedLoopConditionComputation and MakeCountedLoopBodyComputation
247   // use loop_state_shape to create a literal, which requires loop_state_shape
248   // to have a layout.
249   Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values);
250   TF_ASSIGN_OR_RETURN(
251       std::unique_ptr<HloComputation> cond,
252       MakeCountedLoopConditionComputation(loop_state_shape, trip_count));
253   TF_ASSIGN_OR_RETURN(
254       std::unique_ptr<HloComputation> body,
255       MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator));
256   TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple,
257                       MakeInitTupleFromInitValues(computation, init_values));
258   HloModule* module = computation->parent();
259   HloInstruction* while_instr =
260       computation->AddInstruction(HloInstruction::CreateWhile(
261           loop_state_shape, module->AddEmbeddedComputation(std::move(cond)),
262           module->AddEmbeddedComputation(std::move(body)), init_tuple));
263   while_instr->set_metadata(metadata);
264 
265   std::vector<HloInstruction*> result;
266   for (int64 i = 0, e = init_values.size(); i < e; i++) {
267     TF_ASSIGN_OR_RETURN(HloInstruction * user_state,
268                         MakeGetTupleElementHlo(while_instr, i + 1));
269     result.push_back(user_state);
270   }
271   return result;
272 }
273 
GetInvariantGTEsForWhileBody(const HloComputation & while_body)274 /*static*/ std::vector<HloInstruction*> WhileUtil::GetInvariantGTEsForWhileBody(
275     const HloComputation& while_body) {
276   std::vector<HloInstruction*> result;
277   const HloInstruction::InstructionVector root_operands =
278       while_body.root_instruction()->operands();
279   for (int i = 0; i < root_operands.size(); i++) {
280     HloInstruction* instr = root_operands[i];
281     if (instr->opcode() == HloOpcode::kGetTupleElement &&
282         instr->tuple_index() == i &&
283         instr->operand(0) == while_body.parameter_instruction(0)) {
284       result.push_back(instr);
285     }
286   }
287   return result;
288 }
289 
290 /*static*/ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
GetGTEsMapForWhileConditional(const HloComputation & while_conditional)291 WhileUtil::GetGTEsMapForWhileConditional(
292     const HloComputation& while_conditional) {
293   absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> result;
294   for (HloInstruction* user :
295        while_conditional.parameter_instruction(0)->users()) {
296     if (user->opcode() == HloOpcode::kGetTupleElement) {
297       result[user->tuple_index()].push_back(user);
298     }
299   }
300   return result;
301 }
302 
303 }  // namespace xla
304