1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/strings/match.h"
17 #include "tensorflow/cc/framework/scope.h"
18 #include "tensorflow/cc/ops/array_ops.h"
19 #include "tensorflow/cc/ops/function_ops.h"
20 #include "tensorflow/cc/ops/functional_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/common_shape_fns.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/error_codes.pb.h"
34 #include "tensorflow/core/public/session_options.h"
35 #include "tensorflow/core/public/version.h"
36
37 namespace tensorflow {
38
// Builds a graph with an If and a While node whose first argument is a
// DT_RESOURCE and whose second is a DT_BOOL, runs RearrangeFunctionArguments
// on it, and checks that:
//  - the If branch function is rewritten as "f1_rearrange_0" with inputs
//    {DT_BOOL, DT_RESOURCE} and only the DT_BOOL output left;
//  - the If and While nodes now take "arg1" before "arg0";
//  - consumers of the resource outputs ("ret1", "ret2") read "arg0" directly.
TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
  FunctionDefLibrary fdl;
  {
    // Function used as If's "then_branch"/"else_branch".
    // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
    // "ret0" = "arg1"
    // "ret1" = "arg0"
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
    Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
    auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(s.ToGraph(g.get()));
    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
  }
  {
    // Function for While's "body".
    // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
    // "ret0" = "arg0"
    // "ret1" = "arg1"
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
    Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
    auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 1);
    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(s.ToGraph(g.get()));
    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
  }
  {
    // Function for While's "cond".
    // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
    // "ret0" = "arg1"
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
    Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(s.ToGraph(g.get()));
    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  // Build the XLA computation graph.
  // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
  // "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
  // "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
  Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
  NameAttrList f;
  f.set_name("f1");
  auto if_op = ops::If(s.WithOpName("if"), arg1,
                       std::initializer_list<Input>{arg0, arg1},
                       {DT_BOOL, DT_RESOURCE}, f, f);
  auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
  auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("f3");
  body_fn.set_name("f2");
  auto while_op =
      ops::While(s.WithOpName("while"),
                 std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
  auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
  auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(s.ToGraph(g.get()));

  // Instantiates a FunctionBody for `function` from `fld`. Ownership stays in
  // `fbodies` because only a raw pointer is handed back through `fbody`.
  std::vector<std::unique_ptr<FunctionBody>> fbodies;
  auto get_function_body = [&](const NameAttrList &function,
                               const FunctionBody **fbody) -> Status {
    const FunctionDef *fdef = fld.Find(function.name());
    if (fdef == nullptr) {
      // Report a missing function as an error instead of dereferencing null.
      return errors::NotFound("Function ", function.name(),
                              " not found in library.");
    }
    std::unique_ptr<FunctionBody> new_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fdef, AttrSlice(&function.attr()), &fld, &new_fbody));
    *fbody = new_fbody.get();
    fbodies.push_back(std::move(new_fbody));
    return OkStatus();
  };
  TF_CHECK_OK(RearrangeFunctionArguments(get_function_body, g.get(), &fld));

  // Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
  // and output types should be {DT_BOOL}. ASSERT (not CHECK) so a missing
  // function fails this test instead of aborting the whole test binary.
  const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0");
  ASSERT_NE(f1_rewritten, nullptr);
  ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2);
  EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL);
  EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE);
  ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1);
  EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);

  // Looks up a node by name, returning nullptr when absent so that the
  // ASSERT_NE checks below report a test failure (unordered_map::at would
  // throw/abort instead, making those checks unreachable).
  const auto node_name_index = g->BuildNodeNameIndex();
  auto find_node = [&node_name_index](const std::string &name) -> const Node * {
    auto it = node_name_index.find(name);
    return it == node_name_index.end() ? nullptr : it->second;
  };
  const Node *input_node = nullptr;

  // Check node "if" input and output edges.
  const Node *if_node = find_node("if");
  ASSERT_NE(if_node, nullptr);
  TF_CHECK_OK(if_node->input_node(1, &input_node));
  EXPECT_EQ(input_node->name(), "arg1");
  TF_CHECK_OK(if_node->input_node(2, &input_node));
  EXPECT_EQ(input_node->name(), "arg0");
  const Node *ret0_node = find_node("ret0");
  ASSERT_NE(ret0_node, nullptr);
  TF_CHECK_OK(ret0_node->input_node(0, &input_node));
  EXPECT_EQ(input_node->name(), "if");
  const Node *ret1_node = find_node("ret1");
  ASSERT_NE(ret1_node, nullptr);
  TF_CHECK_OK(ret1_node->input_node(0, &input_node));
  EXPECT_EQ(input_node->name(), "arg0");

  // Check node "while" input and output edges.
  const Node *while_node = find_node("while");
  ASSERT_NE(while_node, nullptr);
  TF_CHECK_OK(while_node->input_node(0, &input_node));
  EXPECT_EQ(input_node->name(), "arg1");
  TF_CHECK_OK(while_node->input_node(1, &input_node));
  EXPECT_EQ(input_node->name(), "arg0");
  const Node *ret2_node = find_node("ret2");
  ASSERT_NE(ret2_node, nullptr);
  TF_CHECK_OK(ret2_node->input_node(0, &input_node));
  EXPECT_EQ(input_node->name(), "arg0");
  const Node *ret3_node = find_node("ret3");
  ASSERT_NE(ret3_node, nullptr);
  TF_CHECK_OK(ret3_node->input_node(0, &input_node));
  EXPECT_EQ(input_node->name(), "while");
}
169
// Builds a While node whose body returns its two DT_RESOURCE arguments in
// swapped positions ("ret0" = "arg1", "ret1" = "arg0") and expects
// RearrangeFunctionArguments to report UNIMPLEMENTED for it.
TEST(RearrangeFunctionArgumentForFunctionTest,
     WhileResourceRetvalFromDifferentArgUnimplemented) {
  FunctionDefLibrary fdl;
  {
    // Function for While's "body".
    // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
    // "ret0" = "arg1"
    // "ret1" = "arg0"
    // "ret2" = "arg2"
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
    Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
    Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
    auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
    auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(s.ToGraph(g.get()));
    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
  }
  {
    // Function for While's "cond".
    // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
    // "ret0" = true
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
    Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
    Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
    Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
    auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_CHECK_OK(s.ToGraph(g.get()));
    FunctionDef *xla_fdef = fdl.add_function();
    TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
  }
  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);

  // Build the XLA computation graph.
  // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
  // "arg0", "arg1", "arg2" -> "while" (While)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
  Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
  Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("f1");
  body_fn.set_name("f2");
  auto while_op = ops::While(s.WithOpName("while"),
                             std::initializer_list<Input>{arg0, arg1, arg2},
                             cond_fn, body_fn);
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(s.ToGraph(g.get()));

  // Instantiates a FunctionBody for `function` from `fld`. Ownership stays in
  // `fbodies` because only a raw pointer is handed back through `fbody`.
  std::vector<std::unique_ptr<FunctionBody>> fbodies;
  auto get_function_body = [&](const NameAttrList &function,
                               const FunctionBody **fbody) -> Status {
    const FunctionDef *fdef = fld.Find(function.name());
    if (fdef == nullptr) {
      // Report a missing function as an error instead of dereferencing null.
      return errors::NotFound("Function ", function.name(),
                              " not found in library.");
    }
    std::unique_ptr<FunctionBody> new_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fdef, AttrSlice(&function.attr()), &fld, &new_fbody));
    *fbody = new_fbody.get();
    fbodies.push_back(std::move(new_fbody));
    return OkStatus();
  };
  Status status =
      RearrangeFunctionArguments(get_function_body, g.get(), &fld);
  // Stream the full status so an unexpected code is easy to diagnose.
  EXPECT_EQ(status.code(), error::UNIMPLEMENTED) << status.ToString();
}
237
238 } // namespace tensorflow
239