1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 #include "tensorflow/cc/ops/math_ops.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
23 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
24 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace grappler {
31
32 namespace {
33
// Node-name prefix attached by HoistCommonFactorOutOfAggregation to the
// hoisted outer Div node.
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

// Node-name prefix attached by HoistCommonFactorOutOfAggregation to the
// hoisted outer Mul node.
constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

// Node-name prefix attached by HoistCommonFactorOutOfAggregation to the
// inner Add node.
constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_Add_";

// Node-name prefix attached by SimplifyAggregation to the Const node it
// creates.
constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

// Node-name prefix attached by SimplifyAggregation to the Mul node it
// creates.
constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
48
49 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)50 string HoistMulName(const string& name) {
51 return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
52 }
53
54 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)55 string HoistDivName(const string& name) {
56 return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
57 }
58
59 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)60 string HoistAddName(const string& name) {
61 return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
62 }
63
64 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)65 string AggregationConstName(const string& name) {
66 return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
67 }
68
69 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)70 string AggregationMulName(const string& name) {
71 return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
72 }
73
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)74 void VerifyGraphsMatch(const GraphDef& original_graph,
75 const GraphDef& optimized_graph, int line) {
76 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
77 for (int i = 0; i < original_graph.node_size(); ++i) {
78 const NodeDef& original = original_graph.node(i);
79 const NodeDef& optimized = optimized_graph.node(i);
80 EXPECT_EQ(original.name(), optimized.name()) << line;
81 EXPECT_EQ(original.op(), optimized.op()) << line;
82 EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
83 for (int j = 0; j < original.input_size(); ++j) {
84 EXPECT_EQ(original.input(j), optimized.input(j)) << line;
85 }
86 }
87 }
88 } // namespace
89
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  const Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // The optimizer must leave the graph untouched.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
102
TEST_F(ArithmeticOptimizerTest, OpDedupping) {
  // Two identical constants should be collapsed into one, with both inputs
  // of the consumer redirected to the survivor.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
  Output div = ops::Div(s.WithOpName("div"), c1, c2);

  GrapplerItem item;
  item.fetch = {"div"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(2, output.node_size());
  const NodeDef* surviving_const = node_map.GetNode("c1");
  ASSERT_NE(surviving_const, nullptr);

  const NodeDef* optimized_div = node_map.GetNode("div");
  ASSERT_NE(optimized_div, nullptr);
  EXPECT_EQ(2, optimized_div->input_size());
  EXPECT_EQ("c1", optimized_div->input(0));
  EXPECT_EQ("c1", optimized_div->input(1));

  // The rewritten graph must still produce the same value.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
133
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
  // Duplicate CheckNumerics and Assert ops with identical inputs should also
  // be deduplicated; the survivors remain as data and control inputs.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
                            {assert1.operation, assert2.operation}),
                        check1, check2);

  GrapplerItem item;
  item.fetch = {"div"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  Tensor bool_t(DT_BOOL, TensorShape({}));
  bool_t.scalar<bool>().setConstant(true);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(5, output.node_size());
  const NodeDef* optimized_div = node_map.GetNode("div");
  ASSERT_NE(optimized_div, nullptr);
  EXPECT_EQ(4, optimized_div->input_size());
  EXPECT_EQ("check1", optimized_div->input(0));
  EXPECT_EQ("check1", optimized_div->input(1));
  EXPECT_EQ("^assert1", optimized_div->input(2));
  EXPECT_EQ("^assert1", optimized_div->input(3));

  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
173
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
  // Mul(c1, c2) and Mul(c2, c1) are equivalent for a commutative op, so the
  // pair should be deduplicated into a single Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);

  GrapplerItem item;
  item.fetch = {"div1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(4, output.node_size());
  const NodeDef* surviving_c1 = node_map.GetNode("c1");
  ASSERT_NE(surviving_c1, nullptr);
  const NodeDef* surviving_c2 = node_map.GetNode("c2");
  ASSERT_NE(surviving_c2, nullptr);

  // mul1 survives with its original operand order.
  const NodeDef* surviving_mul = node_map.GetNode("mul1");
  ASSERT_NE(surviving_mul, nullptr);
  EXPECT_EQ(2, surviving_mul->input_size());
  EXPECT_EQ("c1", surviving_mul->input(0));
  EXPECT_EQ("c2", surviving_mul->input(1));

  // div1 now reads the surviving Mul on both sides.
  const NodeDef* optimized_div = node_map.GetNode("div1");
  ASSERT_NE(optimized_div, nullptr);
  EXPECT_EQ(2, optimized_div->input_size());
  EXPECT_EQ("mul1", optimized_div->input(0));
  EXPECT_EQ("mul1", optimized_div->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
212
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  // Mul(c, c) should be rewritten as Square(c); the control dependency on d
  // must be preserved on the new node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output id = ops::Identity(s.WithOpName("id"), mul);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(4, output.node_size());

  NodeMap node_map(&output);
  const string prefix = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square = node_map.GetNode(strings::StrCat(prefix, "_", "mul"));
  ASSERT_NE(square, nullptr);
  EXPECT_EQ("Square", square->op());
  EXPECT_EQ("c", square->input(0));
  EXPECT_EQ("^d", square->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
246
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
  // Back-to-back applications of an involution (Neg(Neg(x)),
  // Reciprocal(Reciprocal(x))) should cancel out entirely.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other; only the constant
  // and the Identity remain, wired directly together.
  EXPECT_EQ(2, output.node_size());
  EXPECT_EQ("id", output.node(1).name());
  EXPECT_EQ("c", output.node(1).input(0));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
277
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AroundValuePreservingChain) {
  // A pair of Reciprocal ops separated by value-preserving ops (Identity,
  // Squeeze) should still cancel out.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef output;
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(3, output.node_size());

  // And const directly flows into squeeze.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      EXPECT_EQ("c", node.input(0));
      ++matched;
    } else if (node.name() == "id2") {
      EXPECT_EQ("squeeze", node.input(0));
      ++matched;
    }
  }
  EXPECT_EQ(2, matched);

  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
321
TEST_F(ArithmeticOptimizerTest, RemoveInvolution_SkipControlDependencies) {
  // recip2 only depends on recip1's chain through a control edge, so the two
  // Reciprocal ops must NOT be treated as a cancelling pair.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
354
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  // Add(x, x) should be rewritten as Mul(Const(2), x).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(5, output.node_size());

  const string const_name = AggregationConstName("add");
  const string mul_name = AggregationMulName("add");

  // The generated constant is anchored to x via a control dependency, and
  // its payload is the raw little-endian encoding of float 2.0.
  const NodeDef* two_const = node_map.GetNode(const_name);
  ASSERT_NE(two_const, nullptr);
  EXPECT_EQ("^x", two_const->input(0));
  EXPECT_EQ(string("\0\0\0@", 4),
            two_const->attr().at("value").tensor().tensor_content());

  const NodeDef* scaled_mul = node_map.GetNode(mul_name);
  ASSERT_NE(scaled_mul, nullptr);
  EXPECT_EQ(const_name, scaled_mul->input(0));
  EXPECT_EQ("x", scaled_mul->input(1));

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  EXPECT_EQ(mul_name, id_node->input(0));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
397
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // Same rewrite as TrivialSumsSimple, but the Add carries a control
  // dependency on y, which must be carried over to the new Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(6, output.node_size());

  const string const_name = AggregationConstName("add");
  const string mul_name = AggregationMulName("add");

  // Generated constant: anchored to x, payload is float 2.0 in raw
  // little-endian bytes.
  const NodeDef* two_const = node_map.GetNode(const_name);
  ASSERT_NE(two_const, nullptr);
  EXPECT_EQ("^x", two_const->input(0));
  EXPECT_EQ(string("\0\0\0@", 4),
            two_const->attr().at("value").tensor().tensor_content());

  const NodeDef* scaled_mul = node_map.GetNode(mul_name);
  ASSERT_NE(scaled_mul, nullptr);
  EXPECT_EQ(const_name, scaled_mul->input(0));
  EXPECT_EQ("x", scaled_mul->input(1));
  EXPECT_EQ("^y", scaled_mul->input(2));

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  EXPECT_EQ(mul_name, id_node->input(0));

  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
442
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
  // Test case from b/69059093.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
  Output add = ops::Add(s.WithOpName("Add"), p, p);
  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
  Output id = ops::Identity(s.WithOpName("id"), add6);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Assign each node an explicit device.
  const std::vector<string> devices{
      "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
      "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
  };
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device(devices[i]);
  }

  ArithmeticOptimizer optimizer;
  DisableAddToAddNCombining(&optimizer);

  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  // Mul(p,
  //     Add_6(Add_4(Const(2), Const(2)),
  //           Add_5(Const(2), Const(2))))
  NodeMap node_map(&output);

  EXPECT_EQ(17, output.node_size());

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  EXPECT_EQ(1, id_node->input_size());
  EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));

  const NodeDef* hoisted_mul = node_map.GetNode(HoistMulName("Add_6"));
  ASSERT_NE(hoisted_mul, nullptr);
  EXPECT_EQ(2, hoisted_mul->input_size());
  EXPECT_EQ("Placeholder", hoisted_mul->input(0));
  EXPECT_EQ(HoistAddName("Add_6"), hoisted_mul->input(1));

  const NodeDef* hoisted_add_6 = node_map.GetNode(HoistAddName("Add_6"));
  ASSERT_NE(hoisted_add_6, nullptr);
  EXPECT_EQ(2, hoisted_add_6->input_size());
  EXPECT_EQ(HoistAddName("Add_4"), hoisted_add_6->input(0));
  EXPECT_EQ(HoistAddName("Add_5"), hoisted_add_6->input(1));

  const NodeDef* hoisted_add_4 = node_map.GetNode(HoistAddName("Add_4"));
  ASSERT_NE(hoisted_add_4, nullptr);
  EXPECT_EQ("Add", hoisted_add_4->op());
  EXPECT_EQ(2, hoisted_add_4->input_size());
  EXPECT_EQ(AggregationConstName("Add"), hoisted_add_4->input(0));
  EXPECT_EQ(AggregationConstName("Add_1"), hoisted_add_4->input(1));

  const NodeDef* hoisted_add_5 = node_map.GetNode(HoistAddName("Add_5"));
  ASSERT_NE(hoisted_add_5, nullptr);
  EXPECT_EQ("Add", hoisted_add_5->op());
  EXPECT_EQ(2, hoisted_add_5->input_size());
  EXPECT_EQ(AggregationConstName("Add"), hoisted_add_5->input(0));
  EXPECT_EQ(AggregationConstName("Add_1"), hoisted_add_5->input(1));

  const NodeDef* add_const = node_map.GetNode(AggregationConstName("Add"));
  ASSERT_NE(add_const, nullptr);
  EXPECT_EQ("Const", add_const->op());
  EXPECT_EQ(1, add_const->input_size());
  EXPECT_EQ("^Placeholder", add_const->input(0));

  const NodeDef* add_1_const = node_map.GetNode(AggregationConstName("Add_1"));
  ASSERT_NE(add_1_const, nullptr);
  EXPECT_EQ("Const", add_1_const->op());
  EXPECT_EQ(1, add_1_const->input_size());
  EXPECT_EQ("^Placeholder", add_1_const->input(0));
}
524
TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
  // Exercise HoistCommonFactorOutOfAggregation over all combinations of
  // matching/mismatched operand shapes and Add vs AddN roots.
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
      Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
      Output y2 = matching_shapes
                      ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
                      : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
      Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
      Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
      Output id =
          use_addn ? ops::Identity(s.WithOpName("id"),
                                   ops::AddN(s.WithOpName("add"), {mul1, mul2}))
                   : ops::Identity(s.WithOpName("id"),
                                   ops::Add(s.WithOpName("add"), mul1, mul2));

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      EXPECT_EQ(1, tensors_expected.size());

      ArithmeticOptimizer optimizer;
      EnableOnlyHoistCommonFactor(&optimizer);

      GraphDef output;
      OptimizeTwice(&optimizer, &item, &output);

      // We expect the following rewrite(s) to occur:
      //
      //        Add                 Mul
      //      /    \               /   \
      //    Mul    Mul       ->   x    Add
      //    / \    / \                 / \
      //   x  y1  y2  x              y1   y2
      //
      // If "root" op is AddN and shapes does not match, this rewrite is not
      // possible and graph should stay intact.
      NodeMap node_map(&output);

      if (use_addn && !matching_shapes) {
        VerifyGraphsMatch(item.graph, output, __LINE__);
      } else {
        EXPECT_EQ(9, output.node_size());

        const NodeDef* hoisted_add = node_map.GetNode(HoistAddName("add"));
        ASSERT_NE(hoisted_add, nullptr) << "Hoisted Add node not found";
        EXPECT_EQ("y1", hoisted_add->input(0));
        EXPECT_EQ("y2", hoisted_add->input(1));

        const NodeDef* hoisted_mul = node_map.GetNode(HoistMulName("add"));
        ASSERT_NE(hoisted_mul, nullptr) << "Hoisted Mul node not found";
        EXPECT_EQ("x", hoisted_mul->input(0));
        EXPECT_EQ(hoisted_add->name(), hoisted_mul->input(1));

        const NodeDef* id_node = node_map.GetNode("id");
        ASSERT_NE(id_node, nullptr) << "Id node not found";
        EXPECT_EQ("id", id_node->name());
        EXPECT_EQ(HoistMulName("add"), id_node->input(0));
      }
      auto tensors = EvaluateNodes(output, item.fetch);
      EXPECT_EQ(1, tensors.size());
      test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
    }
  }
}
591
TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
  // Exercise HoistCommonFactorOutOfAggregation with a common Div denominator
  // over matching/mismatched shapes, Add vs AddN roots, and int vs float
  // element types (the rewrite must not fire for integer division).
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      for (bool use_ints : {true, false}) {
        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
        Output x = use_ints
                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
        Output y1 = use_ints
                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
        Output y2;
        if (matching_shapes) {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
        } else {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
        }
        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
        Output id =
            use_addn
                ? ops::Identity(s.WithOpName("id"),
                                ops::AddN(s.WithOpName("add"), {div1, div2}))
                : ops::Identity(s.WithOpName("id"),
                                ops::Add(s.WithOpName("add"), div1, div2));

        GrapplerItem item;
        item.fetch = {"id"};
        TF_CHECK_OK(s.ToGraphDef(&item.graph));

        auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
        EXPECT_EQ(1, tensors_expected.size());

        ArithmeticOptimizer optimizer;
        EnableOnlyHoistCommonFactor(&optimizer);

        GraphDef output;
        OptimizeTwice(&optimizer, &item, &output);

        // We expect the following rewrite(s) to occur:
        //
        //           Add                 Div
        //         /    \              /     \
        //       Div    Div       -> Add      x
        //       / \    / \           / \
        //     y1   x  y2   x       y1   y2
        //
        // If "root" op is AddN and shapes does not match, this rewrite is not
        // possible and graph should stay intact.
        NodeMap node_map(&output);

        if ((use_addn && !matching_shapes) || use_ints) {
          VerifyGraphsMatch(item.graph, output, __LINE__);
        } else {
          EXPECT_EQ(9, output.node_size());

          // Use ASSERT_NE for consistency with the other tests in this file
          // (it also prints both operands on failure).
          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
          ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
          EXPECT_EQ("y1", new_add_node->input(0));
          EXPECT_EQ("y2", new_add_node->input(1));

          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
          ASSERT_NE(new_div_node, nullptr) << "Hoisted Div node not found";
          EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
          EXPECT_EQ("x", new_div_node->input(1));

          const NodeDef* id_node = node_map.GetNode("id");
          ASSERT_NE(id_node, nullptr) << "Id node not found";
          EXPECT_EQ("id", id_node->name());
          EXPECT_EQ(HoistDivName("add"), id_node->input(0));
        }
        auto tensors = EvaluateNodes(output, item.fetch);
        EXPECT_EQ(1, tensors.size());
        if (use_ints) {
          test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
        } else {
          test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
        }
      }
    }
  }
}
676
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  // Transpose(Conj(z)) should fuse into a single ConjugateTranspose.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(7, output.node_size());

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = strings::StrCat(prefix, "_", "trans");

  const NodeDef* fused_trans = node_map.GetNode(fused_name);
  ASSERT_NE(fused_trans, nullptr);
  EXPECT_EQ("ConjugateTranspose", fused_trans->op());
  EXPECT_EQ("z", fused_trans->input(0));
  EXPECT_EQ("perm", fused_trans->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
713
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  // ConjugateTranspose(Conj(z)) should simplify to a plain Transpose, since
  // the two conjugations cancel.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(s.WithOpName("conj"), z);
  Output transp =
      ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(7, output.node_size());

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = strings::StrCat(prefix, "_", "conjugate_trans");

  const NodeDef* fused_trans = node_map.GetNode(fused_name);
  ASSERT_NE(fused_trans, nullptr);
  EXPECT_EQ("Transpose", fused_trans->op());
  EXPECT_EQ("z", fused_trans->input(0));
  EXPECT_EQ("perm", fused_trans->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
752
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  // Conj(Transpose(z)) should fuse into a single ConjugateTranspose.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
  Output conj = ops::Conj(s.WithOpName("conj"), trans);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(7, output.node_size());

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = strings::StrCat(prefix, "_", "conj");

  const NodeDef* fused_conj = node_map.GetNode(fused_name);
  ASSERT_NE(fused_conj, nullptr);
  EXPECT_EQ("ConjugateTranspose", fused_conj->op());
  EXPECT_EQ("z", fused_conj->input(0));
  EXPECT_EQ("perm", fused_conj->input(1));

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
789
// Transposed operands of a matmul should be folded into the matmul's own
// transpose/adjoint attributes instead of materializing Transpose nodes.
// Exercised for all three matmul flavors.
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
  for (const string matmul_type : {"MatMul", "SparseMatMul", "BatchMatMul"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    // matmul(transpose(a), transpose(b)) over 2x2 constants.
    Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
    Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
    Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
    Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);

    auto matmul_op = s.WithOpName("matmul");
    if (matmul_type == "MatMul") {
      Output matmul = ops::MatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "SparseMatMul") {
      Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMul") {
      Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
    }

    GrapplerItem item;
    item.fetch = {"matmul"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Baseline result from the unoptimized graph.
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    EXPECT_EQ(1, tensors_expected.size());

    ArithmeticOptimizer optimizer;
    EnableOnlyFoldTransposeIntoMatMul(&optimizer);
    GraphDef output;
    OptimizeTwice(&optimizer, &item, &output);
    NodeMap node_map(&output);

    // 6 original nodes plus the newly added fused matmul node.
    EXPECT_EQ(7, output.node_size());

    const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
    const string optimized_name = strings::StrCat(p, "_", "matmul");

    // The fused matmul reads `a` and `b` directly, bypassing the transposes.
    const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
    ASSERT_NE(matmul_fused_node, nullptr);
    EXPECT_EQ("a", matmul_fused_node->input(0));
    EXPECT_EQ("b", matmul_fused_node->input(1));

    // BatchMatMul uses adjoint flags; (Sparse)MatMul uses transpose flags.
    if (matmul_type == "BatchMatMul") {
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
    } else {
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
    }

    // The optimized graph must produce the same values as the original.
    auto tensors = EvaluateNodes(output, item.fetch);
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  }
}
845
// ConjugateTranspose feeding a BatchMatMul should be folded into the matmul's
// adj_x/adj_y attributes (adjoint == conjugate transpose), removing the
// explicit ConjugateTranspose nodes from the data path.
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Two 2x2 complex operands a and b, each fed through a ConjugateTranspose.
  Output re_a =
      ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a =
      ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b =
      ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(s.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(s.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);

  GrapplerItem item;
  item.fetch = {"matmul"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  // 10 original nodes plus the newly added fused matmul node.
  ASSERT_EQ(11, output.node_size());

  const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
  const string optimized_name = strings::StrCat(p, "_", "matmul");

  // The fused matmul reads a and b directly with both adjoint flags set.
  const NodeDef* optimized_matmul = node_map.GetNode(optimized_name);
  ASSERT_NE(optimized_matmul, nullptr);
  EXPECT_EQ("a", optimized_matmul->input(0));
  EXPECT_EQ("b", optimized_matmul->input(1));
  EXPECT_TRUE(optimized_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(optimized_matmul->attr().at("adj_y").b());

  // The optimized graph must produce the same values as the original.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
}
892
// A Reshape whose dynamically computed target shape provably equals the
// input's shape (batch_size from the input concatenated with the static
// [3,28,28] tail) is an identity and should be removed.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output inputs_shape = ops::Shape(s, inputs);
  // The target shape of the reshape is the concatenation of `batch_size` and
  // [3,28,28].
  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                 ops::Const(s, {1}, {1}));
  Output target_shape = ops::Concat(
      s.WithOpName("target_shape"),
      {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline result from the unoptimized graph with a concrete batch of 3.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The identity reshape must be gone, and the result unchanged.
  EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
926
// Identity reshape where both the input and the target shape contain symbolic
// (unknown) dimensions. Only removable in aggressive mode, where the feed is
// assumed to match the placeholder's declared shape.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(s, inputs);
  // The target shape of the reshape is the concatenation of `batch_size`, 3,
  // `height`, and `width`.
  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                 ops::Const(s, {1}, {1}));
  Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
                             ops::Const(s, {1}, {1}));
  Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
                            ops::Const(s, {1}, {1}));
  Output target_shape =
      ops::Concat(s.WithOpName("target_shape"),
                  {batch_size, ops::Const(s, {3}, {1}), height, width},
                  ops::Const(s, {0}, {}));
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape is removed, and the result is unchanged.
  EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
968
// In the default (non-aggressive) mode the optimizer may not assume that the
// runtime feed matches the placeholder's declared shape, so a Reshape whose
// target merely equals that declared shape must be preserved.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(root, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(root, input, shape_const);
  Output out = ops::Identity(root.WithOpName("outputs"), reshaped);

  auto feed_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", feed_value}};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // The reshape is preserved because the shape of the placeholder can be
  // different from the shape of the actual feed.
  EXPECT_EQ(1, CountOpNodes(optimized_graph, "Reshape"));

  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
999
// Counterpart of NotAssumeValidFeeds: in AGGRESSIVE mode the optimizer is
// allowed to trust the placeholder's declared shape, so the same Reshape is
// recognized as an identity and removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(root, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(root, input, shape_const);
  Output out = ops::Identity(root.WithOpName("outputs"), reshaped);

  auto feed_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", feed_value}};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, expected.size());

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // The identity reshape is removed; results must stay numerically identical.
  EXPECT_EQ(0, CountOpNodes(optimized_graph, "Reshape"));
  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  EXPECT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
1028
// A reshape that changes the (partially unknown) shape is NOT an identity and
// must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
  // be from [4,3,28,28] to [8,6,28,28].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline result from the unoptimized graph with batch size 8.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape survives, and the result is unchanged.
  EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1056
// With more than one -1 in the target shape the optimizer cannot prove the
// reshape is an identity, so the Reshape node must be preserved.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped =
      ops::Reshape(root, input, ops::Const(root, {-1, -1}, {2}));
  Output out = ops::Identity(root.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // The ambiguous reshape survives optimization.
  EXPECT_EQ(1, CountOpNodes(optimized_graph, "Reshape"));
}
1076
// Two back-to-back reshapes should be collapsed into one, since only the final
// target shape matters.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) {
  // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
  // reshapes should be combined.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output nchw_vect_c =
      ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_INT8,
                       ops::Placeholder::Shape({8, 3, 28, 28, 4}));
  // NCHW_VECT_C -> NHWC_VECT_C layout via transpose.
  Output transpose =
      ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
                     ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
  // First reshape: merge the C and VECT_C dims into a single channel dim.
  Output nhwc = ops::Reshape(
      s.WithOpName("nhwc"), transpose,
      ops::Const(s.WithOpName("nhwc_shape"), {8, 28, 28, 12}, {4}));
  // Second reshape: flatten to [batch, features].
  Output flatten = ops::Reshape(
      s.WithOpName("flatten"), nhwc,
      ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), flatten);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline result from the unoptimized graph.
  auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4}));
  item.feed = {{"nchw_vect_c", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Exactly one Reshape remains, and the result is unchanged.
  EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1113
// Transpose(Cast(uint8 -> float)) should be reordered to
// Cast(Transpose(uint8)) so the Transpose moves the narrower uint8 data.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_ProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose remains, and it now operates on uint8, i.e. it was
  // moved in front of the Cast.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      transpose_node = &node;
    }
  }
  EXPECT_NE(transpose_node, nullptr);

  // Every Cast must consume the transposed uint8 tensor.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(NodeName(node.input(0)), transpose_node->name());
    }
  }

  // Evaluate the *optimized* graph (not item.graph) so the rewrite itself is
  // verified to be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1157
