1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
31 // Declare here, so we don't need a public header.
32 Status RenameAttribute(const GraphDef& input_graph_def,
33 const TransformFuncContext& context,
34 GraphDef* output_graph_def);
35
36 class RenameAttributeTest : public ::testing::Test {
37 protected:
TestRenameAttribute()38 void TestRenameAttribute() {
39 GraphDef graph_def;
40
41 NodeDef* mul_node1 = graph_def.add_node();
42 mul_node1->set_name("mul_node1");
43 mul_node1->set_op("Mul");
44 mul_node1->add_input("add_node2");
45 mul_node1->add_input("add_node3");
46 AddNodeAttr("foo", 23, mul_node1);
47 AddNodeAttr("bar", "something", mul_node1);
48
49 NodeDef* add_node2 = graph_def.add_node();
50 add_node2->set_name("add_node2");
51 add_node2->set_op("Add");
52 add_node2->add_input("const_node1");
53 add_node2->add_input("const_node2");
54 AddNodeAttr("foo", 46, add_node2);
55 AddNodeAttr("bob", 23, add_node2);
56 AddNodeAttr("bar", "something else", add_node2);
57
58 NodeDef* add_node3 = graph_def.add_node();
59 add_node3->set_name("add_node3");
60 add_node3->set_op("Add");
61 add_node3->add_input("const_node1");
62 add_node3->add_input("const_node3");
63
64 NodeDef* const_node1 = graph_def.add_node();
65 const_node1->set_name("const_node1");
66 const_node1->set_op("Const");
67
68 NodeDef* const_node2 = graph_def.add_node();
69 const_node2->set_name("const_node2");
70 const_node2->set_op("Const");
71
72 NodeDef* const_node3 = graph_def.add_node();
73 const_node3->set_name("const_node3");
74 const_node3->set_op("Const");
75
76 NodeDef* add_node4 = graph_def.add_node();
77 add_node4->set_name("add_node4");
78 add_node4->set_op("Add");
79 add_node4->add_input("add_node2");
80 add_node4->add_input("add_node3");
81
82 GraphDef wildcard_result;
83 TransformFuncContext context;
84 context.input_names = {};
85 context.output_names = {"mul_node1"};
86 context.params.insert(
87 std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
88 context.params.insert(std::pair<string, std::vector<string>>(
89 {"old_attribute_name", {string("foo")}}));
90 context.params.insert(std::pair<string, std::vector<string>>(
91 {"new_attribute_name", {string("baz")}}));
92 TF_ASSERT_OK(RenameAttribute(graph_def, context, &wildcard_result));
93
94 std::map<string, const NodeDef*> node_lookup;
95 MapNamesToNodes(wildcard_result, &node_lookup);
96 EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
97 EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
98 EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
99 EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
100 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("baz"));
101 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
102 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
103
104 GraphDef targeted_result;
105 TransformFuncContext targeted_context;
106 targeted_context.input_names = {};
107 targeted_context.output_names = {"mul_node1"};
108 targeted_context.params.insert(
109 std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
110 targeted_context.params.insert(std::pair<string, std::vector<string>>(
111 {"old_attribute_name", {string("foo")}}));
112 targeted_context.params.insert(std::pair<string, std::vector<string>>(
113 {"new_attribute_name", {string("baz")}}));
114 TF_ASSERT_OK(
115 RenameAttribute(graph_def, targeted_context, &targeted_result));
116
117 MapNamesToNodes(targeted_result, &node_lookup);
118 EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
119 EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
120 EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
121 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
122 EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("baz"));
123 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
124 EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
125 }
126 };
127
// Runs the fixture's rename scenario: a wildcard ("*") rename followed by a
// rename restricted to "Mul" ops.
TEST_F(RenameAttributeTest, TestRenameAttribute) {
  TestRenameAttribute();
}
129
130 } // namespace graph_transforms
131 } // namespace tensorflow
132