1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55 
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59 
60 constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
61 
62 // Do not specialize functions marked with '_nospecialize' attribute.
63 constexpr const char* const kNoSpecializeAttr = "_nospecialize";
64 
65 // Mark functions that were created as a result of function specialization.
66 constexpr const char* const kGrapplerSpecializedFuncAttr =
67     "_GrapplerSpecializedFunc";
68 
69 // There are two ways of calling a Tensorflow function:
70 //
71 // 1. Direct function call: node.op() is the name of the function.
72 //
73 // 2. Indirect function call: the function name is passed through a node
74 //    attribute, and special Tensorflow kernels are responsible for calling the
75 //    function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
76 
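// For illustration only (not part of the original file; hypothetical node and
// function names), the two call styles roughly look like:
//
//   Direct call:   node { name: "y" op: "MyFunc" input: "x" }
//
//   Indirect call: node { name: "y" op: "PartitionedCall" input: "x"
//                         attr { key: "f" value { func { name: "MyFunc" } } } }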
77 // Check if func_node.op() matches the name in FunctionDef signature.
78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79   return func_node.op() == func.signature().name();
80 }
81 
82 // Check if func_node has a function attribute with a function name matching
83 // FunctionDef signature.
84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85   if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86     return false;
87   }
88 
89   auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90   return func_attr != nullptr && func_attr->has_func() &&
91          func_attr->func().name() == func.signature().name();
92 }
93 
94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95                                           const NodeDef& func_node) {
96   if (IsDirectFunctionCall(func, func_node)) {
97     return AttrSlice(func_node);
98 
99   } else if (IsIndirectFunctionCall(func, func_node)) {
100     auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101     return AttrSlice(&func_attr->func().attr());
102 
103   } else {
104     LOG(WARNING) << "Can't resolve function instantiation attributes: "
105                  << SummarizeNodeDef(func_node);
106     return AttrSlice();
107   }
108 }
109 
110 // This is a fake device that should not be used for any op kernel execution;
111 // the only purpose of this device is to be passed as a part of DeviceSet to the
112 // Placer.
113 class FakeDevice : public Device {
114  public:
115   FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
116   explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
117   Status Sync() override { return Status::OK(); }
118 
119  private:
120   static DeviceAttributes attr(const string& device) {
121     DeviceNameUtils::ParsedName parsed_name;
122     bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123     DCHECK(parsed) << "Failed to parse full device name: " << device;
124 
125     DeviceAttributes attr;
126     attr.set_name(device);
127     attr.set_device_type(parsed_name.type);
128     return attr;
129   }
130 };
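// Illustrative usage (hypothetical device name): FakeDevice instances are
// constructed from fully qualified device names, e.g.
//   FakeDevice cpu("/job:localhost/replica:0/task:0/device:CPU:0");
// and are only added to the DeviceSet handed to the Placer; they never run
// kernels.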
131 
132 // -------------------------------------------------------------------------- //
133 // Function specialization.
134 //
135 // FunctionDef is somewhat similar to a function template in C++: given all the
136 // type parameters (and attribute values) it generates a statically defined
137 // graph from the type parametrized "graph template" (function body).
138 //
139 // Function specialization instantiates a parametrized FunctionDef into a
140 // statically defined graph, and then converts it back to the fully defined
141 // FunctionDef (it doesn't have any unknown type parameters or attribute
142 // values, known as placeholders).
143 //
144 // Given the fully specified graph we can apply all the Grappler optimizers to
145 // it (see details in MetaOptimizer). Also we can push known constant inputs
146 // into the function body, and remove unused outputs/inputs.
147 
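// A minimal illustration (hypothetical function, not from this file): a
// FunctionDef "MyFunc" with a type parameter T, called with T=DT_FLOAT, a
// truly-const second input, and only output:0 consumed, would be specialized
// into a new FunctionDef where T is fixed to float, the const input is
// materialized inside the body, and the unused outputs are pruned.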
148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149   const auto attr = AttrSlice(&fdef.attr());
150   bool nospecialize = false;
151   return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153 
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157   // Currently we do not support functions with tensor lists as inputs or
158   // outputs, so caller node input/output ports always match function
159   // input/output arguments.
160   using InputPort = int;
161   using OutputPort = int;
162 
163   string func_name;
164   bool is_in_fetch_set;
165   absl::flat_hash_set<OutputPort> active_outputs;
166   absl::flat_hash_map<string, DataType> type_parameters;
167   absl::flat_hash_map<string, AttrValue> body_parameters;
168   absl::flat_hash_map<InputPort, string> const_inputs;
169 
170   bool operator==(const FunctionSpecializationSignature& other) const {
171     bool equals = func_name == other.func_name &&
172                   is_in_fetch_set == other.is_in_fetch_set &&
173                   active_outputs == other.active_outputs &&
174                   type_parameters == other.type_parameters &&
175                   const_inputs == other.const_inputs;
176 
177     if (!equals) return false;
178 
179     // Equality is not defined for AttrValue.
180     if (body_parameters.size() != other.body_parameters.size()) return false;
181 
182     for (const auto& lhs : body_parameters) {
183       auto it = other.body_parameters.find(lhs.first);
184       if (it == other.body_parameters.end()) return false;
185       if (!AreAttrValuesEqual(lhs.second, (*it).second,
186                               /*allow_false_negatives=*/true)) {
187         return false;
188       }
189     }
190 
191     return true;
192   }
193 
194   template <typename H>
195   friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
196     H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
197 
198     // First pre-compute hashes for all values in collections with
199     // non-deterministic iteration order.
200     std::vector<uint64> hashes;
201     hashes.reserve(s.active_outputs.size()         //
202                    + s.type_parameters.size() * 2  //
203                    + s.body_parameters.size() * 2  //
204                    + s.const_inputs.size() * 2);
205 
206     absl::c_transform(s.active_outputs, std::back_inserter(hashes),
207                       hash<OutputPort>());
208 
209     using TypeParam = std::pair<const string, DataType>;
210     absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
211       AttrValue attr_value;
212       attr_value.set_type(type_param.second);
213       hashes.push_back(Hash64(type_param.first));
214       hashes.push_back(AttrValueHash(attr_value));
215     });
216 
217     using BodyParam = std::pair<const string, AttrValue>;
218     absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
219       hashes.push_back(Hash64(body_param.first));
220       hashes.push_back(FastAttrValueHash(body_param.second));
221     });
222 
223     using ConstInput = std::pair<const InputPort, string>;
224     absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
225       hashes.push_back(hash<InputPort>()(const_input.first));
226       hashes.push_back(Hash64(const_input.second));
227     });
228 
229     // Combine all pre-computed hashes in a deterministic order.
230     absl::c_sort(hashes);
231     return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
232   }
233 };
234 
235 struct FunctionSpecialization {
236   string specialized_func_name;
237   // True if the function caller node is in GrapplerItem fetch set.
238   bool is_in_fetch_set;
239   // Names of the tensors that were pushed down into the function body.
240   absl::flat_hash_set<string> const_inputs;
241   // Control dependencies of pushed down const inputs have to be attached to
242   // function caller node.
243   absl::flat_hash_set<string> control_deps;
244   // Output tensors (ports) that are consumed by other nodes in the graph or
245   // are in a GrapplerItem fetch set.
246   absl::flat_hash_set<int> active_outputs;
247   // Mapping from original function output port to the output port of
248   // specialized function. If function specialization changes the number of
249   // function outputs it's required to update all node consumers.
250   std::vector<std::pair<int, int>> output_mapping;
251 };
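// Example of output_mapping (illustrative): if the original function had three
// outputs and output:1 was pruned as unused, the old output:2 becomes output:1
// of the specialized function; the pair (2, 1) in output_mapping tells the
// optimizer to rewire consumers of "node:2" to read "node:1".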
252 
253 // Function optimizer context is initialized once for each optimization pass,
254 // and it uses the latest available graph (for the first iteration it will be
255 // the GrapplerItem.graph, for subsequent iterations it will be the output of
256 // the previous function optimizer pass).
257 class FunctionOptimizerContext {
258  public:
259   explicit FunctionOptimizerContext(const GrapplerItem& item,
260                                     RewriterConfig::Toggle opt_level,
261                                     const GraphDef& graph)
262       : item_(&item),
263         opt_level_(opt_level),
264         function_library_(OpRegistry::Global(), graph.library()),
265         truly_const_nodes_(InferTrulyConstNodes(item, graph)),
266         graph_view_(&graph) {}
267 
268   const GrapplerItem& item() const { return *item_; }
269 
270   const int graph_version() const { return item_->graph.versions().producer(); }
271 
272   RewriterConfig::Toggle opt_level() const { return opt_level_; }
273 
274   const FunctionLibraryDefinition& function_library() const {
275     return function_library_;
276   }
277   FunctionLibraryDefinition& function_library() { return function_library_; }
278 
279   const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
280   tensor_mapping() const {
281     return tensor_mapping_;
282   }
283 
284   const GraphView& graph_view() const { return graph_view_; }
285 
286   bool IsFeedNode(const string& node_name) const {
287     return absl::c_any_of(
288         item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
289           return ParseTensorName(feed.first).node() == node_name;
290         });
291   }
292 
293   bool IsFetchNode(const string& node_name) const {
294     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
295       return ParseTensorName(fetch).node() == node_name;
296     });
297   }
298 
299   bool IsTrulyConst(const string& name) const {
300     return TrulyConstNode(name) != nullptr;
301   }
302 
303   const NodeDef* TrulyConstNode(const string& name) const {
304     return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
305   }
306 
307   const FunctionSpecialization* FindFunctionSpecialization(
308       const FunctionSpecializationSignature& sig) const {
309     return gtl::FindOrNull(specialized_functions_, sig);
310   }
311 
312   void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
313                               const FunctionSpecialization& specialized_func) {
314     specialized_functions_.emplace(sig, specialized_func);
315   }
316 
317   void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
318     DCHECK(from.index() != Graph::kControlSlot)
319         << "Tensor mapping must be from regular tensor";
320     DCHECK(to.index() != Graph::kControlSlot)
321         << "Tensor mapping must be to regular tensor";
322 
323     auto inserted = tensor_mapping_.insert({from, to});
324     DCHECK(inserted.second)
325         << "Failed to insert duplicated tensor mapping: "
326         << "from=" << from.ToString() << " to=" << to.ToString();
327   }
328 
329   void AddTensorMapping(const string& func_node,
330                         const FunctionSpecialization& specialized_func) {
331     for (const auto& pair : specialized_func.output_mapping) {
332       int from_idx = pair.first;
333       int to_idx = pair.second;
334       if (from_idx != to_idx) {
335         SafeTensorId from_tensor(func_node, from_idx);
336         SafeTensorId to_tensor(func_node, to_idx);
337         AddTensorMapping(from_tensor, to_tensor);
338       }
339     }
340   }
341 
342  private:
343   static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
344       const GrapplerItem& item, const GraphDef& graph) {
345     absl::flat_hash_set<absl::string_view> feed_nodes;
346     for (const auto& feed : item.feed) {
347       feed_nodes.insert(feed.first);
348     }
349 
350     absl::flat_hash_map<string, const NodeDef*> const_nodes;
351     for (const NodeDef& node : graph.node()) {
352       if (IsConstant(node) && !feed_nodes.contains(node.name())) {
353         const_nodes[node.name()] = &node;
354       }
355     }
356 
357     return const_nodes;
358   }
359 
360   const GrapplerItem* item_;  // must outlive this object
361   RewriterConfig::Toggle opt_level_;
362 
363   // Function library constructed from current graph.
364   FunctionLibraryDefinition function_library_;
365 
366   // Nodes that are Const and not in feed.
367   absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
368   // Specialized functions.
369   absl::flat_hash_map<FunctionSpecializationSignature,
370                       const FunctionSpecialization>
371       specialized_functions_;
372 
373   // After function specialization, the optimized graph might be in an invalid
374   // state: nodes can read from an output index that is no longer valid after
375   // unused-output pruning.
376   //
377   // Tensor mapping that has to be applied to the graph after all functions
378   // optimizations (invalidated tensor id -> optimized graph tensor id).
379   absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
380       tensor_mapping_;
381 
382   // Use graph view to find active outputs of the function caller nodes.
383   GraphView graph_view_;
384 
385   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
386 };
387 
388 // Returns a pointer to the called function definition iff the given node is
389 // indeed a function call. Otherwise returns nullptr.
390 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
391                                     const NodeDef& node) {
392   // Check if a node does indirect function call via PartitionedCallOp.
393   if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
394     const AttrValue* func_attr = AttrSlice(node).Find("f");
395     return (func_attr != nullptr && func_attr->has_func())
396                ? ctx.function_library().Find(func_attr->func().name())
397                : nullptr;
398   }
399 
400   // Check if the function op itself is a function name.
401   return ctx.function_library().Find(node.op());
402 }
403 
404 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
405                                           const FunctionOptimizerContext& ctx,
406                                           int size_hint = 0) {
407   absl::flat_hash_set<int> active_outputs;
408   active_outputs.reserve(static_cast<size_t>(size_hint));
409 
410   // 1. Output can be consumed by the other graph node.
411   const auto node_fanout_edges =
412       ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
413   for (const GraphView::Edge& edge : node_fanout_edges) {
414     active_outputs.insert(edge.src.port_id);
415   }
416 
417   // 2. Or it can be in a fetch set.
418   for (const string& fetch : ctx.item().fetch) {
419     TensorId fetch_tensor = ParseTensorName(fetch);
420     if (fetch_tensor.node() == node.name()) {
421       active_outputs.insert(fetch_tensor.index());
422     }
423   }
424 
425   return active_outputs;
426 }
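// Illustrative example (hypothetical node name): if another node reads "fn:0"
// and the GrapplerItem fetch set contains "fn:2", GetActiveOutputs returns
// {0, 2} for the node named "fn"; output 1 is inactive and may be pruned.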
427 
428 bool HasTrulyConstInputs(const NodeDef& node,
429                          const FunctionOptimizerContext& ctx) {
430   const auto is_truly_const = [&ctx](const string& input) {
431     return ctx.IsTrulyConst(NodeName(input));
432   };
433   return absl::c_any_of(node.input(), is_truly_const);
434 }
435 
436 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
437                       const FunctionOptimizerContext& ctx) {
438   // Functions with tensor list outputs are not supported right now, so the
439   // number of output args is the same as number of possible function caller
440   // node outputs.
441   int num_outputs = func.signature().output_arg_size();
442   const absl::flat_hash_set<int> active_outputs =
443       GetActiveOutputs(func_node, ctx, /*size_hint=*/num_outputs);
444   int active_outputs_size = active_outputs.size();
445   return active_outputs_size != num_outputs;
446 }
447 
448 // Return pruned FunctionDefLibrary with functions that are reachable from
449 // the optimized graph.
450 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
451                                         const GraphDef& optimized_graph) {
452   FunctionLibraryDefinition pruned_flib =
453       flib.ReachableDefinitions(optimized_graph);
454 
455   int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
456                          static_cast<int>(flib.num_functions());
457 
458   VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
459           << " functions (" << pruned_functions << ")";
460 
461   return pruned_flib.ToProto();
462 }
463 
464 // Push all constant inputs of an instantiating node into the function body.
465 Status PushDownConstInputs(const NodeDef& func_node,
466                            const FunctionOptimizerContext& ctx,
467                            GrapplerFunctionItem* item,
468                            absl::flat_hash_set<string>* const_inputs,
469                            absl::flat_hash_set<string>* control_deps) {
470   // Record node control dependencies in the control_deps set.
471   const auto record_control_deps = [&](const NodeDef* const_input) {
472     for (int i = const_input->input_size() - 1; i >= 0; --i) {
473       const string& input = const_input->input(i);
474       if (IsControlInput(input))
475         control_deps->insert(input);
476       else
477         break;
478     }
479   };
480 
481   for (int i = func_node.input_size() - 1; i >= 0; --i) {
482     const string& input = func_node.input(i);
483     if (IsControlInput(input)) continue;
484 
485     const string node_name = NodeName(input);
486     if (ctx.IsTrulyConst(node_name)) {
487       VLOG(3) << "Push const into function body: input=" << input;
488       const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
489       const_inputs->insert(input);
490       record_control_deps(const_input);
491       TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
492     }
493   }
494 
495   return Status::OK();
496 }
497 
498 // Remove inputs that were pushed into the function body, and attach their
499 // control dependencies to the function caller node.
500 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
501                                  NodeDef* specialized_func_node) {
502   // Nothing to do if there were no const inputs to the function node.
503   if (specialization.const_inputs.empty()) return;
504 
505   // Keep only non-const inputs.
506   std::vector<string> keep_inputs;
507   const auto& inputs = specialized_func_node->input();
508   absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
509                   [&](const string& input) {
510                     return !specialization.const_inputs.contains(input);
511                   });
512 
513   specialized_func_node->clear_input();
514   for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
515 
516   // Attach control dependencies of pushed down const input to the caller node.
517   if (!specialization.control_deps.empty()) {
518     absl::flat_hash_set<string> existing_control_deps;
519 
520     for (const string& input : keep_inputs) {
521       existing_control_deps.insert(AsControlDependency(NodeName(input)));
522     }
523 
524     for (const string& ctrl : specialization.control_deps) {
525       if (!existing_control_deps.contains(ctrl)) {
526         VLOG(3) << "Forward control dependency: input=" << ctrl;
527         specialized_func_node->add_input(ctrl);
528       }
529     }
530   }
531 }
532 
533 // Remove Tin type parameters for pushed down const inputs.
534 void RemovePushedDownConstInputTypes(
535     const FunctionSpecialization& specialization, const NodeDef& func_node,
536     NodeDef* specialized_func_node) {
537   // Nothing to do if there were no const inputs to the function node.
538   if (specialization.const_inputs.empty()) return;
539 
540   // Make sure that original function caller has Tin attribute.
541   const AttrValue* tin = AttrSlice(func_node).Find("Tin");
542   if (tin == nullptr || !tin->has_list()) return;
543 
544   // Clear input types for the specialized node.
545   auto* attr = specialized_func_node->mutable_attr();
546   (*attr)["Tin"].mutable_list()->clear_type();
547 
548   // Keep types of non-const inputs.
549   for (int i = 0; i < func_node.input_size(); ++i) {
550     const string& input = func_node.input(i);
551     if (IsControlInput(input)) break;
552 
553     if (!specialization.const_inputs.contains(input)) {
554       DataType dt = tin->list().type(i);
555       (*attr)["Tin"].mutable_list()->add_type(dt);
556     }
557   }
558 }
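// Illustrative example (hypothetical values): for a PartitionedCall with
// Tin = [DT_FLOAT, DT_INT32] whose second input was pushed down as a const,
// the specialized node keeps only Tin = [DT_FLOAT]. RemoveUnusedOutputsTypes
// below applies the same idea to Tout for pruned outputs.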
559 
560 // Remove Tout type parameters for pruned function outputs.
561 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
562                               const NodeDef& func_node,
563                               NodeDef* specialized_func_node) {
564   // Make sure that original function caller has Tout attribute.
565   const AttrValue* tout = AttrSlice(func_node).Find("Tout");
566   if (tout == nullptr || !tout->has_list()) return;
567 
568   // Nothing to do if all outputs are active.
569   int specialization_active_outputs_size = specialization.active_outputs.size();
570   if (specialization_active_outputs_size == tout->list().type_size()) return;
571 
572   // Clear output types for the specialized node.
573   auto* attr = specialized_func_node->mutable_attr();
574   (*attr)["Tout"].mutable_list()->clear_type();
575 
576   // Keep output types of active outputs only.
577   for (int i = 0; i < tout->list().type_size(); ++i) {
578     if (specialization.active_outputs.contains(i)) {
579       DataType dt = tout->list().type(i);
580       (*attr)["Tout"].mutable_list()->add_type(dt);
581     }
582   }
583 }
584 
585 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
586                                          const NodeDef& func_node,
587                                          const string& specialized_func_name,
588                                          NodeDef* specialized_func_node) {
589   if (IsDirectFunctionCall(func, func_node)) {
590     specialized_func_node->set_op(specialized_func_name);
591 
592   } else if (IsIndirectFunctionCall(func, func_node)) {
593     auto* attr = specialized_func_node->mutable_attr();
594     (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
595 
596   } else {
597     return errors::InvalidArgument("Unknown function call site");
598   }
599 
600   return Status::OK();
601 }
602 
603 // Update a graph node created from the original function caller node to call
604 // the function specialization. Function specialization might change the number
605 // of inputs and outputs, so we have to make sure that the graph node is
606 // updated accordingly.
607 Status UpdateSpecializedFunctionNode(
608     const FunctionDef& func, const NodeDef& func_node,
609     const FunctionSpecialization& specialization,
610     NodeDef* specialized_func_node) {
611   // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
612   bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
613 
614   // 1. Call the specialized function instead of original one.
615   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
616       func, func_node, specialization.specialized_func_name,
617       specialized_func_node));
618 
619   // 2. Remove inputs corresponding to the pushed down consts.
620   RemovePushedDownConstInputs(specialization, specialized_func_node);
621 
622   // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
623   // types, that must be in sync with updated function signature.
624 
625   // 3. Update input types for the indirect function calls.
626   if (is_indirect_call) {
627     RemovePushedDownConstInputTypes(specialization, func_node,
628                                     specialized_func_node);
629   }
630 
631   // 4. Update output types for the indirect function call. It's unsafe to
632   // change the number of outputs for the fetch nodes, so we just skip them.
633   if (is_indirect_call && !specialization.is_in_fetch_set) {
634     RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
635   }
636 
637   // 5. Remove custom gradient annotation.
638   specialized_func_node->mutable_attr()->erase("_gradient_op_type");
639 
640   return Status::OK();
641 }
642 
643 Status InitializeFunctionSpecializationSignature(
644     const NodeDef& func_node, const FunctionDef& func,
645     const AttrSlice& func_instantiation_attr,
646     const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
647   DCHECK(sig->const_inputs.empty());
648   DCHECK(sig->active_outputs.empty());
649 
650   sig->func_name = func.signature().name();
651   sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
652   sig->active_outputs = GetActiveOutputs(func_node, ctx);
653 
654   TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
655                                                  &sig->type_parameters));
656   TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
657                                                  &sig->body_parameters));
658 
659   for (int i = 0; i < func_node.input_size(); ++i) {
660     const string& input = func_node.input(i);
661     if (IsControlInput(input)) break;
662     if (ctx.IsTrulyConst(input)) {
663       sig->const_inputs.emplace(i, input);
664     }
665   }
666 
667   return Status::OK();
668 }
669 
670 // Create a name for the function specialization. The name of the function, the
671 // name of the node instantiating it, and the Grappler item id together generate
672 // a unique function name. The meta optimizer might create multiple Grappler
673 // items for the same graph when optimizing functions, but it's guaranteed that
674 // they all have unique ids.
675 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
676                                const FunctionDef& func,
677                                const NodeDef& func_node) {
678   return absl::Substitute(
679       "$0_specialized_for_$1_at_$2", func.signature().name(),
680       absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
681 }
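// Illustrative example (hypothetical names): for function "MyFunc" called by
// node "scope/call" in a Grappler item with id "item_0", the generated name
// would be "MyFunc_specialized_for_scope_call_at_item_0".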
682 
683 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
684                           FunctionOptimizerContext* ctx,
685                           GraphDef* optimized_graph) {
686   VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
687 
688   const AttrSlice func_instantiation_attr =
689       FunctionInstantiationAttributes(func, func_node);
690 
691   FunctionSpecializationSignature signature;
692   TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
693       func_node, func, func_instantiation_attr, *ctx, &signature));
694 
695   // Check if function was already specialized for identical context.
696   const FunctionSpecialization* already_specialized =
697       ctx->FindFunctionSpecialization(signature);
698 
699   if (already_specialized) {
700     VLOG(2) << "Function was already specialized in identical context: "
701                "specialized_name="
702             << already_specialized->specialized_func_name;
703 
704     // Add a function call node for the specialized function.
705     NodeDef* specialized_func_node = optimized_graph->add_node();
706     *specialized_func_node = func_node;
707 
708     TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
709         func, func_node, *already_specialized, specialized_func_node));
710 
711     ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
712 
713     return Status::OK();
714   }
715 
716   // Add a new specialized function definition to the library.
717   const auto& flib = ctx->function_library();
718 
719   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
720   // pushing all constant inputs into the function body.
721   GrapplerFunctionItem item;
722   TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
723       func, func_instantiation_attr, flib, ctx->graph_version(), &item));
724 
725   // Push const inputs into the function body, and keep track of their control
726   // dependencies.
727   absl::flat_hash_set<string> const_inputs;
728   absl::flat_hash_set<string> control_deps;
729   TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
730                                          &control_deps));
731 
732   // Remove function outputs that do not have any consumers. We can't safely
733   // update outputs for the fetch nodes, so we just skip them.
734   std::vector<std::pair<int, int>> output_mapping;
735   if (!signature.is_in_fetch_set) {
736     int num_func_outputs = item.output_size();
737 
738     absl::flat_hash_set<int> remove;
739     for (int i = 0; i < num_func_outputs; ++i) {
740       if (!signature.active_outputs.count(i)) remove.insert(i);
741     }
742 
743     TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
744   }
745 
746   // TODO(ezhulenev): Push down known input shapes.
747   FunctionDef specialized_func;
748   TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
749 
750   // Find a name for specialized function.
751   const string specialized_func_name =
752       SpecializedFunctionName(*ctx, func, func_node);
753   if (flib.Contains(specialized_func_name)) {
754     // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
755     // a serious internal error, that must be investigated.
756     return errors::Internal("Created duplicate function specialization");
757   }
758 
759   specialized_func.mutable_signature()->set_name(specialized_func_name);
760   auto* specialized_attr = specialized_func.mutable_attr();
761   (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
762 
763   // Add specialized function to the library.
764   TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
765 
766   // Add a function call node for the specialized function.
767   NodeDef* specialized_func_node = optimized_graph->add_node();
768   *specialized_func_node = func_node;
769 
770   FunctionSpecialization func_specialization = {
771       specialized_func_name, signature.is_in_fetch_set, const_inputs,
772       control_deps,          signature.active_outputs,  output_mapping};
773 
774   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
775       func, func_node, func_specialization, specialized_func_node));
776 
777   ctx->AddSpecializedFunction(signature, func_specialization);
778   ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
779 
780   return Status::OK();
781 }
782 
783 // -------------------------------------------------------------------------- //
784 // Inline function calls into a graph using function inlining implementation
785 // from common_runtime:
786 //
787 // 1) Convert GraphDef to Graph.
788 // 2) Inline function calls.
789 // 3) Convert Graph back to the GraphDef.
790 
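// A rough sketch of this round trip, assuming the common_runtime helpers used
// further below (ConvertGraphDefToGraph, Graph::ToGraphDef); error handling
// and the actual inlining step are omitted:
//
//   FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
//   auto graph = absl::make_unique<Graph>(flib);
//   GraphConstructorOptions opts;
//   opts.allow_internal_ops = true;
//   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph.get()));
//   // ... inline function call nodes into *graph ...
//   graph->ToGraphDef(&output_graph_def);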
791 constexpr const char* const kLowerUsingSwitchMergeAttr =
792     LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
793 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
794     LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
795 
796 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
797 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
798 
799 // Checks if boolean attribute is defined and its value is 'true'.
800 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
801   bool match;
802   bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
803   return found && match;
804 }
805 
806 // Checks if a string attribute is defined and is not empty.
807 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
808   const string& value = GetNodeAttrString(n->attrs(), attr_name);
809   return !value.empty();
810 }
811 
812 bool LowerUsingSwitchMergeIsOn(const Node* n) {
813   return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
814 }
815 
816 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
817   return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
818 }
819 
820 bool MarkedForXlaCompilation(const NodeDef& n) {
821   auto is_enabled = [&](std::string attr_name) -> bool {
822     auto it = n.attr().find(attr_name);
823     return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
824   };
825   return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
826          is_enabled(kXlaMustCompileAttr);
827 }
828 
829 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
830   static const auto* exemption = new absl::flat_hash_set<string>(
831       {// LINT.IfChange
832        // Op types that should not run in program order, e.g. because they need
833        // to run asynchronously to avoid deadlock.
834        "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
835        "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
836        "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
837        "Send", "Recv",
838 
839        // Legacy random ops.
840        // See details in tensorflow/python/framework/auto_control_deps.py.
841        "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
842        "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
843        "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
844        "RandomPoissonV2",
845 
846        // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
847        // but it can't generate any observable side-effect.
848        "ReadVariableOp",
849 
850        // CudnnRNN ops are stateful but they can't generate any observable
851        // side-effect.
852        "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
853        "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
854 
855        // TPUEmbedding EnqueueOps are stateful but this is only between ops with
856        // the same device_ordinal on the same host.
857        "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
858        "EnqueueTPUEmbeddingSparseTensorBatch",
859        "EnqueueTPUEmbeddingRaggedTensorBatch",
860 
861        // SaveV2 and RestoreV2 should be allowed to operate in parallel on
862        // multiple hosts.
863        "SaveV2", "RestoreV2"});
864   // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
865   return exemption->contains(op);
866 }
867 
868 // Validates that all side effects inside function body will be executed after
869 // function inlining. We do it by looking for a path from stateful ops, to one
870 // of the output control sources.
871 //
872 // When the function is executed via FunctionLibraryRuntime we do not have to check
873 // this, because `PruneFunctionBody` has special pruning rules for stateful ops.
874 Status ValidateSideEffectsExecution(
875     const FunctionBody& fbody, OutputControlSource output_control_source,
876     bool has_outgoing_control_edges,
877     bool validate_outgoing_control_edge = true) {
878   // Find all nodes that can produce side effects in the function body graph. We
879   // use 'is_stateful()' bit as an approximation of "has side effects" property.
880   std::vector<const Node*> fbody_side_effects;
881   absl::c_copy_if(
882       fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
883       [](const Node* n) {
884         return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
885                !IsExemptFromSideEffectsExecutionValidation(n->type_string());
886       });
887 
888   // When the graph is executed in a TF-2.0 context with automatic control
889   // dependency tracking, the absence of an outgoing control edge indicates that
890   // no one is interested in observing side effects, so it is safe to inline the
891   // function body, even if some side effects will not be executed.
892   if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
893     const string error_message =
894         "Can't guarantee execution of function side-effects after inlining. "
895         "Function call node has no outgoing control edges.";
896     if (validate_outgoing_control_edge) {
897       return errors::Internal(error_message);
898     } else {
899       VLOG(3) << error_message;
900     }
901   }
902 
903   // Find all nodes in the function body that will be used as control sources.
904   absl::flat_hash_set<const Node*> control_sources;
905   if (output_control_source == OutputControlSource::kDataOutputs) {
906     control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
907   } else if (output_control_source == OutputControlSource::kControlOutputs) {
908     control_sources = {fbody.control_ret_nodes.begin(),
909                        fbody.control_ret_nodes.end()};
910   }
911 
912   for (const Node* side_effect : fbody_side_effects) {
913     VLOG(4) << "Check that node " << side_effect->name()
914             << " will execute after inlining.";
915     bool will_execute = false;
916 
917     const auto is_control_source = [&](const Node* n) -> void {
918       const auto it = control_sources.find(n);
919       if (it != control_sources.end()) {
920         VLOG(4) << "Found a path to control source: " << side_effect->name()
921                 << " ---> " << (*it)->name();
922         will_execute = true;
923       }
924     };
925 
926     DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
927             /*leave=*/{}, NodeComparatorName{});
928 
929     if (!will_execute) {
930       return errors::Internal(
931           "Can't guarantee execution of a side-effectful node, that is not "
932           "reachable from function control source. Function body node: ",
933           SummarizeNode(*side_effect));
934     }
935   }
936 
937   return Status::OK();
938 }
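// Illustrative example (hypothetical op): with output_control_source ==
// kControlOutputs, a stateful node such as an AssignVariableOp inside the body
// must have a path to one of the function's control return nodes; otherwise
// the function is rejected for inlining with the Internal error above.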
939 
940 // Validates that no dead tensor can reach function output.
941 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
942                              const FunctionBody& fbody) {
943   absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
944                                                    fbody.ret_nodes.end()};
945 
946   // Find all nodes that can produce dead tensors.
947   std::vector<const Node*> dead_tensor_sources;
948   for (const Node* n : fbody.graph->nodes()) {
949     if (n->IsSwitch()) {
950       VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
951       dead_tensor_sources.push_back(n);
952       continue;
953     }
954 
955     // A native function call can also produce dead tensors if the function body
956     // has mergeless switches.
957     const FunctionDef* fdef = flib_def.Find(n->type_string());
958     if (fdef != nullptr) {
959       std::unique_ptr<FunctionBody> nested_fbody;
960 
961       NameAttrList func;
962       TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
963       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
964                                                  &flib_def, &nested_fbody));
965 
966       if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
967         VLOG(4) << "Add dead tensors source. Function call: " << func.name()
968                 << " node=" << n->name();
969         dead_tensor_sources.push_back(n);
970       }
971     }
972   }
973 
974   for (const Node* dead_tensor_source : dead_tensor_sources) {
975     bool has_dead_output = false;
976 
977     const auto is_output_node = [&](const Node* n) -> void {
978       const auto it = output_nodes.find(n);
979       if (it != output_nodes.end()) {
980         VLOG(4) << "Found a path to output node from dead tensor source: "
981                 << dead_tensor_source->name() << " ---> " << (*it)->name();
982         has_dead_output = true;
983       }
984     };
985 
986     // Stop DFS traversal at a Merge node or if already found a dead output.
987     const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
988       return !edge.src()->IsMerge() || has_dead_output;
989     };
990 
991     DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
992             /*leave=*/{}, NodeComparatorName{},
993             /*edge_filter=*/stop_traversal);
994 
995     if (has_dead_output) {
996       return errors::Internal(
997           "Can't inline a function with dead outputs. Dead tensor source: ",
998           SummarizeNode(*dead_tensor_source));
999     }
1000   }
1001 
1002   return Status::OK();
1003 }
1004 
1005 // Makes an instance of FunctionBody for inlining from a Node.
1006 Status MakeFunctionBodyForInlining(const Node& node,
1007                                    const FunctionLibraryDefinition& flib_def,
1008                                    std::unique_ptr<FunctionBody>* fbody) {
1009   VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1010 
1011   // Finds a FunctionDef in a library and verifies that it exists.
1012   const auto find_fdef = [&flib_def, &node](
1013                              const string& name,
1014                              const FunctionDef** fdef) -> Status {
1015     if ((*fdef = flib_def.Find(name)) == nullptr) {
1016       return errors::Internal(
1017           "Was not able to find a function definition (name=", name,
1018           ") for a function call: ", SummarizeNode(node));
1019     }
1020     return Status::OK();
1021   };
1022 
1023   // SymbolicGradient is a special "function call" op, which has been
1024   // deprecated for a while, but we still support it for compatibility reasons.
1025   if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1026     NameAttrList func;
1027     TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1028 
1029     const string grad = flib_def.FindGradient(func.name());
1030 
1031     if (!grad.empty()) {
1032       // Function has a custom gradient registered in a library.
1033       const FunctionDef* grad_fdef;
1034       TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1035 
1036       VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1037               << " (function=" << func.name() << ")";
1038       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1039           *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1040 
1041     } else if (flib_def.Find(func.name()) == nullptr) {
1042       // Function is not really a function, but a primitive op.
1043       gradient::Creator creator;
1044       TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1045       if (creator == nullptr) {
1046         return errors::InvalidArgument("No gradient is defined for ",
1047                                        func.name());
1048       }
1049       FunctionDef grad_fdef;
1050       TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1051 
1052       VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1053               << func.name();
1054       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1055           grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1056 
1057     } else {
1058       // Build a gradient graph from the function body.
1059       const FunctionDef* fdef;
1060       TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1061 
1062       VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1063               << func.name();
1064       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1065                                                  &flib_def, fbody));
1066       *fbody = SymbolicGradient(**fbody);
1067     }
1068 
1069   } else {
1070     NameAttrList func;
1071     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1072     const FunctionDef* fdef;
1073     TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1074 
1075     VLOG(4) << "Instantiate a function call: function=" << func.name();
1076     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1077                                                &flib_def, fbody));
1078   }
1079 
1080   return Status::OK();
1081 }
1082 
1083 // Adds control edges from each data input to the 'caller' to enforce strict
1084 // inputs semantics (all inputs are ready and alive). This is required when:
1085 //
1086 //  1) The function takes resources as inputs, and it doesn't have incoming
1087 //     control edges. In Tensorflow v2 context (eager mode) this should never
1088 //     happen, because automatic control dependencies tracking will add a
1089 //     control edge from the last op touching the resource. However such graphs
1090 //     might be produced by legacy v1 code without automatic dependency
1091 //     tracking. In this case strict function call semantics is required for
1092 //     enforcing side effects execution order.
1093 //
1094 //  2) One of the inputs is consuming Enter[is_constant=true] node, in which
1095 //     case it will be always alive, and potentially can lead to partial
1096 //     function execution after the last loop execution.
1097 //
1098 // Both of these cases would be considered illegal by construction in Tensorflow
1099 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1100 // will produce correct results.
1101 void AddStrictInputSemantics(Node* caller, Graph* g) {
1102   absl::flat_hash_set<const Node*> existing_control_sources;
1103   for (const Edge* edge : caller->in_edges()) {
1104     if (edge->IsControlEdge()) {
1105       existing_control_sources.insert(edge->src());
1106     }
1107   }
1108 
1109   const bool has_incoming_control_edges = !existing_control_sources.empty();
1110 
1111   const bool has_resource_input =
1112       absl::c_any_of(caller->input_types(),
1113                      [](const DataType dtype) { return dtype == DT_RESOURCE; });
1114 
1115   const bool has_constant_enter_input =
1116       absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1117         Node* src = edge->src();
1118         return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1119       });
1120 
1121   const bool requires_strict_semantics =
1122       (!has_incoming_control_edges && has_resource_input) ||  // Case #1
1123       (has_constant_enter_input);                             // Case #2
1124   if (!requires_strict_semantics) return;
1125 
1126   std::set<const Node*> data_inputs;
1127   for (const Edge* edge : caller->in_edges()) {
1128     if (!edge->IsControlEdge() &&
1129         !existing_control_sources.contains(edge->src())) {
1130       data_inputs.insert(edge->src());
1131     }
1132   }
1133 
1134   VLOG(3) << "Add control edges from all data inputs to enforce strict "
1135              "semantics with regard to function inputs";
1136 
1137   // Do not add control edges from placeholders, because it will prevent
1138   // pruning, and they can't produce any side effects anyway.
1139   const auto is_placeholder = [](const Node* node) -> bool {
1140     return node->type_string() == "Placeholder";
1141   };
1142 
1143   for (const Node* node : data_inputs) {
1144     if (is_placeholder(node)) continue;
1145     g->AddControlEdge(g->FindNodeId(node->id()), caller,
1146                       /*allow_duplicates=*/true);
1147   }
1148 }
1149 
1150 // Adds a control edge from a frame node if the 'caller' is executing inside a
1151 // While loop (see control_flow.h for the 'frame' node explanation).
1152 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1153                                    Node* caller, Graph* g) {
1154   // All nodes added to the graph by v2 control flow lowering and function
1155   // inlining are guaranteed to have control edges to nested function calls.
1156   int info_size = info.size();
1157   if (caller->id() >= info_size) return;
1158 
1159   // Check if a lowered node is executing inside a while loop.
1160   const Node* frame = info[caller->id()].frame;
1161   const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1162   if (!is_in_while_loop) return;
1163 
1164   // Check if a node already has an incoming control edge. All incoming edges
1165   // must be from the same execution frame (executor.cc invariant), so if we
1166   // already have an incoming control edge, it's guaranteed that it will "carry"
1167   // the same frame as all regular inputs.
1168   const bool has_incoming_control_edges =
1169       absl::c_any_of(caller->in_edges(),
1170                      [](const Edge* edge) { return edge->IsControlEdge(); });
1171   if (has_incoming_control_edges) return;
1172 
1173   VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1174           << " to=" << caller->name();
1175   Node* enter = g->FindNodeId(frame->id());
1176   bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1177   if (is_constant_enter) {
1178     // Enter[is_constant=true] is always alive. So we directly add a control
1179     // edge from that.
1180     g->AddControlEdge(enter, caller);
1181   } else {
1182     // Enter[is_constant=false] activates nodes only in 0th iteration so we
1183     // add an edge from the Merge node which is activated in every iteration.
1184     // A non-constant Enter node must have an edge to a Merge node.
1185     auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1186       return !e->IsControlEdge() && e->dst()->IsMerge();
1187     });
1188     if (it != enter->out_edges().end()) {
1189       g->AddControlEdge((*it)->dst(), caller);
1190     } else {
1191       LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1192                    << " does not have an outgoing edge to a Merge.";
1193     }
1194   }
1195 }

// Inlines all function calls that are safe for inlining into the main graph.
// Also lowers control flow V2 ops (functional If/Case/While) into the V1
// low-level ops (Switch/Merge/...).
//
// Runs a placer after inlining, to keep all nodes in the graph placed.
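//
// Illustrative example (node names are approximate): inlining the direct call
//
//   y = MyFunc(x)   where   MyFunc(a) { b = Mul(a, 2.0); return b }
//
// replaces the call node with the function body, stitched to the surrounding
// graph through Identity nodes, roughly:
//
//   y/a = Identity(x);  y/b = Mul(y/a, 2.0);  y = Identity(y/b)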
Status InlineFunctionCalls(const GrapplerItem& item,
                           const RewriterConfig::Toggle opt_level,
                           const bool lower_control_flow,
                           GraphDef* output_graph) {
  bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
  VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
          << " (aggressive_mode=" << is_aggressive << ")";

  FunctionLibraryDefinition flib_def =
      FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);

  GraphConstructorOptions graph_constructor_options;
  graph_constructor_options.allow_internal_ops = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
                                            item.graph, graph.get()));

  using NodeNames = absl::flat_hash_set<absl::string_view>;
  NodeNames fetch_nodes;
  fetch_nodes.reserve(item.fetch.size());
  for (const string& fetch : item.fetch) {
    fetch_nodes.insert(ParseTensorName(fetch).node());
  }
  NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());

  std::vector<string> inlined_function_names;

  // Do not inline function call nodes that are part of a feed set.
  NodeNames feed_nodes;
  feed_nodes.reserve(item.feed.size());
  for (const std::pair<std::string, Tensor>& feed : item.feed) {
    feed_nodes.insert(ParseTensorName(feed.first).node());
  }

  // If a function call is inside a While loop, it must have an incoming
  // control edge, because it will be used to pass the execution frame into the
  // function body. All nodes without inputs in the function body (e.g. Const
  // and NoOp) will get an extra control edge from the 'input_control_node'.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));

  // Function inlining always adds new nodes to the end of the list, so we keep
  // iterating until we are out of nodes.
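  // Node ids 0 and 1 are reserved for the implicit _SOURCE and _SINK nodes, so
  // iteration starts at id 2.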
  for (int i = 2; i < graph->num_node_ids(); ++i) {
    Node* n = graph->FindNodeId(i);
    if (n == nullptr) continue;  // deleted node

    // Special case for lowering functional control flow ops. We do not rely
    // on LowerFunctionalOpsPass because in Grappler we have to be more
    // restrictive about what type of function calls we are allowed to inline.
    if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
      VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

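      // Illustrative sketch: RewriteIfNode replaces
      //   r = If(pred, args, then_branch=f, else_branch=g)
      // with Switch nodes that route `args` by `pred` into call nodes for `f`
      // and `g`, whose outputs are joined by Merge nodes; those call nodes are
      // inlined on later iterations of this loop.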
      if (n->IsIfNode()) {
        TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
      } else if (n->IsCaseNode()) {
        TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
      } else if (n->IsWhileNode()) {
        TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), &flib_def, false));
      }
      continue;
    }

    // Skip nodes that are not function calls.
    if (!IsFunctionCall(flib_def, *n)) continue;
    // Skip function calls that we plan to compile later.
    if (MarkedForXlaCompilation(n->def())) continue;
    // Skip nodes in a feed set.
    if (feed_nodes.contains(n->name())) continue;

    // Function body that we will inline into the main graph. It can be a
    // function instantiation, or a gradient function instantiated from a
    // SymbolicGradient op.
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));

    InlineFunctionBodyOptions inline_options;
    // Ignore '_noinline' flag in aggressive mode.
    inline_options.ignore_noinline = is_aggressive;

    // Function calls created after inlining If/Case/While ops are always
    // inlined as multi-device functions and are not required to pass
    // additional Grappler validations (side-effects execution validation
    // below).
    bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);

    // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
    // functions:
    // a) The function can be multi-device.
    // b) Automatic control dependency tracking guarantees that all
    //    side-effectful nodes in the function will have a path to one of the
    //    control outputs. Control outputs and control edges between
    //    side-effectful (stateful) nodes are used to explicitly mark the nodes
    //    that must execute, and to define their execution order.
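    //
    // Consequence for inlining (illustrative): with kControlOutputs, outgoing
    // control edges from the inlined call are derived from the function's
    // control outputs (its side-effectful ops); with kDataOutputs they are
    // derived from the nodes producing the function's data outputs.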
    if (n->IsPartitionedCall() || force_inline_as_multi_device) {
      inline_options.output_control_src = OutputControlSource::kControlOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::MultiDevice();
    } else {
      inline_options.output_control_src = OutputControlSource::kDataOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();
    }

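    // Decide whether to keep the caller node around after inlining: a fetched
    // node must stay fetchable (roughly, it becomes an IdentityN over the
    // function outputs), a node in the keep set only has to stay targetable (a
    // NoOp is enough), and otherwise the caller node can be removed.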
    if (fetch_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kFetchable;
    } else if (keep_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kTargetable;
    } else {
      inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
    }

    // Basic validation rules defined in common_runtime, shared by all
    // functions.
    Status can_inline_function_call =
        ValidateInlining(n, fbody.get(), inline_options);

    // Additional validation rules defined only in Grappler.
    // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
    if (can_inline_function_call.ok()) {
      bool has_outgoing_control_edges = absl::c_any_of(
          n->out_edges(),
          [](const Edge* edge) { return edge->IsControlEdge(); });

      can_inline_function_call = ValidateSideEffectsExecution(
          *fbody, inline_options.output_control_src,
          has_outgoing_control_edges);

      if (!can_inline_function_call.ok() &&
          (is_aggressive || force_inline_as_multi_device)) {
        VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
        can_inline_function_call = Status::OK();
      }
    }
    if (can_inline_function_call.ok()) {
      can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
    }

    if (can_inline_function_call.ok()) {
      VLOG(2) << "Inline function call node: " << n->name();
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

      TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                            fbody.get(), inline_options));
      inlined_function_names.push_back(fbody->fdef.signature().name());

    } else {
      VLOG(2) << "Failed to inline function call node: "
              << can_inline_function_call.error_message();
    }
  }

  VLOG(4) << "Inlined " << inlined_function_names.size()
          << " function calls: " << absl::StrJoin(inlined_function_names, ", ");

  // ------------------------------------------------------------------------ //
  // Grappler receives the graph after PRE_PLACEMENT, Placer, and
  // POST_PLACEMENT passes, so each node has a valid device assignment. After
  // function inlining and control flow V2 lowering we have to keep the graph
  // placed.

  if (inlined_function_names.empty()) {
    VLOG(3) << "Not placing graph after function inlining"
            << " (did not inline any of the function calls).";

  } else if (item.devices().empty()) {
    // If there are no devices available for the placer, we do not place the
    // graph after function inlining. This happens when Grappler is optimizing
    // the function library, or when a graph is optimized "offline", without an
    // active runtime session, for example as part of a batch job for graph
    // analysis/optimization. A GrapplerItem instantiated from a function
    // library doesn't have to be fully placed after all optimizations; it will
    // be placed by the function library runtime before execution.
    VLOG(3) << "Not placing graph after function inlining"
            << " (device set is empty)";

  } else {
    // If we are running in an active runtime session, Grappler will get the
    // graph after initial placing is done, and we should have devices for the
    // placer.
    VLOG(3) << "Run placer for the graph after function inlining. "
            << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";

    DeviceSet device_set;                               // does not own devices
    std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices

    for (const string& name : item.devices()) {
      auto device = absl::make_unique<FakeDevice>(name);
      device_set.AddDevice(device.get());
      fake_devices.push_back(std::move(device));
    }

    Placer placer(graph.get(), item.id, &flib_def, &device_set);
    TF_RETURN_IF_ERROR(placer.Run());
  }

  graph->ToGraphDef(output_graph);
  return Status::OK();
}

// Restores tensor mapping after function specialization: all inputs must be
// connected to valid nodes.
void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
                          GraphDef* optimized_graph) {
  if (ctx.tensor_mapping().empty()) return;

  // During function specialization, we might prune unused function outputs. We
  // need to "close the holes" that might appear in the function outputs.
  //
  // Example: prune unused output "f:1"
  //
  //   f = my_func[T=float](...)          f = my_func_specialized[T=float](...)
  //   a = Identity(f:0)             ->   a = Identity(f:0)
  //   b = Identity(f:2)                  b = Identity(f:1)
  //
  // Tensor mapping (size=1): [f:2 -> f:1]
  for (NodeDef& node : *optimized_graph->mutable_node()) {
    for (int idx = 0; idx < node.input_size(); ++idx) {
      TensorId input_tensor = ParseTensorName(node.input(idx));
      if (input_tensor.index() == Graph::kControlSlot) break;

      auto mapping = ctx.tensor_mapping().find(input_tensor);
      if (mapping != ctx.tensor_mapping().end()) {
        node.set_input(idx, TensorIdToString(mapping->second));
      }
    }
  }
}

}  // namespace

Status FunctionOptimizer::RunFunctionOptimizerPass(
    const GrapplerItem& item, GraphDef* optimized_graph) const {
  VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;

  // Inline all function calls into a graph using common_runtime/function
  // implementation (see `InlineFunctionBody` function documentation).
  GraphDef graph_after_inlining;
  TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
                                         &graph_after_inlining));

  // Specialize function calls that we could not inline.
  FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);

  for (const NodeDef& node : graph_after_inlining.node()) {
    // Function specialization can modify the optimized graph only by adding
    // new nodes, so we can check the node count to make sure that the graph
    // was not modified.
    const int num_nodes_before = optimized_graph->node_size();
    const auto is_graph_modified = [&]() {
      int num_nodes = optimized_graph->node_size();
      DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
      return num_nodes > num_nodes_before;
    };

    // Copy the node from `graph_after_inlining` to `optimized_graph`.
    const auto copy_node = [&]() { *optimized_graph->add_node() = node; };

    // Check if the node is a function call (direct or indirect).
    const FunctionDef* func = FindFunctionCall(ctx, node);
    if (func == nullptr) {
      copy_node();
      continue;
    }

    const string& func_name = func->signature().name();

    // Specialize the function to its instantiation context if it has something
    // worth specializing.
    const bool specialization_worthy = IsParametrized(*func) ||
                                       HasTrulyConstInputs(node, ctx) ||
                                       HasUnusedOutputs(node, *func, ctx);

    // Do not specialize if the function has a custom gradient or is marked
    // '_nospecialize', or if the call node is a feed node or is marked for XLA
    // compilation.
    const string grad_func = ctx.function_library().FindGradient(func_name);
    const bool no_specialize =
        !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
        MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);

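    // Example (illustrative, mirroring the tensor mapping example above): a
    // call
    //   f = my_func[T=float](const_input, x)
    // with a truly-const input and/or unused outputs is rewritten to call a
    // new library function, e.g.
    //   f = my_func_specialized[T=float](x)
    // with the constant pushed into the specialized body and unused outputs
    // pruned from its signature.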
    if (specialization_worthy && !no_specialize) {
      // TODO(ezhulenev): Specialize function call if input has a known shape.
      // Specialize function body for its instantiation attributes and inputs.
      Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
      if (!status.ok() && is_graph_modified()) {
        return status;
      } else if (!status.ok() && !is_graph_modified()) {
        VLOG(3) << "Skip specialization error: " << status.error_message();
        copy_node();
      }
      continue;
    } else {
      VLOG(2) << "Skip function specialization: " << func->signature().name();
      copy_node();
    }
  }

  RestoreTensorMapping(ctx, optimized_graph);

  // Preserve the graph version.
  *optimized_graph->mutable_versions() = item.graph.versions();
  // Prune unreachable functions from the library.
  *optimized_graph->mutable_library() =
      PruneFunctionLibrary(ctx.function_library(), *optimized_graph);

  return Status::OK();
}

Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
                                   GraphDef* optimized_graph) {
  // Nothing to do here.
  if (item.graph.library().function_size() == 0) {
    return errors::Aborted("Nothing to do.");
  }

  TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));

  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow