/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <unordered_set>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

28 namespace tensorflow {
29 namespace graph_transforms {
30 
31 // Declare here, so we don't need a public header.
32 Status ObfuscateNames(const GraphDef& input_graph_def,
33                       const TransformFuncContext& context,
34                       GraphDef* output_graph_def);
35 
36 class ObfuscateNamesTest : public ::testing::Test {
37  protected:
TestSimpleTree()38   void TestSimpleTree() {
39     GraphDef graph_def;
40 
41     NodeDef* add_node1 = graph_def.add_node();
42     add_node1->set_name("add_node1");
43     add_node1->set_op("Add");
44     add_node1->add_input("add_node2");
45     add_node1->add_input("add_node3");
46 
47     NodeDef* add_node2 = graph_def.add_node();
48     add_node2->set_name("add_node2");
49     add_node2->set_op("Add");
50     add_node2->add_input("const_node1");
51     add_node2->add_input("const_node2");
52 
53     NodeDef* add_node3 = graph_def.add_node();
54     add_node3->set_name("add_node3");
55     add_node3->set_op("Add");
56     add_node3->add_input("const_node3");
57     add_node3->add_input("const_node4");
58 
59     NodeDef* const_node1 = graph_def.add_node();
60     const_node1->set_name("const_node1");
61     const_node1->set_op("Const");
62 
63     NodeDef* const_node2 = graph_def.add_node();
64     const_node2->set_name("const_node2");
65     const_node2->set_op("Const");
66 
67     NodeDef* const_node3 = graph_def.add_node();
68     const_node3->set_name("const_node3");
69     const_node3->set_op("Const");
70 
71     NodeDef* const_node4 = graph_def.add_node();
72     const_node4->set_name("const_node4");
73     const_node4->set_op("Const");
74 
75     GraphDef result;
76     TF_ASSERT_OK(
77         ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
78 
79     std::map<string, const NodeDef*> node_lookup;
80     MapNamesToNodes(result, &node_lookup);
81 
82     EXPECT_EQ(1, node_lookup.count("add_node1"));
83     EXPECT_EQ(0, node_lookup.count("add_node2"));
84     EXPECT_EQ(0, node_lookup.count("add_node3"));
85     EXPECT_EQ(1, node_lookup.count("const_node1"));
86     EXPECT_EQ(0, node_lookup.count("const_node2"));
87     EXPECT_EQ(0, node_lookup.count("const_node3"));
88     EXPECT_EQ(0, node_lookup.count("const_node4"));
89   }
90 
TestManyNodes()91   void TestManyNodes() {
92     GraphDef graph_def;
93     for (int i = 0; i < 1000; ++i) {
94       NodeDef* const_node = graph_def.add_node();
95       const_node->set_name(strings::StrCat("const_node", i));
96       const_node->set_op("Const");
97     }
98 
99     GraphDef result;
100     TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}},
101                                 &result));
102 
103     std::map<string, const NodeDef*> node_lookup;
104     MapNamesToNodes(result, &node_lookup);
105     EXPECT_EQ(1, node_lookup.count("const_node0"));
106     EXPECT_EQ(0, node_lookup.count("const_node500"));
107     EXPECT_EQ(1, node_lookup.count("const_node999"));
108   }
109 
TestNameClashes()110   void TestNameClashes() {
111     GraphDef graph_def;
112     for (int i = 0; i < 1000; ++i) {
113       NodeDef* const_node = graph_def.add_node();
114       const_node->set_name(strings::StrCat("1", i));
115       const_node->set_op("Const");
116     }
117 
118     GraphDef result;
119     TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result));
120 
121     std::map<string, const NodeDef*> node_lookup;
122     MapNamesToNodes(result, &node_lookup);
123     EXPECT_EQ(1, node_lookup.count("10"));
124     EXPECT_EQ(1, node_lookup.count("19"));
125 
126     std::unordered_set<string> names;
127     for (const NodeDef& node : result.node()) {
128       EXPECT_EQ(0, names.count(node.name()))
129           << "Found multiple nodes with name '" << node.name() << "'";
130       names.insert(node.name());
131     }
132   }
133 };
134 
TEST_F(ObfuscateNamesTest,TestSimpleTree)135 TEST_F(ObfuscateNamesTest, TestSimpleTree) { TestSimpleTree(); }
136 
TEST_F(ObfuscateNamesTest,TestManyNodes)137 TEST_F(ObfuscateNamesTest, TestManyNodes) { TestManyNodes(); }
138 
TEST_F(ObfuscateNamesTest,TestNameClashes)139 TEST_F(ObfuscateNamesTest, TestNameClashes) { TestNameClashes(); }
140 
141 }  // namespace graph_transforms
142 }  // namespace tensorflow
143