1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17
18 #include <complex>
19
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/cc/ops/array_ops.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/grappler/grappler_item.h"
28 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
29 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
30 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/test.h"
34
35 namespace tensorflow {
36 namespace grappler {
37
38 namespace {
39
// Name prefixes that ArithmeticOptimizer stages prepend to the nodes they
// rewrite. Tests below locate optimized nodes by combining one of these
// prefixes with the original node name (see the helper functions that follow).
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_AddV2_";

constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
54
55 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)56 string HoistMulName(const string& name) {
57 return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
58 }
59
60 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)61 string HoistDivName(const string& name) {
62 return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
63 }
64
65 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)66 string HoistAddName(const string& name) {
67 return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
68 }
69
70 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)71 string AggregationConstName(const string& name) {
72 return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
73 }
74
75 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)76 string AggregationMulName(const string& name) {
77 return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
78 }
79
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)80 void VerifyGraphsMatch(const GraphDef& original_graph,
81 const GraphDef& optimized_graph, int line) {
82 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
83 for (int i = 0; i < original_graph.node_size(); ++i) {
84 const NodeDef& original = original_graph.node(i);
85 const NodeDef& optimized = optimized_graph.node(i);
86 EXPECT_EQ(original.name(), optimized.name()) << line;
87 EXPECT_EQ(original.op(), optimized.op()) << line;
88 EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
89 for (int j = 0; j < original.input_size(); ++j) {
90 EXPECT_EQ(original.input(j), optimized.input(j)) << line;
91 }
92 }
93 }
94 } // namespace
95
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize: the output
  // must be identical to the input graph.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  ArithmeticOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
108
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  // Mul(x, x) and MulNoNan(x, x) should both be rewritten to Square(x),
  // with control dependencies preserved on the rewritten node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  // The control dependency on "d" must survive the rewrite.
  EXPECT_EQ(square_node->input(1), "^d");

  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // Fix: the second fetched output ("id2", via MulNoNan) was evaluated but
  // never compared against its expected value; check it as well.
  test::ExpectTensorNear<float>(tensors[1], tensors_expected[1], 1e-6);
}
152
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Build c -> Neg -> Neg -> Reciprocal -> Reciprocal -> Identity. Each
  // adjacent involution pair should cancel out entirely.
  auto input = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto first_neg = ops::Neg(s.WithOpName("neg1"), input);
  auto second_neg = ops::Neg(s.WithOpName("neg2"), first_neg);
  auto first_recip = ops::Reciprocal(s.WithOpName("recip1"), second_neg);
  auto second_recip = ops::Reciprocal(s.WithOpName("recip2"), first_recip);
  auto out = ops::Identity(s.WithOpName("id"), second_recip);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other; only the Const and
  // the Identity remain, wired directly together.
  ASSERT_EQ(output.node_size(), 2);
  const NodeDef& id_node = output.node(1);
  EXPECT_EQ(id_node.name(), "id");
  ASSERT_EQ(id_node.input_size(), 1);
  EXPECT_EQ(id_node.input(0), "c");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
184
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Two Reciprocal nodes separated by a chain of value-preserving ops
  // (Identity, Squeeze). The optimizer should cancel the pair across the
  // chain.
  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(output.node_size(), 3);

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      found++;
    } else if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "squeeze");
      found++;
    }
  }
  // Both rewired nodes must have been seen exactly once.
  EXPECT_EQ(found, 2);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
230
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // recip2 reads "c" directly; recip1 only reaches it through a control
  // dependency (via squeeze), so there is no data-path involution pair and
  // nothing should be removed.
  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
263
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  // Add(x, x) should be rewritten by SimplifyAggregation into
  // Mul(Const(2), x).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  // The new Const carries a control dependency on "x".
  EXPECT_EQ(new_const->input(0), "^x");
  // "\0\0\0@" (0x40000000 little-endian) is the byte encoding of float 2.0f.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 2);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
