1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
// Forward declaration of the transform under test. Declared here so we don't
// need a public header; the definition is expected to be supplied by another
// translation unit at link time.
//
// Reads `input_graph_def`, applies the settings carried by `context`, and
// writes the transformed graph to `output_graph_def`. The returned Status is
// checked with TF_ASSERT_OK by the tests in this file.
Status InsertLogging(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def);
35
36 class InsertLoggingTest : public ::testing::Test {
37 protected:
CheckGraphCanRun(const GraphDef & graph_def,const std::vector<string> & output_names)38 void CheckGraphCanRun(const GraphDef& graph_def,
39 const std::vector<string>& output_names) {
40 std::unique_ptr<Session> session(NewSession(SessionOptions()));
41 TF_ASSERT_OK(session->Create(graph_def));
42 std::vector<Tensor> outputs;
43 TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
44 }
45
TestInsertLogging()46 void TestInsertLogging() {
47 auto root = tensorflow::Scope::NewRootScope();
48 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
49 Tensor const_tensor(DT_FLOAT, TensorShape({10}));
50 test::FillIota<float>(&const_tensor, 1.0f);
51 Output const_node1 =
52 Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
53 Output const_node2 =
54 Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
55 Output const_node3 =
56 Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
57 Output add_node2 =
58 Add(root.WithOpName("add_node2"), const_node1, const_node2);
59 Output add_node3 =
60 Add(root.WithOpName("add_node3"), const_node1, const_node3);
61 Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
62 Output add_node4 =
63 Add(root.WithOpName("add_node4"), mul_node1, const_node3);
64 GraphDef graph_def;
65 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
66 CheckGraphCanRun(graph_def, {"add_node4"});
67
68 GraphDef result;
69 TransformFuncContext context;
70 context.input_names = {};
71 context.output_names = {"add_node4"};
72 TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
73
74 CheckGraphCanRun(result, {"add_node4"});
75
76 std::unordered_set<string> print_inputs;
77 for (const NodeDef& node : result.node()) {
78 if (node.op() == "Print") {
79 print_inputs.insert(node.input(0));
80 }
81 }
82
83 EXPECT_EQ(6, print_inputs.size());
84 EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
85 EXPECT_EQ(1, print_inputs.count("add_node2:0"));
86 EXPECT_EQ(1, print_inputs.count("add_node3:0"));
87 EXPECT_EQ(0, print_inputs.count("add_node4:0"));
88 EXPECT_EQ(1, print_inputs.count("const_node1:0"));
89 EXPECT_EQ(1, print_inputs.count("const_node2:0"));
90 EXPECT_EQ(1, print_inputs.count("const_node3:0"));
91 }
92
TestInsertLoggingByOpType()93 void TestInsertLoggingByOpType() {
94 auto root = tensorflow::Scope::NewRootScope();
95 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
96 Tensor const_tensor(DT_FLOAT, TensorShape({10}));
97 test::FillIota<float>(&const_tensor, 1.0f);
98 Output const_node1 =
99 Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
100 Output const_node2 =
101 Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
102 Output const_node3 =
103 Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
104 Output add_node2 =
105 Add(root.WithOpName("add_node2"), const_node1, const_node2);
106 Output add_node3 =
107 Add(root.WithOpName("add_node3"), const_node1, const_node3);
108 Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
109 Output add_node4 =
110 Add(root.WithOpName("add_node4"), mul_node1, const_node3);
111 GraphDef graph_def;
112 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
113 CheckGraphCanRun(graph_def, {"add_node4"});
114
115 GraphDef result;
116 TransformFuncContext context;
117 context.input_names = {};
118 context.output_names = {"add_node4"};
119 context.params.insert(
120 std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
121 TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
122
123 CheckGraphCanRun(result, {"add_node4"});
124
125 std::unordered_set<string> print_inputs;
126 for (const NodeDef& node : result.node()) {
127 if (node.op() == "Print") {
128 print_inputs.insert(node.input(0));
129 }
130 }
131
132 EXPECT_EQ(3, print_inputs.size());
133 EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
134 EXPECT_EQ(1, print_inputs.count("add_node2:0"));
135 EXPECT_EQ(1, print_inputs.count("add_node3:0"));
136 EXPECT_EQ(0, print_inputs.count("add_node4:0"));
137 EXPECT_EQ(0, print_inputs.count("const_node1:0"));
138 EXPECT_EQ(0, print_inputs.count("const_node2:0"));
139 EXPECT_EQ(0, print_inputs.count("const_node3:0"));
140 }
141
TestInsertLoggingByPrefix()142 void TestInsertLoggingByPrefix() {
143 auto root = tensorflow::Scope::NewRootScope();
144 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
145 Tensor const_tensor(DT_FLOAT, TensorShape({10}));
146 test::FillIota<float>(&const_tensor, 1.0f);
147 Output const_node1 =
148 Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
149 Output const_node2 =
150 Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
151 Output const_node3 =
152 Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
153 Output add_node2 =
154 Add(root.WithOpName("add_node2"), const_node1, const_node2);
155 Output add_node3 =
156 Add(root.WithOpName("add_node3"), const_node1, const_node3);
157 Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
158 Output add_node4 =
159 Add(root.WithOpName("add_node4"), mul_node1, const_node3);
160 GraphDef graph_def;
161 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
162 CheckGraphCanRun(graph_def, {"add_node4"});
163
164 GraphDef result;
165 TransformFuncContext context;
166 context.input_names = {};
167 context.output_names = {"add_node4"};
168 context.params.insert(
169 std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
170 TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
171
172 CheckGraphCanRun(result, {"add_node4"});
173
174 std::unordered_set<string> print_inputs;
175 for (const NodeDef& node : result.node()) {
176 if (node.op() == "Print") {
177 print_inputs.insert(node.input(0));
178 }
179 }
180
181 EXPECT_EQ(2, print_inputs.size());
182 EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
183 EXPECT_EQ(1, print_inputs.count("add_node2:0"));
184 EXPECT_EQ(1, print_inputs.count("add_node3:0"));
185 EXPECT_EQ(0, print_inputs.count("add_node4:0"));
186 EXPECT_EQ(0, print_inputs.count("const_node1:0"));
187 EXPECT_EQ(0, print_inputs.count("const_node2:0"));
188 EXPECT_EQ(0, print_inputs.count("const_node3:0"));
189 }
190 };
191
// Registers the fixture's TestInsertLogging() scenario as a test case.
TEST_F(InsertLoggingTest, TestInsertLogging) {
  this->TestInsertLogging();
}
193
// Registers the fixture's TestInsertLoggingByOpType() scenario as a test case.
TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
  this->TestInsertLoggingByOpType();
}
197
// Registers the fixture's TestInsertLoggingByPrefix() scenario as a test case.
TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
  this->TestInsertLoggingByPrefix();
}
201
202 } // namespace graph_transforms
203 } // namespace tensorflow
204