1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
27 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_optimizer.h"
35 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
36 #include "tensorflow/core/framework/graph_to_functiondef.h"
37 #include "tensorflow/core/framework/node_def_builder.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/control_flow.h"
40 #include "tensorflow/core/graph/node_builder.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/gtl/cleanup.h"
43 #include "tensorflow/core/public/session_options.h"
44 #include "tensorflow/core/public/version.h"
45 #include "tensorflow/core/util/dump_graph.h"
46
47 namespace tensorflow {
48
49 // Helper functions for functionalizing control flow in functions.
50
51 // Maps function name to
52 // - new function name, if the function body was functionalized
53 // - absl::nullopt, if not
54 using FuncMap = std::map<string, absl::optional<string>>;
55 using FuncMapIter = std::map<string, absl::optional<string>>::const_iterator;
56
57 // Returns whether function has been processed before.
FunctionHasBeenProcessed(FuncMapIter func_iter,const FuncMap * func_map)58 bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) {
59 return func_iter != func_map->end();
60 }
61
62 // Returns whether function has been modified (i.e., functionalized) before.
FunctionHasBeenModified(FuncMapIter func_iter)63 bool FunctionHasBeenModified(FuncMapIter func_iter) {
64 return func_iter->second.has_value();
65 }
66
67 // Returns a name for the new functionalized version of a function.
GetNewFunctionName(const string & func_name,Node * n,AssociatedFunctionInfo::AssociatedFunctionType func_type,FunctionLibraryDefinition * fld)68 string GetNewFunctionName(
69 const string& func_name, Node* n,
70 AssociatedFunctionInfo::AssociatedFunctionType func_type,
71 FunctionLibraryDefinition* fld) {
72 // For SymbolicGradient, `func_name` is always "SymbolicGradient" which
73 // is not very informative. Use node name instead.
74 return (
75 func_type ==
76 AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient
77 ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"))
78 : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_")));
79 }
80
81 // Returns name to which a modified function has been mapped.
GetMappedFunctionName(FuncMapIter func_iter)82 const string& GetMappedFunctionName(FuncMapIter func_iter) {
83 DCHECK(func_iter->second.has_value());
84 return func_iter->second.value();
85 }
86
87 // Updates `func_map` with function given by `canonicalized_name`.
UpdateFunctionMap(FuncMap * func_map,const string & canonicalized_name,const string & new_func_name,bool function_modified)88 void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name,
89 const string& new_func_name, bool function_modified) {
90 // If function was modified store its new name, otherwise add empty entry to
91 // record that function has been processed and does not need to be rewritten.
92 (*func_map)[canonicalized_name] =
93 function_modified ? absl::make_optional(new_func_name) : absl::nullopt;
94 }
95
96 // Adds new function def to graph's function library if necessary.
AddFunctionDefToGraphLibrary(const string & func_name,const AssociatedFunctionInfo & associated_function,Graph * graph,FunctionLibraryDefinition * fld)97 Status AddFunctionDefToGraphLibrary(
98 const string& func_name, const AssociatedFunctionInfo& associated_function,
99 Graph* graph, FunctionLibraryDefinition* fld) {
100 const OpRegistrationData* op_reg_data;
101 // We have to be careful with adding the function def since there are three
102 // different `OpRegistryInterface`s involved here:
103 // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`.
104 // We have already added the function def to `fld` before calling this
105 // function but for the subsequent `RewriteAssociatedFunction` call we need
106 // the function def to be in one of the other two registries, otherwise
107 // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case
108 // because it cannot find the associated function def.
109 // On the other hand, we should not add the function def if it is already
110 // contained in one of the last two registries, this would lead to errors when
111 // the function def is already in one registry and we try to add it to the
112 // other one (if we try to add it to the same it's fine). This can happen in
113 // cases where one of the last two registries is identical to `fld` (which we
114 // already updated).
115 // Therefore, before adding the function def we have to check if it's already
116 // contained in either `graph->flib_def()` or
117 // `graph->flib_def().default_registry()` which is done in the following line
118 // (we have to use `LookUp` instead of `Contains` or `Find` because the latter
119 // both don't check the default registry).
120 if (graph->flib_def().LookUp(func_name, &op_reg_data).ok())
121 return Status::OK();
122
123 const FunctionDef* new_fdef = fld->Find(func_name);
124 DCHECK(new_fdef != nullptr);
125 FunctionDefLibrary fdef_lib;
126 *(fdef_lib.add_function()) = *new_fdef;
127 return graph->AddFunctionLibrary(fdef_lib);
128 }
129
// Functionalizes the function `func_name` instantiated with `attrs`, first
// (recursively) functionalizing the functions associated with nodes in its
// body. `*function_modified` is reset to false and set to true if anything was
// rewritten; in that case the result is stored in `fld` under `new_func_name`.
// `func_map` is passed down to the recursive calls so nested functions are not
// processed twice. This function does not record `func_name` itself in
// `func_map`; that is left to the caller.
// Declared ahead of its definition because it is mutually recursive with
// FunctionalizeControlFlowForNodeAssociatedFunctions.
Status FunctionalizeControlFlowForFunction(
    const string& func_name, const string& new_func_name,
    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
    FuncMap* func_map, bool* function_modified,
    const NodeFilter& node_filter = {});
