1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/tools/graph_transforms/transform_graph.h"
17 #include "tensorflow/cc/ops/const_op.h"
18 #include "tensorflow/cc/ops/image_ops.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28 #include "tensorflow/core/public/session.h"
29 #include "tensorflow/tools/graph_transforms/transform_utils.h"
30
31 namespace tensorflow {
32 namespace graph_transforms {
33
34 // Declared here so we don't have to expose it in the public header.
35 Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
36 bool* ignore_errors);
37
38 namespace {
test_empty_graph_transform(const GraphDef & graph_def,const TransformFuncContext & context,GraphDef * result)39 Status test_empty_graph_transform(const GraphDef& graph_def,
40 const TransformFuncContext& context,
41 GraphDef* result) {
42 result->Clear();
43 return Status::OK();
44 }
45 } // namespace
46
47 REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform",
48 test_empty_graph_transform);
49
50 class TransformGraphTest : public ::testing::Test {
51 protected:
TestConstantFolding()52 void TestConstantFolding() {
53 auto root = tensorflow::Scope::NewRootScope();
54 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
55
56 const int width = 100;
57
58 Tensor a_data(DT_FLOAT, TensorShape({width}));
59 test::FillIota<float>(&a_data, 1.0f);
60 Output a_const =
61 Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
62
63 Tensor b_data(DT_FLOAT, TensorShape({width}));
64 test::FillIota<float>(&b_data, 1.0f);
65 Output b_const =
66 Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
67
68 Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
69
70 Output placeholder =
71 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
72
73 Output mul =
74 Mul(root.WithOpName("output_expect_remains"), add, placeholder);
75
76 GraphDef graph_def;
77 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
78 string graph_def_serialized;
79 graph_def.SerializeToString(&graph_def_serialized);
80 const string dir = testing::TmpDir();
81 const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb");
82 const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb");
83 TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb,
84 graph_def_serialized));
85
86 std::vector<string> args = {"some_binary",
87 "--in_graph=" + in_filename_pb,
88 "--out_graph=" + out_filename_pb,
89 "--inputs=placeholder_expect_remains",
90 "--outputs=output_expect_remains",
91 "--transforms=fold_constants"};
92 const int argc = 6;
93 EXPECT_EQ(argc, args.size());
94 char* argv[argc];
95 std::vector<char*> char_strings;
96 for (int i = 0; i < argc; ++i) {
97 string arg = args[i];
98 char* char_string = new char[arg.size() + 1];
99 std::copy_n(arg.c_str(), arg.size() + 1, char_string);
100 argv[i] = char_string;
101 char_strings.push_back(char_string);
102 }
103 ParseFlagsAndTransformGraph(argc, argv, false);
104 for (char* char_string : char_strings) {
105 delete[] char_string;
106 }
107
108 GraphDef out_graph_def;
109 TF_EXPECT_OK(
110 ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def));
111
112 std::map<string, const NodeDef*> out_node_map;
113 graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
114
115 for (const NodeDef& node : out_graph_def.node()) {
116 const int occurrence_count = out_node_map.count(node.name());
117 if (str_util::EndsWith(node.name(), "expect_removed")) {
118 EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
119 }
120 if (str_util::EndsWith(node.name(), "expect_remains")) {
121 EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
122 }
123 }
124 }
125
TestTransformRegistration()126 void TestTransformRegistration() {
127 auto root = tensorflow::Scope::NewRootScope();
128 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
129 Output placeholder =
130 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
131 GraphDef graph_def;
132 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
133 EXPECT_EQ(1, graph_def.node().size());
134 TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}},
135 &graph_def));
136 EXPECT_EQ(0, graph_def.node().size());
137
138 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
139 Status no_such_status =
140 TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
141 EXPECT_TRUE(absl::StrContains(no_such_status.ToString(), "not recognized"));
142 }
143
TestParseTransformParameters()144 void TestParseTransformParameters() {
145 TransformParameters params_list;
146
147 TF_EXPECT_OK(ParseTransformParameters("foo", ¶ms_list));
148 EXPECT_EQ(1, params_list.size());
149 EXPECT_EQ("foo", params_list[0].first);
150 EXPECT_TRUE(params_list[0].second.empty());
151
152 TF_EXPECT_OK(ParseTransformParameters("foo bar", ¶ms_list));
153 EXPECT_EQ(2, params_list.size());
154 EXPECT_EQ("foo", params_list[0].first);
155 EXPECT_TRUE(params_list[0].second.empty());
156 EXPECT_EQ("bar", params_list[1].first);
157 EXPECT_TRUE(params_list[1].second.empty());
158
159 TF_EXPECT_OK(ParseTransformParameters("foo() bar()", ¶ms_list));
160 EXPECT_EQ(2, params_list.size());
161 EXPECT_EQ("foo", params_list[0].first);
162 EXPECT_TRUE(params_list[0].second.empty());
163 EXPECT_EQ("bar", params_list[1].first);
164 EXPECT_TRUE(params_list[1].second.empty());
165
166 TF_EXPECT_OK(
167 ParseTransformParameters("foo(bob_something=sue)", ¶ms_list));
168 EXPECT_EQ(1, params_list.size());
169 EXPECT_EQ("foo", params_list[0].first);
170 EXPECT_EQ(1, params_list[0].second.count("bob_something"));
171 EXPECT_EQ(1, params_list[0].second["bob_something"].size());
172 EXPECT_EQ("sue", params_list[0].second["bob_something"][0]);
173
174 TF_EXPECT_OK(ParseTransformParameters("bar(a=1, b=2, a=3)", ¶ms_list));
175 EXPECT_EQ(1, params_list.size());
176 EXPECT_EQ("bar", params_list[0].first);
177 EXPECT_EQ(1, params_list[0].second.count("a"));
178 EXPECT_EQ(2, params_list[0].second["a"].size());
179 EXPECT_EQ("1", params_list[0].second["a"][0]);
180 EXPECT_EQ("3", params_list[0].second["a"][1]);
181 EXPECT_EQ(1, params_list[0].second.count("b"));
182 EXPECT_EQ(1, params_list[0].second["b"].size());
183 EXPECT_EQ("2", params_list[0].second["b"][0]);
184
185 TF_EXPECT_OK(ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)",
186 ¶ms_list));
187 EXPECT_EQ(1, params_list.size());
188 EXPECT_EQ("bar", params_list[0].first);
189 EXPECT_EQ(1, params_list[0].second.count("a"));
190 EXPECT_EQ(2, params_list[0].second["a"].size());
191 EXPECT_EQ("1", params_list[0].second["a"][0]);
192 EXPECT_EQ("3", params_list[0].second["a"][1]);
193 EXPECT_EQ(1, params_list[0].second.count("b"));
194 EXPECT_EQ(1, params_list[0].second["b"].size());
195 EXPECT_EQ("1,2,3", params_list[0].second["b"][0]);
196 }
197
TestParseEscapedNewline()198 void TestParseEscapedNewline() {
199 // This sequence of characters caused an infinite loop in the parser, which
200 // is responsible for the hang mentioned in
201 // https://github.com/tensorflow/tensorflow/issues/7150
202 TransformParameters params_list;
203 ParseTransformParameters("\\\n", ¶ms_list).IgnoreError();
204 EXPECT_EQ(0, params_list.size());
205 }
206
TestParseExtraSpaces()207 void TestParseExtraSpaces() {
208 TransformParameters params_list;
209 ParseTransformParameters(" ", ¶ms_list).IgnoreError();
210 EXPECT_EQ(0, params_list.size());
211
212 TF_EXPECT_OK(ParseTransformParameters(" foo bar \\\n", ¶ms_list));
213 EXPECT_EQ(2, params_list.size());
214 EXPECT_EQ("foo", params_list[0].first);
215 EXPECT_TRUE(params_list[0].second.empty());
216 EXPECT_EQ("bar", params_list[1].first);
217 EXPECT_TRUE(params_list[1].second.empty());
218 }
219
TestShouldIgnoreErrors()220 void TestShouldIgnoreErrors() {
221 bool ignore_errors;
222 TF_EXPECT_OK(
223 ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors));
224 EXPECT_TRUE(ignore_errors);
225
226 TF_EXPECT_OK(
227 ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors));
228 EXPECT_FALSE(ignore_errors);
229
230 TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors));
231 EXPECT_FALSE(ignore_errors);
232
233 EXPECT_FALSE(
234 ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok());
235 }
236 };
237
// Test registrations: each TEST_F delegates to the fixture helper of the same
// name so every helper on TransformGraphTest actually runs.
TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); }

TEST_F(TransformGraphTest, TestTransformRegistration) {
  TestTransformRegistration();
}

TEST_F(TransformGraphTest, TestParseTransformParameters) {
  TestParseTransformParameters();
}

TEST_F(TransformGraphTest, TestParseEscapedNewline) {
  TestParseEscapedNewline();
}

// Without this registration the TestParseExtraSpaces helper is dead code.
TEST_F(TransformGraphTest, TestParseExtraSpaces) { TestParseExtraSpaces(); }

TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); }
253
254 } // namespace graph_transforms
255 } // namespace tensorflow
256