309
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // Same rewrite as TrivialSumsSimple, but the Add carries a control
  // dependency on "y" which must be preserved on the new Mul node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 6);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  // "\0\0\0@" (0x40000000 little-endian) is the byte encoding of float 2.0f.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 3);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");
  // The original control dependency on "y" survives on the Mul.
  EXPECT_EQ(new_mul->input(2), "^y");

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
357
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
  // Test case from b/69059093.
  // A tree of Add(p, p) nodes; SimplifyAggregation and
  // HoistCommonFactorOutOfAggregation together should rewrite it into a
  // single Mul of the placeholder by a tree of constants.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
  Output add = ops::Add(s.WithOpName("Add"), p, p);
  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
  Output id = ops::Identity(s.WithOpName("id"), add6);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Assign devices by node index (graph construction order), spreading the
  // nodes across CPU and GPU devices.
  const std::vector<string> devices{
      "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
      "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
  };
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device(devices[i]);
  }

  ArithmeticOptimizer optimizer;
  DisableAddToAddNCombining(&optimizer);

  GraphDef output;
  DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  // Mul(p,
  //     Add_6(Add_4(Const(2), Const(2)),
  //           Add_5(Const(2), Const(2)))
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 8);

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));

  const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 2);
  EXPECT_EQ(mul_node->input(0), "Placeholder");
  EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));

  const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
  ASSERT_NE(add_6_node, nullptr);
  ASSERT_EQ(add_6_node->input_size(), 2);
  EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
  EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));

  const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
  ASSERT_NE(add_4_node, nullptr);
  EXPECT_EQ(add_4_node->op(), "Add");
  // Fix: argument order normalized (value first) for consistency with every
  // other assertion in this file.
  ASSERT_EQ(add_4_node->input_size(), 2);
  EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));

  const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
  ASSERT_NE(add_5_node, nullptr);
  EXPECT_EQ(add_5_node->op(), "Add");
  ASSERT_EQ(add_5_node->input_size(), 2);
  EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));

  const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
  ASSERT_NE(add_const_node, nullptr);
  EXPECT_EQ(add_const_node->op(), "Const");
  ASSERT_EQ(add_const_node->input_size(), 1);
  EXPECT_EQ(add_const_node->input(0), "^Placeholder");

  const NodeDef* add_1_const_node =
      node_map.GetNode(AggregationConstName("Add_1"));
  ASSERT_NE(add_1_const_node, nullptr);
  EXPECT_EQ(add_1_const_node->op(), "Const");
  ASSERT_EQ(add_1_const_node->input_size(), 1);
  EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
}
440
TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
  // Exercise HoistCommonFactorOutOfAggregation over all four combinations of
  // (matching vs. broadcasting operand shapes) x (Add vs. AddN root).
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
      Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
      // When shapes don't match, y2 is a {1, 1} tensor that broadcasts.
      Output y2 = matching_shapes
                      ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
                      : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
      // "x" is the common factor: mul1 = x*y1, mul2 = y2*x.
      Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
      Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
      Output id =
          use_addn ? ops::Identity(s.WithOpName("id"),
                                   ops::AddN(s.WithOpName("add"), {mul1, mul2}))
                   : ops::Identity(s.WithOpName("id"),
                                   ops::Add(s.WithOpName("add"), mul1, mul2));

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);
      ArithmeticOptimizer optimizer;
      EnableOnlyHoistCommonFactor(&optimizer);

      GraphDef output;
      OptimizeTwice(&optimizer, &item, &output);

      // We expect the following rewrite(s) to occur:
      //
      //        Add                 Mul
      //      /    \               /   \
      //    Mul    Mul      ->    x    Add
      //    / \    / \                 / \
      //   x  y1  y2  x              y1   y2
      //
      // If "root" op is AddN and shapes does not match, this rewrite is not
      // possible and graph should stay intact.
      NodeMap node_map(&output);

      if (use_addn && !matching_shapes) {
        VerifyGraphsMatch(item.graph, output, __LINE__);
      } else {
        EXPECT_EQ(output.node_size(), 9);

        const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
        ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
        ASSERT_EQ(new_add_node->input_size(), 2);
        EXPECT_EQ(new_add_node->input(0), "y1");
        EXPECT_EQ(new_add_node->input(1), "y2");

        const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
        ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
        ASSERT_EQ(new_mul_node->input_size(), 2);
        EXPECT_EQ(new_mul_node->input(0), "x");
        EXPECT_EQ(new_mul_node->input(1), new_add_node->name());

        const NodeDef* id_node = node_map.GetNode("id");
        ASSERT_NE(id_node, nullptr) << "Id node not found";
        EXPECT_EQ(id_node->name(), "id");
        ASSERT_EQ(id_node->input_size(), 1);
        EXPECT_EQ(id_node->input(0), HoistMulName("add"));
      }
      // In every case the optimized graph must compute the same value.
      auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
    }
  }
}
510
TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
  // Same as HoistFactorMul but with a common divisor, additionally crossed
  // with integer vs. float element types (integer division must NOT be
  // rewritten, since (y1 + y2) / x differs from y1/x + y2/x under truncation).
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      for (bool use_ints : {true, false}) {
        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
        Output x = use_ints
                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
        Output y1 = use_ints
                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
        Output y2;
        if (matching_shapes) {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
        } else {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
        }
        // "x" is the common divisor: div1 = y1/x, div2 = y2/x.
        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
        Output id =
            use_addn
                ? ops::Identity(s.WithOpName("id"),
                                ops::AddN(s.WithOpName("add"), {div1, div2}))
                : ops::Identity(s.WithOpName("id"),
                                ops::Add(s.WithOpName("add"), div1, div2));

        GrapplerItem item;
        item.fetch = {"id"};
        TF_CHECK_OK(s.ToGraphDef(&item.graph));

        auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
        ASSERT_EQ(tensors_expected.size(), 1);

        ArithmeticOptimizer optimizer;
        EnableOnlyHoistCommonFactor(&optimizer);

        GraphDef output;
        OptimizeTwice(&optimizer, &item, &output);

        // We expect the following rewrite(s) to occur:
        //
        //        Add                Div
        //      /    \              /   \
        //    Div    Div     ->   Add    x
        //    / \    / \          / \
        //   y1  x  y2  x       y1   y2
        //
        // If "root" op is AddN and shapes does not match, this rewrite is not
        // possible and graph should stay intact.
        NodeMap node_map(&output);

        if ((use_addn && !matching_shapes) || use_ints) {
          VerifyGraphsMatch(item.graph, output, __LINE__);
        } else {
          EXPECT_EQ(output.node_size(), 9);

          // Fix: use ASSERT_NE(ptr, nullptr) and value-first EXPECT_EQ
          // argument order, consistent with HoistFactorMul and the rest of
          // this file.
          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
          ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
          ASSERT_EQ(new_add_node->input_size(), 2);
          EXPECT_EQ(new_add_node->input(0), "y1");
          EXPECT_EQ(new_add_node->input(1), "y2");

          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
          ASSERT_NE(new_div_node, nullptr) << "Hoisted Div node not found";
          ASSERT_EQ(new_div_node->input_size(), 2);
          EXPECT_EQ(new_div_node->input(0), new_add_node->name());
          EXPECT_EQ(new_div_node->input(1), "x");

          const NodeDef* id_node = node_map.GetNode("id");
          ASSERT_NE(id_node, nullptr) << "Id node not found";
          EXPECT_EQ(id_node->name(), "id");
          ASSERT_EQ(id_node->input_size(), 1);
          EXPECT_EQ(id_node->input(0), HoistDivName("add"));
        }
        // In every case the optimized graph must compute the same value.
        auto tensors = EvaluateNodes(output, item.fetch);
        ASSERT_EQ(tensors.size(), 1);
        if (use_ints) {
          test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
        } else {
          test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
        }
      }
    }
  }
}
598
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  // Transpose(Conj(z)) should be fused into a single ConjugateTranspose.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = absl::StrCat(prefix, "_", "trans");
  const NodeDef* fused_node = node_map.GetNode(fused_name);

  // The fused node bypasses "conj" and reads "z" directly.
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
636
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  // ConjugateTranspose(Conj(z)): the two conjugations cancel, leaving a
  // plain Transpose of "z".
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp =
      ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string optimized_name = absl::StrCat(p, "_", "conjugate_trans");

  // The fused node bypasses "conj", reads "z" directly, and is demoted from
  // ConjugateTranspose to Transpose.
  const NodeDef* conjugate_trans_fused_node = node_map.GetNode(optimized_name);
  ASSERT_NE(conjugate_trans_fused_node, nullptr);
  EXPECT_EQ(conjugate_trans_fused_node->op(), "Transpose");
  ASSERT_EQ(conjugate_trans_fused_node->input_size(), 2);
  EXPECT_EQ(conjugate_trans_fused_node->input(0), "z");
  EXPECT_EQ(conjugate_trans_fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
676
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  // Conj(Transpose(z)) — the mirror of FuseConjAndTranspose — should also be
  // fused into a single ConjugateTranspose.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
  Output conj = ops::Conj(s.WithOpName("conj"), trans);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string optimized_name = absl::StrCat(p, "_", "conj");

  // The fused node reads "z" directly, skipping the intermediate Transpose.
  const NodeDef* conj_fused_node = node_map.GetNode(optimized_name);
  ASSERT_NE(conj_fused_node, nullptr);
  EXPECT_EQ(conj_fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(conj_fused_node->input_size(), 2);
  EXPECT_EQ(conj_fused_node->input(0), "z");
  EXPECT_EQ(conj_fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
714
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
  // For each matmul flavor, MatMul(Transpose(a), Transpose(b)) should be
  // folded into a single matmul with its transpose/adjoint attributes set.
  for (const string matmul_type :
       {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
    Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
    Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
    Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);

    // Build the matmul node of the flavor under test.
    Output matmul;
    auto matmul_op = s.WithOpName("matmul");
    if (matmul_type == "MatMul") {
      matmul = ops::MatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "SparseMatMul") {
      matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMul") {
      matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMulV2") {
      matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
    }

    auto identity = ops::Identity(s.WithOpName("identity"), matmul);

    GrapplerItem item;
    item.fetch = {"identity"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlyFoldTransposeIntoMatMul(&optimizer);
    GraphDef output;
    OptimizeTwice(&optimizer, &item, &output);
    NodeMap node_map(&output);

    EXPECT_EQ(output.node_size(), 8);

    const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
    const string optimized_name = absl::StrCat(p, "_", "matmul");

    // The fused matmul consumes "a" and "b" directly.
    const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
    ASSERT_NE(matmul_fused_node, nullptr);
    ASSERT_EQ(matmul_fused_node->input_size(), 2);
    EXPECT_EQ(matmul_fused_node->input(0), "a");
    EXPECT_EQ(matmul_fused_node->input(1), "b");

    // Batch matmuls express the fold via adj_x/adj_y; the others via
    // transpose_a/transpose_b.
    if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
    } else {
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
    }

    const NodeDef* identity_node = node_map.GetNode("identity");
    ASSERT_NE(identity_node, nullptr);
    ASSERT_EQ(identity_node->input_size(), 1);
    EXPECT_EQ(identity_node->input(0), optimized_name);

    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
782
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  // BatchMatMul(ConjugateTranspose(a), ConjugateTranspose(b)) should be
  // folded into BatchMatMul(a, b) with adj_x/adj_y set (adjoint == conjugate
  // transpose).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output re_a =
      ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a =
      ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b =
      ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(s.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(s.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
  Output identity = ops::Identity(s.WithOpName("identity"), matmul);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(output.node_size(), 12);

  const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
  const string optimized_name = absl::StrCat(p, "_", "matmul");

  // The fused matmul reads "a" and "b" directly with the adjoint flags set.
  const NodeDef* optimized_matmul = node_map.GetNode(optimized_name);
  ASSERT_NE(optimized_matmul, nullptr);
  ASSERT_EQ(optimized_matmul->input_size(), 2);
  EXPECT_EQ(optimized_matmul->input(0), "a");
  EXPECT_EQ(optimized_matmul->input(1), "b");
  EXPECT_TRUE(optimized_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(optimized_matmul->attr().at("adj_y").b());

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
}
831
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  // Runs the same scenario twice: once with Reshape and once with BroadcastTo.
  // In both cases the target shape is provably equal to the input shape, so
  // the op should be removed entirely.
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(s, inputs);
    // The target shape of the reshape is the concatenation of `batch_size` and
    // [3,28,28].
    Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                   ops::Const(s, {1}, {1}));
    Output target_shape = ops::Concat(
        s.WithOpName("target_shape"),
        {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
    if (is_broadcastto) {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::BroadcastTo(s, inputs, target_shape));
    } else {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::Reshape(s, inputs, target_shape));
    }

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // Neither a Reshape nor a BroadcastTo may remain, and the numerical
    // result must be unchanged.
    EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(output, "BroadcastTo"), 0);
    auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
873
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(s, inputs);
  // The target shape of the reshape is the concatenation of `batch_size`, 3,
  // `height`, and `width` — i.e. it is symbolically identical to the input
  // shape, even though three of the dimensions are unknown.
  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                 ops::Const(s, {1}, {1}));
  Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
                             ops::Const(s, {1}, {1}));
  Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
                            ops::Const(s, {1}, {1}));
  Output target_shape =
      ops::Concat(s.WithOpName("target_shape"),
                  {batch_size, ops::Const(s, {3}, {1}), height, width},
                  ops::Const(s, {0}, {}));
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The symbolic identity reshape must be removed without changing results.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
915
// In default (non-aggressive) mode the optimizer may not trust that the feed
// has the placeholder's declared shape, so even a reshape to the identical
// static shape must be preserved.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(s, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(s, placeholder, shape_const);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  item.feed = {{"Placeholder", feed_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape is preserved because the shape of the placeholder can be
  // different from the shape of the actual feed.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
946
// Counterpart to the test above: in AGGRESSIVE mode feeds are assumed to
// match the placeholder shape, so the identity reshape is removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(s, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(s, placeholder, shape_const);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  item.feed = {{"Placeholder", feed_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The identity reshape is gone and the result is unchanged.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
975
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
  // be from [4,3,28,28] to [8,6,28,28].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Feed a tensor whose batch dimension happens to make the reshape a no-op
  // at runtime; the optimizer must still not remove it.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape survives because it is not provably an identity.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1003
// A target shape containing more than one -1 can never be proven to be an
// identity reshape, so the Reshape node must be preserved.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output placeholder =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped =
      ops::Reshape(s, placeholder, ops::Const(s, {-1, -1}, {2}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The ambiguous reshape must survive the optimization pass.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
}
1023
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
  // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
  // reshapes should be combined.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output nchw_vect_c =
      ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_INT8,
                       ops::Placeholder::Shape({8, 3, 28, 28, 4}));
  Output transpose =
      ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
                     ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
  Output nhwc = ops::Reshape(
      s.WithOpName("nhwc"), transpose,
      ops::Const(
          s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
          {8, 28, 28, 12}, {4}));
  Output flatten = ops::Reshape(
      s.WithOpName("flatten"), nhwc,
      ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), flatten);

  GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));
  auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4}));
  // Evaluate both the final output and the intermediate `nhwc` tensor; the
  // latter serves as the feed in the second scenario below.
  auto eval = EvaluateNodes(graph, {"outputs", "nhwc"}, {{"nchw_vect_c", x_t}});

  ASSERT_EQ(eval.size(), 2);
  auto expected_output_t = eval[0];
  auto nhwc_t = eval[1];

  // Scenario 1: feed the placeholder. Reshape(Reshape(x)) collapses into a
  // single Reshape.
  {
    GrapplerItem item;
    item.graph = graph;
    item.fetch = {"outputs"};
    item.feed = {{"nchw_vect_c", x_t}};

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<int8>(tensors[0], expected_output_t);
  }

  // Test when the first reshape node output is the feed tensor.
  // (Expected no reshape removal to happen.)
  {
    GrapplerItem item;
    item.graph = graph;
    item.fetch = {"outputs"};
    item.feed = {{"nhwc", nhwc_t}};

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // Both reshapes must remain since the first one is fed directly.
    EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<int8>(tensors[0], expected_output_t);
  }
}
1089
// Verifies that Transpose(Cast(uint8 -> fp32)) is rewritten so that the
// Transpose runs on the uint8 tensor (i.e. it now precedes the Cast).
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose must remain, and it must operate on uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    }
  }
  ASSERT_NE(transpose_node, nullptr);

  // Every Cast must now consume the Transpose output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
    }
  }

  // Fix: evaluate the *optimized* graph (`output`), not `item.graph` — the
  // previous code trivially compared the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1134
