• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/strings/match.h"
17 #include "tensorflow/cc/framework/scope.h"
18 #include "tensorflow/cc/ops/array_ops.h"
19 #include "tensorflow/cc/ops/function_ops.h"
20 #include "tensorflow/cc/ops/functional_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
24 #include "tensorflow/compiler/xla/test.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/common_shape_fns.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/error_codes.pb.h"
34 #include "tensorflow/core/public/session_options.h"
35 #include "tensorflow/core/public/version.h"
36 
37 namespace tensorflow {
38 
TEST(RearrangeFunctionArgumentForFunctionTest,Basic)39 TEST(RearrangeFunctionArgumentForFunctionTest, Basic) {
40   FunctionDefLibrary fdl;
41   {
42     // Function for StatefulPartitionedCall's "f", If's
43     // "then_branch"/"else_branch".
44     // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
45     // "ret0" = "arg1"
46     // "ret1" = "arg0"
47     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
48     Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
49     Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
50     auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
51     auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
52     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
53     TF_CHECK_OK(s.ToGraph(g.get()));
54     FunctionDef *xla_fdef = fdl.add_function();
55     TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
56   }
57   {
58     // Function for While's "body".
59     // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
60     // "ret0" = "arg0"
61     // "ret1" = "arg1"
62     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
63     Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
64     Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
65     auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0);
66     auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 1);
67     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
68     TF_CHECK_OK(s.ToGraph(g.get()));
69     FunctionDef *xla_fdef = fdl.add_function();
70     TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
71   }
72   {
73     // Function for While's "cond".
74     // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL)
75     // "ret0" = "arg1"
76     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
77     Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
78     Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
79     auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
80     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
81     TF_CHECK_OK(s.ToGraph(g.get()));
82     FunctionDef *xla_fdef = fdl.add_function();
83     TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef));
84   }
85   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
86 
87   // Build the XLA computation graph.
88   // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32)
89   // "arg0", "arg1" -> "if" (If) -> "ret0", "ret1"
90   // "arg0", "arg1" -> "while" (While) -> "ret2", "ret3"
91   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
92   Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
93   Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1);
94   NameAttrList f;
95   f.set_name("f1");
96   auto if_op = ops::If(s.WithOpName("if"), arg1,
97                        std::initializer_list<Input>{arg0, arg1},
98                        {DT_BOOL, DT_RESOURCE}, f, f);
99   auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0);
100   auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1);
101   NameAttrList cond_fn, body_fn;
102   cond_fn.set_name("f3");
103   body_fn.set_name("f2");
104   auto while_op =
105       ops::While(s.WithOpName("while"),
106                  std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn);
107   auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2);
108   auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3);
109   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
110   TF_CHECK_OK(s.ToGraph(g.get()));
111 
112   std::vector<std::unique_ptr<FunctionBody>> fbodies;
113   TF_CHECK_OK(RearrangeFunctionArguments(
114       [&](const NameAttrList &function, const FunctionBody **fbody) {
115         std::unique_ptr<FunctionBody> new_fbody;
116         TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
117                                                    AttrSlice(&function.attr()),
118                                                    &fld, &new_fbody));
119         *fbody = new_fbody.get();
120         fbodies.push_back(std::move(new_fbody));
121         return OkStatus();
122       },
123       g.get(), &fld));
124 
125   // Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE}
126   // and output types should be {DT_BOOL}.
127   const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0");
128   CHECK_NE(f1_rewritten, nullptr);
129   ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2);
130   EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL);
131   EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE);
132   ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1);
133   EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL);
134 
135   // Check node "if" input and output edges.
136   auto node_name_index = g->BuildNodeNameIndex();
137   const Node *if_node = node_name_index.at("if");
138   ASSERT_NE(if_node, nullptr);
139   const Node *input_node;
140   TF_CHECK_OK(if_node->input_node(1, &input_node));
141   EXPECT_EQ(input_node->name(), "arg1");
142   TF_CHECK_OK(if_node->input_node(2, &input_node));
143   EXPECT_EQ(input_node->name(), "arg0");
144   const Node *ret0_node = node_name_index.at("ret0");
145   ASSERT_NE(ret0_node, nullptr);
146   TF_CHECK_OK(ret0_node->input_node(0, &input_node));
147   EXPECT_EQ(input_node->name(), "if");
148   const Node *ret1_node = node_name_index.at("ret1");
149   ASSERT_NE(ret1_node, nullptr);
150   TF_CHECK_OK(ret1_node->input_node(0, &input_node));
151   EXPECT_EQ(input_node->name(), "arg0");
152 
153   // Check node "while" input and output edges.
154   const Node *while_node = node_name_index.at("while");
155   ASSERT_NE(while_node, nullptr);
156   TF_CHECK_OK(while_node->input_node(0, &input_node));
157   EXPECT_EQ(input_node->name(), "arg1");
158   TF_CHECK_OK(while_node->input_node(1, &input_node));
159   EXPECT_EQ(input_node->name(), "arg0");
160   const Node *ret2_node = node_name_index.at("ret2");
161   ASSERT_NE(ret2_node, nullptr);
162   TF_CHECK_OK(ret2_node->input_node(0, &input_node));
163   EXPECT_EQ(input_node->name(), "arg0");
164   const Node *ret3_node = node_name_index.at("ret3");
165   ASSERT_NE(ret3_node, nullptr);
166   TF_CHECK_OK(ret3_node->input_node(0, &input_node));
167   EXPECT_EQ(input_node->name(), "while");
168 }
169 
TEST(RearrangeFunctionArgumentForFunctionTest,WhileResourceRetvalFromDifferentArgUnimplemented)170 TEST(RearrangeFunctionArgumentForFunctionTest,
171      WhileResourceRetvalFromDifferentArgUnimplemented) {
172   FunctionDefLibrary fdl;
173   {
174     // Function for While's "body".
175     // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
176     // "ret0" = "arg1"
177     // "ret1" = "arg0"
178     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
179     Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
180     Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
181     Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
182     auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0);
183     auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1);
184     auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2);
185     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
186     TF_CHECK_OK(s.ToGraph(g.get()));
187     FunctionDef *xla_fdef = fdl.add_function();
188     TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef));
189   }
190   {
191     // Function for While's "cond".
192     // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
193     // "ret0" = true
194     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
195     Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
196     Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
197     Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
198     Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({}));
199     auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0);
200     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
201     TF_CHECK_OK(s.ToGraph(g.get()));
202     FunctionDef *xla_fdef = fdl.add_function();
203     TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef));
204   }
205   FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
206 
207   // Build the XLA computation graph.
208   // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32)
209   // "arg0", "arg1" -> "while" (While)
210   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
211   Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0);
212   Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1);
213   Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
214   NameAttrList cond_fn, body_fn;
215   cond_fn.set_name("f1");
216   body_fn.set_name("f2");
217   auto while_op = ops::While(s.WithOpName("while"),
218                              std::initializer_list<Input>{arg0, arg1, arg2},
219                              cond_fn, body_fn);
220   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
221   TF_CHECK_OK(s.ToGraph(g.get()));
222 
223   std::vector<std::unique_ptr<FunctionBody>> fbodies;
224   Status status = RearrangeFunctionArguments(
225       [&](const NameAttrList &function, const FunctionBody **fbody) {
226         std::unique_ptr<FunctionBody> new_fbody;
227         TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()),
228                                                    AttrSlice(&function.attr()),
229                                                    &fld, &new_fbody));
230         *fbody = new_fbody.get();
231         fbodies.push_back(std::move(new_fbody));
232         return OkStatus();
233       },
234       g.get(), &fld);
235   EXPECT_EQ(status.code(), error::UNIMPLEMENTED);
236 }
237 
238 }  // namespace tensorflow
239