// SpaceToDepth(Cast(uint8 -> float)) should be reordered to
// Cast(SpaceToDepth(uint8)) so the data movement happens on uint8.
TEST_F(ArithmeticOptimizerTest, ReorderS2DCast_ProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth remains, and it now operates on uint8, i.e. it
  // was moved in front of the Cast.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      s2d_node = &node;
    }
  }
  EXPECT_NE(s2d_node, nullptr);

  // Every Cast must consume the SpaceToDepth output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(NodeName(node.input(0)), s2d_node->name());
    }
  }

  // Evaluate the *optimized* graph (not item.graph) so the rewrite itself is
  // verified to be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1203
// Cast(Transpose(float) -> uint8) should be reordered to
// Transpose(Cast(float -> uint8)) so the Transpose moves the narrower type.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_ProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // A constant input makes float->uint8 casting deterministic for comparison.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast remains, and it now reads the placeholder directly,
  // i.e. it was moved in front of the Transpose.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  EXPECT_NE(cast_node, nullptr);

  // The Transpose now operates on uint8 and consumes the Cast.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      EXPECT_EQ(NodeName(node.input(0)), cast_node->name());
    }
  }

  // Evaluate the *optimized* graph (not item.graph) so the rewrite itself is
  // verified to be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<uint8>(tensors_expected[0], tensors[0]);
}
1249
// A chain Transpose(Reverse(Cast(uint8 -> float))) should be rewritten so the
// Cast is pushed to the end: Cast(Transpose(Reverse(uint8))), keeping all data
// movement on the narrower uint8 type.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Locate the single Transpose, single ReverseV2, and the Cast; the first
  // two must now operate on uint8.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  EXPECT_NE(cast_node, nullptr);
  EXPECT_NE(reverse_node, nullptr);
  EXPECT_NE(transpose_node, nullptr);
  // Expected rewritten chain: Placeholder -> ReverseV2 -> Transpose -> Cast.
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // Evaluate the *optimized* graph (not item.graph) so the rewrite itself is
  // verified to be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", input_t}});
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
1303
// A Cast followed by CheckNumerics (not a Transpose-like consumer) must be
// left alone: the optimized graph is expected to be identical to the input.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast_CheckNumericsToIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 = ops::CheckNumerics(s, nhwc_fp32, "foo");
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // No rewrite expected: the output graph must equal the input graph.
  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1320
