• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/transform_graph.h"
17 #include "tensorflow/cc/ops/const_op.h"
18 #include "tensorflow/cc/ops/image_ops.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28 #include "tensorflow/core/public/session.h"
29 #include "tensorflow/tools/graph_transforms/transform_utils.h"
30 
31 namespace tensorflow {
32 namespace graph_transforms {
33 
34 // Declared here so we don't have to expose it in the public header.
35 Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
36                           bool* ignore_errors);
37 
38 namespace {
test_empty_graph_transform(const GraphDef & graph_def,const TransformFuncContext & context,GraphDef * result)39 Status test_empty_graph_transform(const GraphDef& graph_def,
40                                   const TransformFuncContext& context,
41                                   GraphDef* result) {
42   result->Clear();
43   return Status::OK();
44 }
45 }  // namespace
46 
47 REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform",
48                          test_empty_graph_transform);
49 
50 class TransformGraphTest : public ::testing::Test {
51  protected:
TestConstantFolding()52   void TestConstantFolding() {
53     auto root = tensorflow::Scope::NewRootScope();
54     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
55 
56     const int width = 100;
57 
58     Tensor a_data(DT_FLOAT, TensorShape({width}));
59     test::FillIota<float>(&a_data, 1.0f);
60     Output a_const =
61         Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
62 
63     Tensor b_data(DT_FLOAT, TensorShape({width}));
64     test::FillIota<float>(&b_data, 1.0f);
65     Output b_const =
66         Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
67 
68     Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
69 
70     Output placeholder =
71         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
72 
73     Output mul =
74         Mul(root.WithOpName("output_expect_remains"), add, placeholder);
75 
76     GraphDef graph_def;
77     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
78     string graph_def_serialized;
79     graph_def.SerializeToString(&graph_def_serialized);
80     const string dir = testing::TmpDir();
81     const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb");
82     const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb");
83     TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb,
84                                    graph_def_serialized));
85 
86     std::vector<string> args = {"some_binary",
87                                 "--in_graph=" + in_filename_pb,
88                                 "--out_graph=" + out_filename_pb,
89                                 "--inputs=placeholder_expect_remains",
90                                 "--outputs=output_expect_remains",
91                                 "--transforms=fold_constants"};
92     const int argc = 6;
93     EXPECT_EQ(argc, args.size());
94     char* argv[argc];
95     std::vector<char*> char_strings;
96     for (int i = 0; i < argc; ++i) {
97       string arg = args[i];
98       char* char_string = new char[arg.size() + 1];
99       std::copy_n(arg.c_str(), arg.size() + 1, char_string);
100       argv[i] = char_string;
101       char_strings.push_back(char_string);
102     }
103     ParseFlagsAndTransformGraph(argc, argv, false);
104     for (char* char_string : char_strings) {
105       delete[] char_string;
106     }
107 
108     GraphDef out_graph_def;
109     TF_EXPECT_OK(
110         ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def));
111 
112     std::map<string, const NodeDef*> out_node_map;
113     graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
114 
115     for (const NodeDef& node : out_graph_def.node()) {
116       const int occurrence_count = out_node_map.count(node.name());
117       if (str_util::EndsWith(node.name(), "expect_removed")) {
118         EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
119       }
120       if (str_util::EndsWith(node.name(), "expect_remains")) {
121         EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
122       }
123     }
124   }
125 
TestTransformRegistration()126   void TestTransformRegistration() {
127     auto root = tensorflow::Scope::NewRootScope();
128     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
129     Output placeholder =
130         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
131     GraphDef graph_def;
132     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
133     EXPECT_EQ(1, graph_def.node().size());
134     TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}},
135                                 &graph_def));
136     EXPECT_EQ(0, graph_def.node().size());
137 
138     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
139     Status no_such_status =
140         TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
141     EXPECT_TRUE(absl::StrContains(no_such_status.ToString(), "not recognized"));
142   }
143 
TestParseTransformParameters()144   void TestParseTransformParameters() {
145     TransformParameters params_list;
146 
147     TF_EXPECT_OK(ParseTransformParameters("foo", &params_list));
148     EXPECT_EQ(1, params_list.size());
149     EXPECT_EQ("foo", params_list[0].first);
150     EXPECT_TRUE(params_list[0].second.empty());
151 
152     TF_EXPECT_OK(ParseTransformParameters("foo bar", &params_list));
153     EXPECT_EQ(2, params_list.size());
154     EXPECT_EQ("foo", params_list[0].first);
155     EXPECT_TRUE(params_list[0].second.empty());
156     EXPECT_EQ("bar", params_list[1].first);
157     EXPECT_TRUE(params_list[1].second.empty());
158 
159     TF_EXPECT_OK(ParseTransformParameters("foo() bar()", &params_list));
160     EXPECT_EQ(2, params_list.size());
161     EXPECT_EQ("foo", params_list[0].first);
162     EXPECT_TRUE(params_list[0].second.empty());
163     EXPECT_EQ("bar", params_list[1].first);
164     EXPECT_TRUE(params_list[1].second.empty());
165 
166     TF_EXPECT_OK(
167         ParseTransformParameters("foo(bob_something=sue)", &params_list));
168     EXPECT_EQ(1, params_list.size());
169     EXPECT_EQ("foo", params_list[0].first);
170     EXPECT_EQ(1, params_list[0].second.count("bob_something"));
171     EXPECT_EQ(1, params_list[0].second["bob_something"].size());
172     EXPECT_EQ("sue", params_list[0].second["bob_something"][0]);
173 
174     TF_EXPECT_OK(ParseTransformParameters("bar(a=1, b=2, a=3)", &params_list));
175     EXPECT_EQ(1, params_list.size());
176     EXPECT_EQ("bar", params_list[0].first);
177     EXPECT_EQ(1, params_list[0].second.count("a"));
178     EXPECT_EQ(2, params_list[0].second["a"].size());
179     EXPECT_EQ("1", params_list[0].second["a"][0]);
180     EXPECT_EQ("3", params_list[0].second["a"][1]);
181     EXPECT_EQ(1, params_list[0].second.count("b"));
182     EXPECT_EQ(1, params_list[0].second["b"].size());
183     EXPECT_EQ("2", params_list[0].second["b"][0]);
184 
185     TF_EXPECT_OK(ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)",
186                                           &params_list));
187     EXPECT_EQ(1, params_list.size());
188     EXPECT_EQ("bar", params_list[0].first);
189     EXPECT_EQ(1, params_list[0].second.count("a"));
190     EXPECT_EQ(2, params_list[0].second["a"].size());
191     EXPECT_EQ("1", params_list[0].second["a"][0]);
192     EXPECT_EQ("3", params_list[0].second["a"][1]);
193     EXPECT_EQ(1, params_list[0].second.count("b"));
194     EXPECT_EQ(1, params_list[0].second["b"].size());
195     EXPECT_EQ("1,2,3", params_list[0].second["b"][0]);
196   }
197 
TestParseEscapedNewline()198   void TestParseEscapedNewline() {
199     // This sequence of characters caused an infinite loop in the parser, which
200     // is responsible for the hang mentioned in
201     // https://github.com/tensorflow/tensorflow/issues/7150
202     TransformParameters params_list;
203     ParseTransformParameters("\\\n", &params_list).IgnoreError();
204     EXPECT_EQ(0, params_list.size());
205   }
206 
TestParseExtraSpaces()207   void TestParseExtraSpaces() {
208     TransformParameters params_list;
209     ParseTransformParameters(" ", &params_list).IgnoreError();
210     EXPECT_EQ(0, params_list.size());
211 
212     TF_EXPECT_OK(ParseTransformParameters("  foo bar \\\n", &params_list));
213     EXPECT_EQ(2, params_list.size());
214     EXPECT_EQ("foo", params_list[0].first);
215     EXPECT_TRUE(params_list[0].second.empty());
216     EXPECT_EQ("bar", params_list[1].first);
217     EXPECT_TRUE(params_list[1].second.empty());
218   }
219 
TestShouldIgnoreErrors()220   void TestShouldIgnoreErrors() {
221     bool ignore_errors;
222     TF_EXPECT_OK(
223         ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors));
224     EXPECT_TRUE(ignore_errors);
225 
226     TF_EXPECT_OK(
227         ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors));
228     EXPECT_FALSE(ignore_errors);
229 
230     TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors));
231     EXPECT_FALSE(ignore_errors);
232 
233     EXPECT_FALSE(
234         ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok());
235   }
236 };
237 
TEST_F(TransformGraphTest,TestConstantFolding)238 TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); }
239 
TEST_F(TransformGraphTest,TestTransformRegistration)240 TEST_F(TransformGraphTest, TestTransformRegistration) {
241   TestTransformRegistration();
242 }
243 
TEST_F(TransformGraphTest,TestParseTransformParameters)244 TEST_F(TransformGraphTest, TestParseTransformParameters) {
245   TestParseTransformParameters();
246 }
247 
TEST_F(TransformGraphTest,TestParseEscapedNewline)248 TEST_F(TransformGraphTest, TestParseEscapedNewline) {
249   TestParseEscapedNewline();
250 }
251 
TEST_F(TransformGraphTest,TestShouldIgnoreErrors)252 TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); }
253 
254 }  // namespace graph_transforms
255 }  // namespace tensorflow
256