/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

#include <algorithm>
#include <deque>
#include <stack>
#include <unordered_set>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

// Helper functions for functionalizing control flow in functions.

// Maps function name to
// - new function name, if the function body was functionalized
// - absl::nullopt, if not
using FuncMap = std::map<string, absl::optional<string>>;
using FuncMapIter = std::map<string, absl::optional<string>>::const_iterator;
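// Keys are canonicalized function names (function name plus instantiation
// attributes, see `Canonicalize`), so the same function instantiated with
// different attributes is tracked separately.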

// Returns whether function has been processed before.
bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) {
  return func_iter != func_map->end();
}

// Returns whether function has been modified (i.e., functionalized) before.
bool FunctionHasBeenModified(FuncMapIter func_iter) {
  return func_iter->second.has_value();
}

// Returns a name for the new functionalized version of a function.
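// The new name is the original name (or, for SymbolicGradient, the node name)
// with a "_f15n_" suffix, made unique within `fld`. ("f15n" is presumably a
// numeronym for "functionalization".)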
string GetNewFunctionName(
    const string& func_name, Node* n,
    AssociatedFunctionInfo::AssociatedFunctionType func_type,
    FunctionLibraryDefinition* fld) {
  // For SymbolicGradient, `func_name` is always "SymbolicGradient" which
  // is not very informative. Use node name instead.
  return (
      func_type ==
              AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient
          ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"))
          : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_")));
}

// Returns name to which a modified function has been mapped.
const string& GetMappedFunctionName(FuncMapIter func_iter) {
  DCHECK(func_iter->second.has_value());
  return func_iter->second.value();
}

// Updates `func_map` with the function given by `canonicalized_name`.
void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name,
                       const string& new_func_name, bool function_modified) {
  // If the function was modified, store its new name; otherwise add an empty
  // entry to record that the function has been processed and does not need to
  // be rewritten.
  (*func_map)[canonicalized_name] =
      function_modified ? absl::make_optional(new_func_name) : absl::nullopt;
}

// Adds new function def to graph's function library if necessary.
Status AddFunctionDefToGraphLibrary(
    const string& func_name, const AssociatedFunctionInfo& associated_function,
    Graph* graph, FunctionLibraryDefinition* fld) {
  const OpRegistrationData* op_reg_data;
  // We have to be careful when adding the function def since there are three
  // different `OpRegistryInterface`s involved here:
  // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`.
  // We have already added the function def to `fld` before calling this
  // function, but for the subsequent `RewriteAssociatedFunction` call we need
  // the function def to be in one of the other two registries; otherwise
  // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case
  // because it cannot find the associated function def.
  // On the other hand, we should not add the function def if it is already
  // contained in one of the last two registries: this would lead to errors
  // when the function def is already in one registry and we try to add it to
  // the other one (adding it to the same registry again is fine). This can
  // happen in cases where one of the last two registries is identical to `fld`
  // (which we already updated).
  // Therefore, before adding the function def we have to check whether it is
  // already contained in either `graph->flib_def()` or
  // `graph->flib_def().default_registry()`, which is done in the following
  // line. (We have to use `LookUp` instead of `Contains` or `Find` because
  // neither of those checks the default registry.)
  if (graph->flib_def().LookUp(func_name, &op_reg_data).ok())
    return Status::OK();

  const FunctionDef* new_fdef = fld->Find(func_name);
  DCHECK(new_fdef != nullptr);
  FunctionDefLibrary fdef_lib;
  *(fdef_lib.add_function()) = *new_fdef;
  return graph->AddFunctionLibrary(fdef_lib);
}

// Functionalizes the function given by `func_name`. Updates `func_map`
// accordingly.
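// This is a forward declaration; the definition below is mutually recursive
// with `FunctionalizeControlFlowForNodeAssociatedFunctions`, because
// functionalizing a function body may first require functionalizing the
// functions associated with its nodes.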
Status FunctionalizeControlFlowForFunction(
    const string& func_name, const string& new_func_name,
    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
    FuncMap* func_map, bool* function_modified,
    const NodeFilter& node_filter = {});

// Functionalizes all functions that are (directly or indirectly) associated
// with any node in `graph`. Adds processed functions to `func_map`.
Status FunctionalizeControlFlowForNodeAssociatedFunctions(
    FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld,
    FunctionLibraryRuntime* flr, bool* any_function_modified,
    const NodeFilter& node_filter) {
  std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
      nodes_to_associated_functions;
  for (auto* n : graph->nodes()) {
    auto associated_functions = GetAssociatedFunctions(*n, fld);
    if (!associated_functions.empty()) {
      nodes_to_associated_functions.push_back({n, associated_functions});
    }
  }
  for (const auto& pair : nodes_to_associated_functions) {
    Node* n = pair.first;
    auto associated_functions = pair.second;
    for (auto& associated_function : associated_functions) {
      // Note that if `n` is a function call node, then potential calls of
      // `RewriteAssociatedFunction` below might delete `n` and create a new
      // node instead, making `n` an invalid pointer. That's fine because in
      // that case `n` only has one associated function, so this loop has only
      // one iteration and we don't use `n` again after the rewrite.
      // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed
      // below.
      DCHECK(associated_function.type() !=
                 AssociatedFunctionInfo::kFunctionCallNode ||
             associated_functions.size() == 1);

      // Process one node-function-pair.
      string func_name = associated_function.func_name();
      string canonicalized_name =
          Canonicalize(func_name, AttrSlice(&associated_function.attrs()));
      auto func_iter = func_map->find(canonicalized_name);
      string new_func_name;
      if (FunctionHasBeenProcessed(func_iter, func_map)) {
        if (FunctionHasBeenModified(func_iter)) {
          *any_function_modified = true;
          new_func_name = GetMappedFunctionName(func_iter);
          TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
              graph, n, fld, associated_function, new_func_name));
        }
        continue;
      }
      // Function is processed for the first time.
      bool function_modified = false;
      new_func_name =
          GetNewFunctionName(func_name, n, associated_function.type(), fld);
      // Perform functionalization for current function.
      TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
          func_name, new_func_name, associated_function.attrs(), fld, flr,
          func_map, &function_modified, node_filter));
      UpdateFunctionMap(func_map, canonicalized_name, new_func_name,
                        function_modified);
      if (function_modified) {
        *any_function_modified = true;
        TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary(
            new_func_name, associated_function, graph, fld));
        TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
            graph, n, fld, associated_function, new_func_name));
      }
    }
  }
  return Status::OK();
}

Status FunctionalizeControlFlowForFunction(
    const string& func_name, const string& new_func_name,
    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
    FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) {
  *function_modified = false;

  // Convert the function to a graph.
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
  Status ret_status = Status::OK();
  auto cleanup_handle = gtl::MakeCleanup([&]() {
    auto s = flr->ReleaseHandle(handle);
    if (!s.ok()) {
      ret_status.Update(s);
    }
  });
  const FunctionBody* body = flr->GetFunctionBody(handle);
  Graph* g = body->graph;

  // Check if the graph has a Switch or Merge node.
  bool has_switch_or_merge = false;
  for (Node* n : body->graph->nodes()) {
    // Skip nodes that are filtered out.
    if (node_filter && !node_filter(n)) continue;
    if (n->type_string() == "Switch" || n->type_string() == "Merge") {
      has_switch_or_merge = true;
      break;
    }
  }
  // Before functionalizing control flow in `g` we functionalize control flow
  // in functions (directly or indirectly) associated with nodes in `g`.
  TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
      func_map, g, fld, flr, function_modified, node_filter));

  if (has_switch_or_merge) {
    *function_modified = true;

    // Functionalize the function body.
    if (VLOG_IS_ON(4)) {
      DumpGraphToFile(
          absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
          *g, fld);
    }
    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter));
    if (VLOG_IS_ON(4)) {
      DumpGraphToFile(
          absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
          fld);
    }
  }
  if (*function_modified) {
    // Add rewritten FunctionDef into library.
    FunctionDef functionalized_fdef;
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
    if (func_name == new_func_name) {
      VLOG(2) << "Replacing function " << func_name;
      TF_RETURN_IF_ERROR(
          fld->ReplaceFunction(new_func_name, functionalized_fdef));
    } else {
      VLOG(2) << "Adding function " << new_func_name;
      TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
    }
  }

  return ret_status;
}

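// Example usage (a minimal sketch; `graph` and `flib_def` are assumed to be a
// caller-owned Graph* and FunctionLibraryDefinition*):
//
//   TF_RETURN_IF_ERROR(FunctionalizeControlFlow(
//       graph, flib_def, /*node_filter=*/{}, /*include_functions=*/true));
//
// Passing an empty `node_filter` means no nodes are filtered out;
// `include_functions = true` also functionalizes the functions associated
// with nodes in the graph.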
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library,
                                const NodeFilter& node_filter,
                                bool include_functions) {
  VLOG(2) << "FunctionalizeControlFlow (initial): "
          << DumpGraphToFile("functionalize_initial", *graph, library);

  if (include_functions) {
    // Functionalize control flow in functions that are (directly or indirectly)
    // associated with a node in `graph`.
    auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
        /*device_mgr=*/nullptr, tensorflow::Env::Default(),
        /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library,
        tensorflow::OptimizerOptions());
    // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice`
    // (because we constructed it with `device_mgr = nullptr`).
    FunctionLibraryRuntime* flr =
        pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

    FuncMap func_map;
    bool modified = false;
    TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
        &func_map, graph, library, flr, &modified, node_filter));
  }
  // Functionalize and remove while loops from graph.
  TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));

  // FunctionalizeControlFlow is invoked for every function, so the loop
  // bodies and conditionals that were extracted into functions will be
  // handled in successive invocations.
  TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));

  VLOG(2) << "FunctionalizeControlFlow (final): "
          << DumpGraphToFile("functionalize_final", *graph, library);

  return Status::OK();
}

Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
                                           FunctionLibraryDefinition* library,
                                           const NodeFilter& node_filter,
                                           bool include_functions) {
  FunctionDefLibrary function_lib = graph_def->library();
  Graph graph(OpRegistry::Global());

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter,
                                              include_functions));
  graph.ToGraphDef(graph_def);
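  // Keep the original function library in `graph_def`; the functionalized
  // function definitions are tracked in `library` instead.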
  std::swap(*graph_def->mutable_library(), function_lib);
  return Status::OK();
}

Status FunctionalizeControlFlowForXlaPass::Run(
    const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();
  if (VLOG_IS_ON(4)) {
    DumpGraphToFile("functionalize_control_flow_before", *graph,
                    options.flib_def);
  }
  const auto* config = &options.session_options->config;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          /*device_mgr=*/nullptr, options.session_options->env, config,
          TF_GRAPH_DEF_VERSION, options.flib_def,
          config->graph_options().optimizer_options()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  // Find XLA compile ops and their corresponding FunctionDefs.
  // The TPUCompile op is not in the map because graph rewriting might happen
  // multiple times, and we want to avoid functionalizing it again.
  static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
      new std::map<string, string>{
          // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
          {"_TPUReplicate", "computation"},
          // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
          {"XlaLaunch", "function"},
      };
  FuncMap func_map;
  bool fld_modified = false;
  for (Node* n : graph->nodes()) {
    auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
    if (it == kNodeTypeToFunctionAttrMapping->end()) {
      continue;
    }
    const string func_attr = it->second;
    NameAttrList func;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
    VLOG(2) << "Graph has node " << n->type_string()
            << ". Corresponding function: " << func.name();
    string new_func_name = options.flib_def->UniqueFunctionName(
        absl::StrCat(func.name(), "_f15n_"));
    bool modified;
    TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
        func.name(), new_func_name, func.attr(), options.flib_def, flr,
        &func_map, &modified));
    if (modified) {
      n->ClearAttr(func_attr);
      func.set_name(new_func_name);
      n->AddAttr(func_attr, func);
      fld_modified = true;
    }
  }

  // TODO(ylc, endlessroad): Change this to "if (fld_modified)"
  if (false) {
    if (VLOG_IS_ON(4)) {
      DumpGraphToFile("functionalize_control_flow_before_prune", *graph,
                      options.flib_def);
    }
    TF_RETURN_IF_ERROR(
        PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
  }

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile("functionalize_control_flow_after", *graph,
                    options.flib_def);
  }
  return Status::OK();
}

Status FunctionalizeControlFlowPass::Run(
    const GraphOptimizationPassOptions& options) {
  return FunctionalizeControlFlow(options.graph->get(), options.flib_def);
}

}  // namespace tensorflow