1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/inlined_vector.h"
19 #include "tensorflow/compiler/xla/service/while_util.h"
20 #include "tensorflow/compiler/xla/util.h"
21
22 namespace xla {
23
24 // Replaces all uses of old_instr with new_instr except the use at
25 // `while_body_root` (which must be a tuple instruction) at index `tuple_index`.
26 // This utility helps us replace an instruction in the while body with a
27 // constant while still keeping it trivially loop invariant.
ReplaceUsesWhileKeepingLoopInvariance(HloInstruction * old_instr,HloInstruction * new_instr,HloInstruction * while_body_root,int64 tuple_index)28 static Status ReplaceUsesWhileKeepingLoopInvariance(
29 HloInstruction* old_instr, HloInstruction* new_instr,
30 HloInstruction* while_body_root, int64 tuple_index) {
31 CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple);
32
33 std::vector<HloInstruction*> users;
34 users.reserve(old_instr->user_count());
35 absl::c_copy(old_instr->users(), std::back_inserter(users));
36
37 for (auto* user : users) {
38 for (int64 i = 0, e = user->operand_count(); i < e; i++) {
39 if (user->operand(i) == old_instr &&
40 !(user == while_body_root && i == tuple_index)) {
41 TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr));
42 }
43 }
44 }
45
46 return Status::OK();
47 }
48
TrySinkingConstantsIntoWhileLoop(HloInstruction * while_instr)49 StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
50 HloInstruction* while_instr) {
51 HloComputation* while_cond = while_instr->while_condition();
52 HloComputation* while_body = while_instr->while_body();
53
54 const HloInstruction& init_value = *while_instr->operand(0);
55 if (init_value.opcode() != HloOpcode::kTuple) {
56 return false;
57 }
58
59 bool changed = false;
60
61 absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>
62 conditional_gte_index_to_insts =
63 WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
64 std::vector<HloInstruction*> invariant_body_gtes =
65 WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
66
67 for (HloInstruction* invariant_body_gte : invariant_body_gtes) {
68 int64 index = invariant_body_gte->tuple_index();
69 const HloInstruction& invariant_value = *init_value.operand(index);
70
71 // Original value should be a constant.
72 if (invariant_value.opcode() != HloOpcode::kConstant) {
73 continue;
74 }
75
76 // Sink into the while_body.
77 // Should have at least one user that's not while_body_root.
78 if (invariant_body_gte->user_count() > 1) {
79 HloInstruction* constant_instr =
80 while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk"));
81 TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance(
82 invariant_body_gte, constant_instr, while_body->root_instruction(),
83 index));
84 changed = true;
85 }
86
87 // Check if there is a corresponding GTE in while_conditional.
88 auto it = conditional_gte_index_to_insts.find(index);
89 if (it == conditional_gte_index_to_insts.end()) {
90 continue;
91 }
92
93 for (HloInstruction* invariant_cond_gte : it->second) {
94 // Should have at least one user.
95 if (invariant_cond_gte->user_count() > 0) {
96 HloInstruction* constant_instr = while_cond->AddInstruction(
97 invariant_value.Clone(/*suffix=*/".sunk"));
98 TF_RETURN_IF_ERROR(
99 invariant_cond_gte->ReplaceAllUsesWith(constant_instr));
100 changed = true;
101 }
102 }
103 }
104
105 return changed;
106 }
107
Run(HloModule * module)108 StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) {
109 VLOG(2) << "HLO module before WhileLoopConstantSinking:";
110 XLA_VLOG_LINES(2, module->ToString());
111
112 bool changed = false;
113 std::vector<HloInstruction*> while_instrs;
114 for (auto* comp : module->MakeNonfusionComputations()) {
115 // Right now we don't particularly care about optimizing while-of-while
116 // patterns. If/When we do, we'll want to visit the outer while (while_0)
117 // before we visit the inner while (while_1):
118 //
119 // while_1_body(state) {
120 // val = gte(state, 0) // Loop invariant
121 // use(val)
122 // }
123 //
124 // while_0_body(state) {
125 // val = gte(state, 0) // Loop invariant
126 // while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
127 // ...
128 // }
129 //
130 // main {
131 // while_0 = while(init=(constant, ...), body=while_0_body, ...)
132 // }
133 //
134 // This will let us sink the constant into the outer while first and then
135 // into the inner while in a single run of this pass.
136 absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
137 [](const HloInstruction* instr) {
138 return instr->opcode() == HloOpcode::kWhile;
139 });
140 }
141
142 for (HloInstruction* while_instr : while_instrs) {
143 TF_ASSIGN_OR_RETURN(bool result,
144 TrySinkingConstantsIntoWhileLoop(while_instr));
145 changed |= result;
146 }
147
148 if (changed) {
149 VLOG(2) << "HLO module after WhileLoopConstantSinking:";
150 XLA_VLOG_LINES(2, module->ToString());
151 } else {
152 VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
153 }
154
155 return changed;
156 }
157 } // namespace xla
158