1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
24
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace grappler {
30 namespace fusion_utils {
31 namespace {
32
// Returns the node-name part of a function-body connection string such as
// "node:output:0": everything before the first ':'. A string without any ':'
// is returned unchanged.
string ParseNodeConnection(const string &name) {
  const string::size_type colon_pos = name.find(':');
  if (colon_pos == string::npos) return name;
  return name.substr(0, colon_pos);
}
36
CheckUniqueNames(const FunctionDef & function)37 void CheckUniqueNames(const FunctionDef &function) {
38 std::unordered_set<string> inputs;
39 for (const auto &input_arg : function.signature().input_arg())
40 inputs.insert(input_arg.name());
41 EXPECT_EQ(inputs.size(), function.signature().input_arg_size());
42
43 std::unordered_set<string> outputs;
44 for (const auto &output_arg : function.signature().output_arg())
45 outputs.insert(output_arg.name());
46 EXPECT_EQ(outputs.size(), function.signature().output_arg_size());
47
48 std::unordered_set<string> nodes;
49 for (const auto &node : function.node_def()) nodes.insert(node.name());
50
51 EXPECT_EQ(nodes.size(), function.node_def_size());
52 }
53
// Fuses two copies of XTimesTwo by composition and checks the wiring: the
// second Mul reads from the parent's Mul, and the fused function's single
// output is produced by the second Mul.
TEST(FusionUtilsTest, FuseFunctionsByComposition) {
  GraphDef graph;
  auto *parent_function = graph.mutable_library()->add_function();
  *parent_function = test::function::XTimesTwo();
  auto *function = graph.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  auto *fused_function = FuseFunctions(
      *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature,
      fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
      fusion_utils::MergeNodes, graph.mutable_library());
  // Fail the test cleanly rather than crashing on the dereferences below.
  ASSERT_NE(fused_function, nullptr);

  EXPECT_EQ(fused_function->signature().name(), "fused_maps");
  EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
  EXPECT_EQ(fused_function->signature().output_arg_size(), 1);
  EXPECT_EQ(fused_function->ret_size(), 1);
  CheckUniqueNames(*fused_function);

  // The test assumes the parent's Mul keeps its name "y"; any other Mul node
  // is taken to be the second function's (presumably renamed during fusion,
  // since CheckUniqueNames requires distinct node names).
  const NodeDef *parent_mul = nullptr, *output_mul = nullptr;
  for (const auto &fused_node : fused_function->node_def()) {
    if (fused_node.op() == "Mul") {
      if (fused_node.name() == "y")
        parent_mul = &fused_node;
      else
        output_mul = &fused_node;
    }
  }
  ASSERT_NE(parent_mul, nullptr);
  ASSERT_NE(output_mul, nullptr);
  EXPECT_EQ(ParseNodeConnection(output_mul->input(0)), parent_mul->name());

  // Bind by reference: ret().at() already returns a const reference.
  const auto &output_value = fused_function->ret().at(
      fused_function->signature().output_arg(0).name());

  EXPECT_EQ(ParseNodeConnection(output_value), output_mul->name());
}
91
// Combines XTimesTwo with the IsZero predicate. The fused function has one
// input and two outputs: the map's output (same name as in XTimesTwo) and the
// predicate's result, which comes from an "Equal" node fed by the map output.
TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
  GraphDef graph;
  auto *library = graph.mutable_library();

  auto *xtimes_two = library->add_function();
  *xtimes_two = test::function::XTimesTwo();
  auto *is_zero = library->add_function();
  *is_zero = test::function::IsZero();

  auto *fused_function = FuseFunctions(
      *xtimes_two, *is_zero, "fused_map_and_filter_function",
      fusion_utils::CombineSignature, fusion_utils::ComposeInput,
      fusion_utils::CombineOutput, fusion_utils::MergeNodes, library);

  const auto &fused_signature = fused_function->signature();
  EXPECT_EQ(fused_signature.name(), "fused_map_and_filter_function");

  EXPECT_EQ(fused_signature.input_arg_size(), 1);
  EXPECT_EQ(fused_signature.output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);

  ASSERT_TRUE(
      function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
  const auto equal_index =
      function_utils::FindFunctionNodeWithOp("Equal", *fused_function);
  const auto &equal_node = fused_function->node_def(equal_index);

  // First output: the map's result, under its original name.
  const auto &map_output_name = fused_signature.output_arg(0).name();
  EXPECT_EQ(xtimes_two->signature().output_arg(0).name(), map_output_name);

  // Second output: named after, and produced by, the Equal node.
  const auto &predicate_output_name = fused_signature.output_arg(1).name();
  EXPECT_EQ(predicate_output_name, equal_node.name());

  // The predicate consumes the map's output.
  EXPECT_EQ(ParseNodeConnection(equal_node.input(0)), map_output_name);

  auto predicate_value = fused_function->ret().at(predicate_output_name);
  EXPECT_EQ(ParseNodeConnection(predicate_value), equal_node.name());
}
131
// Fuses XTimesTwo with itself using the combining signature/output strategy.
// The result must take a single input, expose both functions' outputs, and
// keep all names distinct.
TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) {
  GraphDef graph;
  auto *library = graph.mutable_library();

  auto *parent_function = library->add_function();
  *parent_function = test::function::XTimesTwo();
  auto *function = library->add_function();
  *function = test::function::XTimesTwo();

  auto *fused_function = FuseFunctions(
      *parent_function, *function, "fused_maps", fusion_utils::CombineSignature,
      fusion_utils::ComposeInput, fusion_utils::CombineOutput,
      fusion_utils::MergeNodes, library);

  const auto &fused_signature = fused_function->signature();
  EXPECT_EQ(fused_signature.input_arg_size(), 1);
  EXPECT_EQ(fused_signature.output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);
}
149
// Zips XTimesTwo with itself using custom signature and input strategies. The
// fused function is expected to take both functions' inputs, produce both
// functions' outputs, and keep all names distinct.
TEST(FusionUtilsTest, ZipFusion) {
  GraphDef graph;
  auto *function = graph.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  // Fused signature: a copy of the parent's signature with the second
  // function's input args and output args appended after the parent's own.
  auto zip_signature = [](const OpDef &parent_function_signature,
                          const OpDef &function_signature,
                          OpDef *fused_function_signature) {
    *fused_function_signature = parent_function_signature;
    fused_function_signature->mutable_input_arg()->MergeFrom(
        function_signature.input_arg());
    fused_function_signature->mutable_output_arg()->MergeFrom(
        function_signature.output_arg());
  };

  // Feed argument `arg_num` of the second function from the corresponding
  // entry of its own inputs (`function_inputs`), NOT from a parent output.
  // The parent inputs/outputs are intentionally unused, so they are unnamed.
  auto zip_input = [](const StringCollection & /*parent_inputs*/,
                      const StringCollection &function_inputs,
                      const StringCollection & /*parent_outputs*/,
                      int arg_num) { return function_inputs.at(arg_num); };

  auto *fused_function =
      FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input,
                    fusion_utils::CombineOutput, fusion_utils::MergeNodes,
                    graph.mutable_library());
  // Fail the test cleanly rather than crashing on the dereferences below.
  ASSERT_NE(fused_function, nullptr);

  EXPECT_EQ(fused_function->signature().input_arg_size(), 2);
  EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);
}
182
183 } // namespace
184 } // namespace fusion_utils
185 } // namespace grappler
186 } // namespace tensorflow
187