1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/graph/default_device.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/graph/testlib.h"
27 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
28 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
29 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/public/session.h"
34 #include "tensorflow/tools/graph_transforms/transform_utils.h"
35
36 namespace tensorflow {
37 namespace graph_transforms {
38
39 // Declared here so we don't have to put it in a public header.
40 Status FuseRemoteGraph(const GraphDef& input_graph_def,
41 const TransformFuncContext& context,
42 GraphDef* output_graph_def);
43
44 Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
45 const TransformFuncContext& context,
46 GraphDef* output_graph_def);
47
48 namespace {
// Executor name passed to the transform by default (see the fixture's
// remote_fused_graph_executor_name_).
constexpr char REMOTE_FUSED_GRAPH_EXECUTOR_NAME[] =
    "remote_fused_graph_executor_name";
// Name given to the fused node(s); CheckGraph() counts clusters by this prefix.
constexpr char REMOTE_FUSED_GRAPH_NODE_NAME[] = "remote_fused_graph_node_name";
// Names under which the two test executor build functions are registered.
constexpr char REMOTE_FUSED_EXECUTOR_NAME0[] =
    "fuse_test_remote_fused_graph_executor0";
constexpr char REMOTE_FUSED_EXECUTOR_NAME1[] =
    "fuse_test_remote_fused_graph_executor1";
57
BuildRemoteFusedGraphExecutor0(std::unique_ptr<IRemoteFusedGraphExecutor> * executor)58 Status BuildRemoteFusedGraphExecutor0(
59 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
60 executor->reset(
61 new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
62 return Status::OK();
63 }
64
BuildRemoteFusedGraphExecutor1(std::unique_ptr<IRemoteFusedGraphExecutor> * executor)65 Status BuildRemoteFusedGraphExecutor1(
66 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
67 executor->reset(new TestRemoteFusedGraphExecutor(
68 {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
69 return Status::OK();
70 }
71
72 class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
73 protected:
SetUp()74 void SetUp() final {
75 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
76 &input_graph_def_));
77 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
78 hexagon_remote_fused_graph_executor_build(
79 REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
80 [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
81 return Status::OK();
82 });
83 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
84 test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
85 BuildRemoteFusedGraphExecutor0);
86
87 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
88 test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
89 BuildRemoteFusedGraphExecutor1);
90 }
91
TearDown()92 void TearDown() final {}
93
Fuse()94 Status Fuse() { return FuseInternal(/*only_place_args=*/false); }
95
PlaceFuseArgs()96 Status PlaceFuseArgs() { return FuseInternal(/*only_place_args*/ true); }
97
FuseWithPlacedArgs()98 Status FuseWithPlacedArgs() {
99 const std::vector<std::pair<string, Tensor>> input_tensors{
100 {"A", {DT_FLOAT, {1, 1, 1, 1}}}};
101 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
102 input_graph_def_with_fuse_args_, input_tensors, &output_graph_def_);
103 }
104
FuseInternal(bool only_place_args)105 Status FuseInternal(bool only_place_args) {
106 TransformFuncContext context;
107 context.input_names = inputs_;
108 context.output_names = outputs_;
109
110 if (!input_types_.empty()) {
111 context.params.insert(std::pair<string, std::vector<string>>(
112 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES,
113 {input_types_}}));
114 }
115 if (!input_shapes_.empty()) {
116 context.params.insert(std::pair<string, std::vector<string>>(
117 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES,
118 {input_shapes_}}));
119 }
120 if (!fused_node_names_str_.empty()) {
121 context.params.insert(std::pair<string, std::vector<string>>(
122 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES,
123 {fused_node_names_str_}}));
124 }
125
126 if (!border_inputs_str_.empty()) {
127 context.params.insert(std::pair<string, std::vector<string>>(
128 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS,
129 {border_inputs_str_}}));
130 }
131 if (!border_outputs_str_.empty()) {
132 context.params.insert(std::pair<string, std::vector<string>>(
133 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS,
134 {border_outputs_str_}}));
135 }
136
137 if (!fused_op_types_str_.empty()) {
138 context.params.insert(std::pair<string, std::vector<string>>(
139 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES,
140 {fused_op_types_str_}}));
141 }
142
143 if (fuse_by_executor_) {
144 context.params.insert(std::pair<string, std::vector<string>>(
145 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
146 {"true"}}));
147 }
148
149 context.params.insert(std::pair<string, std::vector<string>>(
150 {RemoteFusedGraphExecuteUtils::
151 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
152 {remote_fused_graph_executor_name_}}));
153 context.params.insert(std::pair<string, std::vector<string>>(
154 {RemoteFusedGraphExecuteUtils::
155 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
156 {REMOTE_FUSED_GRAPH_NODE_NAME}}));
157
158 if (only_place_args) {
159 return PlaceRemoteGraphArguments(input_graph_def_, context,
160 &input_graph_def_with_fuse_args_);
161 } else {
162 return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_);
163 }
164 }
165
SetInputShapeType()166 void SetInputShapeType() {
167 input_types_ = "float";
168 input_shapes_ = "1,1,1,1";
169 }
170
ReplaceOpType(const std::unordered_set<string> & op_name,const string & new_op_type)171 void ReplaceOpType(const std::unordered_set<string>& op_name,
172 const string& new_op_type) {
173 for (NodeDef& node_def : *input_graph_def_.mutable_node()) {
174 if (op_name.count(node_def.name()) > 0) {
175 node_def.set_op(new_op_type);
176 }
177 }
178 }
179
CheckGraph(int expected_node_count,int expected_cluster_count)180 void CheckGraph(int expected_node_count, int expected_cluster_count) {
181 EXPECT_EQ(expected_node_count, output_graph_def_.node_size());
182
183 int cluster_count = 0;
184 for (const NodeDef& node_def : output_graph_def_.node()) {
185 const string& name = node_def.name();
186 if (absl::StartsWith(name, REMOTE_FUSED_GRAPH_NODE_NAME)) {
187 ++cluster_count;
188 RemoteFusedGraphExecuteInfo info;
189 string serialized_proto;
190 TF_ASSERT_OK(
191 GetNodeAttr(node_def,
192 RemoteFusedGraphExecuteUtils::
193 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
194 &serialized_proto));
195 info.ParseFromString(serialized_proto);
196 CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
197 }
198 }
199 EXPECT_EQ(expected_cluster_count, cluster_count);
200 }
201
202 public:
203 const std::vector<string> inputs_{"A"};
204 const std::vector<string> outputs_{"K"};
205 GraphDef input_graph_def_;
206 string input_types_;
207 string input_shapes_;
208 GraphDef input_graph_def_with_fuse_args_;
209 GraphDef output_graph_def_;
210 string fused_node_names_str_;
211 string border_inputs_str_;
212 string border_outputs_str_;
213 string fused_op_types_str_;
214 string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
215 bool fuse_by_executor_{false};
216 };
217
// Fuses the explicitly named nodes H, I and J; input type/shape are supplied.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithShapeType_HIJ) {
  fused_node_names_str_ = "H,I,J";
  SetInputShapeType();
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
225
// Fuses the explicitly named nodes H, I and J; no input type/shape is given.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithoutShapeType_HIJ) {
  fused_node_names_str_ = "H,I,J";

  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
232
// Fuses every node A..K by name; input type/shape are supplied.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
  SetInputShapeType();
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
240
// Fuses every node A..K by name; no input type/shape is given.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";

  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
247
// Fuses the subgraph bounded by inputs F:0, C:0, G and output J:0; input
// type/shape are supplied.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithShapeType_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";
  SetInputShapeType();
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
256
// Fuses the subgraph bounded by inputs F:0, C:0, G and output J:0; no input
// type/shape is given.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithoutShapeType_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
264
// Fuses the subgraph bounded by inputs A..E and output K; input type/shape are
// supplied.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithShapeType_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";
  SetInputShapeType();
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
273
// Fuses the subgraph bounded by inputs A..E and output K; no input type/shape
// is given.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
281
// Turns H, I and J into Mul nodes, then fuses by op type "Mul".
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByOpTypes_HIJ) {
  fused_op_types_str_ = "Mul";
  ReplaceOpType({"H", "I", "J"}, "Mul");
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
289
// Turns F..J into Mul nodes, then fuses by op types "Const" and "Mul".
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByOpTypes_FGHIJ) {
  fused_op_types_str_ = "Const,Mul";
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
297
// Turns H, I and J into Mul nodes, then lets executor 0 choose what to fuse.
// Executor 0 is built with {"Mul"} (see BuildRemoteFusedGraphExecutor0).
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_HIJ) {
  fuse_by_executor_ = true;
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
  ReplaceOpType({"H", "I", "J"}, "Mul");
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
306
// Turns F..J into Mul nodes, then lets executor 1 choose what to fuse.
// Executor 1 is built with {"Const", "Mul"} (see
// BuildRemoteFusedGraphExecutor1).
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_FGHIJ) {
  fuse_by_executor_ = true;
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  TF_ASSERT_OK(Fuse());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
315
// Two-step flow: place fuse arguments for nodes H, I, J, then fuse the
// annotated graph.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
  fused_node_names_str_ = "H,I,J";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
322
// Two-step flow: place fuse arguments for all nodes A..K, then fuse the
// annotated graph.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDEFGHIJK) {
  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
329
// Two-step flow with border inputs F:0, C:0, G and border output J:0.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_FCG_J) {
  border_outputs_str_ = "J:0";
  border_inputs_str_ = "F:0,C:0,G";

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
337
// Two-step flow with border inputs A..E and border output K; input type/shape
// are supplied.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDE_K) {
  border_outputs_str_ = "K";
  border_inputs_str_ = "A,B,C,D,E";
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/7, /*expected_cluster_count=*/1);
}
346
// Two-step flow fusing by op type "Mul" after turning H, I, J into Mul nodes.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_MUL_HIJ) {
  fused_op_types_str_ = "Mul";
  ReplaceOpType({"H", "I", "J"}, "Mul");
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/9, /*expected_cluster_count=*/1);
}
356
// Two-step flow fusing by op types "Const" and "Mul" after turning F..J into
// Mul nodes.
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       PlaceAndFuse_CONST_MUL_FGHIJ) {
  fused_op_types_str_ = "Const,Mul";
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  SetInputShapeType();

  TF_ASSERT_OK(PlaceFuseArgs());
  TF_ASSERT_OK(FuseWithPlacedArgs());
  CheckGraph(/*expected_node_count=*/3, /*expected_cluster_count=*/1);
}
367
368 } // namespace
369 } // namespace graph_transforms
370 } // namespace tensorflow
371