// Verifies that SpaceToDepth(Cast(uint8 -> fp32)) is rewritten so that the
// SpaceToDepth runs on the uint8 tensor and the Cast follows it.
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth must remain, and it must operate on uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      s2d_node = &node;
    }
  }
  ASSERT_NE(s2d_node, nullptr);

  // Every Cast must now consume the SpaceToDepth output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
    }
  }

  // Fix: evaluate the *optimized* graph (`output`), not `item.graph` — the
  // previous code trivially compared the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1181
// Verifies that Cast(Transpose(fp32) -> uint8) is rewritten so that the Cast
// reads the placeholder directly and the Transpose runs on uint8.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // NOTE(review): a constant input is used here (unlike the sibling tests),
  // presumably to keep the float->uint8 cast deterministic — confirm.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast must remain, reading directly from the placeholder.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  ASSERT_NE(cast_node, nullptr);

  // Every Transpose must now operate on uint8 and consume the Cast output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
    }
  }

  // Fix: evaluate the *optimized* graph (`output`), not `item.graph` — the
  // previous code trivially compared the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<uint8>(tensors[0], tensors_expected[0]);
}
1229
// Verifies that the Cast in Transpose(Reverse(Cast(uint8 -> fp32))) is pushed
// past both the Reverse and the Transpose: after optimization both run on
// uint8 and the Cast becomes the last op in the chain.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Locate the single Transpose, the single ReverseV2, and the Cast; the
  // first two must now operate on uint8.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  // Expected chain: Placeholder -> ReverseV2 -> Transpose -> Cast.
  ASSERT_EQ(reverse_node->input_size(), 2);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // Fix: evaluate the *optimized* graph (`output`), not `item.graph` — the
  // previous code trivially compared the original graph against itself.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1286
// With a CheckNumerics between the Cast and the consumer, the optimizer is
// expected to leave the graph completely unchanged.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output uint8_input =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output casted = ops::Cast(s, uint8_input, DT_FLOAT);
  Output checked = ops::CheckNumerics(s, casted, "foo");
  Output outputs = ops::Identity(s.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1303
// Transpose(Cast(fp32 -> uint8)): the optimizer is expected to leave this
// graph unchanged (no reordering of the cast/transpose pair).
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output fp32_in =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_uint8 = ops::Cast(s, fp32_in, DT_UINT8);
  Output transposed =
      ops::Transpose(s, as_uint8, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1321
// Cast(Transpose(uint8) -> fp32): the optimizer is expected to leave this
// graph unchanged (no reordering of the transpose/cast pair).
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output uint8_in =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed =
      ops::Transpose(s, uint8_in, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output as_fp32 = ops::Cast(s, transposed, DT_FLOAT);
  Output outputs = ops::Identity(s.WithOpName("outputs"), as_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1339
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  // perm1 followed by perm2 composes to the identity permutation; perm3 is
  // the identity permutation itself.
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
  Output transpose3 = ops::Transpose(s.WithOpName("transpose3"), inputs, perm3);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);
  Output id2 = ops::Identity(s.WithOpName("id2"), transpose3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All transposes (and their perm constants) should be optimized and pruned
  // away, leaving only the inputs and the fetched identity nodes.
  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
  }
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id1", "id2", "inputs_shape", "inputs"}));
}
1372
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  // {0, 1} is the identity permutation, so only the conjugation part of the
  // ConjugateTranspose has any effect.
  Output perm = ops::Const(s.WithOpName("perm"), {0, 1}, {2});
  Output transpose = ops::ConjugateTranspose(s.WithOpName("trans"), z, perm);
  Output id = ops::Identity(s.WithOpName("id"), transpose);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  // The identity ConjugateTranspose should be rewritten into a plain Conj
  // that reads `z` directly; the rewrite names it after the original node.
  const string p = "ArithmeticOptimizer/RemoveIdentityTranspose";
  const string optimized_name = absl::StrCat(p, "_", "trans");

  const NodeDef* conj = node_map.GetNode(optimized_name);
  ASSERT_NE(conj, nullptr);
  EXPECT_EQ(conj->op(), "Conj");
  ASSERT_EQ(conj->input_size(), 1);
  EXPECT_EQ(conj->input(0), "z");
}
1403
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // NOTE(review): `inputs_shape` is constructed but never consumed below —
  // looks like a leftover from an earlier revision of this test; confirm.
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  // Split into three branches; only the middle branch goes through a pair of
  // transposes whose permutations compose to the identity.
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // After optimization the Concat should consume the Split outputs directly,
  // with no intervening transposes on any branch.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Concat") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1446
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  // Two {1, 0} transposes compose to the identity.
  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
  // `outputs` consumes the transposes only as a control dependency, not as a
  // data input.
  Output outputs =
      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
                    ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The transposes should be removed, with the control dependency rewired to
  // their original data input (the placeholder).
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1481
// Applying the same non-identity permutation twice does not compose to the
// identity here, so neither transpose may be removed.
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_in =
      ops::RandomUniform(s.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output first_transpose =
      ops::Transpose(s.WithOpName("transpose1"), random_in, perm);
  Output second_transpose =
      ops::Transpose(s.WithOpName("transpose2"), first_transpose, perm);
  Output outputs = ops::Identity(s.WithOpName("outputs"), second_transpose);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All six nodes survive: nothing was removable.
  EXPECT_EQ(output.node_size(), 6);
}
1505
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  // perm1 and perm2 compose to the identity, but the two transposes are
  // separated by an Identity node ("id"), so the optimizer must look through
  // the chain to pair them up.
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output transpose1 = ops::Transpose(
      s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
  Output identity = ops::Identity(s.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Both transposes should be removed: "id" now reads "inputs" directly and
  // "id1" reads "id".
  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "inputs");
    }
    if (node.name() == "id1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id");
    }
  }
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id", "id1", "inputs_shape", "inputs"}));
}
1545
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  // The scalar multiply feeding the transposed conv input is folded into the
  // convolution weights, regardless of the Mul operand order.
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
    Output scaled_inputs = ops::Multiply(scope.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(scope.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights,
                    {1, 1, 1, 1}, "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    GraphDef optimized;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

    NodeMap node_map(&optimized);
    // `conv` is now a folded convolution with scaled weights.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);

    const NodeDef* folded_conv_weights =
        node_map.GetNode(folded_conv->input(1));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // The convolution input should be a transpose of the raw `inputs`.
    const NodeDef* transpose =
        node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
1595
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  // `inputs_nchw` is listed as a feed, so the optimizer must preserve it and
  // may not fold the scalar multiply across it into the conv weights.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(scope.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

  // Zero-filled tensor fed for the preserved transpose output.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0,
         inputs_nchw_tensor.tensor_data().size());

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));

  item.graph.Swap(&optimized);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &optimized));

  // The preserved transpose must still read the (unfolded) scaled inputs.
  NodeMap node_map(&optimized);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
