/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

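// Converts the control dependencies produced by automatic control dependency
// tracking (ACD) in stateful While loops into explicit "chain" loop
// variables. Instead of serializing each iteration behind a single
// end-of-body control barrier, the pass adds one scalar float loop variable
// per control output of the body. Each chain is forwarded through an
// Identity node gated on its control output, and the ops that first use the
// corresponding resources receive a control dependency on the incoming chain
// value. This preserves per-resource ordering across iterations while
// letting independent stateful ops run in parallel.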
// TODO(mdan): Move this into Grappler - cleaner interface.
Status ControlFlowDepsToChainsPass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "ControlFlowDepsToChainsPass::Run";

  if (options.graph == nullptr) {
    VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted";
    return Status::OK();
  }

  Graph* g = options.graph->get();
  DCHECK(g != nullptr);
  FunctionLibraryDefinition* flib_def = options.flib_def;
  DCHECK(flib_def != nullptr);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def);
  }

  for (Node* n : g->nodes()) {
    if (n == nullptr) continue;
    if (!n->IsWhileNode()) continue;

    // TODO(mdan): This breaks encapsulation of Node/Graph. Is a cleaner API
    // needed?
    // TODO(mdan): Consolidate this with AddWhileInputHack.
    NodeDef* while_node = n->mutable_def();
    const auto& attrs = while_node->attr();
    auto* mattrs = while_node->mutable_attr();

    string body_name = attrs.at("body").func().name();
    auto* body_graph = flib_def->Find(body_name);
    DCHECK(body_graph != nullptr);

    // Look for required annotations.

    if (attrs.find("_stateful_parallelism") == attrs.end()) continue;
    if (!attrs.at("_stateful_parallelism").b()) continue;
    // TODO(mdan): We don't really need this attribute.
    if (attrs.find("_num_original_outputs") == attrs.end()) continue;
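
    // Locate the body's end-of-body control barrier (the node tagged with
    // "_acd_function_control_output") and index the body's nodes by name.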
    int body_barrier_loc = -1;
    std::map<string, int> node_index;
    for (int i = 0, s = body_graph->node_def_size(); i < s; i++) {
      node_index.emplace(body_graph->node_def(i).name(), i);
      if (body_barrier_loc < 0) {
        const auto& node_attr = body_graph->node_def(i).attr();
        if (node_attr.find("_acd_function_control_output") !=
            node_attr.end()) {
          body_barrier_loc = i;
        }
      }
    }
    if (body_barrier_loc < 0) continue;
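
    // Lowering also requires that every control output carry the
    // "_res_first_used_by" annotation, which names the ops that first use
    // the corresponding resource(s).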
    bool ok_for_lowering = true;
    for (int i = 0; i < body_graph->control_ret_size(); i++) {
      const auto& control_node = body_graph->node_def(
          node_index[body_graph->signature().control_output(i)]);
      const auto& control_attr = control_node.attr();
      if (control_attr.find("_res_first_used_by") == control_attr.end()) {
        ok_for_lowering = false;
        break;
      }
    }
    if (!ok_for_lowering) continue;

    int num_loop_vars = body_graph->signature().input_arg_size();
    int num_new_chains = body_graph->control_ret_size();
    int num_node_inputs = while_node->input_size();

    if (!num_new_chains) continue;  // Nothing to do for stateless loops.

    // Add extra loop vars to the while node.

    // TODO(mdan): If the loop vars contain the resource, we should reuse it.
    // Note that stateful ops of resource inputs cause their resources to be
    // captured into the loop vars (through the body/cond captures). We could
    // effectively use those as chains.

    // TODO(mdan): Is there a more efficient way to do this?
    // Insert the new While node inputs at the end of the loop vars, but
    // before any non-loop-var inputs (like control dependencies). Once the
    // initial chain values are created below, they will be added to these
    // inputs.
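    // For example, with loop vars [v0, v1], one control input ^dep and two
    // chains, the input list [v0, v1, ^dep] becomes
    // [v0, v1, <chain_0>, <chain_1>, ^dep].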
    for (int i = 0; i < num_new_chains; i++) {
      while_node->add_input();
    }
    for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) {
      while_node->set_input(i + num_new_chains, while_node->input(i));
    }

    std::vector<Node*> new_inputs;
    std::vector<int> new_input_locations;
    // Set their name to a gensym, type to float and shape to scalar.
    for (int i = 0; i < num_new_chains; i++) {
      string c_name = g->NewName("acd__chain");

      // The initial value for the i'th chain loop var.
      NodeDef new_in;
      new_in.set_name(c_name);
      new_in.set_op("Const");
      AttrValue att_dtype;
      att_dtype.set_type(DT_FLOAT);
      new_in.mutable_attr()->insert({"dtype", att_dtype});
      AttrValue att_value;
      att_value.mutable_tensor()->set_dtype(DT_FLOAT);
      att_value.mutable_tensor()->mutable_tensor_shape();
      att_value.mutable_tensor()->add_int_val(0);
      new_in.mutable_attr()->insert({"value", att_value});
      Status status;
      new_inputs.push_back(g->AddNode(new_in, &status));
      TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain", c_name);

      int loc = num_loop_vars + i;
      new_input_locations.push_back(loc);
      while_node->set_input(loc, c_name);
      mattrs->at("T").mutable_list()->add_type(DT_FLOAT);
      mattrs->at("output_shapes").mutable_list()->add_shape();
    }

    // TODO(mdan): This should not be necessary to update. Delete?
    mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains);
    n->UpdateProperties();
    for (int i = 0; i < num_new_chains; i++) {
      g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]);
    }

    // TODO(mdan): This is wasteful. Can we just mutate the original proto?
    FunctionDef modified_body = *body_graph;

    // Disable the global end-of-body barrier from the body function.
    // Because removing a node is too inefficient (would have to walk all the
    // inputs of all graph nodes), we instead clear its control dependencies.
    modified_body.mutable_node_def(body_barrier_loc)->clear_input();

    // Add extra loop vars to the body function.

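    // Each chain contributes three pieces to the body: an input arg carrying
    // the incoming chain value, an Identity node that forwards that value
    // only after the corresponding control output has run, and an output
    // arg / ret entry so the forwarded value becomes the chain's value for
    // the next iteration.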
    for (int i = 0; i < num_new_chains; i++) {
      // Input loop vars.
      // TODO(mdan): Double check that this doesn't clash with names in body.
      string c_name = g->NewName("acd__chainv");
      std::replace(c_name.begin(), c_name.end(), '/', '_');
      auto* new_arg = modified_body.mutable_signature()->add_input_arg();
      new_arg->set_name(c_name);
      new_arg->set_type(DT_FLOAT);

      // Output ops. These are copies of the inputs conditioned on the actual
      // control outputs.
      string c_out_name = g->NewName("acd__outchain");
      auto* new_out = modified_body.add_node_def();
      new_out->set_name(c_out_name);
      new_out->set_op("Identity");
      new_out->add_input(c_name);
      new_out->add_input(
          strings::StrCat("^", body_graph->signature().control_output(i)));
      AttrValue attr;
      attr.set_type(DT_FLOAT);
      new_out->mutable_attr()->insert({"T", attr});

      // Output loop var declarations.
      string c_ret_name = c_out_name;
      std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_');
      auto* new_out_arg = modified_body.mutable_signature()->add_output_arg();
      new_out_arg->set_name(c_ret_name);
      new_out_arg->set_type(DT_FLOAT);

      // Actual output loop vars.
      modified_body.mutable_ret()->insert(
          {c_ret_name, strings::StrCat(c_out_name, ":output:0")});
      AttrValue attr_val;
      attr_val.mutable_list()->mutable_shape();
      FunctionDef_ArgAttrs arg_attrs;
      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
      modified_body.mutable_arg_attr()->insert(
          {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
    }

    // Wire chain loop vars to the ops they need to condition.

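    // "_res_first_used_by" lists, for each control output, the names of the
    // body nodes that first use the corresponding resource(s); each of those
    // nodes receives a control dependency on the matching chain input arg.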
    node_index.clear();
    for (int i = 0; i < modified_body.node_def_size(); i++) {
      node_index.emplace(modified_body.node_def(i).name(), i);
    }
    auto& modified_sig = modified_body.signature();
    for (int i = 0; i < num_new_chains; i++) {
      const auto& control_node =
          modified_body.node_def(node_index[modified_sig.control_output(i)]);
      for (const auto& r :
           control_node.attr().at("_res_first_used_by").list().s()) {
        NodeDef* first_node = modified_body.mutable_node_def(node_index[r]);
        // This control dependency ensures proper sequencing of stateful ops
        // upon entry into the loop body, so that they run after the ops
        // which affected the same resource in the previous iteration.
        first_node->add_input(strings::StrCat(
            "^", modified_sig.input_arg(i + num_loop_vars).name()));
      }
    }

    // Clear body function's control returns.
    modified_body.mutable_control_ret()->clear();

    // Add extra loop vars to the cond function.

    // TODO(mdan): This is wasteful. Can't we just mutate the original proto?
    string cond_name = attrs.at("cond").func().name();
    auto* cond_graph = flib_def->Find(cond_name);
    DCHECK(cond_graph != nullptr);
    FunctionDef modified_cond = *cond_graph;

    int cond_barrier_loc = -1;
    for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) {
      if (cond_barrier_loc < 0) {
        const auto& node_attr = cond_graph->node_def(i).attr();
        if (node_attr.find("_acd_function_control_output") !=
            node_attr.end()) {
          cond_barrier_loc = i;
        }
      }
    }
    if (cond_barrier_loc >= 0) {
      // Disable the global end-of-body barrier from the cond function.
      // Because removing a node is too inefficient (would have to walk all
      // the inputs of all graph nodes), we instead clear its control
      // dependencies.
      modified_cond.mutable_node_def(cond_barrier_loc)->clear_input();
    }

    for (int i = 0; i < num_new_chains; i++) {
      // Input loop vars.
      // TODO(mdan): These should gate the stateful ops in the cond.
      // Until ACD supplies the necessary information, these are dummies in
      // this function.
      string c_name = g->NewName("acd__chain");
      auto* new_arg = modified_cond.mutable_signature()->add_input_arg();
      new_arg->set_name(c_name);
      new_arg->set_type(DT_FLOAT);

      // TODO(mdan): Return values on the cond function? Most likely a bug.
      AttrValue attr_val;
      attr_val.mutable_list()->mutable_shape();
      FunctionDef_ArgAttrs arg_attrs;
      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
      modified_cond.mutable_arg_attr()->insert(
          {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
    }

    // Wire the new cond/body functions to the While node.

    string new_cond_name = g->NewName("acd__while_cond");
    modified_cond.mutable_signature()->set_name(new_cond_name);
    mattrs->at("cond").mutable_func()->set_name(new_cond_name);

    string new_body_name = g->NewName("acd__while_body");
    modified_body.mutable_signature()->set_name(new_body_name);
    mattrs->at("body").mutable_func()->set_name(new_body_name);

    // Commit the new functions.

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        flib_def->AddFunctionDef(modified_body,
                                 flib_def->GetStackTraces(body_name)),
        "while attaching ", new_body_name, " to flib_def");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        flib_def->AddFunctionDef(modified_cond,
                                 flib_def->GetStackTraces(cond_name)),
        "while attaching ", new_cond_name, " to flib_def");

    // TODO(b/183666205): This should not be necessary.
    // It's unclear why adding the functions here is also required.
    // Moreover, it's unclear when graph_lib's default registry is flib_def
    // itself.
    auto* graph_lib = g->mutable_flib_def();
    if (graph_lib->default_registry() != flib_def) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          graph_lib->AddFunctionDef(modified_body,
                                    graph_lib->GetStackTraces(body_name)),
          "while attaching ", new_body_name, " to graph");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          graph_lib->AddFunctionDef(modified_cond,
                                    graph_lib->GetStackTraces(cond_name)),
          "while attaching ", new_cond_name, " to graph");
    }
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def);
  }

  return Status::OK();
}

// Note: This pass needs to run before functional control flow lowering,
// which is registered at phase 10.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
                      ControlFlowDepsToChainsPass);

}  // namespace tensorflow