• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/graph/default_device.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/graph/testlib.h"
27 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
28 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
29 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/public/session.h"
34 #include "tensorflow/tools/graph_transforms/transform_utils.h"
35 
36 namespace tensorflow {
37 namespace graph_transforms {
38 
39 // Declared here so we don't have to put it in a public header.
40 Status FuseRemoteGraph(const GraphDef& input_graph_def,
41                        const TransformFuncContext& context,
42                        GraphDef* output_graph_def);
43 
44 Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
45                                  const TransformFuncContext& context,
46                                  GraphDef* output_graph_def);
47 
48 namespace {
// Executor name the fixture uses unless a test overrides it.
constexpr const char* REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
    "remote_fused_graph_executor_name";

// Name passed to the transforms for the fused node; the fixture also uses it
// as the prefix that identifies fused nodes in the output graph.
constexpr const char* REMOTE_FUSED_GRAPH_NODE_NAME =
    "remote_fused_graph_node_name";

// Names under which the two test executors (built with {"Mul"} and
// {"Const", "Mul"} respectively) are registered.
constexpr const char* REMOTE_FUSED_EXECUTOR_NAME0 =
    "fuse_test_remote_fused_graph_executor0";
constexpr const char* REMOTE_FUSED_EXECUTOR_NAME1 =
    "fuse_test_remote_fused_graph_executor1";
57 
BuildRemoteFusedGraphExecutor0(std::unique_ptr<IRemoteFusedGraphExecutor> * executor)58 Status BuildRemoteFusedGraphExecutor0(
59     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
60   executor->reset(
61       new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
62   return Status::OK();
63 }
64 
BuildRemoteFusedGraphExecutor1(std::unique_ptr<IRemoteFusedGraphExecutor> * executor)65 Status BuildRemoteFusedGraphExecutor1(
66     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
67   executor->reset(new TestRemoteFusedGraphExecutor(
68       {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
69   return Status::OK();
70 }
71 
72 class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
73  protected:
SetUp()74   void SetUp() final {
75     TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
76         &input_graph_def_));
77     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
78         hexagon_remote_fused_graph_executor_build(
79             REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
80             [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
81               return Status::OK();
82             });
83     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
84         test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
85                                                 BuildRemoteFusedGraphExecutor0);
86 
87     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
88         test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
89                                                 BuildRemoteFusedGraphExecutor1);
90   }
91 
TearDown()92   void TearDown() final {}
93 
Fuse()94   Status Fuse() { return FuseInternal(/*only_place_args=*/false); }
95 
PlaceFuseArgs()96   Status PlaceFuseArgs() { return FuseInternal(/*only_place_args*/ true); }
97 
FuseWithPlacedArgs()98   Status FuseWithPlacedArgs() {
99     const std::vector<std::pair<string, Tensor>> input_tensors{
100         {"A", {DT_FLOAT, {1, 1, 1, 1}}}};
101     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
102         input_graph_def_with_fuse_args_, input_tensors, &output_graph_def_);
103   }
104 
FuseInternal(bool only_place_args)105   Status FuseInternal(bool only_place_args) {
106     TransformFuncContext context;
107     context.input_names = inputs_;
108     context.output_names = outputs_;
109 
110     if (!input_types_.empty()) {
111       context.params.insert(std::pair<string, std::vector<string>>(
112           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES,
113            {input_types_}}));
114     }
115     if (!input_shapes_.empty()) {
116       context.params.insert(std::pair<string, std::vector<string>>(
117           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES,
118            {input_shapes_}}));
119     }
120     if (!fused_node_names_str_.empty()) {
121       context.params.insert(std::pair<string, std::vector<string>>(
122           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES,
123            {fused_node_names_str_}}));
124     }
125 
126     if (!border_inputs_str_.empty()) {
127       context.params.insert(std::pair<string, std::vector<string>>(
128           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS,
129            {border_inputs_str_}}));
130     }
131     if (!border_outputs_str_.empty()) {
132       context.params.insert(std::pair<string, std::vector<string>>(
133           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS,
134            {border_outputs_str_}}));
135     }
136 
137     if (!fused_op_types_str_.empty()) {
138       context.params.insert(std::pair<string, std::vector<string>>(
139           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES,
140            {fused_op_types_str_}}));
141     }
142 
143     if (fuse_by_executor_) {
144       context.params.insert(std::pair<string, std::vector<string>>(
145           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
146            {"true"}}));
147     }
148 
149     context.params.insert(std::pair<string, std::vector<string>>(
150         {RemoteFusedGraphExecuteUtils::
151              TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
152          {remote_fused_graph_executor_name_}}));
153     context.params.insert(std::pair<string, std::vector<string>>(
154         {RemoteFusedGraphExecuteUtils::
155              TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
156          {REMOTE_FUSED_GRAPH_NODE_NAME}}));
157 
158     if (only_place_args) {
159       return PlaceRemoteGraphArguments(input_graph_def_, context,
160                                        &input_graph_def_with_fuse_args_);
161     } else {
162       return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_);
163     }
164   }
165 
SetInputShapeType()166   void SetInputShapeType() {
167     input_types_ = "float";
168     input_shapes_ = "1,1,1,1";
169   }
170 
ReplaceOpType(const std::unordered_set<string> & op_name,const string & new_op_type)171   void ReplaceOpType(const std::unordered_set<string>& op_name,
172                      const string& new_op_type) {
173     for (NodeDef& node_def : *input_graph_def_.mutable_node()) {
174       if (op_name.count(node_def.name()) > 0) {
175         node_def.set_op(new_op_type);
176       }
177     }
178   }
179 
CheckGraph(int expected_node_count,int expected_cluster_count)180   void CheckGraph(int expected_node_count, int expected_cluster_count) {
181     EXPECT_EQ(expected_node_count, output_graph_def_.node_size());
182 
183     int cluster_count = 0;
184     for (const NodeDef& node_def : output_graph_def_.node()) {
185       const string& name = node_def.name();
186       if (absl::StartsWith(name, REMOTE_FUSED_GRAPH_NODE_NAME)) {
187         ++cluster_count;
188         RemoteFusedGraphExecuteInfo info;
189         string serialized_proto;
190         TF_ASSERT_OK(
191             GetNodeAttr(node_def,
192                         RemoteFusedGraphExecuteUtils::
193                             ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
194                         &serialized_proto));
195         info.ParseFromString(serialized_proto);
196         CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
197       }
198     }
199     EXPECT_EQ(expected_cluster_count, cluster_count);
200   }
201 
202  public:
203   const std::vector<string> inputs_{"A"};
204   const std::vector<string> outputs_{"K"};
205   GraphDef input_graph_def_;
206   string input_types_;
207   string input_shapes_;
208   GraphDef input_graph_def_with_fuse_args_;
209   GraphDef output_graph_def_;
210   string fused_node_names_str_;
211   string border_inputs_str_;
212   string border_outputs_str_;
213   string fused_op_types_str_;
214   string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
215   bool fuse_by_executor_{false};
216 };
217 
// Fuses nodes H, I and J, selected by name, with the input type and shape
// supplied; expects 9 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithShapeType_HIJ) {
  fused_node_names_str_ = "H,I,J";
  SetInputShapeType();

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
225 
// Fuses nodes H, I and J, selected by name, without supplying the input type
// or shape; expects 9 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithoutShapeType_HIJ) {
  fused_node_names_str_ = "H,I,J";

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
232 
// Fuses nodes A through K, selected by name, with the input type and shape
// supplied; expects 3 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
  SetInputShapeType();

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
240 
// Fuses nodes A through K, selected by name, without supplying the input type
// or shape; expects 3 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
247 
// Selects the subgraph by its borders (inputs "F:0,C:0,G", output "J:0") with
// the input type and shape supplied; expects 9 nodes in the result, one of
// them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithShapeType_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";
  SetInputShapeType();

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
256 
// Selects the subgraph by its borders (inputs "F:0,C:0,G", output "J:0")
// without supplying the input type or shape; expects 9 nodes in the result,
// one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithoutShapeType_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
264 
// Selects the subgraph by its borders (inputs "A,B,C,D,E", output "K") with
// the input type and shape supplied; expects 7 nodes in the result, one of
// them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithShapeType_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";
  SetInputShapeType();

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
273 
// Selects the subgraph by its borders (inputs "A,B,C,D,E", output "K")
// without supplying the input type or shape; expects 7 nodes in the result,
// one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
281 
// Turns H, I and J into "Mul" nodes and fuses by op type "Mul"; expects 9
// nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByOpTypes_HIJ) {
  fused_op_types_str_ = "Mul";
  ReplaceOpType({"H", "I", "J"}, "Mul");

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
289 
// Turns F, G, H, I and J into "Mul" nodes and fuses by op types "Const,Mul";
// expects 3 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByOpTypes_FGHIJ) {
  fused_op_types_str_ = "Const,Mul";
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
297 
// Turns H, I and J into "Mul" nodes and fuses by executor, using the executor
// registered as REMOTE_FUSED_EXECUTOR_NAME0 (built with {"Mul"}); expects 9
// nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_HIJ) {
  fuse_by_executor_ = true;
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
  ReplaceOpType({"H", "I", "J"}, "Mul");

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
306 
// Turns F, G, H, I and J into "Mul" nodes and fuses by executor, using the
// executor registered as REMOTE_FUSED_EXECUTOR_NAME1 (built with
// {"Const", "Mul"}); expects 3 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_FGHIJ) {
  fuse_by_executor_ = true;
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");

  TF_ASSERT_OK(Fuse());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
315 
// Two-step flow for nodes H, I and J selected by name: place the fuse
// arguments first, then fuse from the placed arguments. Expects 9 nodes in the
// result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
  fused_node_names_str_ = "H,I,J";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
322 
// Two-step flow for nodes A through K selected by name: place the fuse
// arguments first, then fuse from the placed arguments. Expects 3 nodes in the
// result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
329 
// Two-step flow for the subgraph bounded by inputs "F:0,C:0,G" and output
// "J:0": place the fuse arguments first, then fuse from the placed arguments.
// Expects 9 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
337 
// Two-step flow for the subgraph bounded by inputs "A,B,C,D,E" and output
// "K", with the input type and shape supplied: place the fuse arguments
// first, then fuse from the placed arguments. Expects 7 nodes in the result,
// one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
346 
// Two-step flow selecting by op type: H, I and J are turned into "Mul" nodes
// and "Mul" is the fused op type, with the input type and shape supplied.
// Expects 9 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_MUL_HIJ) {
  fused_op_types_str_ = "Mul";
  ReplaceOpType({"H", "I", "J"}, "Mul");
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
356 
// Two-step flow selecting by op type: F, G, H, I and J are turned into "Mul"
// nodes and "Const,Mul" are the fused op types, with the input type and shape
// supplied. Expects 3 nodes in the result, one of them a fused node.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       PlaceAndFuse_CONST_MUL_FGHIJ) {
  fused_op_types_str_ = "Const,Mul";
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());

  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
367 
368 }  // namespace
369 }  // namespace graph_transforms
370 }  // namespace tensorflow
371