• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 
28 namespace tensorflow {
29 namespace graph_transforms {
30 
31 // Declare here, so we don't need a public header.
32 Status InsertLogging(const GraphDef& input_graph_def,
33                      const TransformFuncContext& context,
34                      GraphDef* output_graph_def);
35 
36 class InsertLoggingTest : public ::testing::Test {
37  protected:
CheckGraphCanRun(const GraphDef & graph_def,const std::vector<string> & output_names)38   void CheckGraphCanRun(const GraphDef& graph_def,
39                         const std::vector<string>& output_names) {
40     std::unique_ptr<Session> session(NewSession(SessionOptions()));
41     TF_ASSERT_OK(session->Create(graph_def));
42     std::vector<Tensor> outputs;
43     TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
44   }
45 
TestInsertLogging()46   void TestInsertLogging() {
47     auto root = tensorflow::Scope::NewRootScope();
48     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
49     Tensor const_tensor(DT_FLOAT, TensorShape({10}));
50     test::FillIota<float>(&const_tensor, 1.0f);
51     Output const_node1 =
52         Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
53     Output const_node2 =
54         Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
55     Output const_node3 =
56         Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
57     Output add_node2 =
58         Add(root.WithOpName("add_node2"), const_node1, const_node2);
59     Output add_node3 =
60         Add(root.WithOpName("add_node3"), const_node1, const_node3);
61     Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
62     Output add_node4 =
63         Add(root.WithOpName("add_node4"), mul_node1, const_node3);
64     GraphDef graph_def;
65     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
66     CheckGraphCanRun(graph_def, {"add_node4"});
67 
68     GraphDef result;
69     TransformFuncContext context;
70     context.input_names = {};
71     context.output_names = {"add_node4"};
72     TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
73 
74     CheckGraphCanRun(result, {"add_node4"});
75 
76     std::unordered_set<string> print_inputs;
77     for (const NodeDef& node : result.node()) {
78       if (node.op() == "Print") {
79         print_inputs.insert(node.input(0));
80       }
81     }
82 
83     EXPECT_EQ(6, print_inputs.size());
84     EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
85     EXPECT_EQ(1, print_inputs.count("add_node2:0"));
86     EXPECT_EQ(1, print_inputs.count("add_node3:0"));
87     EXPECT_EQ(0, print_inputs.count("add_node4:0"));
88     EXPECT_EQ(1, print_inputs.count("const_node1:0"));
89     EXPECT_EQ(1, print_inputs.count("const_node2:0"));
90     EXPECT_EQ(1, print_inputs.count("const_node3:0"));
91   }
92 
TestInsertLoggingByOpType()93   void TestInsertLoggingByOpType() {
94     auto root = tensorflow::Scope::NewRootScope();
95     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
96     Tensor const_tensor(DT_FLOAT, TensorShape({10}));
97     test::FillIota<float>(&const_tensor, 1.0f);
98     Output const_node1 =
99         Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
100     Output const_node2 =
101         Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
102     Output const_node3 =
103         Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
104     Output add_node2 =
105         Add(root.WithOpName("add_node2"), const_node1, const_node2);
106     Output add_node3 =
107         Add(root.WithOpName("add_node3"), const_node1, const_node3);
108     Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
109     Output add_node4 =
110         Add(root.WithOpName("add_node4"), mul_node1, const_node3);
111     GraphDef graph_def;
112     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
113     CheckGraphCanRun(graph_def, {"add_node4"});
114 
115     GraphDef result;
116     TransformFuncContext context;
117     context.input_names = {};
118     context.output_names = {"add_node4"};
119     context.params.insert(
120         std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
121     TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
122 
123     CheckGraphCanRun(result, {"add_node4"});
124 
125     std::unordered_set<string> print_inputs;
126     for (const NodeDef& node : result.node()) {
127       if (node.op() == "Print") {
128         print_inputs.insert(node.input(0));
129       }
130     }
131 
132     EXPECT_EQ(3, print_inputs.size());
133     EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
134     EXPECT_EQ(1, print_inputs.count("add_node2:0"));
135     EXPECT_EQ(1, print_inputs.count("add_node3:0"));
136     EXPECT_EQ(0, print_inputs.count("add_node4:0"));
137     EXPECT_EQ(0, print_inputs.count("const_node1:0"));
138     EXPECT_EQ(0, print_inputs.count("const_node2:0"));
139     EXPECT_EQ(0, print_inputs.count("const_node3:0"));
140   }
141 
TestInsertLoggingByPrefix()142   void TestInsertLoggingByPrefix() {
143     auto root = tensorflow::Scope::NewRootScope();
144     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
145     Tensor const_tensor(DT_FLOAT, TensorShape({10}));
146     test::FillIota<float>(&const_tensor, 1.0f);
147     Output const_node1 =
148         Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
149     Output const_node2 =
150         Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
151     Output const_node3 =
152         Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
153     Output add_node2 =
154         Add(root.WithOpName("add_node2"), const_node1, const_node2);
155     Output add_node3 =
156         Add(root.WithOpName("add_node3"), const_node1, const_node3);
157     Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
158     Output add_node4 =
159         Add(root.WithOpName("add_node4"), mul_node1, const_node3);
160     GraphDef graph_def;
161     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
162     CheckGraphCanRun(graph_def, {"add_node4"});
163 
164     GraphDef result;
165     TransformFuncContext context;
166     context.input_names = {};
167     context.output_names = {"add_node4"};
168     context.params.insert(
169         std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
170     TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
171 
172     CheckGraphCanRun(result, {"add_node4"});
173 
174     std::unordered_set<string> print_inputs;
175     for (const NodeDef& node : result.node()) {
176       if (node.op() == "Print") {
177         print_inputs.insert(node.input(0));
178       }
179     }
180 
181     EXPECT_EQ(2, print_inputs.size());
182     EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
183     EXPECT_EQ(1, print_inputs.count("add_node2:0"));
184     EXPECT_EQ(1, print_inputs.count("add_node3:0"));
185     EXPECT_EQ(0, print_inputs.count("add_node4:0"));
186     EXPECT_EQ(0, print_inputs.count("const_node1:0"));
187     EXPECT_EQ(0, print_inputs.count("const_node2:0"));
188     EXPECT_EQ(0, print_inputs.count("const_node3:0"));
189   }
190 };
191 
// Runs InsertLogging with no filter parameters.
TEST_F(InsertLoggingTest, TestInsertLogging) {
  TestInsertLogging();
}
193 
// Runs InsertLogging filtered by op type via the "op" parameter.
TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
  TestInsertLoggingByOpType();
}
197 
// Runs InsertLogging filtered by node-name prefix via the "prefix" parameter.
TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
  TestInsertLoggingByPrefix();
}
201 
202 }  // namespace graph_transforms
203 }  // namespace tensorflow
204