// Cast(float -> uint8) followed by Transpose: moving the Transpose before the
// Cast would make it shuffle the *wider* float data, so no reordering should
// happen and the graph must come out unchanged.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast_ProducerIsCast) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_uint8 = ops::Cast(root, nhwc_fp32, DT_UINT8);
  Output nchw_uint8 =
      ops::Transpose(root, nhwc_uint8, ops::Const(root, {0, 3, 1, 2}, {4}));
  Output out = ops::Identity(root.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized_graph));
  CompareGraphs(item.graph, optimized_graph);
}
1338
// Transpose(uint8) followed by Cast(uint8 -> float): the Transpose already
// operates on the narrower type, so there is nothing to gain from reordering
// and the graph must come out unchanged.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast_ProducerIsTranspose) {
  tensorflow::Scope root =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(root, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_uint8 =
      ops::Transpose(root, nhwc_uint8, ops::Const(root, {0, 3, 1, 2}, {4}));
  Output nchw_fp32 = ops::Cast(root, nchw_uint8, DT_FLOAT);
  Output out = ops::Identity(root.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized_graph));
  CompareGraphs(item.graph, optimized_graph);
}
1356
// Transposes that cancel out (perm1 followed by its inverse perm2) and a
// transpose with the identity permutation (perm3) should all be removed,
// leaving only the inputs and the fetched Identity nodes.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  // perm2 is the inverse of perm1; perm3 is the identity permutation.
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
  Output transpose3 = ops::Transpose(s.WithOpName("transpose3"), inputs, perm3);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);
  Output id2 = ops::Identity(s.WithOpName("id2"), transpose3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All Transpose nodes (and their now-unused perm constants) are gone.
  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
  }
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id1", "id2", "inputs_shape", "inputs"}));
}
1389
// A cancelling Transpose pair on one branch of a multi-output Split should be
// removed, rewiring the Concat to consume the Split outputs directly.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  // Split along the channel dim; only branch 1 gets the transpose pair.
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  // perm2 inverts perm1, so this pair is an identity.
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Concat now reads each Split output port directly.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Concat") {
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }

  // The optimized graph must produce the same values as the original.
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1431
// When a removed transpose pair is only reachable through a control edge, the
// control dependency should be rerouted to the transposes' input.
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  // Two transposes with the same 2-D perm cancel out.
  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
  // "outputs" only depends on transpose2 via a control edge.
  Output outputs =
      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
                    ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline result from the unoptimized graph.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The control edge now points at the Placeholder that fed the transposes.
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  EXPECT_EQ(2, outputs_node->input_size());
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  // The optimized graph must produce the same values as the original.
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1466
// Applying perm {1,2,3,0} twice does not yield the identity permutation, so
// neither Transpose may be removed and all 6 nodes must survive.
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(root.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_input =
      ops::RandomUniform(root.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output perm = ops::Const(root.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output first_transpose =
      ops::Transpose(root.WithOpName("transpose1"), random_input, perm);
  Output second_transpose =
      ops::Transpose(root.WithOpName("transpose2"), first_transpose, perm);
  Output out = ops::Identity(root.WithOpName("outputs"), second_transpose);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized_graph);

  // Nothing was removed: inputs_shape, inputs, perm, both transposes, outputs.
  EXPECT_EQ(6, optimized_graph.node_size());
}
1490
// A cancelling transpose pair separated by an intermediate Identity node
// should still be removed (aggressive mode), with control dependencies on the
// removed transposes preserved on the surviving node.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  // perm2 is the inverse of perm1, so transpose2(id(transpose1(x))) == x.
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  // transpose1 also carries a control dependency on perm2.
  Output transpose1 = ops::Transpose(
      s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
  Output identity = ops::Identity(s.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
    // "id" now reads the inputs directly and keeps the ^perm2 control edge
    // that used to be on transpose1.
    if (node.name() == "id") {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("inputs", node.input(0));
      EXPECT_EQ("^perm2", node.input(1));
    }
    if (node.name() == "id1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("id", node.input(0));
    }
  }
  // Both transposes and perm1 are gone; perm2 survives as a control input.
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"}));
}
1531
// Mul(inputs, scale) -> Transpose -> Conv2D should be folded into the
// convolution by scaling the constant weights instead:
//   Conv2D(Transpose(Mul(I, S)), W)  =>  Conv2D(Transpose(I), W*S)
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  NodeMap node_map(&output);

  // `conv` is now a folded convolution with scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);

  // The scale multiply has been moved onto the (constant) weights input.
  const NodeDef* folded_conv_weights = node_map.GetNode(folded_conv->input(1));
  ASSERT_NE(folded_conv_weights, nullptr);
  EXPECT_EQ("Mul", folded_conv_weights->op());

  // Its input should be a transpose of `inputs` (the Mul on the data path
  // is gone).
  const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
  ASSERT_NE(transpose, nullptr);
  EXPECT_EQ("inputs", transpose->input(0));
}
1574
// When the Transpose output is fed (making it a preserved node), the Mul
// must NOT be folded across it into the convolution weights; the transpose
// keeps consuming the scaled inputs.
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  // Feeding "inputs_nchw" turns it into a preserved node, which blocks the
  // fold. The feed value itself is irrelevant; zero-fill it.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0,
         inputs_nchw_tensor.tensor_data().size());

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  // The transpose must still read the (un-folded) scaled inputs. Guard the
  // dereference so a missing node fails the test instead of crashing.
  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
