1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/core/common_runtime/function_def_utils.h"
20 #include "tensorflow/core/common_runtime/inline_function_utils.h"
21 #include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/graph_node_util.h"
25 #include "tensorflow/core/platform/errors.h"
26 
27 namespace tensorflow {
28 
29 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
30 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
31 
RewriteFunctionCallNode(Node * n,Graph * g,const FunctionLibraryDefinition & flib_def,bool keep_caller_fetchable)32 Status RewriteFunctionCallNode(Node* n, Graph* g,
33                                const FunctionLibraryDefinition& flib_def,
34                                bool keep_caller_fetchable) {
35   VLOG(2) << "Lower function call node: " << SummarizeNode(*n);
36 
37   // We support lowering of two types of functions that could be invoked by the
38   // node `n`: 1) native functions and 2) multi-device functions.
39   // NOTE(ezhulenev): We explicitly choose not to deal with SymbolicGradient,
40   // because it has been deprecated for a long time.
41   InlineFunctionBodyOptions inline_options;
42   inline_options.keep_caller_node = keep_caller_fetchable
43                                         ? KeepCallerNode::kFetchable
44                                         : KeepCallerNode::kTargetable;
45 
46   FunctionCallInlinePolicy policy = GetFunctionCallInlinePolicy(n);
47   if (policy == FunctionCallInlinePolicy::kMultiDevicePlacer) {
48     // Multi-device function calls (PartitionedCall or StatefulPartitionedCall
49     // ops) can execute on multiple devices and accept DT_RESOURCE inputs that
50     // belong to different devices. This type of functions was added in
51     // Tensorflow 2.0 Eager mode, and it has control outputs to represent
52     // side-effects that must always execute (see `control_ret` in FunctionDef).
53     inline_options.output_control_src = OutputControlSrc::kControlOutputs;
54     inline_options.inlined_function_body_placer =
55         InlinedFunctionBodyPlacer::MultiDevice();
56   } else if (policy == FunctionCallInlinePolicy::kSingleDevicePlacer) {
57     // Native function call (node.type_string() is the function name). These
58     // functions are always executed on a single-device, which is the device of
59     // the function call node.
60     inline_options.output_control_src = OutputControlSrc::kDataOutputs;
61     inline_options.inlined_function_body_placer =
62         InlinedFunctionBodyPlacer::SingleDevice();
63   } else {
64     return errors::InvalidArgument("Unsupported function inlining policy");
65   }
66 
67   const FunctionDef* fdef;
68   if (n->IsPartitionedCall()) {
69     NameAttrList func;
70     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "f", &func));
71     fdef = flib_def.Find(func.name());
72   } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
73     VLOG(2) << "Skip SymbolicGradient lowering";
74     return Status::OK();
75   } else {
76     fdef = flib_def.Find(n->type_string());
77   }
78 
79   if (fdef == nullptr) {
80     return errors::Internal("Can't find a function: node=", SummarizeNode(*n));
81   }
82 
83   std::unique_ptr<FunctionBody> fbody;
84   TF_RETURN_IF_ERROR(
85       FunctionDefToBodyHelper(*fdef, n->attrs(), &flib_def, &fbody));
86 
87   Status can_inline_function_call =
88       ValidateInlining(n, fbody.get(), inline_options);
89   if (can_inline_function_call.ok()) {
90     TF_RETURN_IF_ERROR(
91         InlineFunctionBody(flib_def, g, n, fbody.get(), inline_options));
92   } else {
93     VLOG(2) << "Failed to inline function call node: "
94             << can_inline_function_call.error_message();
95   }
96 
97   return Status::OK();
98 }
99 
100 }  // namespace tensorflow
101