1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
17
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/core/common_runtime/function_def_utils.h"
20 #include "tensorflow/core/common_runtime/inline_function_utils.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/graph/graph_node_util.h"
24
25 namespace tensorflow {
26 namespace {
27
28 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
29 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
30
31 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
32 LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
33
LowerAsMultiDeviceFunction(const Node * n)34 bool LowerAsMultiDeviceFunction(const Node* n) {
35 if (n->IsPartitionedCall()) return true;
36
37 bool match;
38 bool found =
39 TryGetNodeAttr(n->attrs(), kLowerAsMultiDeviceFunctionAttr, &match);
40 return found && match;
41 }
42
43 } // namespace
44
RewriteFunctionCallNode(Node * n,Graph * g,const FunctionLibraryDefinition & flib_def,bool keep_caller_fetchable)45 Status RewriteFunctionCallNode(Node* n, Graph* g,
46 const FunctionLibraryDefinition& flib_def,
47 bool keep_caller_fetchable) {
48 VLOG(2) << "Lower function call node: " << SummarizeNode(*n);
49
50 // We support lowering of two types of functions that could be invoked by the
51 // node `n`: 1) native functions and 2) multi-device functions.
52 // NOTE(ezhulenev): We explicitly choose not to deal with SymbolicGradient,
53 // because it has been deprecated for a long time.
54 InlineFunctionBodyOptions inline_options;
55 inline_options.keep_caller_node = keep_caller_fetchable
56 ? KeepCallerNode::kFetchable
57 : KeepCallerNode::kTargetable;
58
59 if (LowerAsMultiDeviceFunction(n)) {
60 // Multi-device function calls (PartitionedCall or StatefulPartitionedCall
61 // ops) can execute on multiple devices and accept DT_RESOURCE inputs that
62 // belong to different devices. This type of functions was added in
63 // Tensorflow 2.0 Eager mode, and it has control outputs to represent
64 // side-effects that must always execute (see `control_ret` in FunctionDef).
65 inline_options.output_control_src = OutputControlSrc::kControlOutputs;
66 inline_options.inlined_function_body_placer =
67 InlinedFunctionBodyPlacer::MultiDevice();
68 } else {
69 // Native function call (node.type_string() is the function name). These
70 // functions are always executed on a single-device, which is the device of
71 // the function call node.
72 inline_options.output_control_src = OutputControlSrc::kDataOutputs;
73 inline_options.inlined_function_body_placer =
74 InlinedFunctionBodyPlacer::SingleDevice();
75 }
76
77 const FunctionDef* fdef;
78 if (n->IsPartitionedCall()) {
79 NameAttrList func;
80 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "f", &func));
81 fdef = flib_def.Find(func.name());
82 } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
83 VLOG(2) << "Skip SymbolicGradient lowering";
84 return Status::OK();
85 } else {
86 fdef = flib_def.Find(n->type_string());
87 }
88
89 if (fdef == nullptr) {
90 return errors::Internal("Can't find a function: node=", SummarizeNode(*n));
91 }
92
93 std::unique_ptr<FunctionBody> fbody;
94 TF_RETURN_IF_ERROR(
95 FunctionDefToBodyHelper(*fdef, n->attrs(), &flib_def, &fbody));
96
97 Status can_inline_function_call =
98 ValidateInlining(n, fbody.get(), inline_options);
99 if (can_inline_function_call.ok()) {
100 TF_RETURN_IF_ERROR(
101 InlineFunctionBody(flib_def, g, n, fbody.get(), inline_options));
102 } else {
103 VLOG(2) << "Failed to inline function call node: "
104 << can_inline_function_call.error_message();
105 }
106
107 return Status::OK();
108 }
109
110 } // namespace tensorflow
111