1637
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  // The scalar multiply on the Conv3D input should be folded into the
  // convolution weights: after optimization the conv reads `inputs` directly
  // and its second input is a Mul over the original weights.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  // Use EXPECT_EQ rather than glog CHECK_EQ: a mismatch should fail this test
  // and report it, not abort the whole test binary.
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
1672
TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
  // Exercises two optimizations working together: folding the Mul into the
  // Conv2D weights, and reordering the Cast with the value-preserving
  // Transpose so the transpose runs on the cheaper uint8 data.
  //
  // Conv2D(Transpose(Mul(Cast(I), S)), W)
  //   => Conv2D(Transpose(Cast(I)), W*S)
  //   => Conv2D(Cast(Transpose(I)), W*S)
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  Output inputs = ops::Placeholder(scope, DT_UINT8,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output cast = ops::Cast(scope, inputs, DT_FLOAT);
  Output mul = ops::Mul(scope, cast, ops::Const(scope, 1.0f / 255.0f));
  Output transpose = ops::Transpose(
      scope, mul, ops::Const(scope.WithOpName("perm"), {0, 3, 1, 2}));
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv = ops::Conv2D(scope, transpose, weights, {1, 1, 1, 1}, "VALID",
                            ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;  // all optimization stages are on
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized, /*const_folding=*/true);
  NodeMap node_map(&optimized);

  // Names the optimizer assigns to the reordered cast and transpose.
  const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
  const string optimized_cast_name = absl::StrCat(p, "float_Cast");
  const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");

  // Name the optimizer assigns to the folded (scaled) weights.
  const string optimized_weights =
      "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";

  const NodeDef* inputs_node = node_map.GetNode("Placeholder");
  const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
  const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");

  ASSERT_NE(inputs_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(conv_node, nullptr);

  // Final graph: inputs -> Transpose -> Cast -> Conv2D(scaled weights).
  EXPECT_EQ(optimized.node_size(), 7);
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(cast_node->input(0), transpose_node->name());
  ASSERT_EQ(conv_node->input_size(), 2);
  EXPECT_EQ(conv_node->input(0), cast_node->name());
  EXPECT_EQ(conv_node->input(1), weights_node->name());
}
1736
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // Folding mul into conv must work independently for every (mul, conv) pair
  // in the graph, not just the first one found.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  for (int i = 0; i < 2; ++i) {
    Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                     ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(scope, inputs, ops::Const(scope, 1.0f / 255.0f));
    Output weights = ops::Const(scope.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(scope, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(scope.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized, /*const_folding=*/true);

  NodeMap node_map(&optimized);

  using absl::StrCat;
  const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
  const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  // Each convolution now reads its own scaled weights.
  ASSERT_EQ(conv_node->input_size(), 2);
  ASSERT_EQ(conv_node_1->input_size(), 2);
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
1786
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  // Two chained Bitcasts with distinct source/target types collapse into a
  // single Bitcast from the original source type to the final target type.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);
  NodeMap node_map(&optimized);

  // Bitcasts combined into one op; `inputs` feeds the surviving Bitcast.
  EXPECT_EQ(optimized.node_size(), 3);
  EXPECT_EQ(CountOpNodes(optimized, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  // Numerical output is unchanged by the rewrite.
  auto tensors = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
1820
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  // When the combined Bitcast would map a type back onto itself, the op is
  // removed entirely and consumers read the original input.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope, bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);
  NodeMap node_map(&optimized);

  // Bitcasts removed; `inputs` feeds `outputs` directly.
  EXPECT_EQ(optimized.node_size(), 2);
  EXPECT_EQ(CountOpNodes(optimized, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // Numerical output is unchanged by the rewrite.
  auto tensors = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
1854
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  // A Cast whose source and destination types are identical is a no-op and
  // should be removed from the graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output cast = ops::Cast(scope, inputs, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);
  NodeMap node_map(&optimized);

  // Cast removed; `inputs` feeds `outputs` directly.
  EXPECT_EQ(optimized.node_size(), 2);
  EXPECT_EQ(CountOpNodes(optimized, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // Numerical output is unchanged by the rewrite.
  auto tensors = EvaluateNodes(optimized, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
1887
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
  // A chain of Adds over same-shaped variables collapses into a single AddN;
  // the rewritten node is placed under the innermost sub-scope ("y").
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope scope_x = root.NewSubScope("x");
  tensorflow::Scope scope_y = root.NewSubScope("y");

  auto a = ops::Variable(root.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(root.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(root.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_bc = ops::Add(scope_x.WithOpName("Add_bc"), b, c);
  auto add_abc = ops::Add(scope_y.WithOpName("Add_abc"), a, add_bc);

  auto outputs = ops::Identity(root.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expected rewrite:
  //
  //     +
  //    / \
  //   a   +     -->    AddN(a, b, c)
  //      / \
  //     b   c
  EXPECT_EQ(optimized.node_size(), 5);

  NodeMap node_map(&optimized);

  // Add tree replaced with a single AddN in sub-scope "y".
  const NodeDef* collapsed_add =
      node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // Output re-wired to the new AddN node.
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // Numerical result is preserved.
  auto tensors = EvaluateNodes(optimized, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1951
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
  // Two independent Add trees feeding a Mul; both should be collapsed into
  // AddN nodes within a single optimization run.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(scope.WithOpName("Add_abc"), add_ab, c);

  auto x = ops::Variable(scope.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(scope.WithOpName("y"), {2, 2}, DT_FLOAT);
  auto z = ops::Variable(scope.WithOpName("z"), {2, 2}, DT_FLOAT);
  auto add_xy = ops::Add(scope.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(scope.WithOpName("Add_xyz"), add_xy, z);

  auto mul = ops::Multiply(scope.WithOpName("Mul"), add_abc, add_xyz);
  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expected rewrite:
  //
  //         *
  //       /   \
  //      +     +                    *
  //     / \   / \                  / \
  //    +   c x   +  --> AddN(a, b, c)  AddN(x, y, z)
  //   / \       / \
  //  a   b     y   z
  EXPECT_EQ(optimized.node_size(), 10);

  NodeMap node_map(&optimized);

  // Left Add subtree collapsed into an AddN.
  const NodeDef* collapsed_left =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_left, nullptr);

  EXPECT_EQ(collapsed_left->op(), "AddN");
  ASSERT_EQ(collapsed_left->input_size(), 3);
  EXPECT_EQ(collapsed_left->input(0), "a");
  EXPECT_EQ(collapsed_left->input(1), "b");
  EXPECT_EQ(collapsed_left->input(2), "c");

  // Right Add subtree collapsed into an AddN.
  const NodeDef* collapsed_right =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
  ASSERT_NE(collapsed_right, nullptr);

  EXPECT_EQ(collapsed_right->op(), "AddN");
  ASSERT_EQ(collapsed_right->input_size(), 3);
  EXPECT_EQ(collapsed_right->input(0), "x");
  EXPECT_EQ(collapsed_right->input(1), "y");
  EXPECT_EQ(collapsed_right->input(2), "z");

  // Mul inputs re-wired to the new AddN nodes.
  const NodeDef* updated_mul = node_map.GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);

  EXPECT_EQ(updated_mul->op(), "Mul");
  ASSERT_EQ(updated_mul->input_size(), 2);
  EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
  EXPECT_EQ(updated_mul->input(1), collapsed_right->name());

  // Numerical result is preserved.
  auto tensors = EvaluateNodes(optimized, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2039
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  // `b` feeds both leaf Adds, so it must appear twice among the inputs of the
  // collapsed AddN.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(scope.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(scope.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(scope.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expected rewrite:
  //
  //      +
  //     / \
  //    +   +    -->   AddN(a, b, b, c)
  //   / \ / \                ^
  //  a   b   c        b is added twice!
  EXPECT_EQ(optimized.node_size(), 5);

  NodeMap node_map(&optimized);

  // Add tree replaced with a single AddN.
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 4);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "b");
  EXPECT_EQ(collapsed_add->input(3), "c");

  // Numerical result is preserved.
  auto tensors = EvaluateNodes(optimized, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2096
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
  // Shapes are only symbolically known (leading dim is -1), but since all
  // summands share the same symbolic shape the Add tree still collapses.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Unknown input shape propagated symbolically through the graph.
  auto input = ops::Variable(scope.WithOpName("input"), {-1, 2}, DT_FLOAT);

  // [a, b, c] all inherit the same symbolic shape from `input`.
  auto a = ops::Sqrt(scope.WithOpName("a"), input);
  auto b = ops::Square(scope.WithOpName("b"), input);
  auto c = ops::Round(scope.WithOpName("c"), input);

  // [add_ab, add_abc] shapes must be inferred from the inputs.
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(scope.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expected rewrite:
  //
  //      +
  //     / \
  //    +   c   -->   AddN(a, b, c)
  //   / \
  //  a   b
  EXPECT_EQ(optimized.node_size(), 6);

  NodeMap node_map(&optimized);

  // Add tree replaced with a single AddN.
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);
  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // Output re-wired to the new AddN node.
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // Numerical result is preserved.
  auto tensors = EvaluateNodes(optimized, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2160
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
  // Verifies that AddOpsRewrite aggregates same-shaped inputs first and then
  // builds an aggregation tree that minimizes the cost of broadcasting.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  // 1) [a, x], [b, y], [c, z] - aggregate same shapes first
  // 2) Build an aggregation tree minimizing cost of broadcast
  //
  //         +                              +
  //       /   \                          /   \
  //      +     +                  +          AddN(c, z)
  //     / \   / \                / \
  //    +   c x   +  --> AddN(a, x)  AddN(b, y)
  //   / \       / \
  //  a   b     y   z
  EXPECT_EQ(output.node_size(), 12);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
  string outer_0_add_name =
      "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
  string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
  string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
  string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";

  // Add [a, x] first
  const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
  ASSERT_NE(add_ax_node, nullptr);
  EXPECT_EQ(add_ax_node->op(), "AddN");
  ASSERT_EQ(add_ax_node->input_size(), 2);
  EXPECT_EQ(add_ax_node->input(0), "a");
  EXPECT_EQ(add_ax_node->input(1), "x");

  // Then add [b, y]
  const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
  ASSERT_NE(add_by_node, nullptr);
  EXPECT_EQ(add_by_node->op(), "AddN");
  // (actual, expected) argument order, consistent with the rest of the file.
  ASSERT_EQ(add_by_node->input_size(), 2);
  EXPECT_EQ(add_by_node->input(0), "b");
  EXPECT_EQ(add_by_node->input(1), "y");

  // Then add [c, z]
  const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
  ASSERT_NE(add_cz_node, nullptr);
  EXPECT_EQ(add_cz_node->op(), "AddN");
  ASSERT_EQ(add_cz_node->input_size(), 2);
  EXPECT_EQ(add_cz_node->input(0), "c");
  EXPECT_EQ(add_cz_node->input(1), "z");

  // Then add results together starting from smaller shapes [a, x] + [b, y]
  const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
  ASSERT_NE(outer_0_node, nullptr);
  EXPECT_EQ(outer_0_node->op(), "AddV2");
  ASSERT_EQ(outer_0_node->input_size(), 2);
  EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
  EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);

  // And finally top level Add node
  const NodeDef* outer_node = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_node, nullptr);
  EXPECT_EQ(outer_node->op(), "AddV2");
  ASSERT_EQ(outer_node->input_size(), 2);
  EXPECT_EQ(outer_node->input(0), outer_0_add_name);
  EXPECT_EQ(outer_node->input(1), inner_2_add_name);

  // And outputs reading new top level Add node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // Numerical result is preserved.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2272
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // A small input whose first dimension is unknown.
  auto small = ops::Variable(scope.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);

  // A larger second input sharing the same unknown dimension; pinning the
  // device keeps this node from being rewritten.
  auto cpu_device = "/device:CPU:0";
  auto v = ops::Variable(scope.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
  auto large =
      ops::Add(scope.WithOpName("large").WithDevice(cpu_device), small, v);

  // [a, c] have shape {?, 1, 1}; [b] has shape {?, 32, 32}.
  auto a = ops::Sqrt(scope.WithOpName("a"), small);
  auto b = ops::Square(scope.WithOpName("b"), large);
  auto c = ops::Round(scope.WithOpName("c"), small);

  // Shapes of [add_ab, add_abc] must be inferred from their inputs.
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(scope.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto small_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
  auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {{"small", small_t},
                                                 {"v", v_t}};
  auto expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite: it is much cheaper to add the small tensors first and
  // perform the broadcast only once:
  //
  //      +                +
  //     / \              / \
  //    +   c    -->     +   b
  //   / \              / \
  //  a   b            a   c
  EXPECT_EQ(output.node_size(), 9);
  NodeMap node_map(&output);

  // Expected names of the rewritten outer and inner nodes.
  const string outer_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
  const string inner_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";

  // Outer add node.
  const NodeDef* outer = node_map.GetNode(outer_name);
  ASSERT_NE(outer, nullptr);
  EXPECT_EQ(outer->op(), "AddV2");
  ASSERT_EQ(outer->input_size(), 2);
  EXPECT_EQ(outer->input(0), inner_name);
  EXPECT_EQ(outer->input(1), "b");

  // Inner add node combining the two small tensors.
  const NodeDef* inner = node_map.GetNode(inner_name);
  ASSERT_NE(inner, nullptr);
  ASSERT_EQ(inner->input_size(), 2);
  EXPECT_EQ(inner->input(0), "a");
  EXPECT_EQ(inner->input(1), "c");

  // The fetch node must now read from the rewritten outer node.
  const NodeDef* fetched = node_map.GetNode("outputs");
  ASSERT_NE(fetched, nullptr);
  ASSERT_EQ(fetched->input_size(), 1);
  EXPECT_EQ(fetched->input(0), outer_name);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], expected[0], 1e-6);
}
2351
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Variable(scope.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(scope.WithOpName("y"), {2, 2}, DT_FLOAT);
  Output negated_x = ops::Neg(scope.WithOpName("Neg_x"), x);
  Output negated_y = ops::Neg(scope.WithOpName("Neg_y"), y);
  Output add_x_y = ops::Add(scope.WithOpName("Add_x_y"), x, y);
  Output add_negx_y = ops::Add(scope.WithOpName("Add_negx_y"), negated_x, y);
  Output add_x_negy = ops::Add(scope.WithOpName("Add_x_negy"), x, negated_y);
  Output add_negx_negy =
      ops::Add(scope.WithOpName("Add_negx_negy"), negated_x, negated_y);
  Output sub_x_y = ops::Sub(scope.WithOpName("Sub_x_y"), x, y);
  Output sub_negx_y = ops::Sub(scope.WithOpName("Sub_negx_y"), negated_x, y);
  Output sub_x_negy = ops::Sub(scope.WithOpName("Sub_x_negy"), x, negated_y);
  Output sub_negx_negy =
      ops::Sub(scope.WithOpName("Sub_negx_negy"), negated_x, negated_y);
  Output negated_x_with_dep = ops::Neg(
      scope.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}),
      x);
  Output add_negx_with_dep_y = ops::Add(
      scope.WithOpName("Add_negx_with_dep_y"), negated_x_with_dep, y);
  auto add_all =
      ops::AddN(scope.WithOpName("add_all"),
                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});

  GrapplerItem item;
  item.fetch = {"add_all"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto x_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"x", x_tensor},
                                                 {"y", y_tensor}};
  auto expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveNegation(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Each Add/Sub with a negated operand should be rewritten into the
  // equivalent op without the negation (e.g. x + (-y) => x - y).
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "Add_negx_y") {
      ++matched;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_x_negy") {
      ++matched;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Add_negx_negy") {
      ++matched;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Neg_x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_x_negy") {
      ++matched;
      EXPECT_EQ(node.op(), "AddV2");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_negx_negy") {
      ++matched;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_negx_with_dep_y") {
      ++matched;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
      // The control dependency of the removed Neg must be preserved.
      EXPECT_EQ(node.input(2), "^Add_x_y");
    }
  }
  EXPECT_EQ(matched, 6);

  auto actual = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
2439
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(scope.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::Div(scope.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // x / sqrt(y) should have been rewritten to x * rsqrt(y).
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2476
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  // "output0" (the Sqrt divisor) is itself fetched, so the rewrite must not
  // convert it to Rsqrt.
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // The graph must be unchanged: "grad" is still a Div, "output0" a Sqrt.
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "grad") {
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
}
2518
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(scope.WithOpName("sqrt_y"), y);
  Output floordiv_x_sqrt_y =
      ops::FloorDiv(scope.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // FloorDiv is not a plain division, so the rewrite must leave the graph
  // untouched.
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2555
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  // Run once with real inputs and once with complex ones; the fusion must
  // only happen for the real case.
  for (bool use_complex : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output complex_x = ops::Complex(s.WithOpName("complex_x"), x, x);
    Output complex_y = ops::Complex(s.WithOpName("complex_y"), y, y);
    Output sub_x_y =
        use_complex ? ops::Sub(s.WithOpName("sub_x_y"), complex_x, complex_y)
                    : ops::Sub(s.WithOpName("sub_x_y"), x, y);
    Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    const auto expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    const auto actual = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(actual.size(), 1);

    if (use_complex) {
      // No fusion for complex inputs; graph keeps its size.
      test::ExpectTensorNear<std::complex<float>>(actual[0], expected[0],
                                                  1e-6);
      EXPECT_EQ(output.node_size(), item.graph.node_size());
    } else {
      test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
      // The two now-unused Complex nodes get pruned away.
      EXPECT_EQ(output.node_size(), item.graph.node_size() - 2);
    }
    for (const NodeDef& node : output.node()) {
      if (node.name() == "output") {
        EXPECT_EQ(node.op(), use_complex ? "Square" : "Identity");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        EXPECT_EQ(node.op(), use_complex ? "Sub" : "SquaredDifference");
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), use_complex ? "complex_x" : "x");
        EXPECT_EQ(node.input(1), use_complex ? "complex_y" : "y");
      }
    }
  }
}
2605
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  // "sub_x_y" is fetched directly, so Square(Sub) must not be fused into
  // SquaredDifference.
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // The graph must be unchanged.
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
}
2645
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output softmax = ops::Softmax(scope.WithOpName("softmax"), x);
  Output logsoftmax = ops::Log(scope.WithOpName("output"), softmax);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);

  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
  // Log(Softmax(x)) collapses into a single LogSoftmax node.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "LogSoftmax");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
2676
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  // "softmax" is fetched directly, so Log(Softmax) must not be fused.
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
2706
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  Output out3 =
      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2", "out3", "out1", "out.5", "out0",
                "out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 10);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 10);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected rewrites: pow(x,2)->Square, pow(x,3)->x*Square(x) (keeping the
  // original node's device), pow(x,1)->Identity, pow(x,.5)->Sqrt,
  // pow(x,0)->Const(1) with a control dep on x, pow(x,-.5)->Rsqrt,
  // pow(x,-1)->Reciprocal. A general exponent or a broadcasting exponent is
  // left as Pow.
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x"}, {}, &want);
  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
          &want)
      ->set_device("/device:CPU:0");
  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
          {}, &want)
      ->set_device("/device:CPU:0");
  AddNode("out1", "Identity", {"x"}, {}, &want);
  AddNode("out.5", "Sqrt", {"x"}, {}, &want);
  AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);
  AddNode("out_1", "Reciprocal", {"x"}, {}, &want);
  AddNode("out", "Pow", {"x", "y"}, {}, &want);
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
2776
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(scope.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(scope.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(scope.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto add_12 =
      ops::Add(scope.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto add_23 = ops::Add(scope.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(scope.WithOpName("out1"), add_12);
  Output out2 = ops::Log(scope.WithOpName("out2"), add_23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto actual = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(actual.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(actual[i].NumElements(), expected[i].NumElements());
    test::ExpectTensorNear<float>(actual[i], expected[i], 1e-6);
  }

  // log(1 + x2) becomes Log1p(x2) (x1 is all ones), keeping the control
  // dependency on x3; log(x2 + x3) has no constant-one operand and stays Log.
  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
2815
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(scope.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(scope.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto x3 = ops::Const(scope.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto exp1 =
      ops::Exp(scope.WithOpName("exp1").WithControlDependencies(x3), x1);
  Output out1 = ops::Sub(scope.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(scope.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto actual = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(actual.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(actual[i].NumElements(), expected[i].NumElements());
    test::ExpectTensorNear<float>(actual[i], expected[i], 1e-6);
  }

  // exp(x1) - 1 becomes Expm1(x1) (x2 is all ones), keeping the control
  // dependency on x3; exp(x1) - x3 still reads the original Exp node.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
2853
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensor_a = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto tensor_b = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto tensor_c = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", tensor_a}, {"b", tensor_b}, {"c", tensor_c}};
  auto expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite: multiply the two vectors first so the broadcast
  // against the matrix happens only once.
  //
  //      *                  *
  //     / \                / \
  //    *   c      -->     *   b
  //   / \                / \
  //  a   b              a   c
  NodeMap node_map(&output);

  const NodeDef* mul1_def = node_map.GetNode("mul1");
  ASSERT_NE(mul1_def, nullptr);
  ASSERT_EQ(mul1_def->input_size(), 2);
  EXPECT_EQ(mul1_def->input(0), "a");
  EXPECT_EQ(mul1_def->input(1), "c");

  const NodeDef* mul2_def = node_map.GetNode("mul2");
  ASSERT_NE(mul2_def, nullptr);
  ASSERT_EQ(mul2_def->input_size(), 2);
  EXPECT_EQ(mul2_def->input(0), "mul1");
  EXPECT_EQ(mul2_def->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], expected[0], 1e-6);
}
2909
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // [a, c, d, e] are vectors; [b] is the only matrix, so every Mul that
  // touches b must broadcast.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  // Tall multiply chain (((a*b)*c)*d)*e: broadcasting at every level.
  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *      b
  //    / \                 /  \
  //   *   c      -->     *      *
  //  / \                / \    / \
  // a   b              a   c  d   e
  NodeMap node_map(&output);

  // mul1 and mul2 become the vector-only leaves of the balanced tree.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "d");
  EXPECT_EQ(mul2_node->input(1), "e");

  // mul3 combines the two vector products; mul4 applies the single
  // broadcasting multiply against the matrix b at the root.
  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "mul1");
  EXPECT_EQ(mul3_node->input(1), "mul2");

  const NodeDef* mul4_node = node_map.GetNode("mul4");
  ASSERT_NE(mul4_node, nullptr);
  ASSERT_EQ(mul4_node->input_size(), 2);
  EXPECT_EQ(mul4_node->input(0), "mul3");
  EXPECT_EQ(mul4_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
2988
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // [a, b, c] are vectors; [D] is the only matrix.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
  auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //                              *
  //                            /   \
  //       *       *            *    D
  //      / \     / \          / \
  //     *   *   -> *   c
  //    / \ / \        / \
  //   a  b c  D      a   b
  //
  // i.e. all vector multiplies are pushed below the single broadcasting
  // multiply against the matrix D.
  NodeMap node_map(&output);

  // Fix: locals are now named after the node they actually hold (previously
  // `mul1_node` held the node named "mul2" and vice versa).
  // After the rewrite the node named "mul2" is the a*b leaf.
  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "a");
  EXPECT_EQ(mul2_node->input(1), "b");

  // The node named "mul1" multiplies that leaf with c.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "mul2");
  EXPECT_EQ(mul1_node->input(1), "c");

  // "mul3" is the root: the single broadcasting multiply with D.
  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "D");
  EXPECT_EQ(mul3_node->input(1), "mul1");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3056
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // Two conv -> bias -> relu branches feeding a Concat.
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Guard the indexing below against a failed evaluation (consistent with
  // the other tests in this file).
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3097
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  // Verifies that a unary op applied to every input of a Concat is hoisted
  // above the Concat, and that control dependencies are preserved on the
  // surviving nodes.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //   Concat({Exp(Sin(a)), Exp(b), Exp(c)})
  // into
  //   Exp(Concat({Sin(a), b, c})).
  // Only Exp is common to all three inputs; Sin is applied to "a" alone and
  // must stay below the Concat.
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //   Concat({Cos(Exp(Sin(a))), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //   Cos(Exp(Concat({Sin(a), b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // Chain-length-1 case: concat now reads the raw inputs and feeds exp_a.
    if (node.name() == "concat") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat");
      found++;
    }
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a");
      found++;
    }

    // Chain-length-2 case: both Exp and Cos are hoisted above concat2.
    if (node.name() == "concat2") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat2");
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a2");
      found++;
    }
    if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "cos_exp_a2");
      found++;
    }
  }
  EXPECT_EQ(found, 7);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3206
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  // Verifies that identical unary chains applied to every output of a
  // Split/SplitV are pushed below the split, i.e. applied once to the split's
  // input instead of once per output. The hoisted ops get optimizer-generated
  // names of the form "ArithmeticOptimizer/_<op>_<split>".
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //     [Sin(y) for y in Split(x)]
  // into
  //     [y for y in Split(Sin(x))].
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  // exp_b is only applied to the second output, so it stays put; only the
  // common Sin is hoisted into the split input.
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //     [Cos(Exp(y)) for y in Split(x)]
  // into
  //     [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    // split1 now reads Sin(x) via the hoisted node.
    if (node.name() == "split1") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "axis");
      EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ(node.op(), "Sin");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "id_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1");
      found++;
    }
    // The non-common Exp now reads the split's second output directly.
    if (node.name() == "exp_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1:1");
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_b");
      found++;
    }
    // SplitV case: the whole Cos(Exp(.)) chain is hoisted above split2.
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Exp");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Cos");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
      found++;
    }
    if (node.name() == "split2") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
      EXPECT_EQ(node.input(1), "size_splits2");
      EXPECT_EQ(node.input(2), "axis");
      found++;
    }
    if (node.name() == "id_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2");
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2:1");
      found++;
    }
  }
  EXPECT_EQ(found, 10);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3331
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  // Applying an idempotent op twice is the same as applying it once:
  // Snapshot(Snapshot(a)) and Identity(Identity(a)) each collapse so that the
  // consumer reads the inner node directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // All 7 original nodes are still present; the duplicates are bypassed, not
  // pruned. (actual, expected) argument order matches the rest of this file.
  EXPECT_EQ(output.node_size(), 7);
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sn1");  // sn2 bypassed.
      found++;
    } else if (node.name() == "out2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id1");  // id2 bypassed.
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
      found++;
    }
  }
  EXPECT_EQ(found, 3);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3378
TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
  // Verifies that LogicalNot(cmp) is removed by negating the comparison op
  // in place (Equal<->NotEqual, Less<->GreaterEqual, LessEqual<->Greater)
  // and rewiring the LogicalNot's consumers to the comparison node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
  Output eq = ops::Equal(s.WithOpName("eq"), a, b);
  Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
  Output lt = ops::Less(s.WithOpName("lt"), a, b);
  Output le = ops::LessEqual(s.WithOpName("le"), a, b);
  Output gt = ops::Greater(s.WithOpName("gt"), a, b);
  Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
  // The name "not_eq" is reserved, hence "not_eq1".
  Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
  Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
  Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
  Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
  Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
  Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
  Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
  Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
  Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
  Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
  Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
  Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);

  GrapplerItem item;
  item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
                "id_not_le", "id_not_gt", "id_not_ge"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveLogicalNot(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // Each Identity should now read the comparison node directly; the
    // LogicalNot nodes are gone.
    if (node.name() == "id_not_eq") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "eq");
      ++found;
    }
    if (node.name() == "id_not_neq") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "neq");
      ++found;
    }
    if (node.name() == "id_not_lt") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "lt");
      ++found;
    }
    if (node.name() == "id_not_le") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "le");
      ++found;
    }
    if (node.name() == "id_not_gt") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "gt");
      ++found;
    }
    if (node.name() == "id_not_ge") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ge");
      ++found;
    }

    // Each comparison node's op should have been replaced by its negation.
    if (node.name() == "eq") {
      EXPECT_EQ(node.op(), "NotEqual");
      ++found;
    }
    if (node.name() == "neq") {
      EXPECT_EQ(node.op(), "Equal");
      ++found;
    }
    if (node.name() == "lt") {
      EXPECT_EQ(node.op(), "GreaterEqual");
      ++found;
    }
    if (node.name() == "le") {
      EXPECT_EQ(node.op(), "Greater");
      ++found;
    }
    if (node.name() == "gt") {
      EXPECT_EQ(node.op(), "LessEqual");
      ++found;
    }
    if (node.name() == "ge") {
      EXPECT_EQ(node.op(), "Less");
      ++found;
    }
  }
  EXPECT_EQ(found, 12);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
  }
}
3482
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  // Sqrt is monotonically increasing, so Max(Sqrt(x)) can be rewritten as
  // Sqrt(Max(x)); the optimizer swaps the inputs of the two nodes.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_op = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output max_op = ops::Max(scope.WithOpName("reduce_max"), sqrt_op, {0});
  Output out = ops::Identity(scope.WithOpName("final_out"), max_op);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // After the rewrite "sqrt" consumes "reduce_max" and "reduce_max" reads
  // "x" directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++matched;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3523
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  // ArgMax(Sqrt(x)) yields the same indices as ArgMax(x) since Sqrt is
  // monotonically increasing; the Sqrt node should be bypassed and pruned.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_op = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output arg_max_op = ops::ArgMax(scope.WithOpName("arg_max"), sqrt_op, 1);
  Output out = ops::Identity(scope.WithOpName("final_out"), arg_max_op);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
  // The Sqrt node was removed, so exactly one node fewer remains.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++matched;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");  // Reads "x" directly now.
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3564
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  // "sqrt" is itself fetched, so swapping it with the reduction would change
  // an observable output; the optimizer must leave the graph untouched.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_op = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output max_op = ops::Max(scope.WithOpName("reduce_max"), sqrt_op, {0});
  Output out = ops::Identity(scope.WithOpName("final_out"), max_op);

  GrapplerItem item;
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expected to be a no-op: the outputs of fetch nodes may not change.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
3588
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  // The reduction node "z" is fetched directly, so rewriting Max(Neg(.))
  // would change its observable value; the optimizer must not touch it.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {2, 3}, {1, 2});
  Output flat = ops::Reshape(scope.WithOpName("reshape"), x, {-1});
  Output y = ops::Neg(scope.WithOpName("y"), flat);
  Output z = ops::Max(scope.WithOpName("z"), y, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expected to be a no-op: the outputs of fetch nodes may not change.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int>(tensors[0], tensors_expected[0]);
  // max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(tensors[0], Tensor(-2));
}
3617
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  // Neg is monotonically non-increasing, so Max(Neg(x)) can be rewritten as
  // Neg(Min(x)): the reduction flips from Max to Min when pulled inside.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Check if the inputs are switched and the reduction negated.
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Min");  // Max flipped to Min.
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // (actual, expected) argument order for consistency with the rest of the
  // file.
  EXPECT_EQ(required_node_count, 2);
}
3659
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  // MaxPool fed by the non-increasing Neg must not be rewritten — the
  // optimizer is expected to be a no-op on this graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output neg_op = ops::Neg(scope.WithOpName("neg"), x);
  Output pool = ops::MaxPool(scope.WithOpName("max_pool"), neg_op,
                             {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expected to be a no-op.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3686
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  // Conv2D -> BiasAdd -> Relu -> MaxPool: the optimizer must not reorder any
  // of these nodes — expected to be a no-op.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(scope.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output input = ops::Const(scope.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output conv = ops::Conv2D(scope.WithOpName("conv"), input, weights,
                            {1, 1, 1, 1}, "SAME");
  Output bias = ops::BiasAdd(scope.WithOpName("biasadd"), conv, biases);
  Output relu = ops::Relu(scope.WithOpName("relu"), bias);
  Output pool = ops::MaxPool(scope.WithOpName("max_pool"), relu, {1, 2, 2, 1},
                             {1, 2, 2, 1}, "VALID");
  Output out = ops::Identity(scope.WithOpName("output"), pool);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &new_graph);

  // Expected to be a no-op.
  VerifyGraphsMatch(item.graph, new_graph, __LINE__);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3721
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  // Sqrt is monotonically increasing, so MaxPool(Sqrt(x)) can be rewritten
  // as Sqrt(MaxPool(x)); the optimizer swaps the inputs of the two nodes.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt_op = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output pool = ops::MaxPool(scope.WithOpName("max_pool"), sqrt_op,
                             {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output out = ops::Identity(scope.WithOpName("final_out"), pool);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // After the rewrite "sqrt" consumes "max_pool" and "max_pool" reads "x".
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++matched;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3763
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  // A chain of unary ops (Sqrt -> Log -> Relu) should be fused into a single
  // _UnaryOpsComposition node that records the op names in order.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_op = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output log_op = ops::Log(scope.WithOpName("log"), sqrt_op);
  Output relu_op = ops::Relu(scope.WithOpName("relu"), log_op);
  Output out = ops::Identity(scope.WithOpName("final_out"), relu_op);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Pin every node to CPU.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 3);

  // Sqrt/Log/Relu must have been replaced with one composed op.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++matched;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      auto op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(op_names.size(), 3);
      EXPECT_EQ(op_names[0], "Sqrt");
      EXPECT_EQ(op_names[1], "Log");
      EXPECT_EQ(op_names[2], "Relu");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3820
