1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59
60 constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
61
62 // Do not specialize functions marked with '_nospecialize' attribute.
63 constexpr const char* const kNoSpecializeAttr = "_nospecialize";
64
65 // Mark functions that were created as a result of function specialization.
66 constexpr const char* const kGrapplerSpecializedFuncAttr =
67 "_GrapplerSpecializedFunc";
68
69 // There are two ways of calling a Tensorflow function:
70 //
71 // 1. Direct function call: node.op() is the name of the function.
72 //
73 // 2. Indirect function call: the function name is passed through a node
74 // attribute, and special Tensorflow kernels are responsible for calling the
75 // function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
76
77 // Check if func_node.op() matches the name in FunctionDef signature.
IsDirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79 return func_node.op() == func.signature().name();
80 }
81
82 // Check if func_node has function attribute with a function name matching
83 // FunctionDef signature.
IsIndirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85 if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86 return false;
87 }
88
89 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90 return func_attr != nullptr && func_attr->has_func() &&
91 func_attr->func().name() == func.signature().name();
92 }
93
FunctionInstantiationAttributes(const FunctionDef & func,const NodeDef & func_node)94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95 const NodeDef& func_node) {
96 if (IsDirectFunctionCall(func, func_node)) {
97 return AttrSlice(func_node);
98
99 } else if (IsIndirectFunctionCall(func, func_node)) {
100 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101 return AttrSlice(&func_attr->func().attr());
102
103 } else {
104 LOG(WARNING) << "Can't resolve function instantiation attributes: "
105 << SummarizeNodeDef(func_node);
106 return AttrSlice();
107 }
108 }
109
110 // This is a fake device that should not be used for any op kernel execution,
111 // the only purpose of this device is to be passed as a part of DeviceSet to the
112 // Placer.
113 class FakeDevice : public Device {
114 public:
FakeDevice(Env * env,const string & device)115 FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
FakeDevice(const string & device)116 explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
Sync()117 Status Sync() override { return Status::OK(); }
118
119 private:
attr(const string & device)120 static DeviceAttributes attr(const string& device) {
121 DeviceNameUtils::ParsedName parsed_name;
122 bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123 DCHECK(parsed) << "Failed to parse full device name: " << device;
124
125 DeviceAttributes attr;
126 attr.set_name(device);
127 attr.set_device_type(parsed_name.type);
128 return attr;
129 }
130 };
131
132 // -------------------------------------------------------------------------- //
133 // Function specialization.
134 //
135 // FunctionDef is somewhat similar to function template in C++, given all the
136 // type parameters (and attribute values) it generates a statically defined
137 // graph from the type parametrized "graph template" (function body).
138 //
139 // Function specialization instantiates a parametrized FunctionDef into a
140 // statically defined graph, and then converts it back to the fully defined
141 // FunctionDef (it doesn't have any unknown type parameters or attribute
142 // values, known as placeholders).
143 //
144 // Given the fully specified graph we can apply all the Grappler optimizers to
145 // it (see details in MetaOptimizer). Also we can push known constant inputs
146 // into the function body, and remove unused outputs/inputs.
147
MarkedNoSpecialize(const FunctionDef & fdef)148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149 const auto attr = AttrSlice(&fdef.attr());
150 bool nospecialize = false;
151 return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157 // Currently we do not support functions with tensor lists as inputs or
158 // outputs, so caller node input/output ports always match function
159 // input/output arguments.
160 using InputPort = int;
161 using OutputPort = int;
162
163 string func_name;
164 bool is_in_fetch_set;
165 absl::flat_hash_set<OutputPort> active_outputs;
166 absl::flat_hash_map<string, DataType> type_parameters;
167 absl::flat_hash_map<string, AttrValue> body_parameters;
168 absl::flat_hash_map<InputPort, string> const_inputs;
169
operator ==tensorflow::grappler::__anon42ddc3660111::FunctionSpecializationSignature170 bool operator==(const FunctionSpecializationSignature& other) const {
171 bool equals = func_name == other.func_name &&
172 is_in_fetch_set == other.is_in_fetch_set &&
173 active_outputs == other.active_outputs &&
174 type_parameters == other.type_parameters &&
175 const_inputs == other.const_inputs;
176
177 if (!equals) return false;
178
179 // Equality is not defined for AttrValue.
180 if (body_parameters.size() != other.body_parameters.size()) return false;
181
182 for (const auto& lhs : body_parameters) {
183 auto it = other.body_parameters.find(lhs.first);
184 if (it == other.body_parameters.end()) return false;
185 if (!AreAttrValuesEqual(lhs.second, (*it).second,
186 /*allow_false_negatives=*/true)) {
187 return false;
188 }
189 }
190
191 return true;
192 }
193
194 template <typename H>
AbslHashValue(H h,const FunctionSpecializationSignature & s)195 friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
196 H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
197
198 // First pre-compute hashes for all values in collections with
199 // non-deterministic iteration order.
200 std::vector<uint64> hashes;
201 hashes.reserve(s.active_outputs.size() //
202 + s.type_parameters.size() * 2 //
203 + s.body_parameters.size() * 2 //
204 + s.const_inputs.size() * 2);
205
206 absl::c_transform(s.active_outputs, std::back_inserter(hashes),
207 hash<OutputPort>());
208
209 using TypeParam = std::pair<const string, DataType>;
210 absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
211 AttrValue attr_value;
212 attr_value.set_type(type_param.second);
213 hashes.push_back(Hash64(type_param.first));
214 hashes.push_back(AttrValueHash(attr_value));
215 });
216
217 using BodyParam = std::pair<const string, AttrValue>;
218 absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
219 hashes.push_back(Hash64(body_param.first));
220 hashes.push_back(FastAttrValueHash(body_param.second));
221 });
222
223 using ConstInput = std::pair<const InputPort, string>;
224 absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
225 hashes.push_back(hash<InputPort>()(const_input.first));
226 hashes.push_back(Hash64(const_input.second));
227 });
228
229 // Combine all pre-computed hashes in a deterministic order.
230 absl::c_sort(hashes);
231 return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
232 }
233 };
234
235 struct FunctionSpecialization {
236 string specialized_func_name;
237 // True if the function caller node is in GrapplerItem fetch set.
238 bool is_in_fetch_set;
239 // Names of the tensors that were pushed down into the function body.
240 absl::flat_hash_set<string> const_inputs;
241 // Control dependencies of pushed down const inputs have to be attached to
242 // function caller node.
243 absl::flat_hash_set<string> control_deps;
244 // Output tensors (ports) that consumed by other nodes in the graph or in a
245 // GrapplerItem fetch set.
246 absl::flat_hash_set<int> active_outputs;
247 // Mapping from original function output port to the output port of
248 // specialized function. If function specialization changes the number of
249 // function outputs it's required to update all node consumers.
250 std::vector<std::pair<int, int>> output_mapping;
251 };
252
253 // Function optimizer context initialized once for each optimization pass, and
254 // it uses the latest available graph (for the first iteration it will be the
255 // GrapplerItem.graph, for next iterations it will be the output of previous
256 // function optimizer pass).
257 class FunctionOptimizerContext {
258 public:
FunctionOptimizerContext(const GrapplerItem & item,RewriterConfig::Toggle opt_level,const GraphDef & graph)259 explicit FunctionOptimizerContext(const GrapplerItem& item,
260 RewriterConfig::Toggle opt_level,
261 const GraphDef& graph)
262 : item_(&item),
263 opt_level_(opt_level),
264 function_library_(OpRegistry::Global(), graph.library()),
265 truly_const_nodes_(InferTrulyConstNodes(item, graph)),
266 graph_view_(&graph) {}
267
item() const268 const GrapplerItem& item() const { return *item_; }
269
graph_version() const270 const int graph_version() const { return item_->graph.versions().producer(); }
271
opt_level() const272 RewriterConfig::Toggle opt_level() const { return opt_level_; }
273
function_library() const274 const FunctionLibraryDefinition& function_library() const {
275 return function_library_;
276 }
function_library()277 FunctionLibraryDefinition& function_library() { return function_library_; }
278
279 const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const280 tensor_mapping() const {
281 return tensor_mapping_;
282 }
283
graph_view() const284 const GraphView& graph_view() const { return graph_view_; }
285
IsFeedNode(const string & node_name) const286 bool IsFeedNode(const string& node_name) const {
287 return absl::c_any_of(
288 item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
289 return ParseTensorName(feed.first).node() == node_name;
290 });
291 }
292
IsFetchNode(const string & node_name) const293 bool IsFetchNode(const string& node_name) const {
294 return absl::c_any_of(item_->fetch, [&](const string& fetch) {
295 return ParseTensorName(fetch).node() == node_name;
296 });
297 }
298
IsTrulyConst(const string & name) const299 bool IsTrulyConst(const string& name) const {
300 return TrulyConstNode(name) != nullptr;
301 }
302
TrulyConstNode(const string & name) const303 const NodeDef* TrulyConstNode(const string& name) const {
304 return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
305 }
306
FindFunctionSpecialization(const FunctionSpecializationSignature & sig) const307 const FunctionSpecialization* FindFunctionSpecialization(
308 const FunctionSpecializationSignature& sig) const {
309 return gtl::FindOrNull(specialized_functions_, sig);
310 }
311
AddSpecializedFunction(const FunctionSpecializationSignature & sig,const FunctionSpecialization & specialized_func)312 void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
313 const FunctionSpecialization& specialized_func) {
314 specialized_functions_.emplace(sig, specialized_func);
315 }
316
AddTensorMapping(const SafeTensorId & from,const SafeTensorId & to)317 void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
318 DCHECK(from.index() != Graph::kControlSlot)
319 << "Tensor mapping must be from regular tensor";
320 DCHECK(to.index() != Graph::kControlSlot)
321 << "Tensor mapping must be to regular tensor";
322
323 auto inserted = tensor_mapping_.insert({from, to});
324 DCHECK(inserted.second)
325 << "Failed to insert duplicated tensor mapping: "
326 << "from=" << from.ToString() << " to=" << to.ToString();
327 }
328
AddTensorMapping(const string & func_node,const FunctionSpecialization & specialized_func)329 void AddTensorMapping(const string& func_node,
330 const FunctionSpecialization& specialized_func) {
331 for (const auto& pair : specialized_func.output_mapping) {
332 int from_idx = pair.first;
333 int to_idx = pair.second;
334 if (from_idx != to_idx) {
335 SafeTensorId from_tensor(func_node, from_idx);
336 SafeTensorId to_tensor(func_node, to_idx);
337 AddTensorMapping(from_tensor, to_tensor);
338 }
339 }
340 }
341
342 private:
InferTrulyConstNodes(const GrapplerItem & item,const GraphDef & graph)343 static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
344 const GrapplerItem& item, const GraphDef& graph) {
345 absl::flat_hash_set<absl::string_view> feed_nodes;
346 for (const auto& feed : item.feed) {
347 feed_nodes.insert(feed.first);
348 }
349
350 absl::flat_hash_map<string, const NodeDef*> const_nodes;
351 for (const NodeDef& node : graph.node()) {
352 if (IsConstant(node) && !feed_nodes.contains(node.name())) {
353 const_nodes[node.name()] = &node;
354 }
355 }
356
357 return const_nodes;
358 }
359
360 const GrapplerItem* item_; // must outlive this object
361 RewriterConfig::Toggle opt_level_;
362
363 // Function library constructed from current graph.
364 FunctionLibraryDefinition function_library_;
365
366 // Nodes that are Const and not in feed.
367 absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
368 // Specialized functions.
369 absl::flat_hash_map<FunctionSpecializationSignature,
370 const FunctionSpecialization>
371 specialized_functions_;
372
373 // After function specialization, the optimized graph might be in invalid
374 // state, nodes can read from output index that is no longer valid after
375 // unused outputs pruning.
376 //
377 // Tensor mapping that has to be applied to the graph after all functions
378 // optimizations (invalidated tensor id -> optimized graph tensor id).
379 absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
380 tensor_mapping_;
381
382 // Use graph view to find active outputs of the function caller nodes.
383 GraphView graph_view_;
384
385 TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
386 };
387
388 // Returns a pointer to the called function definition iff the given node is
389 // indeed a function call. Otherwise returns nullptr.
FindFunctionCall(const FunctionOptimizerContext & ctx,const NodeDef & node)390 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
391 const NodeDef& node) {
392 // Check if a node does indirect function call via PartitionedCallOp.
393 if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
394 const AttrValue* func_attr = AttrSlice(node).Find("f");
395 return (func_attr != nullptr && func_attr->has_func())
396 ? ctx.function_library().Find(func_attr->func().name())
397 : nullptr;
398 }
399
400 // Check if the function op itself is a function name.
401 return ctx.function_library().Find(node.op());
402 }
403
GetActiveOutputs(const NodeDef & node,const FunctionOptimizerContext & ctx,int size_hint=0)404 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
405 const FunctionOptimizerContext& ctx,
406 int size_hint = 0) {
407 absl::flat_hash_set<int> active_outputs;
408 active_outputs.reserve(static_cast<size_t>(size_hint));
409
410 // 1. Output can be consumed by the other graph node.
411 const auto node_fanout_edges =
412 ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
413 for (const GraphView::Edge& edge : node_fanout_edges) {
414 active_outputs.insert(edge.src.port_id);
415 }
416
417 // 2. Or it can be in a fetch set.
418 for (const string& fetch : ctx.item().fetch) {
419 TensorId fetch_tensor = ParseTensorName(fetch);
420 if (fetch_tensor.node() == node.name()) {
421 active_outputs.insert(fetch_tensor.index());
422 }
423 }
424
425 return active_outputs;
426 }
427
HasTrulyConstInputs(const NodeDef & node,const FunctionOptimizerContext & ctx)428 bool HasTrulyConstInputs(const NodeDef& node,
429 const FunctionOptimizerContext& ctx) {
430 const auto is_truly_const = [&ctx](const string& input) {
431 return ctx.IsTrulyConst(NodeName(input));
432 };
433 return absl::c_any_of(node.input(), is_truly_const);
434 }
435
HasUnusedOutputs(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx)436 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
437 const FunctionOptimizerContext& ctx) {
438 // Functions with tensor list outputs are not supported right now, so the
439 // number of output args is the same as number of possible function caller
440 // node outputs.
441 int num_outputs = func.signature().output_arg_size();
442 const absl::flat_hash_set<int> active_outputs =
443 GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
444 int active_outputs_size = active_outputs.size();
445 return active_outputs_size != num_outputs;
446 }
447
448 // Return pruned FunctionDefLibrary with functions that are reachable from
449 // the optimized graph.
PruneFunctionLibrary(const FunctionLibraryDefinition & flib,const GraphDef & optimized_graph)450 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
451 const GraphDef& optimized_graph) {
452 FunctionLibraryDefinition pruned_flib =
453 flib.ReachableDefinitions(optimized_graph);
454
455 int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
456 static_cast<int>(flib.num_functions());
457
458 VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
459 << " functions (" << pruned_functions << ")";
460
461 return pruned_flib.ToProto();
462 }
463
464 // Push all constant inputs of an instantiating node into the function body.
PushDownConstInputs(const NodeDef & func_node,const FunctionOptimizerContext & ctx,GrapplerFunctionItem * item,absl::flat_hash_set<string> * const_inputs,absl::flat_hash_set<string> * control_deps)465 Status PushDownConstInputs(const NodeDef& func_node,
466 const FunctionOptimizerContext& ctx,
467 GrapplerFunctionItem* item,
468 absl::flat_hash_set<string>* const_inputs,
469 absl::flat_hash_set<string>* control_deps) {
470 // Record node control dependencies in the control_deps set.
471 const auto record_control_deps = [&](const NodeDef* const_input) {
472 for (int i = const_input->input_size() - 1; i >= 0; --i) {
473 const string& input = const_input->input(i);
474 if (IsControlInput(input))
475 control_deps->insert(input);
476 else
477 break;
478 }
479 };
480
481 for (int i = func_node.input_size() - 1; i >= 0; --i) {
482 const string& input = func_node.input(i);
483 if (IsControlInput(input)) continue;
484
485 const string node_name = NodeName(input);
486 if (ctx.IsTrulyConst(node_name)) {
487 VLOG(3) << "Push const into function body: input=" << input;
488 const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
489 const_inputs->insert(input);
490 record_control_deps(const_input);
491 TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
492 }
493 }
494
495 return Status::OK();
496 }
497
498 // Remove inputs that were pushed into the function body, and attach their
499 // control dependencies to the function caller node.
RemovePushedDownConstInputs(const FunctionSpecialization & specialization,NodeDef * specialized_func_node)500 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
501 NodeDef* specialized_func_node) {
502 // Nothing to do if it was no const inputs to the function node.
503 if (specialization.const_inputs.empty()) return;
504
505 // Keep only non-const inputs.
506 std::vector<string> keep_inputs;
507 const auto& inputs = specialized_func_node->input();
508 absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
509 [&](const string& input) {
510 return !specialization.const_inputs.contains(input);
511 });
512
513 specialized_func_node->clear_input();
514 for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
515
516 // Attach control dependencies of pushed down const input to the caller node.
517 if (!specialization.control_deps.empty()) {
518 absl::flat_hash_set<string> existing_control_deps;
519
520 for (const string& input : keep_inputs) {
521 existing_control_deps.insert(AsControlDependency(NodeName(input)));
522 }
523
524 for (const string& ctrl : specialization.control_deps) {
525 if (!existing_control_deps.contains(ctrl)) {
526 VLOG(3) << "Forward control dependency: input=" << ctrl;
527 specialized_func_node->add_input(ctrl);
528 }
529 }
530 }
531 }
532
533 // Remove Tin type parameters for pushed down const inputs.
RemovePushedDownConstInputTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)534 void RemovePushedDownConstInputTypes(
535 const FunctionSpecialization& specialization, const NodeDef& func_node,
536 NodeDef* specialized_func_node) {
537 // Nothing to do if it was no const inputs to the function node.
538 if (specialization.const_inputs.empty()) return;
539
540 // Make sure that original function caller has Tin attribute.
541 const AttrValue* tin = AttrSlice(func_node).Find("Tin");
542 if (tin == nullptr || !tin->has_list()) return;
543
544 // Clear input types for the specialized node.
545 auto* attr = specialized_func_node->mutable_attr();
546 (*attr)["Tin"].mutable_list()->clear_type();
547
548 // Keep types of non-const inputs.
549 for (int i = 0; i < func_node.input_size(); ++i) {
550 const string& input = func_node.input(i);
551 if (IsControlInput(input)) break;
552
553 if (!specialization.const_inputs.contains(input)) {
554 DataType dt = tin->list().type(i);
555 (*attr)["Tin"].mutable_list()->add_type(dt);
556 }
557 }
558 }
559
560 // Remove Tout type parameters for pruned function outputs.
RemoveUnusedOutputsTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)561 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
562 const NodeDef& func_node,
563 NodeDef* specialized_func_node) {
564 // Make sure that original function caller has Tout attribute.
565 const AttrValue* tout = AttrSlice(func_node).Find("Tout");
566 if (tout == nullptr || !tout->has_list()) return;
567
568 // Nothing to do if all outputs are active.
569 int specialization_active_outputs_size = specialization.active_outputs.size();
570 if (specialization_active_outputs_size == tout->list().type_size()) return;
571
572 // Clear input types for the specialized node.
573 auto* attr = specialized_func_node->mutable_attr();
574 (*attr)["Tout"].mutable_list()->clear_type();
575
576 // Keep output types of active outputs only.
577 for (int i = 0; i < tout->list().type_size(); ++i) {
578 if (specialization.active_outputs.contains(i)) {
579 DataType dt = tout->list().type(i);
580 (*attr)["Tout"].mutable_list()->add_type(dt);
581 }
582 }
583 }
584
UpdateSpecializedFunctionCallSite(const FunctionDef & func,const NodeDef & func_node,const string & specialized_func_name,NodeDef * specialized_func_node)585 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
586 const NodeDef& func_node,
587 const string& specialized_func_name,
588 NodeDef* specialized_func_node) {
589 if (IsDirectFunctionCall(func, func_node)) {
590 specialized_func_node->set_op(specialized_func_name);
591
592 } else if (IsIndirectFunctionCall(func, func_node)) {
593 auto* attr = specialized_func_node->mutable_attr();
594 (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
595
596 } else {
597 return errors::InvalidArgument("Unknown function call site");
598 }
599
600 return Status::OK();
601 }
602
603 // Update a graph node created from the original function caller node, to the
604 // function specialization. Function specialization might change the number of
605 // inputs and outputs, so we have to make sure that graph node is updated
606 // accordingly.
UpdateSpecializedFunctionNode(const FunctionDef & func,const NodeDef & func_node,const FunctionSpecialization & specialization,NodeDef * specialized_func_node)607 Status UpdateSpecializedFunctionNode(
608 const FunctionDef& func, const NodeDef& func_node,
609 const FunctionSpecialization& specialization,
610 NodeDef* specialized_func_node) {
611 // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
612 bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
613
614 // 1. Call the specialized function instead of original one.
615 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
616 func, func_node, specialization.specialized_func_name,
617 specialized_func_node));
618
619 // 2. Remove inputs corresponding to the pushed down consts.
620 RemovePushedDownConstInputs(specialization, specialized_func_node);
621
622 // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
623 // types, that must be in sync with updated function signature.
624
625 // 3. Update input types for the indirect function calls.
626 if (is_indirect_call) {
627 RemovePushedDownConstInputTypes(specialization, func_node,
628 specialized_func_node);
629 }
630
631 // 4. Update output types for the indirect function call. It's unsafe to
632 // change the number of outputs for the fetch nodes, so we just skip them.
633 if (is_indirect_call && !specialization.is_in_fetch_set) {
634 RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
635 }
636
637 // 5. Remove custom gradient annotation.
638 specialized_func_node->mutable_attr()->erase("_gradient_op_type");
639
640 return Status::OK();
641 }
642
InitializeFunctionSpecializationSignature(const NodeDef & func_node,const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionOptimizerContext & ctx,FunctionSpecializationSignature * sig)643 Status InitializeFunctionSpecializationSignature(
644 const NodeDef& func_node, const FunctionDef& func,
645 const AttrSlice& func_instantiation_attr,
646 const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
647 DCHECK(sig->const_inputs.empty());
648 DCHECK(sig->active_outputs.empty());
649
650 sig->func_name = func.signature().name();
651 sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
652 sig->active_outputs = GetActiveOutputs(func_node, ctx);
653
654 TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
655 &sig->type_parameters));
656 TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
657 &sig->body_parameters));
658
659 for (int i = 0; i < func_node.input_size(); ++i) {
660 const string& input = func_node.input(i);
661 if (IsControlInput(input)) break;
662 if (ctx.IsTrulyConst(input)) {
663 sig->const_inputs.emplace(i, input);
664 }
665 }
666
667 return Status::OK();
668 }
669
670 // Create a name for the function specialization. The name of the function, name
671 // of the node instantiating it, and a Grappler item id should generate unique
672 // function name. Meta optimizer might create multiple Grappler items for the
673 // same graph when optimizing functions, but it's guaranteed that they all will
674 // have unique ids.
SpecializedFunctionName(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)675 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
676 const FunctionDef& func,
677 const NodeDef& func_node) {
678 return absl::Substitute(
679 "$0_specialized_for_$1_at_$2", func.signature().name(),
680 absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
681 }
682
SpecializeFunction(const NodeDef & func_node,const FunctionDef & func,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)683 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
684 FunctionOptimizerContext* ctx,
685 GraphDef* optimized_graph) {
686 VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
687
688 const AttrSlice func_instantiation_attr =
689 FunctionInstantiationAttributes(func, func_node);
690
691 FunctionSpecializationSignature signature;
692 TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
693 func_node, func, func_instantiation_attr, *ctx, &signature));
694
695 // Check if function was already specialized for identical context.
696 const FunctionSpecialization* already_specialized =
697 ctx->FindFunctionSpecialization(signature);
698
699 if (already_specialized) {
700 VLOG(2) << "Function was already specialized in identical context: "
701 "specialized_name="
702 << already_specialized->specialized_func_name;
703
704 // Add a function call node for the specialized function.
705 NodeDef* specialized_func_node = optimized_graph->add_node();
706 *specialized_func_node = func_node;
707
708 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
709 func, func_node, *already_specialized, specialized_func_node));
710
711 ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
712
713 return Status::OK();
714 }
715
716 // Add a new specialized function definition to the library.
717 const auto& flib = ctx->function_library();
718
719 // Make a GrapplerFunctionItem and convert it back to FunctionDef after
720 // pushing all constant inputs into the function body.
721 GrapplerFunctionItem item;
722 TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
723 func, func_instantiation_attr, flib, ctx->graph_version(), &item));
724
725 // Push const inputs into the function body, and keep track of their control
726 // dependencies.
727 absl::flat_hash_set<string> const_inputs;
728 absl::flat_hash_set<string> control_deps;
729 TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
730 &control_deps));
731
732 // Remove function outputs that do not have any consumers. We can't safely
733 // update outputs for the fetch nodes, so we just skip them.
734 std::vector<std::pair<int, int>> output_mapping;
735 if (!signature.is_in_fetch_set) {
736 int num_func_outputs = item.output_size();
737
738 absl::flat_hash_set<int> remove;
739 for (int i = 0; i < num_func_outputs; ++i) {
740 if (!signature.active_outputs.count(i)) remove.insert(i);
741 }
742
743 TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
744 }
745
746 // TODO(ezhulenev): Push down known input shapes.
747 FunctionDef specialized_func;
748 TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
749
750 // Find a name for specialized function.
751 const string specialized_func_name =
752 SpecializedFunctionName(*ctx, func, func_node);
753 if (flib.Contains(specialized_func_name)) {
754 // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
755 // a serious internal error, that must be investigated.
756 return errors::Internal("Created duplicate function specialization");
757 }
758
759 specialized_func.mutable_signature()->set_name(specialized_func_name);
760 auto* specialized_attr = specialized_func.mutable_attr();
761 (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
762
763 // Add specialized function to the library.
764 TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
765
766 // Add a function call node for the specialized function.
767 NodeDef* specialized_func_node = optimized_graph->add_node();
768 *specialized_func_node = func_node;
769
770 FunctionSpecialization func_specialization = {
771 specialized_func_name, signature.is_in_fetch_set, const_inputs,
772 control_deps, signature.active_outputs, output_mapping};
773
774 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
775 func, func_node, func_specialization, specialized_func_node));
776
777 ctx->AddSpecializedFunction(signature, func_specialization);
778 ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
779
780 return Status::OK();
781 }
782
783 // -------------------------------------------------------------------------- //
784 // Inline function calls into a graph using function inlining implementation
785 // from common_runtime:
786 //
787 // 1) Convert GraphDef to Graph.
788 // 2) Inline function calls.
789 // 3) Convert Graph back to the GraphDef.
790
791 constexpr const char* const kLowerUsingSwitchMergeAttr =
792 LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
793 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
794 LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
795
796 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
797 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
798
799 // Checks if boolean attribute is defined and its value is 'true'.
CheckBoolAttr(const Node * n,absl::string_view attr_name)800 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
801 bool match;
802 bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
803 return found && match;
804 }
805
806 // Checks if string attribute is defined and it's not empty.
CheckStringAttr(const Node * n,absl::string_view attr_name)807 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
808 const string& value = GetNodeAttrString(n->attrs(), attr_name);
809 return !value.empty();
810 }
811
LowerUsingSwitchMergeIsOn(const Node * n)812 bool LowerUsingSwitchMergeIsOn(const Node* n) {
813 return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
814 }
815
LowerAsMultiDeviceFunctionIsOn(const Node * n)816 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
817 return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
818 }
819
MarkedForXlaCompilation(const NodeDef & n)820 bool MarkedForXlaCompilation(const NodeDef& n) {
821 auto is_enabled = [&](std::string attr_name) -> bool {
822 auto it = n.attr().find(attr_name);
823 return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
824 };
825 return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
826 is_enabled(kXlaMustCompileAttr);
827 }
828
IsExemptFromSideEffectsExecutionValidation(const string & op)829 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
830 static const auto* exemption = new absl::flat_hash_set<string>(
831 {// LINT.IfChange
832 // Op types that should not run in program order, e.g. because they need
833 // to run asynchronously to avoid deadlock.
834 "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
835 "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
836 "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
837 "Send", "Recv",
838
839 // Legacy random ops.
840 // See details in tensorflow/python/framework/auto_control_deps.py.
841 "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
842 "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
843 "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
844 "RandomPoissonV2",
845
846 // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
847 // but it can't generate any observable side-effect.
848 "ReadVariableOp",
849
850 // CudnnRNN ops are stateful but they can't generate any observable
851 // side-effect.
852 "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
853 "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
854
855 // TPUEmbedding EnqueueOps are stateful but this is only between ops with
856 // the same device_ordinal on the same host.
857 "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
858 "EnqueueTPUEmbeddingSparseTensorBatch",
859 "EnqueueTPUEmbeddingRaggedTensorBatch",
860
861 // SaveV2 and RestoreV2 should be allowed to operate in parallel on
862 // multiple hosts.
863 "SaveV2", "RestoreV2"});
864 // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
865 return exemption->contains(op);
866 }
867
868 // Validates that all side effects inside function body will be executed after
869 // function inlining. We do it by looking for a path from stateful ops, to one
870 // of the output control sources.
871 //
872 // When function executed via FunctionLibraryRuntime we do not have to check
873 // this, because `PruneFunctionBody` has special pruning rules for stateful ops.
ValidateSideEffectsExecution(const FunctionBody & fbody,OutputControlSource output_control_source,bool has_outgoing_control_edges,bool validate_outgoing_control_edge=true)874 Status ValidateSideEffectsExecution(
875 const FunctionBody& fbody, OutputControlSource output_control_source,
876 bool has_outgoing_control_edges,
877 bool validate_outgoing_control_edge = true) {
878 // Find all nodes that can produce side effects in the function body graph. We
879 // use 'is_stateful()' bit as an approximation of "has side effects" property.
880 std::vector<const Node*> fbody_side_effects;
881 absl::c_copy_if(
882 fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
883 [](const Node* n) {
884 return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
885 !IsExemptFromSideEffectsExecutionValidation(n->type_string());
886 });
887
888 // When graph executed in TF-2.0 context with automatic control dependencies
889 // tracking, absence of outgoing control edge indicates that no one is
890 // interested in observing side effects, so it is safe to inline the function
891 // body, even if some side-effects will not be executed.
892 if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
893 const string error_message =
894 "Can't guarantee execution of function side-effects after inlining. "
895 "Function call node has no outgoing control edges.";
896 if (validate_outgoing_control_edge) {
897 return errors::Internal(error_message);
898 } else {
899 VLOG(3) << error_message;
900 }
901 }
902
903 // Find all nodes in the function body that will be used as control sources.
904 absl::flat_hash_set<const Node*> control_sources;
905 if (output_control_source == OutputControlSource::kDataOutputs) {
906 control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
907 } else if (output_control_source == OutputControlSource::kControlOutputs) {
908 control_sources = {fbody.control_ret_nodes.begin(),
909 fbody.control_ret_nodes.end()};
910 }
911
912 for (const Node* side_effect : fbody_side_effects) {
913 VLOG(4) << "Check that node " << side_effect->name()
914 << " will execute after inlining.";
915 bool will_execute = false;
916
917 const auto is_control_source = [&](const Node* n) -> void {
918 const auto it = control_sources.find(n);
919 if (it != control_sources.end()) {
920 VLOG(4) << "Found a path to control source: " << side_effect->name()
921 << " ---> " << (*it)->name();
922 will_execute = true;
923 }
924 };
925
926 DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
927 /*leave=*/{}, NodeComparatorName{});
928
929 if (!will_execute) {
930 return errors::Internal(
931 "Can't guarantee execution of a side-effectful node, that is not "
932 "reachable from function control source. Function body node: ",
933 SummarizeNode(*side_effect));
934 }
935 }
936
937 return Status::OK();
938 }
939
940 // Validates that no dead tensor can reach function output.
ValidateNoDeadOutputs(const FunctionLibraryDefinition & flib_def,const FunctionBody & fbody)941 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
942 const FunctionBody& fbody) {
943 absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
944 fbody.ret_nodes.end()};
945
946 // Find all nodes that can produce dead tensors.
947 std::vector<const Node*> dead_tensor_sources;
948 for (const Node* n : fbody.graph->nodes()) {
949 if (n->IsSwitch()) {
950 VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
951 dead_tensor_sources.push_back(n);
952 continue;
953 }
954
955 // Native function call can also produce dead tensors if the function body
956 // has mergeless switches.
957 const FunctionDef* fdef = flib_def.Find(n->type_string());
958 if (fdef != nullptr) {
959 std::unique_ptr<FunctionBody> nested_fbody;
960
961 NameAttrList func;
962 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
963 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
964 &flib_def, &nested_fbody));
965
966 if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
967 VLOG(4) << "Add dead tensors source. Function call: " << func.name()
968 << " node=" << n->name();
969 dead_tensor_sources.push_back(n);
970 }
971 }
972 }
973
974 for (const Node* dead_tensor_source : dead_tensor_sources) {
975 bool has_dead_output = false;
976
977 const auto is_output_node = [&](const Node* n) -> void {
978 const auto it = output_nodes.find(n);
979 if (it != output_nodes.end()) {
980 VLOG(4) << "Found a path to output node from dead tensor source: "
981 << dead_tensor_source->name() << " ---> " << (*it)->name();
982 has_dead_output = true;
983 }
984 };
985
986 // Stop DFS traversal at a Merge node or if already found a dead output.
987 const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
988 return !edge.src()->IsMerge() || has_dead_output;
989 };
990
991 DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
992 /*leave=*/{}, NodeComparatorName{},
993 /*edge_filter=*/stop_traversal);
994
995 if (has_dead_output) {
996 return errors::Internal(
997 "Can't inline a function with dead outputs. Dead tensor source: ",
998 SummarizeNode(*dead_tensor_source));
999 }
1000 }
1001
1002 return Status::OK();
1003 }
1004
1005 // Makes an instance of FunctionBody for inlining from a Node.
MakeFunctionBodyForInlining(const Node & node,const FunctionLibraryDefinition & flib_def,std::unique_ptr<FunctionBody> * fbody)1006 Status MakeFunctionBodyForInlining(const Node& node,
1007 const FunctionLibraryDefinition& flib_def,
1008 std::unique_ptr<FunctionBody>* fbody) {
1009 VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1010
1011 // Finds a FunctionDef in a library and verifies that it exists.
1012 const auto find_fdef = [&flib_def, &node](
1013 const string& name,
1014 const FunctionDef** fdef) -> Status {
1015 if ((*fdef = flib_def.Find(name)) == nullptr) {
1016 return errors::Internal(
1017 "Was not able to find a function definition (name=", name,
1018 ") for a function call: ", SummarizeNode(node));
1019 }
1020 return Status::OK();
1021 };
1022
1023 // SymbolicGradient is a special "function call" op, which has been
1024 // deprecated for a while, but we still support for compatibility reasons.
1025 if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1026 NameAttrList func;
1027 TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1028
1029 const string grad = flib_def.FindGradient(func.name());
1030
1031 if (!grad.empty()) {
1032 // Function has a custom gradient registered in a library.
1033 const FunctionDef* grad_fdef;
1034 TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1035
1036 VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1037 << " (function=" << func.name() << ")";
1038 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1039 *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1040
1041 } else if (flib_def.Find(func.name()) == nullptr) {
1042 // Function is not really a function, but a primitive op.
1043 gradient::Creator creator;
1044 TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1045 if (creator == nullptr) {
1046 return errors::InvalidArgument("No gradient is defined for ",
1047 func.name());
1048 }
1049 FunctionDef grad_fdef;
1050 TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1051
1052 VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1053 << func.name();
1054 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1055 grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1056
1057 } else {
1058 // Build a gradient graph from the function body.
1059 const FunctionDef* fdef;
1060 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1061
1062 VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1063 << func.name();
1064 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1065 &flib_def, fbody));
1066 *fbody = SymbolicGradient(**fbody);
1067 }
1068
1069 } else {
1070 NameAttrList func;
1071 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1072 const FunctionDef* fdef;
1073 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1074
1075 VLOG(4) << "Instantiate a function call: function=" << func.name();
1076 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1077 &flib_def, fbody));
1078 }
1079
1080 return Status::OK();
1081 }
1082
1083 // Adds a control edges from each data input to the 'caller' to enforce strict
1084 // inputs semantics (all inputs are ready and alive). This is required when:
1085 //
1086 // 1) The function takes resources as inputs, and it doesn't have incoming
1087 // control edges. In Tensorflow v2 context (eager mode) this should never
1088 // happen, because automatic control dependencies tracking will add a
1089 // control edge from the last op touching the resource. However such graphs
1090 // might be produced by legacy v1 code without automatic dependency
1091 // tracking. In this case strict function call semantics is required for
1092 // enforcing side effects execution order.
1093 //
1094 // 2) One of the inputs is consuming Enter[is_constant=true] node, in which
1095 // case it will be always alive, and potentially can lead to partial
1096 // function execution after the last loop execution.
1097 //
1098 // Both of these cases would be considered illegal by construction in Tensorflow
1099 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1100 // will produce correct results.
AddStrictInputSemantics(Node * caller,Graph * g)1101 void AddStrictInputSemantics(Node* caller, Graph* g) {
1102 absl::flat_hash_set<const Node*> existing_control_sources;
1103 for (const Edge* edge : caller->in_edges()) {
1104 if (edge->IsControlEdge()) {
1105 existing_control_sources.insert(edge->src());
1106 }
1107 }
1108
1109 const bool has_incoming_control_edges = !existing_control_sources.empty();
1110
1111 const bool has_resource_input =
1112 absl::c_any_of(caller->input_types(),
1113 [](const DataType dtype) { return dtype == DT_RESOURCE; });
1114
1115 const bool has_constant_enter_input =
1116 absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1117 Node* src = edge->src();
1118 return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1119 });
1120
1121 const bool requires_strict_semantics =
1122 (!has_incoming_control_edges && has_resource_input) || // Case #1
1123 (has_constant_enter_input); // Case #2
1124 if (!requires_strict_semantics) return;
1125
1126 std::set<const Node*> data_inputs;
1127 for (const Edge* edge : caller->in_edges()) {
1128 if (!edge->IsControlEdge() &&
1129 !existing_control_sources.contains(edge->src())) {
1130 data_inputs.insert(edge->src());
1131 }
1132 }
1133
1134 VLOG(3) << "Add control edges from all data inputs to enforce strict "
1135 "semantics with regard to function inputs";
1136
1137 // Do not add control edges from placeholders, because it will prevent
1138 // pruning, and they can't produce any side effects anyway.
1139 const auto is_placeholder = [](const Node* node) -> bool {
1140 return node->type_string() == "Placeholder";
1141 };
1142
1143 for (const Node* node : data_inputs) {
1144 if (is_placeholder(node)) continue;
1145 g->AddControlEdge(g->FindNodeId(node->id()), caller,
1146 /*allow_duplicates=*/true);
1147 }
1148 }
1149
1150 // Adds a control edge from a frame node if the 'caller' is executing inside a
1151 // While loop (see control_flow.h for the 'frame' node explanation).
AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo> & info,Node * caller,Graph * g)1152 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1153 Node* caller, Graph* g) {
1154 // All nodes added to the graph by v2 control flow lowering and function
1155 // inlining are guaranteed to have control edges to nested function calls.
1156 int info_size = info.size();
1157 if (caller->id() >= info_size) return;
1158
1159 // Check if a lowered node is executing inside a while loop.
1160 const Node* frame = info[caller->id()].frame;
1161 const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1162 if (!is_in_while_loop) return;
1163
1164 // Check if a node already has an incoming control edge. All incoming edges
1165 // must be from the same execution frame (executor.cc invariant), so if we
1166 // already have an incoming control edge, it's guaranteed that it will "carry"
1167 // the same frame as all regular inputs.
1168 const bool has_incoming_control_edges =
1169 absl::c_any_of(caller->in_edges(),
1170 [](const Edge* edge) { return edge->IsControlEdge(); });
1171 if (has_incoming_control_edges) return;
1172
1173 VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1174 << " to=" << caller->name();
1175 Node* enter = g->FindNodeId(frame->id());
1176 bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1177 if (is_constant_enter) {
1178 // Enter[is_constant=true] is always alive. So we directly add a control
1179 // edge from that.
1180 g->AddControlEdge(enter, caller);
1181 } else {
1182 // Enter[is_constant=false] activates nodes only in 0th iteration so we
1183 // add an edge from the Merge node which is activated in every iteration.
1184 // A non-constant Enter node must have an edge to a Merge node.
1185 auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1186 return !e->IsControlEdge() && e->dst()->IsMerge();
1187 });
1188 if (it != enter->out_edges().end()) {
1189 g->AddControlEdge((*it)->dst(), caller);
1190 } else {
1191 LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1192 << " does not have an outgoing edge to a Merge.";
1193 }
1194 }
1195 }
1196
1197 // Inlines all function calls that are safe for inlining into the main graph.
1198 // Also lowers control flow V2 ops (functional If/While) into the V1 low level
1199 // ops (Switch/Merge/...).
1200 //
1201 // Runs a placer after inlining, to keep all nodes in a graph placed.
InlineFunctionCalls(const GrapplerItem & item,const RewriterConfig::Toggle opt_level,const bool lower_control_flow,GraphDef * output_graph)1202 Status InlineFunctionCalls(const GrapplerItem& item,
1203 const RewriterConfig::Toggle opt_level,
1204 const bool lower_control_flow,
1205 GraphDef* output_graph) {
1206 bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
1207 VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
1208 << " (aggressive_mode=" << is_aggressive << ")";
1209
1210 FunctionLibraryDefinition flib_def =
1211 FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
1212 std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);
1213
1214 GraphConstructorOptions graph_constructor_options;
1215 graph_constructor_options.allow_internal_ops = true;
1216 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
1217 item.graph, graph.get()));
1218
1219 using NodeNames = absl::flat_hash_set<absl::string_view>;
1220 NodeNames fetch_nodes;
1221 fetch_nodes.reserve(item.fetch.size());
1222 for (const string& fetch : item.fetch) {
1223 fetch_nodes.insert(ParseTensorName(fetch).node());
1224 }
1225 NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
1226
1227 std::vector<string> inlined_function_names;
1228
1229 // Do not inline function call nodes that are part of a feed set.
1230 NodeNames feed_nodes;
1231 feed_nodes.reserve(item.feed.size());
1232 for (const std::pair<std::string, Tensor>& feed : item.feed) {
1233 feed_nodes.insert(ParseTensorName(feed.first).node());
1234 }
1235
1236 // If a function call is inside a While loop, it must have an incoming control
1237 // edge, because it will be used to pass execution frame into the function
1238 // body. All nodes without inputs in the function body (e.g. Const and NoOp)
1239 // will be added an extra control edge from the 'input_control_node'.
1240 std::vector<ControlFlowInfo> control_flow_info;
1241 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
1242
1243 // Function inlining always adds new nodes to the end of the list, so we keep
1244 // iterating until we are out of nodes.
1245 for (int i = 2; i < graph->num_node_ids(); ++i) {
1246 Node* n = graph->FindNodeId(i);
1247 if (n == nullptr) continue; // deleted node
1248
1249 // Special case for lowering functional control flow ops. We do not rely on
1250 // LowerFunctionOpsPass because in Grappler we have to be more restrictive
1251 // about what type of function calls we are allowed to inline.
1252 if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
1253 VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
1254 AddStrictInputSemantics(n, graph.get());
1255 AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1256
1257 if (n->IsIfNode()) {
1258 TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
1259 } else if (n->IsCaseNode()) {
1260 TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
1261 } else if (n->IsWhileNode()) {
1262 TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), &flib_def, false));
1263 }
1264 continue;
1265 }

    // Skip nodes that are not function calls.
    if (!IsFunctionCall(flib_def, *n)) continue;
    // Skip function calls that we plan to compile with XLA later.
    if (MarkedForXlaCompilation(n->def())) continue;
    // Skip nodes that are part of a feed set.
    if (feed_nodes.contains(n->name())) continue;

    // Function body that we will inline into the main graph. It can be a
    // function instantiation, or a gradient function instantiated from a
    // SymbolicGradient op.
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));

    InlineFunctionBodyOptions inline_options;
    // Ignore the '_noinline' flag in aggressive mode.
    inline_options.ignore_noinline = is_aggressive;

    // Function calls created after inlining If/While ops are always inlined as
    // multi-device functions and are not required to pass additional Grappler
    // validations (the side-effects execution validation below).
    bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);

    // `PartitionedCall` is the TF-2.0 function call mechanism for multi-device
    // functions:
    // a) The function can be multi-device.
    // b) Automatic control dependency tracking guarantees that all
    //    side-effectful nodes in the function body have a path to one of the
    //    control outputs. Control outputs and control edges between
    //    side-effectful (stateful) nodes explicitly mark the nodes that must
    //    execute, and define their execution order.
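    //
    // Example (hypothetical function body): an `AssignVariableOp` with no data
    // consumers inside the function is still reachable from a control output,
    // so when inlined with `kControlOutputs` below it keeps a control path to
    // whatever depended on the call node's control output.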
    if (n->IsPartitionedCall() || force_inline_as_multi_device) {
      inline_options.output_control_src = OutputControlSource::kControlOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::MultiDevice();
    } else {
      inline_options.output_control_src = OutputControlSource::kDataOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();
    }

    if (fetch_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kFetchable;
    } else if (keep_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kTargetable;
    } else {
      inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
    }
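
    // Roughly: `kFetchable` keeps the caller node as an IdentityN forwarding
    // the function outputs (so the original fetch names stay valid),
    // `kTargetable` keeps it as a NoOp control target, and `kDoNotKeep`
    // removes it entirely.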

    // Basic validation rules defined in common_runtime, shared by all function
    // calls.
    Status can_inline_function_call =
        ValidateInlining(n, fbody.get(), inline_options);

    // Additional validation rules defined only in Grappler.
    // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
    if (can_inline_function_call.ok()) {
      bool has_outgoing_control_edges = absl::c_any_of(
          n->out_edges(),
          [](const Edge* edge) { return edge->IsControlEdge(); });

      can_inline_function_call = ValidateSideEffectsExecution(
          *fbody, inline_options.output_control_src,
          has_outgoing_control_edges);

      if (!can_inline_function_call.ok() &&
          (is_aggressive || force_inline_as_multi_device)) {
        VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
        can_inline_function_call = Status::OK();
      }
    }
    if (can_inline_function_call.ok()) {
      can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
    }
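
    // Example (illustrative): a function body whose output comes straight from
    // one side of a `Switch` with no matching `Merge` may produce a dead
    // tensor at runtime; inlining such a body would leak dead tensors into the
    // main graph, so the call node is left as is.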

    if (can_inline_function_call.ok()) {
      VLOG(2) << "Inline function call node: " << n->name();
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

      TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                            fbody.get(), inline_options));
      inlined_function_names.push_back(fbody->fdef.signature().name());

    } else {
      VLOG(2) << "Failed to inline function call node: "
              << can_inline_function_call.error_message();
    }
  }

  VLOG(4) << "Inlined " << inlined_function_names.size()
          << " function calls: " << absl::StrJoin(inlined_function_names, ", ");

  // ------------------------------------------------------------------------ //
  // Grappler receives the graph after the PRE_PLACEMENT passes, the Placer,
  // and the POST_PLACEMENT passes, so each node has a valid device assignment.
  // After function inlining and control flow V2 lowering we have to keep the
  // graph fully placed.

  if (inlined_function_names.empty()) {
    VLOG(3) << "Not placing graph after function inlining"
            << " (did not inline any of the function calls).";

  } else if (item.devices().empty()) {
    // If there are no devices available for the Placer, we do not place the
    // graph after function inlining. This happens when Grappler is optimizing
    // the function library, or when a graph is optimized "offline", without an
    // active runtime session, for example as part of a batch job for graph
    // analysis/optimization. A GrapplerItem instantiated from a function
    // library doesn't have to be fully placed after all optimizations; it will
    // be placed by the function library runtime before execution.
    VLOG(3) << "Not placing graph after function inlining"
            << " (device set is empty)";

  } else {
    // If we are running in an active runtime session, Grappler will get the
    // graph after initial placing is done, and we should have devices for the
    // placer.
    VLOG(3) << "Run placer for the graph after function inlining. "
            << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";

    DeviceSet device_set;                               // does not own devices
    std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices

    for (const string& name : item.devices()) {
      auto device = absl::make_unique<FakeDevice>(name);
      device_set.AddDevice(device.get());
      fake_devices.push_back(std::move(device));
    }
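
    // Device names here are fully qualified, e.g.
    // "/job:localhost/replica:0/task:0/device:CPU:0". A FakeDevice carries
    // only the device name (no real device state), which is enough for the
    // Placer to assign devices to nodes.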

    Placer placer(graph.get(), item.id, &flib_def, &device_set);
    TF_RETURN_IF_ERROR(placer.Run());
  }

  graph->ToGraphDef(output_graph);
  return Status::OK();
}

// Restores tensor mapping after function specialization: all inputs must be
// connected to valid nodes.
void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
                          GraphDef* optimized_graph) {
  if (ctx.tensor_mapping().empty()) return;

  // During function specialization, we might prune unused function outputs. We
  // need to "close the holes" that might appear in the function outputs.
  //
  // Example: prune unused output "f:1"
  //
  //   f = my_func[T=float](...)      f = my_func_specialized[T=float](...)
  //   a = Identity(f:0)         ->   a = Identity(f:0)
  //   b = Identity(f:2)              b = Identity(f:1)
  //
  // Tensor mapping (size=1): [f:2 -> f:1]
  for (NodeDef& node : *optimized_graph->mutable_node()) {
    for (int idx = 0; idx < node.input_size(); ++idx) {
      TensorId input_tensor = ParseTensorName(node.input(idx));
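      // Control inputs ("^node") always come after data inputs in a NodeDef,
      // so it is safe to stop at the first control input.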
      if (input_tensor.index() == Graph::kControlSlot) break;

      auto mapping = ctx.tensor_mapping().find(input_tensor);
      if (mapping != ctx.tensor_mapping().end()) {
        node.set_input(idx, TensorIdToString(mapping->second));
      }
    }
  }
}

}  // namespace

Status FunctionOptimizer::RunFunctionOptimizerPass(
    const GrapplerItem& item, GraphDef* optimized_graph) const {
  VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;

  // Inline all function calls into the graph using the common_runtime/function
  // implementation (see the `InlineFunctionBody` documentation).
  GraphDef graph_after_inlining;
  TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
                                         &graph_after_inlining));

  // Specialize function calls that we could not inline.
  FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);

  for (const NodeDef& node : graph_after_inlining.node()) {
    // Function specialization can modify the optimized graph only by adding
    // new nodes, so we can compare node counts to check whether the graph was
    // modified.
    const int num_nodes_before = optimized_graph->node_size();
    const auto is_graph_modified = [&]() {
      int num_nodes = optimized_graph->node_size();
      DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
      return num_nodes > num_nodes_before;
    };

    // Copy the node from `graph_after_inlining` to `optimized_graph`.
    const auto copy_node = [&]() { *optimized_graph->add_node() = node; };

    // Check if the node is a function call (direct or indirect).
    const FunctionDef* func = FindFunctionCall(ctx, node);
    if (func == nullptr) {
      copy_node();
      continue;
    }

    const string& func_name = func->signature().name();

    // Specialize the function to its instantiation context if there is
    // something worth specializing.
    const bool specialization_worthy = IsParametrized(*func) ||
                                       HasTrulyConstInputs(node, ctx) ||
                                       HasUnusedOutputs(node, *func, ctx);
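
    // Example (hypothetical call): `f = my_func[T=float](Const(...), x)` with
    // an unused output `f:1` is worth specializing: the attribute values and
    // the constant input can be folded into a specialized body, and the unused
    // output can be pruned.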

    // Do not specialize the function if it has a custom gradient or is marked
    // with the '_nospecialize' attribute.
    const string grad_func = ctx.function_library().FindGradient(func_name);
    const bool no_specialize =
        !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
        MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);

    if (specialization_worthy && !no_specialize) {
      // TODO(ezhulenev): Specialize function call if input has a known shape.
      // Specialize function body for its instantiation attributes and inputs.
      Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
      if (!status.ok() && is_graph_modified()) {
        return status;
      } else if (!status.ok() && !is_graph_modified()) {
        VLOG(3) << "Skip specialization error: " << status.error_message();
        copy_node();
      }
      continue;
    } else {
      VLOG(2) << "Skip function specialization: " << func->signature().name();
      copy_node();
    }
  }

  RestoreTensorMapping(ctx, optimized_graph);

  // Preserve the graph version.
  *optimized_graph->mutable_versions() = item.graph.versions();
  // Prune unreachable functions from the library.
  *optimized_graph->mutable_library() =
      PruneFunctionLibrary(ctx.function_library(), *optimized_graph);

  return Status::OK();
}

Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
                                   GraphDef* optimized_graph) {
  // Nothing to do here.
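  // (Returning a non-OK status such as `Aborted` tells the meta optimizer that
  // this pass produced no result, so the input graph is kept unchanged.)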
  if (item.graph.library().function_size() == 0) {
    return errors::Aborted("Nothing to do.");
  }

  TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));

  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow