/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/inline_function_utils.h"

#include <deque>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
namespace {
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name that represents this endpoint.
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

std::vector<string> InputDevices(const Node& caller) {
  std::vector<string> input_devices(caller.in_edges().size());
  std::vector<string> input_tensors(caller.in_edges().size());

  for (const Edge* edge : caller.in_edges()) {
    if (edge->IsControlEdge()) continue;
    const string& input_device = edge->src()->has_assigned_device_name()
                                     ? edge->src()->assigned_device_name()
                                     : edge->src()->requested_device();
    input_devices[edge->dst_input()] = input_device;
    input_tensors[edge->dst_input()] =
        absl::StrCat(edge->src()->name(), ":", edge->src_output());
  }

  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Function instantiation input devices:";
    for (int i = 0; i < input_devices.size(); ++i) {
      if (input_tensors[i].empty()) continue;  // skip control edges
      VLOG(4) << " [index " << i << "]"
              << " device: " << input_devices[i]
              << " (input: " << input_tensors[i] << ")";
    }
  }

  return input_devices;
}

// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit DefaultFunctionBodyPlacer(const Node& caller)
      : input_devices_(InputDevices(caller)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return absl::nullopt;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return absl::nullopt;
  }

 private:
  const std::vector<string> input_devices_;
};

// Place all nodes on the same device as the caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return caller_device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return caller_device_;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return caller_device_;
  }

 private:
  const string caller_device_;
};

// Place input nodes on the same device as the corresponding caller input
// node. Do not place output nodes. Place control nodes on the same device as
// the caller node. For all function body nodes, override the job, replica,
// and task parts of the device assignment to match the function caller node.
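//
// For example (an illustrative sketch, device names hypothetical): if the
// caller node is placed on "/job:worker/replica:0/task:1/device:CPU:0" and a
// body node requests only "/device:GPU:0", the body node ends up on
// "/job:worker/replica:0/task:1/device:GPU:0"; a body node with no requested
// device at all simply inherits the full caller device.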
class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()),
        input_devices_(InputDevices(caller)) {
    has_parsed_caller_device_ =
        DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
  }

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return true; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    // TODO(ezhulenev): If the function had been instantiated as a multi-device
    // function and executed via FunctionLibraryRuntime, it could potentially
    // be placed on any available device. However, there are multiple tests
    // relying on this assumption. Fix them, and remove this line.
    if (ndef.device().empty()) return caller_device_;

    if (!has_parsed_caller_device_) return ndef.device();

    DeviceNameUtils::ParsedName ndef_parsed_device;
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
      return ndef.device();

    // Nodes with explicit device placements in the function body have those
    // respected, but otherwise the caller node's placement provides the
    // default job, replica, and task.
    if (caller_parsed_device_.has_job && !ndef_parsed_device.has_job) {
      ndef_parsed_device.has_job = caller_parsed_device_.has_job;
      ndef_parsed_device.job = caller_parsed_device_.job;
    }

    if (caller_parsed_device_.has_replica && !ndef_parsed_device.has_replica) {
      ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
      ndef_parsed_device.replica = caller_parsed_device_.replica;
    }

    if (caller_parsed_device_.has_task && !ndef_parsed_device.has_task) {
      ndef_parsed_device.has_task = caller_parsed_device_.has_task;
      ndef_parsed_device.task = caller_parsed_device_.task;
    }
    return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
  }

 private:
  string caller_device_;
  bool has_parsed_caller_device_;
  DeviceNameUtils::ParsedName caller_parsed_device_;
  std::vector<string> input_devices_;
};

}  // namespace

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                         const Node& caller) {
  VLOG(3) << "Create default placer for inlined function body.";
  return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                              const Node& caller) {
  VLOG(3) << "Create single device placer for inlined function body.";
  return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                             const Node& caller) {
  VLOG(3) << "Create multi device placer for inlined function body.";
  return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}
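
// A rough usage sketch (illustrative only, variable names hypothetical):
// callers typically pick one of the factories above when deciding how to
// place the inlined body, e.g.
//
//   std::unique_ptr<InlinedFunctionBodyPlacer> placer =
//       InlinedFunctionBodyPlacer::MultiDevicePlacer(*graph, *caller_node);
//   absl::optional<string> device = placer->BodyNodeDevice(body_node_def);
//
// where `graph` is the graph being rewritten, `caller_node` is the function
// call node, and `body_node_def` is a copied function body node.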

namespace {

Status ValidateNoInline(const FunctionBody* fbody) {
  const auto attr = AttrSlice(&fbody->fdef.attr());
  bool noinline = false;
  if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
    return errors::InvalidArgument(
        "Can't inline function marked with '_noinline'");
  }
  return Status::OK();
}
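
// For reference (a sketch, not tied to any particular graph): a function can
// opt out of inlining by carrying the `_noinline` attribute in its
// FunctionDef, e.g. in textproto form:
//
//   attr { key: "_noinline" value { b: true } }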

using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;

// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// are used.
void PropagateDebugInfoToNode(const string& func,
                              const std::vector<const Node*>& nodes,
                              NodeDef* target) {
  if (nodes.empty() || target->has_experimental_debug_info()) {
    return;
  }
  for (const Node* node : nodes) {
    const auto& node_def = node->def();
    if (node_def.has_experimental_debug_info()) {
      target->mutable_experimental_debug_info()->MergeFrom(
          node_def.experimental_debug_info());
    } else {
      target->mutable_experimental_debug_info()->add_original_node_names(
          node_def.name());
      target->mutable_experimental_debug_info()->add_original_func_names(func);
    }
  }
}
}  // namespace

string InlineFunctionBodyOptions::DebugString() const {
  const auto true_false = [](bool b) { return b ? "true" : "false"; };

  const auto keep_caller_node_str = [this]() -> string {
    switch (keep_caller_node) {
      case KeepCallerNode::kDoNotKeep:
        return "DoNotKeep";
      case KeepCallerNode::kFetchable:
        return "Fetchable";
      case KeepCallerNode::kTargetable:
        return "Targetable";
    }
  };

  return absl::StrCat(
      "disable_inlining=", true_false(disable_inlining),
      ", ignore_noinline=", true_false(ignore_noinline),
      ", inline_impl_selection_group_functions=",
      true_false(inline_impl_selection_group_functions),
      ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
      output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
                                                           : "ControlOutputs",
      ", inlined_function_body_placer=", inlined_function_body_placer.name,
      ", uniquify_frame_names=", true_false(uniquify_frame_names));
}
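
// An illustrative example of the string produced above (the field values
// shown here are hypothetical):
//
//   disable_inlining=false, ignore_noinline=false,
//   inline_impl_selection_group_functions=false, keep_caller_node=Fetchable,
//   output_control_src=DataOutputs, inlined_function_body_placer=default,
//   uniquify_frame_names=true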

Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options) {
  // TODO(ezhulenev): Currently common_runtime function inlining can't
  // guarantee that all side-effectful ops will be executed after inlining. See
  // Grappler function_optimizer for details. Unify all function inlining
  // mechanisms. Do not inline if `!fbody->control_ret_nodes.empty()`.

  const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
  const auto num_node_outputs = static_cast<size_t>(node->num_outputs());

  if (num_node_inputs != fbody->arg_types.size() ||
      num_node_inputs != fbody->arg_nodes.size()) {
    return errors::InvalidArgument(
        "Node inputs do not match function arguments: inputs=",
        num_node_inputs, " arg_types=", fbody->arg_types.size(),
        " arg_nodes=", fbody->arg_nodes.size());
  }

  if (num_node_outputs != fbody->ret_types.size() ||
      num_node_outputs != fbody->ret_nodes.size()) {
    return errors::InvalidArgument(
        "Node outputs do not match function returns: outputs=",
        num_node_outputs, " ret_types=", fbody->ret_types.size(),
        " ret_nodes=", fbody->ret_nodes.size());
  }

  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) {
      return errors::InvalidArgument(
          "Node input type doesn't match function argument type: ",
          node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
    }
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) {
      return errors::InvalidArgument(
          "Node output type doesn't match function return type: ",
          node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
    }
  }

  if (options.disable_inlining) {
    return errors::InvalidArgument(
        "Function inlining explicitly disabled by 'options.disable_inlining'");
  }

  if (!options.inline_impl_selection_group_functions) {
    bool is_impl_selection_group_function =
        fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
    if (is_impl_selection_group_function) {
      return errors::InvalidArgument(
          "Inlining of implementation selection group function ",
          fbody->fdef.signature().name(),
          " is disabled by options.inline_impl_selection_group_functions");
    }
  }

  if (!options.ignore_noinline) {
    TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
  }

  return Status::OK();
}

// Function inlining must preserve function execution semantics with regard to
// side-effect visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces a well-defined execution
// order of all side effects. Any other frontend (e.g. Swift) must produce
// graphs following the same rules, to ensure that function inlining works
// correctly.
//
// IMPORTANT: Currently we do not have a true notion of a "side-effectful"
// node; we assume that all stateful nodes might have side effects, though this
// is not true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side effect.
//
// Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
//    inlining it might happen that the data output is unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
//    guaranteed to run in program order. This is also done by adding control
//    edges at graph construction time. The last op touching the resource
//    must be in a control return set, which will guarantee that all side
//    effects to the resource will happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
//
// 2) All side effects to the captured resources that happened inside the
//    function body must be visible to every op/function using those resources
//    after the function call has completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create an "input_control_node" NoOp. The function call node's incoming
//    control edges will be forwarded *to* this node. Function inputs (Identity
//    nodes) will have a control edge *from* this node. If the function body
//    has nodes without inputs, they will have a control edge *from* this node.
//
// 2) Create an "output_control_node" NoOp. All nodes that have an incoming
//    control edge *from* the function call node will be forwarded to this
//    node.
//
// We have two options for choosing which nodes will have a control edge *to*
// the "output control node":
//   a) control returns (`control_ret` field in FunctionDef)
//   b) data returns (`ret` field in FunctionDef)
//
// We do a) for multi-device function calls in Tensorflow v2 and b)
// for the rest for compatibility with Tensorflow v1.
//
// Following the automatic control dependencies tracking rules, a node that
// has an incoming control edge from the function call node is dependent on
// the side effects happening inside the function body. The output control
// node will guarantee side-effect execution order.
//
// If the function call node doesn't have an outgoing control edge, it means
// that no one is interested in observing side effects that might have
// happened.
//
// Function inlining might leave the graph in a partially-placed state. The
// caller of function inlining must run the Placer afterwards to guarantee that
// all nodes are placed.
//
// Function inlining with `options.override_device=true` will leave the graph
// in a fully placed state, by overriding all inlined node devices with the
// caller node device, but it will make functions always single-device. After
// inlining, these functions will not be able to handle resources on multiple
// devices. This is currently acceptable for XLA use cases (an XLA cluster is
// always executed on a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
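//
// As a rough illustration (node and function names below are hypothetical),
// inlining a call node named "call":
//
//   y = call(x)      where the called function F(a) computes  b = Mul(a, a)
//
// conceptually rewrites the graph into
//
//   x -> call/input (Identity) -> call/Mul -> call/output (Identity) -> uses
//                                                                        of y
//
// with an optional NoOp "input_control_node" in front of the inlined body and
// an optional NoOp "output_control_node" collecting the control edges that
// previously pointed at the call node.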
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options) {
  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
          << options.DebugString() << "]";

  Status validation = ValidateInlining(caller, fbody, options);
  if (!validation.ok()) {
    return errors::Internal("Inlining mismatch: ", validation.error_message());
  }

  // Placer is responsible for assigning devices for all nodes that we will add
  // to the graph.
  const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
      options.inlined_function_body_placer.get(*g, *caller);

  // We can't possibly introduce a duplicate control edge during function
  // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'.
  static constexpr bool kDoNotCheckDuplicates = true;

  // ------------------------------------------------------------------------ //
  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
  // control nodes and inlined function inputs and outputs.

  // Add a NoOp node for function control inputs/outputs.
  const auto no_op = [&](StringPiece name) -> Node* {
    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
    const absl::optional<string> device = placer->ControlNodeDevice();
    if (device.has_value()) node->set_requested_device(*device);
    return node;
  };

  // Add an Identity node for function input.
  const auto input_identity = [&](StringPiece name, Endpoint input,
                                  int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->InputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // Add an Identity node for function output.
  const auto output_identity = [&](StringPiece name, Endpoint input,
                                   int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->OutputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // ------------------------------------------------------------------------ //
  // Helper function to get an input/output argument name by index. For
  // functions instantiated from SymbolicGradient the corresponding FunctionDef
  // is empty, and the argument name is unknown.

  auto arg_name = [&](auto& args, size_t i) -> absl::string_view {
    if (i < args.size()) {
      return args[i].name();
    } else {
      return "<unknown>";
    }
  };

  // ------------------------------------------------------------------------ //
  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = no_op("input_control_node");
      }
      g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }
  if (input_control_node != nullptr) {
    VLOG(3) << "Created input control node: " << input_control_node->name();
  }

  // ------------------------------------------------------------------------ //
  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes. We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::unordered_set<string> fn_nodes;
  for (Node* n : fbody->graph->op_nodes()) {
    fn_nodes.insert(n->name());
  }
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();

    // Maybe override requested node device assignment.
    const absl::optional<string> device = placer->BodyNodeDevice(ndef);
    if (device.has_value()) ndef.set_device(*device);

    // Add inlined function name to inlined node debug information.
    PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);

    // Add the function node name as a prefix:
    //  1) to node name to avoid collisions
    //  2) to frame name to avoid multiple LoopCond nodes in one frame
    //  3) to colocation attribute
    const string prefix = strings::StrCat(caller->name(), "/");
    TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
                                                options.uniquify_frame_names));
    TF_RETURN_IF_ERROR(
        MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));

    Status added_node;
    Node* clone = g->AddNode(ndef, &added_node);
    TF_CHECK_OK(added_node);
    node_map[n->id()] = clone;
    clone->SetStackTrace(n->GetStackTrace());

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call (including SymbolicGradient),
    // then add a control edge from the input control node to the clone (only
    // if it does not already have a control input).
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    //
    // This edge is required to transfer execution frame down to all function
    // body nodes of inlined nested function calls.
    if (input_control_node) {
      const auto is_input_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource();
      };
      const auto is_control_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource() && e->IsControlEdge();
      };

      // Forward execution frame if:
      //
      // a) The node has no data or control inputs.
      // b) OR the node is a function call without control inputs (control edge
      //    will be used in nested function inlining to forward execution frame
      //    to constants inside the function body).
      //
      // c) Do not forward control frame to function argument nodes, they will
      //    be connected to the corresponding function input later.
      const bool forward_execution_frame =
          (absl::c_none_of(n->in_edges(), is_input_edge) ||       // (a)
           (n->IsFunctionCall() &&                                 // (b)
            absl::c_none_of(n->in_edges(), is_control_edge))) &&  //
          !n->IsArg();                                             // (c)

      if (forward_execution_frame) {
        VLOG(4) << "Add control edge from input control node to: "
                << clone->name();
        g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
      }
    }
  }
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // ------------------------------------------------------------------------ //
  // Connect input edges.
  //
  // We create one Identity node for each input. Then, we connect inputs[i] to
  // the i-th identity node added. The nodes that previously connected
  // to the j-th output of the i-th arg node are reconnected to the i-th
  // identity node.
  //
  // The added identity nodes depend on "input_control_node".
  VLOG(4) << "Add input Identity nodes for each function argument:";

  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = input_identity("input", inputs[i], i);
    VLOG(4) << " [index " << i << "] "
            << arg_name(fbody->fdef.signature().input_arg(), i) << " as "
            << n->name() << " (input: " << inputs[i].name()
            << ", requested_device: " << n->requested_device() << ")";

    if (input_control_node) {
      g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // ------------------------------------------------------------------------ //
  // Connect output edges.
  //
  // For the i-th return node in fbody->graph, we add in "g" an identity node
  // (outputs[i-th]). We then reconnect every incoming edge into the i-th
  // return node to the added identity node.
  //
  // For every data edge coming out of "callee"s i-th output, we reconnect it
  // to the i-th identity added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a NoOp
  // node "output_control_node". "output_control_node" depends on all identity
  // nodes added above or on all control return nodes (controlled by
  // `options.output_control_src` value). And nodes that previously depended on
  // "callee" are changed to depend on "output_control_node".
  //
  // If `keep_node_fetchable` is `true` we always add an output control node,
  // to guarantee that executing a fetchable node will execute all side
  // effects.
  VLOG(4) << "Add output Identity nodes for each function output argument:";

  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = output_identity("output", data, i);
    outputs[i] = n;
    VLOG(4) << " [index " << i << "] "
            << arg_name(fbody->fdef.signature().output_arg(), i) << " as "
            << n->name() << " (ret: " << data.node->name() << ":" << data.index
            << ", requested_device: " << n->requested_device() << ")";
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }

  Node* output_control_node = nullptr;
  const bool has_control_outputs = absl::c_any_of(
      caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });

  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
  const bool keep_caller_node =
      options.keep_caller_node == KeepCallerNode::kFetchable ||
      options.keep_caller_node == KeepCallerNode::kTargetable;

  if (has_control_outputs || keep_caller_node) {
    output_control_node = no_op("output_control_node");
    VLOG(4) << "Add output control node: " << output_control_node->name();
    if (options.output_control_src == OutputControlSrc::kDataOutputs) {
      for (Node* n : outputs) {
        VLOG(4) << " [data output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    } else {
      for (Node* fbody_node : fbody->control_ret_nodes) {
        Node* n = node_map[fbody_node->id()];
        VLOG(4) << " [control output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    }
  }

  // We can't leave the output control node without incoming control edges,
  // because in this case the outgoing control edge will lose execution frame
  // information. We connect input_control_node and output_control_node with a
  // control edge to forward the execution frame to the controlled nodes. Above
  // we add a control edge to all function calls inside the function body, to
  // guarantee that we will always have an input_control_node when we need it.
  if (output_control_node && output_control_node->in_edges().empty()) {
    if (input_control_node) {
      VLOG(4) << "Add a control edge between input and output control nodes: "
              << input_control_node->name() << " to "
              << output_control_node->name();
      g->AddControlEdge(input_control_node, output_control_node,
                        kDoNotCheckDuplicates);
    } else {
      VLOG(4) << "Function inlining potentially dropped execution frame "
                 "information from outgoing control edges.";
    }
  }

  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }

  // ------------------------------------------------------------------------ //
  // Add an IdentityN or NoOp node in place of the caller node to keep `caller`
  // fetchable or targetable.

  if (keep_caller_node) {
    std::vector<NodeBuilder::NodeOut> output_tensors;
    absl::c_transform(outputs, std::back_inserter(output_tensors),
                      [](Node* n) { return NodeBuilder::NodeOut(n, 0); });

    Node* caller_substitute_node;
    if (options.keep_caller_node == KeepCallerNode::kTargetable ||
        output_tensors.empty()) {
      // An IdentityN node must have at least one data input. If the function
      // has no data outputs, we can't keep it fetchable.
      TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
                      .Device(caller->requested_device())
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));

    } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
      TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
                      .Device(caller->requested_device())
                      .Input(output_tensors)
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));
    }
  }

  // ------------------------------------------------------------------------ //
  // 'caller' is replaced with inlined function body nodes and maybe an
  // IdentityN to keep it fetchable.
  VLOG(3) << "Successfully inlined function call node: " << caller->name();
  g->RemoveNode(caller);

  return Status::OK();
}

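// A rough usage sketch (illustrative only): passes that want to inline every
// eligible call typically invoke this in a fixed-point loop, e.g.
//
//   while (ExpandInlineFunctions(lib, graph)) {
//   }
//
// so that calls exposed by one round of inlining are themselves inlined on
// the next round.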
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options) {
  std::vector<std::pair<Node*, const FunctionBody*>> candidates;

  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();

  for (Node* node : graph->nodes()) {
    // Skip nodes that are not function calls or SymbolicGradient calls.
    if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
      continue;
    }
    // Skip function calls that are marked noinline.
    bool noinline;
    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
      VLOG(3) << "noinline: " << SummarizeNode(*node);
      continue;
    }
    FunctionLibraryRuntime::Handle handle;
    Status s = InstantiateFunctionCall(node->def(), lib, &handle);
    if (!s.ok()) {
      LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
      continue;
    }
    const FunctionBody* fbody = lib->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    candidates.emplace_back(node, fbody);
  }

  bool inlined_any = false;
  for (const auto& p : candidates) {
    Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
                                        p.first->IsPartitionedCall()
                                            ? options.multi_device_options
                                            : options.native_options);
    if (inlined.ok()) {
      inlined_any = true;
    } else {
      VLOG(1) << "Failed to inline function call: node=" << p.first->name()
              << " error=" << inlined.error_message();
    }
  }

  // TODO(ezhulenev): Release handles for inlined function calls.

  return inlined_any;
}

}  // end namespace tensorflow