TEST_F(ArithmeticOptimizerTest,RemoveStackStridedSliceSameAxis)3821 TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
3822 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3823 auto a_in =
3824 ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
3825 auto b_in =
3826 ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
3827 auto c_in =
3828 ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
3829 auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
3830 PartialTensorShape({-1, -1}));
3831 auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
3832 PartialTensorShape({-1, -1}));
3833 auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
3834 PartialTensorShape({-1, -1}));
3835 // stacked = tf.stack((a, b, c), axis=1).
3836 // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
3837 auto stacked =
3838 ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
3839 ops::Stack::Axis(1));
3840 auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
3841 auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
3842 auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
3843 auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
3844 auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
3845 auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
3846 auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
3847 auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
3848 auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
3849 auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
3850 auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
3851
3852 // stacked[:, 0]
3853 using SS = ops::StridedSlice;
3854 auto pa_slice = ops::Identity(
3855 s.WithOpName("pa_slice_out"),
3856 SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
3857 SS::BeginMask(0b0101) // 5
3858 .EllipsisMask(0)
3859 .EndMask(0b0101) // 5
3860 .NewAxisMask(0)
3861 .ShrinkAxisMask(0b0010))); // 2
3862
3863 // stacked[:, 1]
3864 auto pb_slice = ops::Identity(
3865 s.WithOpName("pb_slice_out"),
3866 SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
3867 SS::BeginMask(0b0101) // 5
3868 .EllipsisMask(0)
3869 .EndMask(0b0101) // 5
3870 .NewAxisMask(0)
3871 .ShrinkAxisMask(0b0010))); // 2
3872
3873 // stacked[:, 2]
3874 auto pc_slice = ops::Identity(
3875 s.WithOpName("pc_slice_out"),
3876 SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
3877 SS::BeginMask(0b0101) // 5
3878 .EllipsisMask(0)
3879 .EndMask(0b0101) // 5
3880 .NewAxisMask(0)
3881 .ShrinkAxisMask(0b0010))); // 2
3882
3883 // stacked[:, 0:1, :]
3884 auto pa_slice_01 = ops::Identity(
3885 s.WithOpName("pa_slice_01_out"),
3886 SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
3887 SS::BeginMask(0b0101) // 5
3888 .EllipsisMask(0)
3889 .EndMask(0b0101) // 5
3890 .NewAxisMask(0)
3891 .ShrinkAxisMask(0)));
3892
3893 // stacked[:, :1, :]
3894 auto pa_slice_to1 = ops::Identity(
3895 s.WithOpName("pa_slice_to1_out"),
3896 SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
3897 SS::BeginMask(0b0111) // 7
3898 .EllipsisMask(0)
3899 .EndMask(0b0101) // 5
3900 .NewAxisMask(0)
3901 .ShrinkAxisMask(0)));
3902
3903 // stacked[:, 1:2, :]
3904 auto pb_slice_12 = ops::Identity(
3905 s.WithOpName("pb_slice_12_out"),
3906 SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
3907 SS::BeginMask(0b0101) // 5
3908 .EllipsisMask(0)
3909 .EndMask(0b0101) // 5
3910 .NewAxisMask(0)
3911 .ShrinkAxisMask(0)));
3912
3913 // stacked[:, 2:, :].
3914 auto pc_slice_2to = ops::Identity(
3915 s.WithOpName("pc_slice_2to_out"),
3916 SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
3917 SS::BeginMask(0b0101) // 5
3918 .EllipsisMask(0)
3919 .EndMask(0b0111) // 7
3920 .NewAxisMask(0)
3921 .ShrinkAxisMask(0)));
3922
3923 GrapplerItem item;
3924 item.fetch = {"a",
3925 "b",
3926 "c",
3927 "pa_slice_out",
3928 "pb_slice_out",
3929 "pc_slice_out",
3930 "expanded_a",
3931 "expanded_b",
3932 "expanded_c",
3933 "pa_slice_01_out",
3934 "pa_slice_to1_out",
3935 "pb_slice_12_out",
3936 "pc_slice_2to_out"};
3937 enum FetchItem {
3938 fA,
3939 fB,
3940 fC,
3941 fASliceOut,
3942 fBSliceOut,
3943 fCSliceOut,
3944 fExpandedA,
3945 fExpandedB,
3946 fExpandedC,
3947 fASlice01Out,
3948 fASliceTo1Out,
3949 fBSlice12Out,
3950 fCSlice2ToOut,
3951 };
3952 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3953 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3954
3955 // stacked[:, 0, :] == a.
3956 test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
3957 tensors_expected[fA]);
3958 // stacked[:, 1, :] == b.
3959 test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
3960 tensors_expected[fB]);
3961 // stacked[:, 2, :] == c.
3962 test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
3963 tensors_expected[fC]);
3964
3965 // stacked[:, 0:1, :] == expand_dims(a, 1).
3966 test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
3967 tensors_expected[fExpandedA]);
3968
3969 // stacked[:, :1, :] == expand_dims(a, 1).
3970 test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
3971 tensors_expected[fExpandedA]);
3972
3973 // stacked[:, 1:2, :] == expand_dims(b, 1).
3974 test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
3975 tensors_expected[fExpandedB]);
3976 // stacked[:, 2:, :] == expand_dims(c, 1).
3977 test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
3978 tensors_expected[fExpandedC]);
3979
3980 GraphDef output;
3981 ArithmeticOptimizer optimizer;
3982 EnableOnlyRemoveStackSliceSameAxis(&optimizer);
3983 OptimizeAndPrune(&optimizer, &item, &output);
3984
3985 for (const auto& node : output.node()) {
3986 if (node.name() == "pa_slice_out") {
3987 ASSERT_EQ(node.input_size(), 1);
3988 EXPECT_EQ(node.input(0), "a");
3989 } else if (node.name() == "pb_slice_out") {
3990 ASSERT_EQ(node.input_size(), 1);
3991 EXPECT_EQ(node.input(0), "b");
3992 } else if (node.name() == "pc_slice_out") {
3993 ASSERT_EQ(node.input_size(), 1);
3994 EXPECT_EQ(node.input(0), "c");
3995 } else if (str_util::EndsWith(node.name(), "_out")) {
3996 ASSERT_EQ(node.input_size(), 1);
3997 EXPECT_EQ(
3998 absl::StrCat(node.input(0), "_out"),
3999 absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
4000 node.name()));
4001 }
4002 }
4003
4004 auto tensors = EvaluateNodes(output, item.fetch);
4005
4006 // stacked[:, 0, :] == a.
4007 test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);
4008
4009 // stacked[:, 1, :] == b.
4010 test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
4011 // stacked[:, 2, :] == c.
4012 test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);
4013
4014 // stacked[:, 0:1, :] == expand_dims(a, 1).
4015 test::ExpectTensorEqual<float>(tensors[fASlice01Out],
4016 tensors_expected[fExpandedA]);
4017
4018 // stacked[:, :1, :] == expand_dims(a, 1).
4019 test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
4020 tensors_expected[fExpandedA]);
4021
4022 // stacked[:, 1:2, :] == expand_dims(b, 1).
4023 test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
4024 tensors_expected[fExpandedB]);
4025 // stacked[:, 2:, :] == expand_dims(c, 1).
4026 test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
4027 tensors_expected[fExpandedC]);
4028 }
4029
// Verifies that a Slice that extracts exactly one input of a Stack (along the
// stacked axis) is rewritten to an ExpandDims of that input, bypassing the
// Stack entirely, and that the rewritten graph produces identical values.
TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  // Feed the constants through placeholders with unknown static shape so the
  // rewrite cannot rely on fully-known input shapes.
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  // Reference values: each slice below should equal expand_dims(x, 1).
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  // A size of -1 means "to the end" along that axis; only axis 1 (the stack
  // axis) is restricted to a single element.
  auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});

  // stacked[:, 0:1, :]
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));

  // stacked[:, 1:2, :]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));

  // stacked[:, 2:3, :]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c"};
  // Indices into the evaluated tensor vectors; must stay in exactly the same
  // order as item.fetch above.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph first.
  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fExpandedA]);
  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Name prefix of the ExpandDims nodes the optimizer synthesizes; the suffix
  // is the original slice node's name minus its leading 'p'.
  const string kExpandDimsNamePrefix(
      "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
    } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
      EXPECT_EQ(node.op(), "ExpandDims");
      // The input is "a", "b", or "c", as appropriate.
      EXPECT_EQ(node.input(0),
                node.name().substr(kExpandDimsNamePrefix.size(), 1));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // The rewritten graph must produce the same values as the original.
  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors[fASliceOut],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors[fCSliceOut],
                                 tensors_expected[fExpandedC]);
}
4145
// AddN over two identical bfloat16 tensors should be simplified by the
// SimplifyAggregation rewrite (which introduces a constant multiplier node)
// without changing the computed result.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const Output x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  const Output cast = ops::Cast(scope.WithOpName("cast"), x, DT_BFLOAT16);
  const Output add = ops::AddN(scope.WithOpName("add"), {cast, cast});
  const Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Four original nodes plus one extra node created for the multiplier.
  EXPECT_EQ(output.node_size(), 5);

  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]);
}
4171
// Gather(embeddings, Unique(indices).y) followed by
// SparseSegmentSum(..., Unique(indices).idx, segment_ids) computes the same
// result as SparseSegmentSum(embeddings, indices, segment_ids); the
// SimplifyEmbeddingLookup rewrite should eliminate the Unique/Gather pair.
// Both Unique index dtypes are exercised.
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  for (const DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    const Output embeddings = ops::Const(scope.WithOpName("embeddings"),
                                         {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    const Output segment_ids =
        ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    const Output indices =
        ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    const auto unique = ops::Unique(scope.WithOpName("unique"), indices,
                                    /*attrs=*/{unique_idx_type});
    const Output gathered_rows =
        ops::Gather(scope.WithOpName("gathered_rows"), embeddings, unique.y);
    const Output result = ops::SparseSegmentSum(
        scope.WithOpName("result"), gathered_rows, unique.idx, segment_ids);
    const Output id = ops::Identity(scope.WithOpName("id"), result);

    GrapplerItem item;
    item.fetch = {"id"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    GraphDef output;
    OptimizeAndPrune(&optimizer, &item, &output);

    for (const auto& node : output.node()) {
      // The Unique/Gather pair must be gone from the optimized graph.
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
      if (node.name() == "result") {
        // The reduction now reads the embedding table and raw indices
        // directly.
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
    }

    const auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}
4215
// Casting int32 indices/segment_ids before feeding them to SparseSegmentSum
// is redundant (the reduction handles both integer widths), so the
// RemoveCastIntoSegmentReduction rewrite should wire the original tensors in
// directly and drop every Cast. All four dtype combinations are exercised.
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  for (const DataType indices_type : {DT_INT32, DT_INT64}) {
    for (const DataType segment_ids_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      const Output embeddings = ops::Const(scope.WithOpName("embeddings"),
                                           {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      const Output indices = ops::Cast(
          scope.WithOpName("cast_indices"),
          ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
          indices_type);
      const Output segment_ids = ops::Cast(
          scope.WithOpName("cast_segment_ids"),
          ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          segment_ids_type);
      const Output result = ops::SparseSegmentSum(
          scope.WithOpName("result"), embeddings, indices, segment_ids);
      const Output id = ops::Identity(scope.WithOpName("id"), result);

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);

      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      GraphDef output;
      OptimizeAndPrune(&optimizer, &item, &output);

      for (const auto& node : output.node()) {
        // No Cast may survive the rewrite.
        EXPECT_NE(node.op(), "Cast");
        if (node.name() == "result") {
          // Inputs 1 and 2 now bypass the removed casts.
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
      }

      const auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}
4259
4260 } // namespace grappler
4261 } // namespace tensorflow
4262