1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"

#include <algorithm>
#include <queue>
#include <utility>
19
20 namespace xla {
// Returns the id of the value that `value` is represented by after
// optimization. If the phi node for `value` was folded away by
// ReplaceNodeWith, this is the id of the node that replaced it; otherwise it
// is `value`'s own id.
//
// CHECK-fails if `value` was never registered with this graph, or if the node
// it maps to is marked dead.
HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
  // Use find() rather than operator[]: operator[] would silently insert a
  // null Node* for an unregistered value and then dereference it.
  auto iter = value_id_to_node_.find(value.id());
  CHECK(iter != value_id_to_node_.end());
  Node* node = iter->second;
  CHECK(!node->mark_as_dead);
  return node->value_id;
}
26
27 // Returns true if the inputs to a hlo value are the same as `inputs`.
InputsEqualTo(const HloValue & value,absl::Span<const HloValue * const> inputs)28 bool PhiGraph::InputsEqualTo(const HloValue& value,
29 absl::Span<const HloValue* const> inputs) {
30 auto iter = value_id_to_node_.find(value.id());
31 CHECK(iter != value_id_to_node_.end());
32 absl::flat_hash_set<HloValue::Id> existing_set;
33 for (Node* operand : iter->second->operands) {
34 existing_set.insert(operand->value_id);
35 }
36 absl::flat_hash_set<HloValue::Id> new_set;
37 for (const HloValue* input : inputs) {
38 new_set.insert(input->id());
39 }
40 return existing_set == new_set;
41 }
42
// Maps value id `id` to the id of the node that currently represents it. For a
// value whose node was never replaced this is `id` itself. CHECK-fails if `id`
// was never registered, or if the node it maps to is marked dead.
HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
  auto found = value_id_to_node_.find(id);
  CHECK(found != value_id_to_node_.end());
  const Node* representative = found->second;
  CHECK(!representative->mark_as_dead);
  return representative->value_id;
}
49
CreateOrReuseNode(const HloValue & value)50 PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
51 auto iter = value_id_to_node_.find(value.id());
52 if (iter == value_id_to_node_.end()) {
53 node_storage_.emplace_back(absl::make_unique<Node>());
54 Node* node = node_storage_.back().get();
55 node->value_id = value.id();
56 value_id_to_node_[value.id()] = node;
57 node_to_value_id_[node].push_back(value.id());
58 return node;
59 } else {
60 // A node is already registered with this value, check the value_id
61 // is the same as previously registrated.
62 CHECK_NE(iter->second, nullptr);
63 CHECK_EQ(iter->second->value_id, value.id());
64 return iter->second;
65 }
66 }
67
ReplaceNodeWith(PhiGraph::Node * node,PhiGraph::Node * replace)68 void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
69 // Update users.
70 CHECK(node->is_phi);
71 if (node->mark_as_dead) {
72 // The node has already been replaced with another.
73 return;
74 }
75 if (replace->mark_as_dead) {
76 // The node we are placing with has already been replaced with another node.
77 auto iter = value_id_to_node_.find(replace->value_id);
78 CHECK(iter != value_id_to_node_.end());
79 return ReplaceNodeWith(node, iter->second);
80 }
81 CHECK(!replace->mark_as_dead);
82 for (Node* user : node->users) {
83 absl::c_replace(user->operands, node, replace);
84 }
85
86 // Update operand's users
87 for (Node* operand : node->operands) {
88 absl::c_replace(operand->users, node, replace);
89 }
90
91 for (HloValue::Id value_id : node_to_value_id_[node]) {
92 CHECK(value_id_to_node_.contains(value_id));
93 value_id_to_node_[value_id] = replace;
94 }
95 // Update mappings to HloValue::Id.
96 absl::c_copy(node_to_value_id_[node],
97 std::back_inserter(node_to_value_id_[replace]));
98 node_to_value_id_[node].clear();
99 node->mark_as_dead = true;
100 }
101
RegisterPhi(const HloValue & value,absl::Span<const HloValue * const> inputs)102 void PhiGraph::RegisterPhi(const HloValue& value,
103 absl::Span<const HloValue* const> inputs) {
104 Node* node = CreateOrReuseNode(value);
105 CHECK(value.is_phi());
106 node->is_phi = true;
107 node->operands.clear();
108 for (auto input : inputs) {
109 CHECK(input != nullptr);
110 Node* input_node = CreateOrReuseNode(*input);
111 node->operands.push_back(input_node);
112 }
113 }
114
ToString()115 std::string PhiGraph::ToString() {
116 std::string out = "PhiGraph: \n";
117 for (auto& node : node_storage_) {
118 std::string is_phi = node->is_phi ? ", phi" : "";
119 std::string is_optimized = node->mark_as_dead ? ", dead" : "";
120 absl::StrAppend(&out, node->value_id);
121 absl::StrAppend(&out, is_phi);
122 absl::StrAppend(&out, is_optimized, ":\n");
123 for (Node* input : node->operands) {
124 absl::StrAppend(&out, " ", input->value_id);
125 absl::StrAppend(&out, "\n");
126 }
127 }
128 return out;
129 }
130
Optimize()131 void PhiGraph::Optimize() {
132 VLOG(2) << "Optimizing phi graph:";
133 XLA_VLOG_LINES(2, ToString());
134 // Set up users for each node.
135 for (auto& node : node_storage_) {
136 for (Node* input : node->operands) {
137 input->users.push_back(node.get());
138 }
139 }
140
141 // input_node->users.push_back(node);
142 bool changed = true;
143
144 // Run the optimization to a fixed point.
145 while (changed) {
146 changed = false;
147 absl::flat_hash_set<Node*> checked_for_closure;
148 for (auto& node : node_storage_) {
149 // Only optimize phi node.
150 if (!node->is_phi) {
151 continue;
152 }
153 // Skip dead nodes
154 if (node->mark_as_dead) {
155 continue;
156 }
157
158 Node* node_ptr = node.get();
159
160 VLOG(2) << "Optimizing: " << node_ptr->value_id;
161
162 CHECK_GE(node_ptr->operands.size(), 1);
163
164 // Remove self-referencing ids from users and operands.
165 auto it = absl::c_find(node_ptr->operands, node_ptr);
166 while (it != node_ptr->operands.end()) {
167 node_ptr->operands.erase(it);
168 it = absl::c_find(node_ptr->operands, node_ptr);
169 }
170
171 it = absl::c_find(node_ptr->users, node_ptr);
172 while (it != node_ptr->users.end()) {
173 node_ptr->users.erase(it);
174 it = absl::c_find(node_ptr->users, node_ptr);
175 }
176
177 // If all inputs to phi (after self referencing ids are removed) are the
178 // same value, replace the phi with that value.
179 //
180 // phi(A, A, ... A) => A
181 // phi(A, self) = phi(A) => A
182 CHECK_GE(node_ptr->operands.size(), 1);
183 bool all_inputs_are_same = absl::c_all_of(
184 node_ptr->operands,
185 [&](Node* elem) { return elem == node_ptr->operands[0]; });
186
187 if (all_inputs_are_same) {
188 VLOG(1) << "All inputs to node " << node_ptr->value_id
189 << " are the same, replacing it with "
190 << node_ptr->operands[0]->value_id;
191 ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
192 changed = true;
193 continue;
194 }
195
196 // Find a closure of inter-connected phis and one non-phi node. Replace
197 // all phis with that non-phi node.
198 //
199 // def A = phi(B, C)
200 // def B = phi(C, D)
201 // def C = phi(A, B)
202 // def D = non-phi
203 // Replace A, B, and C with D:
204 // A = phi(B, C) => D
205 // B = phi(C, D) => D
206 // C = phi(A, B) => D
207 if (checked_for_closure.contains(node_ptr)) {
208 continue;
209 }
210 // Keeps track of nodes in the current closure being tested.
211 absl::flat_hash_set<Node*> workset;
212 std::queue<Node*> worklist;
213 Node* non_phi = nullptr;
214 worklist.push(node_ptr);
215 while (!worklist.empty()) {
216 Node* todo = worklist.front();
217 worklist.pop();
218 if (workset.contains(todo)) {
219 continue;
220 }
221 checked_for_closure.insert(todo);
222 workset.insert(todo);
223 for (Node* operand : todo->operands) {
224 worklist.push(operand);
225 }
226 if (!todo->is_phi) {
227 if (non_phi != nullptr && non_phi != todo) {
228 // We see distinct non-phi nodes in the closure, can't apply the
229 // optimization.
230 non_phi = nullptr;
231 // Break the while loop non_phi setting to nullptr, signaling that
232 // the optimization can't be applied.
233 break;
234 } else {
235 // This is the non_phi node we are seeing so far.
236 non_phi = todo;
237 }
238 }
239 }
240 if (non_phi != nullptr) {
241 // Replace all phi nodes in the closure/workset with the non_phi node.
242 for (Node* node : workset) {
243 if (!node->is_phi) {
244 CHECK_EQ(node, non_phi);
245 continue;
246 }
247 VLOG(1) << "Replace node " << node->value_id
248 << " in the closure with node " << non_phi->value_id;
249 ReplaceNodeWith(node, non_phi);
250 changed = true;
251 }
252 }
253 }
254 }
255 }
256 } // namespace xla
257