/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

// TODO(mdan): Move this into Grappler - cleaner interface.
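// At a high level, this pass rewrites stateful While loops so that the body
// function's automatic-control-dependency (ACD) control outputs are carried
// as extra scalar float "chain" loop variables instead of function control
// returns: each chain value flows through an Identity node gated on one
// control output, and the ops that first use the corresponding resource in an
// iteration receive a control dependency on the matching chain input. This is
// intended to sequence ops that touch the same resource across iterations
// while leaving independent resources free to run in parallel.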
Status ControlFlowDepsToChainsPass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "ControlFlowDepsToChainsPass::Run";

  if (options.graph == nullptr) {
    VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted";
    return Status::OK();
  }

  Graph* g = options.graph->get();
  DCHECK(g != nullptr);
  FunctionLibraryDefinition* flib_def = options.flib_def;
  DCHECK(flib_def != nullptr);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def);
  }

  for (Node* n : g->nodes()) {
    if (n == nullptr) continue;
    if (!n->IsWhileNode()) continue;

    // TODO(mdan): This breaks encapsulation of Node/Graph. Is there a better
    // way?
    // TODO(mdan): Consolidate this with AddWhileInputHack.
    NodeDef* while_node = n->mutable_def();
    const auto& attrs = while_node->attr();
    auto* mattrs = while_node->mutable_attr();

    string body_name = attrs.at("body").func().name();
    auto* body_graph = flib_def->Find(body_name);
    DCHECK(body_graph != nullptr);

    // Look for required annotations.
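    // Concretely, lowering proceeds only if the While node carries the
    // "_stateful_parallelism" (set to true) and "_num_original_outputs"
    // attrs, the body has a node tagged "_acd_function_control_output" (the
    // end-of-body barrier), and every control output node of the body carries
    // a "_res_first_used_by" attr.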

    if (attrs.find("_stateful_parallelism") == attrs.end()) continue;
    if (!attrs.at("_stateful_parallelism").b()) continue;
    // TODO(mdan): We don't really need this attribute.
    if (attrs.find("_num_original_outputs") == attrs.end()) continue;
    int body_barrier_loc = -1;
    std::map<string, int> node_index;
    for (int i = 0, s = body_graph->node_def_size(); i < s; i++) {
      node_index.emplace(body_graph->node_def(i).name(), i);
      if (body_barrier_loc < 0) {
        const auto& node_attr = body_graph->node_def(i).attr();
        if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
          body_barrier_loc = i;
        }
      }
    }
    if (body_barrier_loc < 0) continue;
    bool ok_for_lowering = true;
    for (int i = 0; i < body_graph->control_ret_size(); i++) {
      const auto& control_node = body_graph->node_def(
          node_index[body_graph->signature().control_output(i)]);
      const auto& control_attr = control_node.attr();
      if (control_attr.find("_res_first_used_by") == control_attr.end()) {
        ok_for_lowering = false;
        break;
      }
    }
    if (!ok_for_lowering) continue;

    int num_loop_vars = body_graph->signature().input_arg_size();
    int num_new_chains = body_graph->control_ret_size();
    int num_node_inputs = while_node->input_size();

    if (!num_new_chains) continue;  // Nothing to do for stateless loops.

    // Add extra loop vars to the while node.

    // TODO(mdan): If the loop vars already contain the resource, reuse it.
    // Note that stateful ops that take resource inputs cause those resources
    // to be captured into the loop vars (through the body/cond captures). We
    // could effectively use those captures as chains.

    // TODO(mdan): Is there a more efficient way to do this?
    // Insert the new While node inputs: at the end of the loop vars, but before
    // any non-loop var inputs (like control dependencies). Once the initial
    // chain values are created below, they will be added to these inputs.
    for (int i = 0; i < num_new_chains; i++) {
      while_node->add_input();
    }
    for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) {
      while_node->set_input(i + num_new_chains, while_node->input(i));
    }
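    // For illustration: with inputs [x, y, ^dep], num_loop_vars=2 and
    // num_new_chains=1, this appends one empty input slot and shifts "^dep"
    // into it; slot 2 is then overwritten below with the name of the chain's
    // initial-value constant.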

    std::vector<Node*> new_inputs;
    std::vector<int> new_input_locations;
    // Set their name to a gensym, type to float and shape to scalar.
    for (int i = 0; i < num_new_chains; i++) {
      string c_name = g->NewName("acd__chain");

      // The initial value for the i'th chain loop var.
      NodeDef new_in;
      new_in.set_name(c_name);
      new_in.set_op("Const");
      AttrValue att_dtype;
      att_dtype.set_type(DT_FLOAT);
      new_in.mutable_attr()->insert({"dtype", att_dtype});
      AttrValue att_value;
      att_value.mutable_tensor()->set_dtype(DT_FLOAT);
      att_value.mutable_tensor()->mutable_tensor_shape();
      att_value.mutable_tensor()->add_float_val(0);
      new_in.mutable_attr()->insert({"value", att_value});
      Status status;
      new_inputs.push_back(g->AddNode(new_in, &status));
      TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain ", c_name);

      int loc = num_loop_vars + i;
      new_input_locations.push_back(loc);
      while_node->set_input(loc, c_name);
      mattrs->at("T").mutable_list()->add_type(DT_FLOAT);
      mattrs->at("output_shapes").mutable_list()->add_shape();
    }
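    // Each iteration above produces, roughly, a Const NodeDef of the form
    // (the name is a gensym derived from "acd__chain"):
    //   op: "Const"
    //   attr { key: "dtype" value { type: DT_FLOAT } }
    //   attr { key: "value"
    //          value { tensor { dtype: DT_FLOAT tensor_shape {} float_val: 0 } } }
    // and installs it as the loop var at position num_loop_vars + i, widening
    // the While node's "T" and "output_shapes" attrs to match.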

    // TODO(mdan): This should not be necessary to update. Delete?
    mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains);
    n->UpdateProperties();
    for (int i = 0; i < num_new_chains; i++) {
      g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]);
    }
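    // These Graph edges mirror the textual inputs set on while_node above:
    // new_inputs[i] becomes the data input at slot new_input_locations[i] of
    // the While node.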

    // TODO(mdan): This is wasteful. Can we just mutate the original proto?
    FunctionDef modified_body = *body_graph;

    // Disable the global end-of-body barrier in the body function.
    // Because removing a node is too inefficient (it would require walking the
    // inputs of all graph nodes), we instead clear its control dependencies.
    modified_body.mutable_node_def(body_barrier_loc)->clear_input();

    // Add extra loop vars to the body function.

    for (int i = 0; i < num_new_chains; i++) {
      // Input loop vars.
      // TODO(mdan): Double check that this doesn't clash with names in body.
      string c_name = g->NewName("acd__chainv");
      std::replace(c_name.begin(), c_name.end(), '/', '_');
      auto* new_arg = modified_body.mutable_signature()->add_input_arg();
      new_arg->set_name(c_name);
      new_arg->set_type(DT_FLOAT);

      // Output ops. These are copies of the inputs conditioned on the actual
      // control outputs.
      string c_out_name = g->NewName("acd__outchain");
      auto* new_out = modified_body.add_node_def();
      new_out->set_name(c_out_name);
      new_out->set_op("Identity");
      new_out->add_input(c_name);
      new_out->add_input(
          strings::StrCat("^", body_graph->signature().control_output(i)));
      AttrValue attr;
      attr.set_type(DT_FLOAT);
      new_out->mutable_attr()->insert({"T", attr});

      // Output loop var declarations.
      string c_ret_name = c_out_name;
      std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_');
      auto* new_out_arg = modified_body.mutable_signature()->add_output_arg();
      new_out_arg->set_name(c_ret_name);
      new_out_arg->set_type(DT_FLOAT);

      // Actual output loop vars.
      modified_body.mutable_ret()->insert(
          {c_ret_name, strings::StrCat(c_out_name, ":output:0")});
      AttrValue attr_val;
      attr_val.mutable_list()->mutable_shape();
      FunctionDef_ArgAttrs arg_attrs;
      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
      modified_body.mutable_arg_attr()->insert(
          {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
    }
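    // Per chain, the body gains (sketch; the names below are gensyms):
    //   * an input arg "acd__chainv..." of type DT_FLOAT;
    //   * an Identity node "acd__outchain..." whose data input is that arg
    //     and whose control input is "^<control_output(i)>", with T=DT_FLOAT;
    //   * an output arg of the same (sanitized) name, returned as
    //     "acd__outchain...:output:0".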

    // Wire chain loop vars to the ops they need to condition.

    node_index.clear();
    for (int i = 0; i < modified_body.node_def_size(); i++) {
      node_index.emplace(modified_body.node_def(i).name(), i);
    }
    auto& modified_sig = modified_body.signature();
    for (int i = 0; i < num_new_chains; i++) {
      const auto& control_node =
          modified_body.node_def(node_index[modified_sig.control_output(i)]);
      for (const auto& r :
           control_node.attr().at("_res_first_used_by").list().s()) {
        NodeDef* first_node = modified_body.mutable_node_def(node_index[r]);
        // This control dependency ensures proper sequencing of stateful ops
        // upon entry into the loop body, so that they run after the ops
        // which affected the same resource in the previous iteration.
        first_node->add_input(strings::StrCat(
            "^", modified_sig.input_arg(i + num_loop_vars).name()));
      }
    }
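    // For example, if a (hypothetical) node "read_var" appears in
    // "_res_first_used_by" for chain i, it gains the control input
    // "^acd__chainv...", so it should not start before the previous
    // iteration's chain value (gated on that iteration's control output) is
    // available.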

    // Clear body function's control returns.
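    // The sequencing they used to provide is now expressed through the chain
    // loop vars added above, so they are dropped here.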
    modified_body.mutable_control_ret()->clear();

    // Add extra loop vars to the cond function.

    // TODO(mdan): This is wasteful. Can't we just mutate the original proto?
    string cond_name = attrs.at("cond").func().name();
    auto* cond_graph = flib_def->Find(cond_name);
    DCHECK(cond_graph != nullptr);
    FunctionDef modified_cond = *cond_graph;

    int cond_barrier_loc = -1;
    for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) {
      if (cond_barrier_loc < 0) {
        const auto& node_attr = cond_graph->node_def(i).attr();
        if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
          cond_barrier_loc = i;
        }
      }
    }
    if (cond_barrier_loc >= 0) {
      // Disable the global end-of-body barrier in the cond function.
      // Because removing a node is too inefficient (it would require walking
      // the inputs of all graph nodes), we instead clear its control
      // dependencies.
      modified_cond.mutable_node_def(cond_barrier_loc)->clear_input();
    }

    for (int i = 0; i < num_new_chains; i++) {
      // Input loop vars.
      // TODO(mdan): These should gate the stateful ops in the cond.
      // Until ACD supplies the necessary information, these are dummies in this
      // function.
      string c_name = g->NewName("acd__chain");
      auto* new_arg = modified_cond.mutable_signature()->add_input_arg();
      new_arg->set_name(c_name);
      new_arg->set_type(DT_FLOAT);

      // TODO(mdan): Return values on the cond function? Most likely a bug.
      AttrValue attr_val;
      attr_val.mutable_list()->mutable_shape();
      FunctionDef_ArgAttrs arg_attrs;
      arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
      modified_cond.mutable_arg_attr()->insert(
          {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
    }

    // Wire the new cond/body functions to the While node.

    string new_cond_name = g->NewName("acd__while_cond");
    modified_cond.mutable_signature()->set_name(new_cond_name);
    mattrs->at("cond").mutable_func()->set_name(new_cond_name);

    string new_body_name = g->NewName("acd__while_body");
    modified_body.mutable_signature()->set_name(new_body_name);
    mattrs->at("body").mutable_func()->set_name(new_body_name);
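    // The modified copies are registered under fresh names below; the
    // original cond/body functions remain in the library untouched.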

    // Commit the new functions.

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        flib_def->AddFunctionDef(modified_body,
                                 flib_def->GetStackTraces(body_name)),
        "while attaching ", new_body_name, " to flib_def");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        flib_def->AddFunctionDef(modified_cond,
                                 flib_def->GetStackTraces(cond_name)),
        "while attaching ", new_cond_name, " to flib_def");

    // TODO(b/183666205): This should not be necessary.
    // It's unclear why adding the functions here is also required.
    // Moreover, it's unclear when graph_lib's parent is flib_def itself.
    auto* graph_lib = g->mutable_flib_def();
    if (graph_lib->default_registry() != flib_def) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          graph_lib->AddFunctionDef(modified_body,
                                    graph_lib->GetStackTraces(body_name)),
          "while attaching ", new_body_name, " to graph");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          graph_lib->AddFunctionDef(modified_cond,
                                    graph_lib->GetStackTraces(cond_name)),
          "while attaching ", new_cond_name, " to graph");
    }
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def);
  }

  return Status::OK();
}

// Note: This needs to run before functional control flow lowering, which is
// registered at priority 10.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
                      ControlFlowDepsToChainsPass);

}  // namespace tensorflow