1614
// Mul(inputs, scale) feeding a Conv3D directly is folded into the
// convolution by scaling the constant weights:
//   Conv3D(Mul(I, S), W)  =>  Conv3D(I, W*S)
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  // Use gtest assertions rather than CHECK_* so a regression reports a test
  // failure instead of aborting the binary, and guard the dereferences.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  EXPECT_EQ(inputs.node()->name(), NodeName(folded_conv->input(0)));
  const NodeDef* folded_weights =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_weights, nullptr);
  EXPECT_EQ("Mul", folded_weights->op());
}
1644
TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
  // This unit test exercises two optimizations, folding mul into conv, and
  // reordering cast and transpose.
  //
  //   Conv2D(Transpose(Mul(Cast(I), S)), W)
  //     =>
  //   Conv2D(Transpose(Cast(I)), W*S)
  //     =>
  //   Conv2D(Cast(Transpose(I)), W*S)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  // Most ops are left unnamed, so they get auto-generated names
  // ("Placeholder", "Cast", "Conv2D", ...) that the checks below rely on.
  Output inputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output cast = ops::Cast(s, inputs, DT_FLOAT);
  Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
  Output transpose =
      ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
                            ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;  // all optimization stages are on
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
  NodeMap node_map(&output);

  // Expected names for reordered cast and transpose.
  const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
  const string optimized_cast_name = strings::StrCat(p, "float_Cast");
  const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose");

  // Expected names for folded multiply and conv.
  const string optimized_weights =
      "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";

  const NodeDef* inputs_node = node_map.GetNode("Placeholder");
  const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
  const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");

  ASSERT_NE(inputs_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(conv_node, nullptr);

  // Final chain: Placeholder -> Transpose (uint8) -> Cast -> Conv2D with the
  // pre-scaled weights.
  EXPECT_EQ(output.node_size(), 7);
  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
  EXPECT_EQ(cast_node->input(0), transpose_node->name());
  EXPECT_EQ(conv_node->input(0), cast_node->name());
  EXPECT_EQ(conv_node->input(1), weights_node->name());
}
1705
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // This unit test exercises optimization of folding mul into conv for
  // multiple nodes in the graph.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  // Build two independent Mul -> Conv2D chains. Unnamed ops receive
  // auto-generated names ("Conv2D", "Conv2D_1", ...) used by the checks.
  for (int i = 0; i < 2; ++i) {
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);

  NodeMap node_map(&output);

  using strings::StrCat;
  // Names of the scaled-weight Mul nodes created by the rewrite, one per
  // convolution.
  const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
  const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  // Each convolution now reads its own scaled weights.
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
1753
// Two back-to-back Bitcasts (uint8 -> qint8 -> int8) should collapse into a
// single Bitcast that reads directly from the placeholder, without changing
// the computed result.
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Evaluate the original graph so the rewrite can be verified numerically.
  item.feed = {
      {"inputs", GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}))}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // The two Bitcasts combined into one, wired straight to the inputs.
  EXPECT_EQ(3, output.node_size());
  EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1787
// An int8 -> qint8 -> int8 Bitcast round-trip is fully redundant: both
// Bitcasts should be removed, wiring the placeholder straight to the
// Identity.
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope, bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Evaluate the original graph so the rewrite can be verified numerically.
  item.feed = {
      {"inputs", GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}))}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Both Bitcasts are gone and inputs feeds outputs directly.
  EXPECT_EQ(2, output.node_size());
  EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1821
// Cast(int8 -> int8) is a no-op: the Cast should be removed and the
// placeholder wired straight to the Identity.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output cast = ops::Cast(scope, inputs, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Evaluate the original graph so the rewrite can be verified numerically.
  item.feed = {
      {"inputs", GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}))}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  EXPECT_EQ(1, tensors_expected.size());

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);

  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // The Cast is gone and inputs feeds outputs directly.
  EXPECT_EQ(2, output.node_size());
  EXPECT_EQ(0, CountOpNodes(output, "Cast"));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
1854
// Simplest AddOpsRewrite case: a tree of Adds over tensors of identical
// shape collapses into a single AddN. The rewritten node is named after the
// root add ("Add_abc") inside its sub-scope ("y/").
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Scope sx = s.NewSubScope("x");
  tensorflow::Scope sy = s.NewSubScope("y");

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(sx.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(sy.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   c   -->   AddN(a, b, c)
  //  / \
  // a   b
  EXPECT_EQ(5, output.node_size());

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ("AddN", collapsed_add->op());
  EXPECT_EQ(3, collapsed_add->input_size());
  EXPECT_EQ("a", collapsed_add->input(0));
  EXPECT_EQ("b", collapsed_add->input(1));
  EXPECT_EQ("c", collapsed_add->input(2));

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);

  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
1918
// AddOpsRewrite must handle multiple disjoint Add trees in one graph: both
// operands of the Mul are independent Add trees, each collapsing into its
// own AddN.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), mul);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //          *
  //         / \
  //        +   +                        *
  //       / \ / \                      / \
  //      +  c x  +   -->   AddN(a, b, c)  AddN(x, y, z)
  //     / \     / \
  //    a   b   y   z
  EXPECT_EQ(10, output.node_size());

  NodeMap node_map(&output);

  // check left Add subtree replaced with AddN
  const NodeDef* collapsed_left =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_left, nullptr);

  EXPECT_EQ("AddN", collapsed_left->op());
  EXPECT_EQ(3, collapsed_left->input_size());
  EXPECT_EQ("a", collapsed_left->input(0));
  EXPECT_EQ("b", collapsed_left->input(1));
  EXPECT_EQ("c", collapsed_left->input(2));

  // check right Add subtree replaced with AddN
  const NodeDef* collapsed_right =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
  ASSERT_NE(collapsed_right, nullptr);

  EXPECT_EQ("AddN", collapsed_right->op());
  EXPECT_EQ(3, collapsed_right->input_size());
  EXPECT_EQ("x", collapsed_right->input(0));
  EXPECT_EQ("y", collapsed_right->input(1));
  EXPECT_EQ("z", collapsed_right->input(2));

  // check that Mul inputs re-wired to new Nodes
  const NodeDef* updated_mul = node_map.GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);

  EXPECT_EQ("Mul", updated_mul->op());
  EXPECT_EQ(2, updated_mul->input_size());
  EXPECT_EQ(collapsed_left->name(), updated_mul->input(0));
  EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2006
// An input reachable through both branches of the Add tree must appear once
// per occurrence in the resulting AddN ("b" shows up twice), preserving the
// original sum.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //      +
  //     / \
  //    +   +      -->      AddN(a, b, b, c)
  //   / \ / \                        ^
  //  a   b   c              b added twice!
  EXPECT_EQ(5, output.node_size());

  NodeMap node_map(&output);

  // check Add tree replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ("AddN", collapsed_add->op());
  EXPECT_EQ(4, collapsed_add->input_size());
  EXPECT_EQ("a", collapsed_add->input(0));
  EXPECT_EQ("b", collapsed_add->input(1));
  EXPECT_EQ("b", collapsed_add->input(2));
  EXPECT_EQ("c", collapsed_add->input(3));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2063
// AddOpsRewrite must also fire when shapes are only symbolically equal: all
// inputs share the unknown first dimension propagated from `input`.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // unknown input shape propagated symbolically through the graph
  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);

  // [a, b, c] have symbolically equal shapes
  auto a = ops::Sqrt(s.WithOpName("a"), input);
  auto b = ops::Square(s.WithOpName("b"), input);
  auto c = ops::Round(s.WithOpName("c"), input);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   c   -->   AddN(a, b, c)
  //  / \
  // a   b
  EXPECT_EQ(6, output.node_size());

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);
  EXPECT_EQ("AddN", collapsed_add->op());
  EXPECT_EQ(3, collapsed_add->input_size());
  EXPECT_EQ("a", collapsed_add->input(0));
  EXPECT_EQ("b", collapsed_add->input(1));
  EXPECT_EQ("c", collapsed_add->input(2));

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2126
// With mixed shapes, AddOpsRewrite groups same-shaped inputs into AddN
// leaves and then combines the leaves smallest-first, so broadcasting to the
// largest shape happens as late (and as few times) as possible.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //  1) [a, x], [b, y], [c, z] - aggregate same shapes first
  //  2) Build an aggregation tree minimizing cost of broadcast
  //
  //        +                              +
  //       / \                            / \
  //      +   +                          +   AddN(c, z)
  //     / \ / \                        / \
  //    +  c x  +        -->   AddN(a, x) AddN(b, y)
  //   / \     / \
  //  a   b   y   z
  EXPECT_EQ(12, output.node_size());
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
  string outer_0_add_name =
      "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
  string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
  string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
  string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";

  // Add [a, x] first
  const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
  ASSERT_NE(add_ax_node, nullptr);
  EXPECT_EQ("AddN", add_ax_node->op());
  EXPECT_EQ(2, add_ax_node->input_size());
  EXPECT_EQ("a", add_ax_node->input(0));
  EXPECT_EQ("x", add_ax_node->input(1));

  // Then add [b, y]
  const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
  ASSERT_NE(add_by_node, nullptr);
  EXPECT_EQ("AddN", add_by_node->op());
  EXPECT_EQ(2, add_by_node->input_size());
  EXPECT_EQ("b", add_by_node->input(0));
  EXPECT_EQ("y", add_by_node->input(1));

  // Then add [c, z]
  const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
  ASSERT_NE(add_cz_node, nullptr);
  EXPECT_EQ("AddN", add_cz_node->op());
  EXPECT_EQ(2, add_cz_node->input_size());
  EXPECT_EQ("c", add_cz_node->input(0));
  EXPECT_EQ("z", add_cz_node->input(1));

  // Then add results together starting from smaller shapes [a, x] + [b, y]
  const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
  ASSERT_NE(outer_0_node, nullptr);
  EXPECT_EQ("Add", outer_0_node->op());
  EXPECT_EQ(2, outer_0_node->input_size());
  EXPECT_EQ(inner_0_add_name, outer_0_node->input(0));
  EXPECT_EQ(inner_1_add_name, outer_0_node->input(1));

  // And finally top level Add node
  const NodeDef* outer_node = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_node, nullptr);
  EXPECT_EQ("Add", outer_node->op());
  EXPECT_EQ(2, outer_node->input_size());
  EXPECT_EQ(outer_0_add_name, outer_node->input(0));
  EXPECT_EQ(inner_2_add_name, outer_node->input(1));

  // And outputs reading new top level Add node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  EXPECT_EQ(outer_add_name, updated_outputs->input(0));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2237
// The rewrite should prefer aggregating the two symbolically-equal small
// shapes {?, 1, 1} first, so the broadcast against {?, 32, 32} happens only
// once.
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // We have a small input with one unknown dimension
  auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);

  // And second input which is larger, but has the same unknown dimension
  // device spec prevents this node from rewriting
  auto d = "/device:CPU:0";
  auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
  auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);

  // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
  auto a = ops::Sqrt(s.WithOpName("a"), small);
  auto b = ops::Square(s.WithOpName("b"), large);
  auto c = ops::Round(s.WithOpName("c"), small);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
  auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: it's much cheaper to add
  // small tensors, and do the broadcast just once
  //
  //      +                  +
  //     / \                / \
  //    +   c      -->     +   b
  //   / \                / \
  //  a   b              a   c
  EXPECT_EQ(9, output.node_size());
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
  string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";

  // outer Add node
  const NodeDef* outer_add = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_add, nullptr);
  EXPECT_EQ("Add", outer_add->op());
  // Verify arity before indexing the inputs.
  EXPECT_EQ(2, outer_add->input_size());
  EXPECT_EQ(inner_add_name, outer_add->input(0));
  EXPECT_EQ("b", outer_add->input(1));

  // inner AddN node aggregating the two small tensors
  const NodeDef* inner_add = node_map.GetNode(inner_add_name);
  ASSERT_NE(inner_add, nullptr);
  // Previously only the input count was checked; also pin the op type.
  EXPECT_EQ("AddN", inner_add->op());
  EXPECT_EQ(2, inner_add->input_size());
  EXPECT_EQ("a", inner_add->input(0));
  EXPECT_EQ("c", inner_add->input(1));

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  EXPECT_EQ(outer_add_name, updated_outputs->input(0));

  // The rewritten graph must compute the same value as the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
