1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <map>
#include <string>
#include <unordered_set>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
31 // Declare here, so we don't need a public header.
32 Status ObfuscateNames(const GraphDef& input_graph_def,
33 const TransformFuncContext& context,
34 GraphDef* output_graph_def);
35
36 class ObfuscateNamesTest : public ::testing::Test {
37 protected:
TestSimpleTree()38 void TestSimpleTree() {
39 GraphDef graph_def;
40
41 NodeDef* add_node1 = graph_def.add_node();
42 add_node1->set_name("add_node1");
43 add_node1->set_op("Add");
44 add_node1->add_input("add_node2");
45 add_node1->add_input("add_node3");
46
47 NodeDef* add_node2 = graph_def.add_node();
48 add_node2->set_name("add_node2");
49 add_node2->set_op("Add");
50 add_node2->add_input("const_node1");
51 add_node2->add_input("const_node2");
52
53 NodeDef* add_node3 = graph_def.add_node();
54 add_node3->set_name("add_node3");
55 add_node3->set_op("Add");
56 add_node3->add_input("const_node3");
57 add_node3->add_input("const_node4");
58
59 NodeDef* const_node1 = graph_def.add_node();
60 const_node1->set_name("const_node1");
61 const_node1->set_op("Const");
62
63 NodeDef* const_node2 = graph_def.add_node();
64 const_node2->set_name("const_node2");
65 const_node2->set_op("Const");
66
67 NodeDef* const_node3 = graph_def.add_node();
68 const_node3->set_name("const_node3");
69 const_node3->set_op("Const");
70
71 NodeDef* const_node4 = graph_def.add_node();
72 const_node4->set_name("const_node4");
73 const_node4->set_op("Const");
74
75 GraphDef result;
76 TF_ASSERT_OK(
77 ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
78
79 std::map<string, const NodeDef*> node_lookup;
80 MapNamesToNodes(result, &node_lookup);
81
82 EXPECT_EQ(1, node_lookup.count("add_node1"));
83 EXPECT_EQ(0, node_lookup.count("add_node2"));
84 EXPECT_EQ(0, node_lookup.count("add_node3"));
85 EXPECT_EQ(1, node_lookup.count("const_node1"));
86 EXPECT_EQ(0, node_lookup.count("const_node2"));
87 EXPECT_EQ(0, node_lookup.count("const_node3"));
88 EXPECT_EQ(0, node_lookup.count("const_node4"));
89 }
90
TestManyNodes()91 void TestManyNodes() {
92 GraphDef graph_def;
93 for (int i = 0; i < 1000; ++i) {
94 NodeDef* const_node = graph_def.add_node();
95 const_node->set_name(strings::StrCat("const_node", i));
96 const_node->set_op("Const");
97 }
98
99 GraphDef result;
100 TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}},
101 &result));
102
103 std::map<string, const NodeDef*> node_lookup;
104 MapNamesToNodes(result, &node_lookup);
105 EXPECT_EQ(1, node_lookup.count("const_node0"));
106 EXPECT_EQ(0, node_lookup.count("const_node500"));
107 EXPECT_EQ(1, node_lookup.count("const_node999"));
108 }
109
TestNameClashes()110 void TestNameClashes() {
111 GraphDef graph_def;
112 for (int i = 0; i < 1000; ++i) {
113 NodeDef* const_node = graph_def.add_node();
114 const_node->set_name(strings::StrCat("1", i));
115 const_node->set_op("Const");
116 }
117
118 GraphDef result;
119 TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result));
120
121 std::map<string, const NodeDef*> node_lookup;
122 MapNamesToNodes(result, &node_lookup);
123 EXPECT_EQ(1, node_lookup.count("10"));
124 EXPECT_EQ(1, node_lookup.count("19"));
125
126 std::unordered_set<string> names;
127 for (const NodeDef& node : result.node()) {
128 EXPECT_EQ(0, names.count(node.name()))
129 << "Found multiple nodes with name '" << node.name() << "'";
130 names.insert(node.name());
131 }
132 }
133 };
134
// Test registrations. Each one forwards to the fixture method of the same
// name, which contains the actual checks.
TEST_F(ObfuscateNamesTest, TestSimpleTree) {
  TestSimpleTree();
}

TEST_F(ObfuscateNamesTest, TestManyNodes) {
  TestManyNodes();
}

TEST_F(ObfuscateNamesTest, TestNameClashes) {
  TestNameClashes();
}
140
141 } // namespace graph_transforms
142 } // namespace tensorflow
143