/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {
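// Builds a function-valued AttrValue that names a function from the library;
// used below as the "f" attr of PartitionedCall nodes.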
AttrValue FuncAttr(const string& name) {
  AttrValue attr;
  attr.mutable_func()->set_name(name);
  return attr;
}

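// Overload that also sets the "T" attr, for instantiating a type-polymorphic
// function such as XTimesTwo at a concrete dtype.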
AttrValue FuncAttr(const string& name, const DataType type) {
  AttrValue attr;
  attr.mutable_func()->set_name(name);
  (*attr.mutable_func()->mutable_attr())["T"].set_type(type);
  return attr;
}

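// Returns session options with function inlining enabled in the optimizer
// options, so the lowering pass is allowed to inline function calls.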
SessionOptions SessionOptionsWithInlining() {
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  return session_options;
}

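// Runs the LowerFunctionalOpsPass (the pass under test) over `graph` with
// function inlining enabled.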
Status Rewrite(std::unique_ptr<Graph>* graph) {
  FunctionLibraryDefinition flib_def((*graph)->flib_def());
  GraphOptimizationPassOptions opt_options;
  SessionOptions session_options = SessionOptionsWithInlining();
  opt_options.session_options = &session_options;
  opt_options.graph = graph;
  opt_options.flib_def = &flib_def;
  LowerFunctionalOpsPass pass;
  return pass.Run(opt_options);
}

TEST(LowerFunctionCallTest, InlineFunctionCall) {
  using FDH = FunctionDefHelper;

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  FunctionDefLibrary f_lib_proto;

  // The `add` node is not needed to compute the regular output `o`, but it
  // must still execute because it is listed in `control_ret`.
  *(f_lib_proto.add_function()) =
      FDH::Create("AddAndMul", {"i: int32"}, {"o: int32"}, {},
                  {{{"add"}, "Add", {"i", "i"}, {{"T", DT_INT32}}},
                   {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_INT32}}}},
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});

  // Construct a graph:
  //   A = Placeholder[dtype=int32]
  //   F = PartitionedCall[f=AddAndMul](A)
  //   B = Identity(F, ^F)
  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  Node* function_call;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
  TF_ASSERT_OK(NodeBuilder("F", "PartitionedCall", &root.graph()->flib_def())
                   .Input(inputs)
                   .Attr("Tin", {DT_INT32})
                   .Attr("Tout", {DT_INT32})
                   .Attr("f", FuncAttr("AddAndMul"))
                   .Finalize(root.graph(), &function_call));
  TF_ASSERT_OK(root.DoShapeInference(function_call));

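  // B consumes F through both a data edge and a control edge, so inlining
  // must preserve the control dependency on the inlined function body.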
  auto b = ops::Identity(root.WithOpName("B"), Output(function_call, 0));
  root.graph()->AddControlEdge(function_call, b.node());

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(Rewrite(&graph));

  // Verify that the resulting graph has no PartitionedCall ops and that the
  // function body was inlined into the main graph.
  int partitioned_call_count = 0;
  int add_count = 0;
  int mul_count = 0;
  for (const auto* op : graph->op_nodes()) {
    if (op->IsPartitionedCall()) partitioned_call_count++;
    if (op->type_string() == "Add") add_count++;
    if (op->type_string() == "Mul") mul_count++;
  }

  ASSERT_EQ(partitioned_call_count, 0);
  ASSERT_EQ(add_count, 1);
  ASSERT_EQ(mul_count, 1);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(b)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 100);
  }
}

TEST(LowerFunctionCallTest, DoNotInlineTpuOrXlaFunctions) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

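  // Functions marked for TPU replication or XLA compilation must stay as
  // function call nodes so that later TPU/XLA graph rewrites can find them.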
  FunctionDef tpu_func = test::function::XTimesTwo();
  tpu_func.mutable_signature()->set_name("TpuXTimesTwo");
  (*tpu_func.mutable_attr())["_tpu_replicate"].set_b(true);

  FunctionDef xla_func = test::function::XTimesTwo();
  xla_func.mutable_signature()->set_name("XlaXTimesTwo");
  (*xla_func.mutable_attr())["_xla_compile_id"].set_s("cluster_0");

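  // Note: only the plain XTimesTwo is registered in the function library;
  // the no-inline markers are attached directly to the call nodes built below.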
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = test::function::XTimesTwo();

  // Construct a graph:
  //   A = Placeholder[dtype=int32]
  //   B = PartitionedCall[f=XTimesTwo, _tpu_replicate="cluster"](A)
  //   C = PartitionedCall[f=XTimesTwo, _xla_compile_id="cluster"](A)
  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});

  Node* tpu_call;
  TF_ASSERT_OK(NodeBuilder("B", "PartitionedCall", &root.graph()->flib_def())
                   .Input(inputs)
                   .Attr("Tin", {DT_INT32})
                   .Attr("Tout", {DT_INT32})
                   .Attr("f", FuncAttr("XTimesTwo", DT_INT32))
                   .Attr("_tpu_replicate", "cluster")
                   .Finalize(root.graph(), &tpu_call));

  Node* xla_call;
  TF_ASSERT_OK(NodeBuilder("C", "PartitionedCall", &root.graph()->flib_def())
                   .Input(inputs)
                   .Attr("Tin", {DT_INT32})
                   .Attr("Tout", {DT_INT32})
                   .Attr("f", FuncAttr("XTimesTwo", DT_INT32))
                   .Attr("_xla_compile_id", "cluster")
                   .Finalize(root.graph(), &xla_call));

  TF_ASSERT_OK(root.DoShapeInference(tpu_call));
  TF_ASSERT_OK(root.DoShapeInference(xla_call));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(Rewrite(&graph));

  // Verify that we do not inline any of the special function call nodes.
  int partitioned_call_count = 0;
  for (const auto* op : graph->op_nodes()) {
    if (op->IsPartitionedCall()) partitioned_call_count++;
  }
  ASSERT_EQ(partitioned_call_count, 2);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(
        session.Run(feeds, {Output(tpu_call), Output(xla_call)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 2);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
    EXPECT_EQ(out_tensors[1].scalar<int>()(), 20);
  }
}

}  // namespace
}  // namespace tensorflow