• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 
28 namespace tensorflow {
29 namespace graph_transforms {
30 
31 // Declare here, so we don't need a public header.
32 Status RenameAttribute(const GraphDef& input_graph_def,
33                        const TransformFuncContext& context,
34                        GraphDef* output_graph_def);
35 
36 class RenameAttributeTest : public ::testing::Test {
37  protected:
TestRenameAttribute()38   void TestRenameAttribute() {
39     GraphDef graph_def;
40 
41     NodeDef* mul_node1 = graph_def.add_node();
42     mul_node1->set_name("mul_node1");
43     mul_node1->set_op("Mul");
44     mul_node1->add_input("add_node2");
45     mul_node1->add_input("add_node3");
46     AddNodeAttr("foo", 23, mul_node1);
47     AddNodeAttr("bar", "something", mul_node1);
48 
49     NodeDef* add_node2 = graph_def.add_node();
50     add_node2->set_name("add_node2");
51     add_node2->set_op("Add");
52     add_node2->add_input("const_node1");
53     add_node2->add_input("const_node2");
54     AddNodeAttr("foo", 46, add_node2);
55     AddNodeAttr("bob", 23, add_node2);
56     AddNodeAttr("bar", "something else", add_node2);
57 
58     NodeDef* add_node3 = graph_def.add_node();
59     add_node3->set_name("add_node3");
60     add_node3->set_op("Add");
61     add_node3->add_input("const_node1");
62     add_node3->add_input("const_node3");
63 
64     NodeDef* const_node1 = graph_def.add_node();
65     const_node1->set_name("const_node1");
66     const_node1->set_op("Const");
67 
68     NodeDef* const_node2 = graph_def.add_node();
69     const_node2->set_name("const_node2");
70     const_node2->set_op("Const");
71 
72     NodeDef* const_node3 = graph_def.add_node();
73     const_node3->set_name("const_node3");
74     const_node3->set_op("Const");
75 
76     NodeDef* add_node4 = graph_def.add_node();
77     add_node4->set_name("add_node4");
78     add_node4->set_op("Add");
79     add_node4->add_input("add_node2");
80     add_node4->add_input("add_node3");
81 
82     GraphDef wildcard_result;
83     TransformFuncContext context;
84     context.input_names = {};
85     context.output_names = {"mul_node1"};
86     context.params.insert(
87         std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
88     context.params.insert(std::pair<string, std::vector<string>>(
89         {"old_attribute_name", {string("foo")}}));
90     context.params.insert(std::pair<string, std::vector<string>>(
91         {"new_attribute_name", {string("baz")}}));
92     TF_ASSERT_OK(RenameAttribute(graph_def, context, &wildcard_result));
93 
94     std::map<string, const NodeDef*> node_lookup;
95     MapNamesToNodes(wildcard_result, &node_lookup);
96     EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
97     EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
98     EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
99     EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
100     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("baz"));
101     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
102     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
103 
104     GraphDef targeted_result;
105     TransformFuncContext targeted_context;
106     targeted_context.input_names = {};
107     targeted_context.output_names = {"mul_node1"};
108     targeted_context.params.insert(
109         std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
110     targeted_context.params.insert(std::pair<string, std::vector<string>>(
111         {"old_attribute_name", {string("foo")}}));
112     targeted_context.params.insert(std::pair<string, std::vector<string>>(
113         {"new_attribute_name", {string("baz")}}));
114     TF_ASSERT_OK(
115         RenameAttribute(graph_def, targeted_context, &targeted_result));
116 
117     MapNamesToNodes(targeted_result, &node_lookup);
118     EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
119     EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
120     EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
121     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
122     EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("baz"));
123     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
124     EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
125   }
126 };
127 
// Registers the fixture's rename-attribute scenario as a gtest case; all the
// graph construction and checks live in RenameAttributeTest.
TEST_F(RenameAttributeTest, TestRenameAttribute) {
  TestRenameAttribute();
}
129 
130 }  // namespace graph_transforms
131 }  // namespace tensorflow
132