2314
// RemoveNegation should cancel Neg inputs of Add/Sub by flipping the
// operation and/or swapping operands, while preserving control dependencies.
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
  Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
  // Every Add/Sub combination of plain and negated operands, plus one
  // negation that carries a control dependency.
  Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
  Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
  Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
  Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
  Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
  Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
  Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
  Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
  Output neg_x_with_dep = ops::Neg(
      s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
  Output add_negx_with_dep_y =
      ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
  auto add_all =
      ops::AddN(s.WithOpName("add_all"),
                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});

  GrapplerItem item;
  item.fetch = {"add_all"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveNegation(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // No node is added or removed; existing nodes are rewritten in place.
  EXPECT_EQ(item.graph.node_size(), output.node_size());

  // Asserts that `node` now has op `op` and exactly the given inputs.
  int matched = 0;
  const auto check_rewrite = [&matched](const NodeDef& node, const string& op,
                                        const std::vector<string>& inputs) {
    ++matched;
    EXPECT_EQ(op, node.op());
    EXPECT_EQ(static_cast<int>(inputs.size()), node.input_size());
    for (int j = 0;
         j < node.input_size() && j < static_cast<int>(inputs.size()); ++j) {
      EXPECT_EQ(inputs[j], node.input(j));
    }
  };
  for (const NodeDef& node : output.node()) {
    if (node.name() == "Add_negx_y") {
      check_rewrite(node, "Sub", {"y", "x"});
    } else if (node.name() == "Add_x_negy") {
      check_rewrite(node, "Sub", {"x", "y"});
    } else if (node.name() == "Add_negx_negy") {
      // Only one negation can be folded into the Sub; the other remains.
      check_rewrite(node, "Sub", {"Neg_x", "y"});
    } else if (node.name() == "Sub_x_negy") {
      check_rewrite(node, "Add", {"x", "y"});
    } else if (node.name() == "Sub_negx_negy") {
      check_rewrite(node, "Sub", {"y", "x"});
    } else if (node.name() == "Add_negx_with_dep_y") {
      // The control dependency of the removed Neg is transferred here.
      check_rewrite(node, "Sub", {"y", "x", "^Add_x_y"});
    }
  }
  EXPECT_EQ(6, matched);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2402
// Div(x, Sqrt(y)) should be rewritten as Mul(x, Rsqrt(y)): the Sqrt node is
// converted to Rsqrt in place and the Div becomes a Mul.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_of_y = ops::Sqrt(s.WithOpName("sqrt_y"), y);
  Output ratio = ops::Div(s.WithOpName("output"), x, sqrt_of_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());

  // The rewrite keeps the node count constant.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      // The division became a multiplication by the rewritten divisor.
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("sqrt_y", node.input(1));
    } else if (node.name() == "sqrt_y") {
      // The Sqrt node was converted to Rsqrt in place.
      EXPECT_EQ("Rsqrt", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("y", node.input(0));
    }
  }
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2439
// When the Sqrt divisor is itself a fetch node, the SqrtDivToRsqrtMul rewrite
// must not fire: converting it to Rsqrt in place would change a fetched value.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output values = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output sqrt_out = ops::Sqrt(s.WithOpName("output0"), values);
  Output ones = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output halves = ops::Multiply(s.WithOpName("mul1"), ones, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), halves, sqrt_out);

  GrapplerItem item;
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
  // Nothing was rewritten: same node count, same ops, same wiring.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "grad") {
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("mul1", node.input(0));
      EXPECT_EQ("output0", node.input(1));
    } else if (node.name() == "output0") {
      // The fetched Sqrt node is left untouched.
      EXPECT_EQ("Sqrt", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("floats", node.input(0));
    }
  }
}
2481
// Square(Sub(x, y)) should be fused: the Sub becomes SquaredDifference and
// the former Square node degrades to an Identity forwarding its result.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output diff = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output squared = ops::Square(s.WithOpName("output"), diff);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());

  // Node count stays the same; two nodes change op in place.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("sub_x_y", node.input(0));
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ("SquaredDifference", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2518
// The SquaredDifference fusion must be skipped when the Sub node is fetched:
// rewriting it in place would change the fetched tensor.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output diff = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output squared = ops::Square(s.WithOpName("output"), diff);

  GrapplerItem item;
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
  // Both nodes keep their original ops and inputs.
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("Square", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("sub_x_y", node.input(0));
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ("Sub", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("y", node.input(1));
    }
  }
}
2558
// Log(Softmax(x)) should collapse into a single LogSoftmax node.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sm = ops::Softmax(s.WithOpName("softmax"), x);
  Output log_sm = ops::Log(s.WithOpName("output"), sm);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());

  // The intermediate Softmax node is pruned away.
  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ("LogSoftmax", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));
    }
  }
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2589
// The Log(Softmax) fusion must not fire when the Softmax itself is fetched.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output values = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output sm = ops::Softmax(s.WithOpName("softmax"), values);
  Output log_sm = ops::Log(s.WithOpName("final_output"), sm);

  GrapplerItem item;
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors.size());

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
2619
// Pow with a constant exponent should be specialized to a cheaper op:
//   Pow(x, 2)    -> Square(x)
//   Pow(x, 1)    -> Identity(x)
//   Pow(x, 0.5)  -> Sqrt(x)
//   Pow(x, 0)    -> constant ones (both inputs demoted to control deps)
//   Pow(x, -0.5) -> Rsqrt(x)
//   Pow(x, -1)   -> Reciprocal(x)
// Other exponents, and cases where the rewrite would change the output shape
// via broadcasting (scalar base, vector exponent), must stay as Pow.
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  // One exponent constant per rewrite case.
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  // Scalar base with vector exponents: the result shape depends on
  // broadcasting, so these must not be rewritten.
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
                "out_1", "out",  "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(9, tensors_expected.size());

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  EXPECT_EQ(9, tensors.size());

  // Numerical results must be unchanged by the rewrite.
  for (int i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Build the exact expected graph. The replaced exponent inputs survive as
  // control dependencies so pruning semantics are preserved.
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y2", "Const", {}, {}, &want);
  AddNode("y1", "Const", {}, {}, &want);
  AddNode("y.5", "Const", {}, {}, &want);
  AddNode("y0", "Const", {}, {}, &want);
  AddNode("y_.5", "Const", {}, {}, &want);
  AddNode("y_1", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
  AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
  AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
  AddNode("out0", "Const",
          {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
  AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
  AddNode("out", "Pow", {"x", "y"}, {}, &want);
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
2687
// Log(Add(ones, x)) should be rewritten as Log1p(x). The rewrite only fires
// when one addend is the constant 1 (out1: x1 is all-ones); out2 adds
// {2,2} + {3,3}, so it must stay a plain Log.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  // The Add carries a control dependency, which must survive the rewrite.
  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(s.WithOpName("out1"), a12);
  Output out2 = ops::Log(s.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(2, tensors_expected.size());

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  EXPECT_EQ(2, tensors.size());

  // The rewrite must not change the numerical results.
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected graph: out1 becomes Log1p(x2); the folded ones input (x1) and
  // the Add's control dependency (x3) are kept as control inputs. out2 is
  // untouched.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p",
          {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
          &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
2729
// Sub(Exp(x), ones) should be rewritten as Expm1(x). The rewrite only fires
// when the subtrahend is the constant 1 (out1: x2 is all-ones); out2
// subtracts {3,3}, so it must remain Sub(exp1, x3).
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  // The Exp carries a control dependency, which must survive the rewrite.
  auto exp1 = ops::Exp(s.WithOpName("exp1").WithControlDependencies(x3), x1);
  Output out1 = ops::Sub(s.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(s.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(2, tensors_expected.size());

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  EXPECT_EQ(2, tensors.size());

  // The rewrite must not change the numerical results.
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected graph: out1 becomes Expm1(x1); the folded ones input (x2) and
  // the Exp's control dependency (x3) become control inputs. exp1 survives
  // because out2 still consumes it.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1",
          {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
          &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
2770
// MinimizeBroadcasts should swap operands of a Mul chain so the two
// vector-shaped inputs are multiplied first and the broadcast against the
// matrix happens only once, at the root.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //            *                  *
  //          /   \              /   \
  //         *     c   -->      *     b
  //        / \                / \
  //       a   b              a   c
  NodeMap node_map(&output);
  const auto expect_mul_inputs = [&node_map](const string& name,
                                             const string& in0,
                                             const string& in1) {
    const NodeDef* node = node_map.GetNode(name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(in0, node->input(0));
    EXPECT_EQ(in1, node->input(1));
  };
  expect_mul_inputs("mul1", "a", "c");
  expect_mul_inputs("mul2", "mul1", "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2824
// A tall left-leaning Mul chain should be flattened into a balanced tree with
// the matrix-shaped operand pushed to the top, so the broadcast is done once.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *       b
  //    / \                 /   \
  //   *   c      -->     *       *
  //  / \                / \     / \
  // a   b              a   c   d   e
  NodeMap node_map(&output);
  const auto expect_mul_inputs = [&node_map](const string& name,
                                             const string& in0,
                                             const string& in1) {
    const NodeDef* node = node_map.GetNode(name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(in0, node->input(0));
    EXPECT_EQ(in1, node->input(1));
  };
  expect_mul_inputs("mul1", "a", "c");
  expect_mul_inputs("mul2", "d", "e");
  expect_mul_inputs("mul3", "mul1", "mul2");
  expect_mul_inputs("mul4", "mul3", "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
2899
// When the matrix operand sits in a leaf Mul, MinimizeBroadcasts should
// rebuild the tree so all vector operands combine first and the matrix is
// multiplied in last, at the root.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // [a, b, c] are vectors, [D] is a matrix.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
  auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //                              *
  //                            /  \
  //       *                   *    D
  //     /   \                / \
  //    *     *      ->      *   c
  //   / \   / \            / \
  //  a   b c   D          a   b
  NodeMap node_map(&output);
  const auto expect_inputs = [&node_map](const string& name, const string& in0,
                                         const string& in1) {
    const NodeDef* node = node_map.GetNode(name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(in0, node->input(0));
    EXPECT_EQ(in1, node->input(1));
  };
  // Node names are reused by the rewrite: "mul2" is now the innermost
  // product, "mul1" combines it with c, and "mul3" applies the matrix.
  expect_inputs("mul2", "a", "b");
  expect_inputs("mul1", "mul2", "c");
  expect_inputs("mul3", "D", "mul1");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
2964
// Chains of element-wise unary ops that are identical across all Concat
// inputs should be hoisted above the Concat, so the chain runs once on the
// concatenated tensor. Control dependencies of the hoisted ops are moved onto
// the Concat node.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Scalars used only as control-dependency anchors.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //       Concat({Exp(a), Exp(b), Exp(c)})
  // into
  //       Exp(Concat({a, b, c})).
  // sin_a is NOT common to all branches, so it stays below the Concat.
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //       Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //       Cos(Exp(Concat({a, b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "concat") {
      // Concat now consumes the pre-Exp tensors; the hoisted ops' control
      // dependencies (ctrl1 from exp_a, ctrl2 from exp_c) moved here.
      EXPECT_EQ(6, node.input_size());
      EXPECT_EQ("sin_a", node.input(0));
      EXPECT_EQ("b", node.input(1));
      EXPECT_EQ("c", node.input(2));
      EXPECT_EQ("axis", node.input(3));
      EXPECT_EQ("^ctrl1", node.input(4));
      EXPECT_EQ("^ctrl2", node.input(5));
      found++;
    }
    if (node.name() == "exp_a") {
      // exp_a was reused as the single hoisted Exp above the Concat.
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("concat", node.input(0));
      EXPECT_EQ("^ctrl1", node.input(1));
      found++;
    }
    if (node.name() == "id") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("exp_a", node.input(0));
      found++;
    }

    if (node.name() == "concat2") {
      // Control deps from both hoisted chain links are accumulated here.
      EXPECT_EQ(7, node.input_size());
      EXPECT_EQ("sin_a", node.input(0));
      EXPECT_EQ("b", node.input(1));
      EXPECT_EQ("c", node.input(2));
      EXPECT_EQ("axis", node.input(3));
      EXPECT_EQ("^ctrl1", node.input(4));
      EXPECT_EQ("^ctrl2", node.input(5));
      EXPECT_EQ("^ctrl3", node.input(6));
      found++;
    }
    if (node.name() == "exp_a2") {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("concat2", node.input(0));
      EXPECT_EQ("^ctrl1", node.input(1));
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      // The hoisted chain keeps its original order: Exp below Cos.
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("exp_a2", node.input(0));
      EXPECT_EQ("^ctrl1", node.input(1));
      found++;
    }
    if (node.name() == "id2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("cos_exp_a2", node.input(0));
      found++;
    }
  }
  EXPECT_EQ(7, found);

  // The hoist must not change the numerical results.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
3081
// Verifies that per-output chains of element-wise unary ops below a Split (or
// SplitV) are hoisted above the split so each op runs once on the split input
// rather than once per output, and that control dependencies attached to the
// hoisted ops are preserved on the new nodes.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Control-dependency anchors; their control edges must survive the rewrite.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //          [Sin(y) for y in Split(x)]
  // into
  //          [y for y in Split(Sin(x))].
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //          [Cos(Exp(y)) for y in Split(x)]
  // into
  //          [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    // Use ASSERT_EQ for input_size() before calling node.input(i): with
    // EXPECT_EQ a size mismatch would fall through into an out-of-range
    // repeated-field access that aborts the entire test binary.
    if (node.name() == "split1") {
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("axis", node.input(0));
      EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ("Sin", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^ctrl1", node.input(1));
      found++;
    }
    if (node.name() == "id_a") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("split1", node.input(0));
      found++;
    }
    if (node.name() == "exp_b") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("split1:1", node.input(0));
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("exp_b", node.input(0));
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ("Exp", node.op());
      ASSERT_EQ(4, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("^ctrl1", node.input(1));
      EXPECT_EQ("^ctrl2", node.input(2));
      EXPECT_EQ("^ctrl3", node.input(3));
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ("Cos", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("ArithmeticOptimizer/_exp_a2_split2", node.input(0));
      found++;
    }
    if (node.name() == "split2") {
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
      EXPECT_EQ("size_splits2", node.input(1));
      EXPECT_EQ("axis", node.input(2));
      found++;
    }
    if (node.name() == "id_a2") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("split2", node.input(0));
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("split2:1", node.input(0));
      found++;
    }
  }
  EXPECT_EQ(10, found);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Guard the element-wise comparison loop below against size mismatches.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
3210
// Verifies that chains of idempotent ops are collapsed:
// Snapshot(Snapshot(x)) -> Snapshot(x) and Identity(Identity(x)) ->
// Identity(x), while the fetch nodes keep producing the same values.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const NodeDef& node : output.node()) {
    // ASSERT_EQ on input_size() before node.input(0): an EXPECT would fall
    // through into an out-of-range access that aborts the test binary.
    if (node.name() == "out1") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("sn1", node.input(0));
      found++;
    } else if (node.name() == "out2") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("id1", node.input(0));
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("a", node.input(0));
      found++;
    }
  }
  EXPECT_EQ(3, found);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Guard the indexed comparison loop below.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
3257
// Verifies that LogicalNot applied to a comparison is removed by negating the
// comparison op in place: Not(Equal)->NotEqual, Not(Less)->GreaterEqual,
// Not(LessEqual)->Greater, Not(Greater)->LessEqual, Not(GreaterEqual)->Less.
TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
  Output eq = ops::Equal(s.WithOpName("eq"), a, b);
  Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
  Output lt = ops::Less(s.WithOpName("lt"), a, b);
  Output le = ops::LessEqual(s.WithOpName("le"), a, b);
  Output gt = ops::Greater(s.WithOpName("gt"), a, b);
  Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
  // not_eq is reserved
  Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
  Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
  Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
  Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
  Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
  Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
  Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
  Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
  Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
  Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
  Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
  Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);

  GrapplerItem item;
  item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
                "id_not_le", "id_not_gt", "id_not_ge"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveLogicalNot(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // Each Identity should now read the negated comparison directly.
    if (node.name() == "id_not_eq") {
      EXPECT_EQ("eq", node.input(0));
      ++found;
    }
    if (node.name() == "id_not_neq") {
      EXPECT_EQ("neq", node.input(0));
      ++found;
    }
    if (node.name() == "id_not_lt") {
      EXPECT_EQ("lt", node.input(0));
      ++found;
    }
    if (node.name() == "id_not_le") {
      EXPECT_EQ("le", node.input(0));
      ++found;
    }
    if (node.name() == "id_not_gt") {
      EXPECT_EQ("gt", node.input(0));
      ++found;
    }
    if (node.name() == "id_not_ge") {
      EXPECT_EQ("ge", node.input(0));
      ++found;
    }

    // The comparison nodes keep their names but flip to the negated op.
    if (node.name() == "eq") {
      EXPECT_EQ("NotEqual", node.op());
      ++found;
    }
    if (node.name() == "neq") {
      EXPECT_EQ("Equal", node.op());
      ++found;
    }
    if (node.name() == "lt") {
      EXPECT_EQ("GreaterEqual", node.op());
      ++found;
    }
    if (node.name() == "le") {
      EXPECT_EQ("Greater", node.op());
      ++found;
    }
    if (node.name() == "gt") {
      EXPECT_EQ("LessEqual", node.op());
      ++found;
    }
    if (node.name() == "ge") {
      EXPECT_EQ("Less", node.op());
      ++found;
    }
  }
  EXPECT_EQ(12, found);

  auto tensors = EvaluateNodes(output, item.fetch);
  // ASSERT before indexing: an EXPECT would fall through into out-of-range
  // vector access on size mismatch and abort the whole test binary.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<bool>(tensors_expected[i], tensors[i]);
  }
}
3355
// Verifies that Max(Sqrt(x)) is rewritten as Sqrt(Max(x)): for a monotonically
// increasing element-wise op, the reduction can be pulled inside so the unary
// op runs on the (smaller) reduced tensor.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT, not EXPECT: tensors_expected[0] is indexed below, and continuing
  // past a size mismatch would abort the whole test binary.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "sqrt") {
      EXPECT_EQ("Sqrt", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("reduce_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ("Max", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3396
// Verifies that ArgMax(Sqrt(x)) is rewritten as ArgMax(x): for a monotonically
// increasing element-wise op, the arg-reduction index is unchanged and the
// unary op can be dropped entirely (hence one fewer node).
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(s.WithOpName("arg_max"), sqrt, 1);
  Output final_out = ops::Identity(s.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT, not EXPECT: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorEqual<int64>(tensors_expected[0], tensors[0]);
  // The Sqrt node is removed, so the optimized graph has one fewer node.
  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "final_out") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("arg_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ("ArgMax", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3437
// The Sqrt/Max swap must not fire when the unary op ("sqrt") is itself a
// fetch node, since rewriting it would change a fetched output.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(2, tensors_expected.size());

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, optimized_graph, __LINE__);
}
3461
// The Neg/Max swap must not fire when the reduction ("z") is itself a fetch
// node; the graph and its fetched value must stay exactly as built.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {2, 3}, {1, 2});
  Output reshape = ops::Reshape(scope.WithOpName("reshape"), x, {-1});
  Output y = ops::Neg(scope.WithOpName("y"), reshape);
  Output z = ops::Max(scope.WithOpName("z"), y, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, expected.size());

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, optimized_graph, __LINE__);

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(1, actual.size());
  // max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(actual[0], expected[0]);
  test::ExpectTensorEqual<int>(actual[0], Tensor(-2));
}
3490
// Verifies the monotonically *decreasing* case: Max(Neg(x)) is rewritten as
// Neg(Min(x)) — the reduction flips from Max to Min when pulled through a
// non-increasing element-wise op.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT, not EXPECT: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());

  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  EXPECT_EQ(item.graph.node_size(), output.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ("Neg", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("reduce_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      // The reduction op flips to Min because Neg is non-increasing.
      EXPECT_EQ("Min", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}
3532
// MaxPool has no "MinPool" counterpart, so a non-increasing op (Neg) feeding
// a MaxPool must be left untouched by the rewrite.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output neg = ops::Neg(scope.WithOpName("neg"), x);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), neg,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, expected.size());

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  // Should be a NoOp
  VerifyGraphsMatch(item.graph, optimized_graph, __LINE__);

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(1, actual.size());
  test::ExpectTensorNear<float>(expected[0], actual[0], 1e-6);
}
3559
TEST_F(ArithmeticOptimizerTest,OptimizeMaxOrMinOfMonotonicElementWiseMaxPool)3560 TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
3561 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3562 auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
3563 Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
3564 Output max_pool = ops::MaxPool(s.WithOpName("max_pool"), sqrt, {1, 2, 2, 1},
3565 {1, 2, 2, 1}, "VALID");
3566 Output final_out = ops::Identity(s.WithOpName("final_out"), max_pool);
3567
3568 GrapplerItem item;
3569 item.fetch = {"final_out"};
3570 TF_CHECK_OK(s.ToGraphDef(&item.graph));
3571 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3572 EXPECT_EQ(1, tensors_expected.size());
3573
3574 GraphDef output;
3575 ArithmeticOptimizer optimizer;
3576 EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
3577 OptimizeAndPrune(&optimizer, &item, &output);
3578 auto tensors = EvaluateNodes(output, item.fetch);
3579 EXPECT_EQ(1, tensors.size());
3580
3581 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3582 EXPECT_EQ(item.graph.node_size(), output.node_size());
3583 // Check if the inputs are switched
3584 int required_node_count = 0;
3585 for (int i = 0; i < output.node_size(); ++i) {
3586 const NodeDef& node = output.node(i);
3587 if (node.name() == "sqrt") {
3588 EXPECT_EQ("Sqrt", node.op());
3589 EXPECT_EQ(1, node.input_size());
3590 EXPECT_EQ("max_pool", node.input(0));
3591 ++required_node_count;
3592 } else if (node.name() == "max_pool") {
3593 EXPECT_EQ("MaxPool", node.op());
3594 EXPECT_EQ(1, node.input_size());
3595 EXPECT_EQ("x", node.input(0));
3596 ++required_node_count;
3597 }
3598 }
3599 EXPECT_EQ(2, required_node_count);
3600 }
3601
// Verifies that a chain of unary ops (Sqrt -> Log -> Relu) is fused into a
// single _UnaryOpsComposition node carrying the op names in order. Nodes are
// pinned to CPU because the composed kernel is registered there.
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output log = ops::Log(s.WithOpName("log"), sqrt);
  Output relu = ops::Relu(s.WithOpName("relu"), log);
  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT, not EXPECT: tensors_expected[0] is indexed below.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(3, output.node_size());

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "final_out") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("relu/unary_ops_composition", node.input(0));
      ++required_node_count;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ("_UnaryOpsComposition", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("x", node.input(0));

      // The fused ops must be recorded in application order.
      auto op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(3, op_names.size());
      EXPECT_EQ("Sqrt", op_names[0]);
      EXPECT_EQ("Log", op_names[1]);
      EXPECT_EQ("Relu", op_names[2]);
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
3658
// Verifies RemoveStackStridedSliceSameAxis: when a StridedSlice reads exactly
// one of the inputs previously packed by Stack (slicing along the stack
// axis), the slice is replaced by the stacked input itself — directly when
// shrink_axis_mask drops the axis, or via an inserted ExpandDims when the
// axis is kept. The test builds both mask variants for each of a, b, c and
// checks values before and after optimization.
TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  // Placeholders with unknown shape so the rewrite cannot rely on static
  // shape inference of the stacked operands.
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  // Reference values for the expand_dims variants of the slices below.
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  // begin/end vectors; only axis-1 entries matter — the begin/end masks in
  // the slices below make axes 0 and 2 full-range regardless of these values.
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
  // Used as the `end` of pc_slice_2to, whose end_mask covers all three axes,
  // so these values are ignored at runtime (node name is historical).
  auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
  auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});

  // stacked[:, 0]
  // Mask scheme for the three "[:, k]" slices: begin/end mask 0b0101 makes
  // axes 0 and 2 full-range; shrink_axis_mask 0b0010 removes axis 1.
  using SS = ops::StridedSlice;
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 1]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 2]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 0:1, :]
  // No shrink_axis_mask: axis 1 is kept with size 1, matching expanded_a.
  auto pa_slice_01 = ops::Identity(
      s.WithOpName("pa_slice_01_out"),
      SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, :1, :]
  // Begin mask also covers axis 1 (0b0111), so the slice starts at 0 there.
  auto pa_slice_to1 = ops::Identity(
      s.WithOpName("pa_slice_to1_out"),
      SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0111)  // 7
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 1:2, :]
  auto pb_slice_12 = ops::Identity(
      s.WithOpName("pb_slice_12_out"),
      SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 2:, :].
  // End mask covers axis 1 (0b0111), so the slice runs to the end there.
  auto pc_slice_2to = ops::Identity(
      s.WithOpName("pc_slice_2to_out"),
      SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0111)  // 7
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c",
                "pa_slice_01_out",
                "pa_slice_to1_out",
                "pb_slice_12_out",
                "pc_slice_2to_out"};
  // Indices into item.fetch / the evaluated tensor vectors below; must stay
  // in the same order as item.fetch.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
    fASlice01Out,
    fASliceTo1Out,
    fBSlice12Out,
    fCSlice2ToOut,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph: each slice equals the stacked
  // operand (or its expand_dims) it was built to select.
  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fA],
                                 tensors_expected[fASliceOut]);
  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fB],
                                 tensors_expected[fBSliceOut]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fC],
                                 tensors_expected[fCSliceOut]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
                                 tensors_expected[fASlice01Out]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
                                 tensors_expected[fASliceTo1Out]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
                                 tensors_expected[fBSlice12Out]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
                                 tensors_expected[fCSlice2ToOut]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The shrink-axis slices should now read a/b/c directly; the kept-axis
  // slices should read an optimizer-inserted node named after themselves.
  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      EXPECT_EQ(node.input(0), "a");
    } else if (node.name() == "pb_slice_out") {
      EXPECT_EQ(node.input(0), "b");
    } else if (node.name() == "pc_slice_out") {
      EXPECT_EQ(node.input(0), "c");
    } else if (str_util::EndsWith(node.name(), "_out")) {
      EXPECT_EQ(strings::StrCat(node.input(0), "_out"),
                strings::StrCat(
                    "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
                    node.name()));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // Re-check all values against the pre-optimization results.
  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fA], tensors[fASliceOut]);

  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fB], tensors[fBSliceOut]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fC], tensors[fCSliceOut]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
                                 tensors[fASlice01Out]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
                                 tensors[fASliceTo1Out]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
                                 tensors[fBSlice12Out]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
                                 tensors[fCSlice2ToOut]);
}
3863
// Verifies that AddN of N identical bfloat16 inputs is rewritten as a
// multiplication by a constant N, preserving the bfloat16 result exactly.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output cast = ops::Cast(s.WithOpName("cast"), x, DT_BFLOAT16);
  Output add = ops::AddN(s.WithOpName("add"), {cast, cast});
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // ASSERT, not EXPECT: tensors_expected[0] is indexed below, and continuing
  // past a size mismatch would abort the whole test binary.
  ASSERT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Extra node created for multiplier.
  EXPECT_EQ(5, output.node_size());

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<bfloat16>(tensors_expected[0], tensors[0]);
}
3889
3890 } // namespace grappler
3891 } // namespace tensorflow
3892