/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

FuncAttr(const string & name)38 AttrValue FuncAttr(const string& name) {
39   AttrValue attr;
40   attr.mutable_func()->set_name(name);
41   return attr;
42 }
43 
FuncAttr(const string & name,const DataType type)44 AttrValue FuncAttr(const string& name, const DataType type) {
45   AttrValue attr;
46   attr.mutable_func()->set_name(name);
47   (*attr.mutable_func()->mutable_attr())["T"].set_type(type);
48   return attr;
49 }
50 
SessionOptionsWithInlining()51 SessionOptions SessionOptionsWithInlining() {
52   SessionOptions session_options;
53   session_options.config.mutable_graph_options()
54       ->mutable_optimizer_options()
55       ->set_do_function_inlining(true);
56   return session_options;
57 }
58 
Rewrite(std::unique_ptr<Graph> * graph)59 Status Rewrite(std::unique_ptr<Graph>* graph) {
60   FunctionLibraryDefinition flib_def((*graph)->flib_def());
61   GraphOptimizationPassOptions opt_options;
62   SessionOptions session_options = SessionOptionsWithInlining();
63   opt_options.session_options = &session_options;
64   opt_options.graph = graph;
65   opt_options.flib_def = &flib_def;
66   LowerFunctionalOpsPass pass;
67   return pass.Run(opt_options);
68 }
69 
TEST(LowerFunctionCallTest,InlineFunctionCall)70 TEST(LowerFunctionCallTest, InlineFunctionCall) {
71   using FDH = FunctionDefHelper;
72 
73   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
74 
75   FunctionDefLibrary f_lib_proto;
76 
77   // `add` node is not required to compute regular output `o`, but it must
78   // execute because it is in `control_ret`.
79   *(f_lib_proto.add_function()) =
80       FDH::Create("AddAndMul", {"i: int32"}, {"o: int32"}, {},
81                   {{{"add"}, "Add", {"i", "i"}, {{"T", DT_INT32}}},
82                    {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_INT32}}}},
83                   /*ret_def=*/{{"o", "ret:z:0"}},
84                   /*control_ret_def=*/{{"must_execute", "add"}});
85 
86   // Construct a graph:
87   //   A = Placeholder[dtype=int32]
88   //   F = PartitionedCall[f=AddAndMul](a)
89   //   B = Identity(func, ^func)
90   Scope root = Scope::NewRootScope().ExitOnError();
91   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
92   auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
93   Node* function_call;
94   std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
95   TF_ASSERT_OK(NodeBuilder("F", "PartitionedCall", &root.graph()->flib_def())
96                    .Input(inputs)
97                    .Attr("Tin", {DT_INT32})
98                    .Attr("Tout", {DT_INT32})
99                    .Attr("f", FuncAttr("AddAndMul"))
100                    .Finalize(root.graph(), &function_call));
101   TF_ASSERT_OK(root.DoShapeInference(function_call));
102 
103   auto b = ops::Identity(root.WithOpName("B"), Output(function_call, 0));
104   root.graph()->AddControlEdge(function_call, b.node());
105 
106   TF_ASSERT_OK(root.ToGraph(graph.get()));
107   TF_ASSERT_OK(Rewrite(&graph));
108 
109   // Verify the resultant graph has no PartitionedCall ops and function body was
110   // inlined into the main graph.
111   int partitioned_call_count = 0;
112   int add_count = 0;
113   int mul_count = 0;
114   for (const auto* op : graph->op_nodes()) {
115     if (op->IsPartitionedCall()) partitioned_call_count++;
116     if (op->type_string() == "Add") add_count++;
117     if (op->type_string() == "Mul") mul_count++;
118   }
119 
120   ASSERT_EQ(partitioned_call_count, 0);
121   ASSERT_EQ(add_count, 1);
122   ASSERT_EQ(mul_count, 1);
123 
124   // Verify execution.
125   ClientSession session(root, SessionOptionsWithInlining());
126   {
127     ClientSession::FeedType feeds;
128     feeds.emplace(Output(a.node()), Input::Initializer(10));
129     std::vector<Tensor> out_tensors;
130     TF_ASSERT_OK(session.Run(feeds, {Output(b)}, &out_tensors));
131     EXPECT_EQ(out_tensors.size(), 1);
132     EXPECT_EQ(out_tensors[0].scalar<int>()(), 100);
133   }
134 }
135 
TEST(LowerFunctionCallTest,DoNotInlineTpuOrXlaFunctions)136 TEST(LowerFunctionCallTest, DoNotInlineTpuOrXlaFunctions) {
137   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
138 
139   FunctionDef tpu_func = test::function::XTimesTwo();
140   tpu_func.mutable_signature()->set_name("TpuXTimesTwo");
141   (*tpu_func.mutable_attr())["_tpu_replicate"].set_b(true);
142 
143   FunctionDef xla_func = test::function::XTimesTwo();
144   xla_func.mutable_signature()->set_name("XlaXTimesTwo");
145   (*xla_func.mutable_attr())["_xla_compile_id"].set_s("cluster_0");
146 
147   FunctionDefLibrary f_lib_proto;
148   *(f_lib_proto.add_function()) = test::function::XTimesTwo();
149 
150   // Construct a graph:
151   //   A = Placeholder[dtype=int32]
152   //   B = XTimesTwo[_tpu_replicate="cluster"](A)
153   //   C = XTimesTwo[_xla_compile_id="cluster"](A)
154   Scope root = Scope::NewRootScope().ExitOnError();
155   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
156   auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
157   std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
158 
159   Node* tpu_call;
160   TF_ASSERT_OK(NodeBuilder("B", "PartitionedCall", &root.graph()->flib_def())
161                    .Input(inputs)
162                    .Attr("Tin", {DT_INT32})
163                    .Attr("Tout", {DT_INT32})
164                    .Attr("f", FuncAttr("XTimesTwo", DT_INT32))
165                    .Attr("_tpu_replicate", "cluster")
166                    .Finalize(root.graph(), &tpu_call));
167 
168   Node* xla_call;
169   TF_ASSERT_OK(NodeBuilder("C", "PartitionedCall", &root.graph()->flib_def())
170                    .Input(inputs)
171                    .Attr("Tin", {DT_INT32})
172                    .Attr("Tout", {DT_INT32})
173                    .Attr("f", FuncAttr("XTimesTwo", DT_INT32))
174                    .Attr("_xla_compile_id", "cluster")
175                    .Finalize(root.graph(), &xla_call));
176 
177   TF_ASSERT_OK(root.DoShapeInference(tpu_call));
178   TF_ASSERT_OK(root.DoShapeInference(xla_call));
179   TF_ASSERT_OK(root.ToGraph(graph.get()));
180   TF_ASSERT_OK(Rewrite(&graph));
181 
182   // Verify that we do not inline any of the special function call nodes.
183   int partitioned_call_count = 0;
184   for (const auto* op : graph->op_nodes()) {
185     if (op->IsPartitionedCall()) partitioned_call_count++;
186   }
187   ASSERT_EQ(partitioned_call_count, 2);
188 
189   // Verify execution.
190   ClientSession session(root, SessionOptionsWithInlining());
191   {
192     ClientSession::FeedType feeds;
193     feeds.emplace(Output(a.node()), Input::Initializer(10));
194     std::vector<Tensor> out_tensors;
195     TF_ASSERT_OK(
196         session.Run(feeds, {Output(tpu_call), Output(xla_call)}, &out_tensors));
197     EXPECT_EQ(out_tensors.size(), 2);
198     EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
199     EXPECT_EQ(out_tensors[1].scalar<int>()(), 20);
200   }
201 }
}  // namespace
}  // namespace tensorflow