• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17 
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
24 
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace fusion_utils {
31 namespace {
32 
// Strips the output suffix from a function-body tensor reference and returns
// only the node name, i.e. the text before the first ':'. For example,
// "y:z:0" yields "y"; a reference containing no ':' is returned unchanged.
string ParseNodeConnection(const string &name) {
  const string::size_type separator = name.find(':');
  if (separator == string::npos) return name;
  return name.substr(0, separator);
}
36 
CheckUniqueNames(const FunctionDef & function)37 void CheckUniqueNames(const FunctionDef &function) {
38   std::unordered_set<string> inputs;
39   for (const auto &input_arg : function.signature().input_arg())
40     inputs.insert(input_arg.name());
41   EXPECT_EQ(inputs.size(), function.signature().input_arg_size());
42 
43   std::unordered_set<string> outputs;
44   for (const auto &output_arg : function.signature().output_arg())
45     outputs.insert(output_arg.name());
46   EXPECT_EQ(outputs.size(), function.signature().output_arg_size());
47 
48   std::unordered_set<string> nodes;
49   for (const auto &node : function.node_def()) nodes.insert(node.name());
50 
51   EXPECT_EQ(nodes.size(), function.node_def_size());
52 }
53 
// Composes XTimesTwo with itself. The fused function should have a single
// input and a single output. Its second Mul node should consume the Mul node
// named "y" (assumed to be the parent's, with the other Mul renamed during
// merging so that names stay unique). The fused output should be produced by
// that second Mul.
TEST(FusionUtilsTest, FuseFunctionsByComposition) {
  GraphDef graph;
  auto *parent_function = graph.mutable_library()->add_function();
  *parent_function = test::function::XTimesTwo();
  auto *function = graph.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  auto *fused_function = FuseFunctions(
      *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature,
      fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
      fusion_utils::MergeNodes, graph.mutable_library());
  ASSERT_NE(fused_function, nullptr);

  EXPECT_EQ(fused_function->signature().name(), "fused_maps");
  EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
  EXPECT_EQ(fused_function->signature().output_arg_size(), 1);
  EXPECT_EQ(fused_function->ret_size(), 1);
  CheckUniqueNames(*fused_function);

  // Split the Mul nodes: the one named "y" is treated as the parent's, and any
  // other Mul as the composed (second) function's.
  const NodeDef *parent_mul = nullptr, *output_mul = nullptr;
  for (const auto &fused_node : fused_function->node_def()) {
    if (fused_node.op() == "Mul") {
      if (fused_node.name() == "y")
        parent_mul = &fused_node;
      else
        output_mul = &fused_node;
    }
  }
  ASSERT_NE(parent_mul, nullptr);
  ASSERT_NE(output_mul, nullptr);
  // The second function's Mul must read the parent's Mul result.
  EXPECT_EQ(ParseNodeConnection(output_mul->input(0)), parent_mul->name());

  // The fused function's only output must come from the second Mul.
  const auto &output_value = fused_function->ret().at(
      fused_function->signature().output_arg(0).name());

  EXPECT_EQ(ParseNodeConnection(output_value), output_mul->name());
}
91 
// Fuses a map function (XTimesTwo) with a predicate (IsZero) via
// CombineSignature / CombineOutput. The fused function keeps the map's output
// as output 0 and exposes the predicate's Equal result as output 1.
TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
  GraphDef graph;
  auto *library = graph.mutable_library();
  auto *xtimes_two = library->add_function();
  *xtimes_two = test::function::XTimesTwo();
  auto *is_zero = library->add_function();
  *is_zero = test::function::IsZero();

  auto *fused_function = FuseFunctions(
      *xtimes_two, *is_zero, "fused_map_and_filter_function",
      fusion_utils::CombineSignature, fusion_utils::ComposeInput,
      fusion_utils::CombineOutput, fusion_utils::MergeNodes, library);

  const auto &signature = fused_function->signature();
  EXPECT_EQ(signature.name(), "fused_map_and_filter_function");
  EXPECT_EQ(signature.input_arg_size(), 1);
  EXPECT_EQ(signature.output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);

  ASSERT_TRUE(
      function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
  const auto equal_index =
      function_utils::FindFunctionNodeWithOp("Equal", *fused_function);
  const auto &equal_node = fused_function->node_def(equal_index);

  const auto &map_output_name = signature.output_arg(0).name();
  // The map's output name carries over unchanged into the fused signature.
  EXPECT_EQ(xtimes_two->signature().output_arg(0).name(), map_output_name);

  const auto &predicate_output_name = signature.output_arg(1).name();
  // The predicate output is named after the Equal node...
  EXPECT_EQ(predicate_output_name, equal_node.name());
  // ...whose first input is the node named like the map's output...
  EXPECT_EQ(ParseNodeConnection(equal_node.input(0)), map_output_name);
  // ...and the fused function returns that Equal node's result.
  EXPECT_EQ(
      ParseNodeConnection(fused_function->ret().at(predicate_output_name)),
      equal_node.name());
}
131 
// Fuses XTimesTwo with itself via CombineSignature / CombineOutput. Both
// functions' outputs are kept, so the fused function has a single input and
// two outputs, and merging must still leave every name unique.
TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) {
  GraphDef graph;
  auto *parent_function = graph.mutable_library()->add_function();
  *parent_function = test::function::XTimesTwo();
  auto *function = graph.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  auto *fused_function = FuseFunctions(
      *parent_function, *function, "fused_maps", fusion_utils::CombineSignature,
      fusion_utils::ComposeInput, fusion_utils::CombineOutput,
      fusion_utils::MergeNodes, graph.mutable_library());
  ASSERT_NE(fused_function, nullptr);

  // Consistent with the other fusion tests: the requested name is applied.
  EXPECT_EQ(fused_function->signature().name(), "fused_maps");
  EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
  EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);
}
149 
// Zip fusion: XTimesTwo is fused with itself side by side. The custom
// signature appends the second function's inputs and outputs after the
// parent's. Each input of the second function is bound to its own
// corresponding input, so the result has two inputs and two outputs.
TEST(FusionUtilsTest, ZipFusion) {
  GraphDef graph;
  auto *function = graph.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  auto zip_signature = [](const OpDef &parent_function_signature,
                          const OpDef &function_signature,
                          OpDef *fused_function_signature) {
    // Start from the parent's signature, then append the second function's
    // input and output args after the parent's.
    *fused_function_signature = parent_function_signature;
    fused_function_signature->mutable_input_arg()->MergeFrom(
        function_signature.input_arg());
    fused_function_signature->mutable_output_arg()->MergeFrom(
        function_signature.output_arg());
  };

  auto zip_input = [](const StringCollection & /*parent_inputs*/,
                      const StringCollection &function_inputs,
                      const StringCollection & /*parent_outputs*/,
                      int arg_num) {
    // Bind the second function's argument to its own corresponding input; the
    // parent's inputs and outputs are deliberately ignored.
    return function_inputs.at(arg_num);
  };

  auto *fused_function =
      FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input,
                    fusion_utils::CombineOutput, fusion_utils::MergeNodes,
                    graph.mutable_library());
  ASSERT_NE(fused_function, nullptr);

  EXPECT_EQ(fused_function->signature().input_arg_size(), 2);
  EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
  EXPECT_EQ(fused_function->ret_size(), 2);
  CheckUniqueNames(*fused_function);
}
182 
183 }  // namespace
184 }  // namespace fusion_utils
185 }  // namespace grappler
186 }  // namespace tensorflow
187