137
138 // Functionalizes all functions that are (directly or indirectly) associated to
139 // any node in `graph`. Adds processed functions to `func_map`.
FunctionalizeControlFlowForNodeAssociatedFunctions(FuncMap * func_map,Graph * graph,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,bool * any_function_modified,const NodeFilter & node_filter)140 Status FunctionalizeControlFlowForNodeAssociatedFunctions(
141 FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld,
142 FunctionLibraryRuntime* flr, bool* any_function_modified,
143 const NodeFilter& node_filter) {
144 std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
145 nodes_to_associated_functions;
146 for (auto* n : graph->nodes()) {
147 auto associated_functions = GetAssociatedFunctions(*n, fld);
148 if (!associated_functions.empty()) {
149 nodes_to_associated_functions.push_back({n, associated_functions});
150 }
151 }
152 for (const auto& pair : nodes_to_associated_functions) {
153 Node* n = pair.first;
154 auto associated_functions = pair.second;
155 for (auto& associated_function : associated_functions) {
156 // Note that if `n` is a function call node, then potential calls of
157 // `RewriteAssociatedFunction` below might delete `n` and create a new
158 // node instead, making `n` an invalid pointer. That's fine because in
159 // that case `n` only has one associated function, so this loop has only
160 // one iteration and we don't use `n` again after the rewrite.
161 // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed
162 // below.
163 DCHECK(associated_function.type() !=
164 AssociatedFunctionInfo::kFunctionCallNode ||
165 associated_functions.size() == 1);
166
167 // Process one node-function-pair.
168 string func_name = associated_function.func_name();
169 string canonicalized_name =
170 Canonicalize(func_name, AttrSlice(&associated_function.attrs()));
171 auto func_iter = func_map->find(canonicalized_name);
172 string new_func_name;
173 if (FunctionHasBeenProcessed(func_iter, func_map)) {
174 if (FunctionHasBeenModified(func_iter)) {
175 *any_function_modified = true;
176 new_func_name = GetMappedFunctionName(func_iter);
177 TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
178 graph, n, fld, associated_function, new_func_name));
179 }
180 continue;
181 }
182 // Function is processed for the first time.
183 bool function_modified = false;
184 new_func_name =
185 GetNewFunctionName(func_name, n, associated_function.type(), fld);
186 // Perform functionalization for current function.
187 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
188 func_name, new_func_name, associated_function.attrs(), fld, flr,
189 func_map, &function_modified, node_filter));
190 UpdateFunctionMap(func_map, canonicalized_name, new_func_name,
191 function_modified);
192 if (function_modified) {
193 *any_function_modified = true;
194 TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary(
195 new_func_name, associated_function, graph, fld));
196 TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
197 graph, n, fld, associated_function, new_func_name));
198 }
199 }
200 }
201 return Status::OK();
202 }
203
FunctionalizeControlFlowForFunction(const string & func_name,const string & new_func_name,const protobuf::Map<string,tensorflow::AttrValue> & attrs,FunctionLibraryDefinition * fld,FunctionLibraryRuntime * flr,FuncMap * func_map,bool * function_modified,const NodeFilter & node_filter)204 Status FunctionalizeControlFlowForFunction(
205 const string& func_name, const string& new_func_name,
206 const protobuf::Map<string, tensorflow::AttrValue>& attrs,
207 FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
208 FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) {
209 *function_modified = false;
210
211 // Convert the function to a graph.
212 FunctionLibraryRuntime::Handle handle;
213 TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
214 Status ret_status = Status::OK();
215 auto cleanup_handle = gtl::MakeCleanup([&]() {
216 auto s = flr->ReleaseHandle(handle);
217 if (!s.ok()) {
218 ret_status.Update(s);
219 }
220 });
221 const FunctionBody* body = flr->GetFunctionBody(handle);
222 Graph* g = body->graph;
223
224 // Check if the graph has Switch or Merge node.
225 bool has_switch_or_merge = false;
226 for (Node* n : body->graph->nodes()) {
227 // Skip nodes that are filtered out.
228 if (node_filter && !node_filter(n)) continue;
229 if (n->type_string() == "Switch" || n->type_string() == "Merge") {
230 has_switch_or_merge = true;
231 break;
232 }
233 }
234 // Before functionalizing control flow in `g` we functionalize control flow
235 // in functions (directly or indirectly) associated with nodes in `g`.
236 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
237 func_map, g, fld, flr, function_modified, node_filter));
238
239 if (has_switch_or_merge) {
240 *function_modified = true;
241
242 // Functionalize the function body.
243 if (VLOG_IS_ON(4)) {
244 DumpGraphToFile(
245 absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
246 *g, fld);
247 }
248 TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter));
249 if (VLOG_IS_ON(4)) {
250 DumpGraphToFile(
251 absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
252 fld);
253 }
254 }
255 if (*function_modified) {
256 // Add rewritten FunctionDef into library.
257 FunctionDef functionalized_fdef;
258 TF_RETURN_IF_ERROR(
259 GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
260 if (func_name == new_func_name) {
261 VLOG(2) << "Replacing function " << func_name;
262 TF_RETURN_IF_ERROR(
263 fld->ReplaceFunction(new_func_name, functionalized_fdef));
264 } else {
265 VLOG(2) << "Adding function " << new_func_name;
266 TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
267 }
268 }
269
270 return ret_status;
271 }
272
FunctionalizeControlFlow(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)273 Status FunctionalizeControlFlow(Graph* graph,
274 FunctionLibraryDefinition* library,
275 const NodeFilter& node_filter,
276 bool include_functions) {
277 VLOG(2) << "FunctionalizeControlFlow (initial): "
278 << DumpGraphToFile("functionalize_initial", *graph, library);
279
280 if (include_functions) {
281 // Functionalize control flow in functions that are (directly or indirectly)
282 // associated with a node in `graph`.
283 auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
284 /*device_mgr=*/nullptr, tensorflow::Env::Default(),
285 /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library,
286 tensorflow::OptimizerOptions());
287 // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice`
288 // (because we constructed it with `device_mgr = nullptr`).
289 FunctionLibraryRuntime* flr =
290 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
291
292 FuncMap func_map;
293 bool modified = false;
294 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions(
295 &func_map, graph, library, flr, &modified, node_filter));
296 }
297 // Functionalize and remove while loops from graph.
298 TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter));
299
300 // FunctionalizeControlFlow is invoked for every function, so the loops's
301 // bodies and conditionals that were extracted into functions will be handled
302 // in successive invocations.
303 TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library, node_filter));
304
305 VLOG(2) << "FunctionalizeControlFlow (final): "
306 << DumpGraphToFile("functionalize_final", *graph, library);
307
308 return Status::OK();
309 }
310
FunctionalizeControlFlowForGraphDef(GraphDef * graph_def,FunctionLibraryDefinition * library,const NodeFilter & node_filter,bool include_functions)311 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def,
312 FunctionLibraryDefinition* library,
313 const NodeFilter& node_filter,
314 bool include_functions) {
315 FunctionDefLibrary function_lib = graph_def->library();
316 Graph graph(OpRegistry::Global());
317
318 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph));
319 TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter,
320 include_functions));
321 graph.ToGraphDef(graph_def);
322 std::swap(*graph_def->mutable_library(), function_lib);
323 return Status::OK();
324 }
325
Run(const GraphOptimizationPassOptions & options)326 Status FunctionalizeControlFlowForXlaPass::Run(
327 const GraphOptimizationPassOptions& options) {
328 Graph* graph = options.graph->get();
329 if (VLOG_IS_ON(4)) {
330 DumpGraphToFile("functionalize_control_flow_before", *graph,
331 options.flib_def);
332 }
333 const auto* config = &options.session_options->config;
334 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
335 new ProcessFunctionLibraryRuntime(
336 /*device_mgr=*/nullptr, options.session_options->env, config,
337 TF_GRAPH_DEF_VERSION, options.flib_def,
338 config->graph_options().optimizer_options()));
339 FunctionLibraryRuntime* flr =
340 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
341
342 // Find XLA compile ops and its corresponding FunctionDef.
343 // TPUCompile op is not in the map because graph rewriting might happen
344 // multiple times, and we want to avoid functionalize it again.
345 static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
346 new std::map<string, string>{
347 // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass.
348 {"_TPUReplicate", "computation"},
349 // XlaLaunch ops are generated by EncapsulateXlaComputationsPass.
350 {"XlaLaunch", "function"},
351 };
352 FuncMap func_map;
353 bool fld_modified = false;
354 for (Node* n : graph->nodes()) {
355 auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
356 if (it == kNodeTypeToFunctionAttrMapping->end()) {
357 continue;
358 }
359 const string func_attr = it->second;
360 NameAttrList func;
361 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
362 VLOG(2) << "Graph has node " << n->type_string()
363 << ". Corresponding function: " << func.name();
364 string new_func_name = options.flib_def->UniqueFunctionName(
365 absl::StrCat(func.name(), "_f15n_"));
366 bool modified;
367 TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
368 func.name(), new_func_name, func.attr(), options.flib_def, flr,
369 &func_map, &modified));
370 if (modified) {
371 n->ClearAttr(func_attr);
372 func.set_name(new_func_name);
373 n->AddAttr(func_attr, func);
374 fld_modified = true;
375 }
376 }
377
378 // TODO(ylc, endlessroad): Change this to "if (fld_modified")"
379 if (false) {
380 if (VLOG_IS_ON(4)) {
381 DumpGraphToFile("functionalize_control_flow_before_prune", *graph,
382 options.flib_def);
383 }
384 TF_RETURN_IF_ERROR(
385 PruneUnreachableFunctionsFromGraph(*graph, options.flib_def));
386 }
387
388 if (VLOG_IS_ON(4)) {
389 DumpGraphToFile("functionalize_control_flow_after", *graph,
390 options.flib_def);
391 }
392 return Status::OK();
393 }
394
Run(const GraphOptimizationPassOptions & options)395 Status FunctionalizeControlFlowPass::Run(
396 const GraphOptimizationPassOptions& options) {
397 return FunctionalizeControlFlow(options.graph->get(), options.flib_def);
398 }
399
400 } // namespace tensorflow
401