• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 
18 #include <complex>
19 
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/cc/ops/array_ops.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24 #include "tensorflow/cc/ops/resource_variable_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor_testutil.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
30 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
31 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
39 namespace {
40 
// Name prefixes that ArithmeticOptimizer stages prepend to the nodes they
// create; the tests below use these to look up generated nodes by name.

// Outer Div node created by HoistCommonFactorOutOfAggregation.
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

// Outer Mul node created by HoistCommonFactorOutOfAggregation.
constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

// Inner AddV2 node created by HoistCommonFactorOutOfAggregation.
constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_AddV2_";

// Const node created by SimplifyAggregation.
constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

// Mul node created by SimplifyAggregation.
constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
55 
56 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)57 string HoistMulName(const string& name) {
58   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
59 }
60 
61 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)62 string HoistDivName(const string& name) {
63   return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
64 }
65 
66 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)67 string HoistAddName(const string& name) {
68   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
69 }
70 
71 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)72 string AggregationConstName(const string& name) {
73   return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
74 }
75 
76 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)77 string AggregationMulName(const string& name) {
78   return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
79 }
80 
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)81 void VerifyGraphsMatch(const GraphDef& original_graph,
82                        const GraphDef& optimized_graph, int line) {
83   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
84   for (int i = 0; i < original_graph.node_size(); ++i) {
85     const NodeDef& original = original_graph.node(i);
86     const NodeDef& optimized = optimized_graph.node(i);
87     EXPECT_EQ(original.name(), optimized.name()) << line;
88     EXPECT_EQ(original.op(), optimized.op()) << line;
89     EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
90     for (int j = 0; j < original.input_size(); ++j) {
91       EXPECT_EQ(original.input(j), optimized.input(j)) << line;
92     }
93   }
94 }
95 }  // namespace
96 
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  ArithmeticOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  // The optimizer must have left the graph untouched.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
109 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
  // Graph from b/176172427
  // input * ones, where `ones` only broadcasts (all values 1.0) along the
  // size-1 dims of `input`, should be rewritten to a Tile of `input`.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Evaluate the unoptimized graph first to get reference values.
  auto tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplaceMulWithBroadcastByTile stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The Mul must be gone, replaced by a single Tile.
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // The rewrite generates a Tile node plus a Const holding its multiples,
  // both named with the stage prefix plus the original node name.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(t, nullptr);
  ASSERT_NE(c, nullptr);
  EXPECT_EQ(t->op(), "Tile");
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), "input");
  EXPECT_EQ(t->input(1), c->name());
  EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type());

  // The optimized graph must produce the same values as the original.
  auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
154 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
  // Same rewrite as above, but `ones` carries a control dependency on
  // `input`; the generated Const must keep that control edge.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
                           1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplaceMulWithBroadcastByTile stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // The generated multiples Const must have inherited the ^input control
  // dependency from the original `ones` node.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(c, nullptr);
  ASSERT_EQ(c->input_size(), 1);
  EXPECT_TRUE(IsControlInput(c->input(0)));
  EXPECT_EQ(c->input(0), "^input");
}
188 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
  // `ones` has the same shape as the input, so the multiply does not
  // broadcast; the rewrite must not fire and the graph must stay unchanged.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output in =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
  Output all_ones = ops::Const(scope.WithOpName("ones"), 1.0f, {1, 2, 1});
  Output product = ops::Mul(scope.WithOpName("multiply"), in, all_ones);
  Output out = ops::Identity(scope.WithOpName("output"), product);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto feed = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplaceMulWithBroadcastByTile stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The optimizer must be a no-op on this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", feed}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
217 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
  // Both multiply operands are placeholders, so neither is a constant
  // all-ones tensor; the rewrite must not fire and the graph must stay
  // unchanged.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 1, 1}));
  Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 2, 1}));
  Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch,
                                {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplaceMulWithBroadcastByTile stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  VerifyGraphsMatch(item.graph, g, __LINE__);

  // Evaluate the optimized graph `g` (not item.graph) so the numeric check
  // exercises the optimizer's output, consistent with the sibling tests.
  auto result = EvaluateNodes(g, item.fetch,
                              {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
249 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
  // The broadcast constant holds 2.0, not 1.0, so the multiply is not a pure
  // broadcast; the rewrite must not fire and the graph must stay unchanged.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output in =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
  Output twos = ops::Const(scope.WithOpName("ones"), 2.0f, {1, 2, 1});
  Output product = ops::Mul(scope.WithOpName("multiply"), in, twos);
  Output out = ops::Identity(scope.WithOpName("output"), product);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto feed = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplaceMulWithBroadcastByTile stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The optimizer must be a no-op on this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", feed}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
278 
TEST_F(ArithmeticOptimizerTest, ReduceUpsamplingDims) {
  // An upsampling pattern Reshape -> Tile -> Reshape expressed with 6-D
  // intermediates should be rewritten to use lower-rank intermediates.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 22, 48, 64}));
  Output reshape_a = ops::Reshape(
      s.WithOpName("reshape_a"), input,
      ops::Const(s.WithOpName("shape_a"), {1, 22, 1, 48, 1, 64}, {6}));
  Output tile =
      ops::Tile(s.WithOpName("tile"), reshape_a,
                ops::Const(s.WithOpName("multiples"), {1, 1, 2, 1, 2, 1}, {6}));
  Output reshape_b =
      ops::Reshape(s.WithOpName("reshape_b"), tile,
                   ops::Const(s.WithOpName("shape_b"), {1, 44, 96, 64}));
  Output output = ops::Identity(s.WithOpName("output"), reshape_b);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Evaluate the unoptimized graph to get reference values.
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 22, 48, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  // Phase 1: run only the ReduceUpsamplingDims stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReduceUpsamplingDims(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 8);

  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
  ASSERT_EQ(CountOpNodes(g, "Reshape"), 2);
  ASSERT_EQ(CountOpNodes(g, "Const"), 3);

  // The stage generates a new Reshape and Tile (prefixed names) feeding the
  // original reshape_b. Check the rewritten chain:
  // input -> new Reshape -> new Tile -> reshape_b.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReduceUpsamplingDims";
  const NodeDef* ra =
      node_map.GetNode(absl::StrCat(p, "_", "Reshape_reshape_b"));
  const NodeDef* rb = node_map.GetNode("reshape_b");
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_reshape_b"));
  ASSERT_NE(ra, nullptr);
  ASSERT_NE(rb, nullptr);
  ASSERT_NE(t, nullptr);

  ASSERT_EQ(rb->input_size(), 2);
  EXPECT_EQ(rb->input(0), t->name());
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), ra->name());
  ASSERT_EQ(ra->input_size(), 2);
  EXPECT_EQ(ra->input(0), "input");

  // The rewritten graph must compute the same values.
  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }

  // Check to make sure the first reshape is removed
  // Phase 2: RemoveRedundantReshape should then drop two more nodes while
  // preserving the computed values.
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 6);

  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }
}
345 
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  // Mul(x, x) and MulNoNan(x, x) should both be rewritten to Square(x),
  // preserving any control dependencies of the original node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  // Run only the ReplaceMulWithSquare stage.
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  // The Mul becomes a Square that keeps the ^d control dependency.
  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  EXPECT_EQ(square_node->input(1), "^d");

  // The MulNoNan also becomes a Square, with no extra inputs.
  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  // Validate both fetch outputs numerically: previously only tensors[0]
  // (the Mul path) was compared, leaving the MulNoNan path unchecked.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  test::ExpectTensorNear<float>(tensors[1], tensors_expected[1], 1e-6);
}
389 
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshape) {
  // A chain of Pack nodes that duplicates a tensor along new axes should be
  // rewritten into a single Tile followed by a Reshape.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output o = ops::Identity(s.WithOpName("output"), c);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  // Run only the ReplacePackWithTileReshape stage.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // Both Pack nodes are gone; one Tile + one Reshape (plus two Consts for
  // the multiples and the target shape) remain.
  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  // Look up the generated nodes by their stage-prefixed names.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
  const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
  const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
  const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
  const NodeDef* r_node = node_map.GetNode(absl::StrCat(p, "_", "Reshape_c"));
  const NodeDef* a_node = node_map.GetNode("a");
  ASSERT_NE(t_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(s_node, nullptr);
  ASSERT_NE(r_node, nullptr);
  ASSERT_NE(a_node, nullptr);

  EXPECT_EQ(c_node->op(), "Const");
  EXPECT_EQ(s_node->op(), "Const");

  // Check Reshape properties
  ASSERT_EQ(r_node->input_size(), 2);
  EXPECT_EQ(r_node->op(), "Reshape");
  EXPECT_EQ(r_node->input(0), t_node->name());
  EXPECT_EQ(r_node->input(1), s_node->name());

  // Check Tile properties
  ASSERT_EQ(t_node->input_size(), 2);
  EXPECT_EQ(t_node->op(), "Tile");
  EXPECT_EQ(t_node->input(0), a_node->name());
  EXPECT_EQ(t_node->input(1), c_node->name());
  EXPECT_EQ(t_node->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t_node->attr().at("Tmultiples").type(),
            c_node->attr().at("dtype").type());

  // The rewritten graph must compute the same values.
  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
452 
TEST_F(ArithmeticOptimizerTest,ReplacePackWithTileReshapeControlDeps)453 TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeControlDeps) {
454   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
455   Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
456                               ops::Placeholder::Shape({3, 5, 7, 11}));
457 
458   Output x = ops::Identity(s.WithOpName("x"), a);
459   Output y = ops::Identity(s.WithOpName("y"), a);
460 
461   Output b = ops::Stack(s.WithOpName("b").WithControlDependencies(x), {a, a},
462                         ops::Stack::Axis(3));
463   Output c = ops::Stack(s.WithOpName("c").WithControlDependencies(y), {b, b},
464                         ops::Stack::Axis(2));
465   Output o = ops::Identity(s.WithOpName("output"), c);
466 
467   GrapplerItem item;
468   item.fetch = {"output"};
469   item.keep_ops = {"x", "y"};
470   TF_CHECK_OK(s.ToGraphDef(&item.graph));
471   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
472   auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
473   ASSERT_EQ(expected.size(), 1);
474 
475   GraphDef g;
476   ArithmeticOptimizer optimizer;
477   EnableOnlyReplacePackWithTileReshape(&optimizer);
478   OptimizeAndPrune(&optimizer, &item, &g);
479 
480   EXPECT_EQ(g.node_size(), 8);
481   EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
482   EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
483   EXPECT_EQ(CountOpNodes(g, "Const"), 2);
484   EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);
485   EXPECT_EQ(CountOpNodes(g, "Identity"), 3);
486 
487   NodeMap node_map(&g);
488   const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
489   const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
490   const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
491   const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
492   const NodeDef* a_node = node_map.GetNode("a");
493   ASSERT_NE(t_node, nullptr);
494   ASSERT_NE(c_node, nullptr);
495   ASSERT_NE(s_node, nullptr);
496   ASSERT_NE(a_node, nullptr);
497 
498   ASSERT_EQ(t_node->input_size(), 4);
499   EXPECT_EQ(t_node->op(), "Tile");
500   EXPECT_EQ(t_node->input(0), a_node->name());
501   EXPECT_EQ(t_node->input(1), c_node->name());
502   EXPECT_EQ(t_node->input(2), "^y");
503   EXPECT_EQ(t_node->input(3), "^x");
504 
505   ASSERT_EQ(c_node->input_size(), 1);
506   EXPECT_EQ(c_node->input(0), "^a");
507 
508   ASSERT_EQ(s_node->input_size(), 1);
509   ASSERT_EQ(s_node->input(0), "^a");
510 
511   auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
512   ASSERT_EQ(result.size(), 1);
513   test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
514 }
515 
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileRemoveReshape) {
  // Pack-chain rewrite followed by RemoveRedundantReshape: the user's own
  // trailing Reshape should collapse with the one the rewrite generates.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output r =
      ops::Reshape(s.WithOpName("r"), c, ops::Const(s, {3, 10, 14, 11}, {4}));
  Output o = ops::Identity(s.WithOpName("output"), r);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  // Phase 1: run only the ReplacePackWithTileReshape stage. After this the
  // graph has the generated Reshape plus the original one (two Reshapes).
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 3);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 2);

  // Phase 2: RemoveRedundantReshape merges the back-to-back Reshapes.
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  // The fully optimized graph must compute the same values.
  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
558 
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeOutOfRange) {
  // Packing along axis 4 of a rank-4 input is outside the axis range this
  // rewrite handles, so the optimizer must leave the graph unchanged.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output in = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT,
                               ops::Placeholder::Shape({3, 5, 7, 11}));
  // ops::Stack lowers to a Pack node.
  Output packed =
      ops::Stack(scope.WithOpName("b"), {in, in}, ops::Stack::Axis(4));
  Output out = ops::Identity(scope.WithOpName("output"), packed);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Run only the ReplacePackWithTileReshape stage; expect a no-op.
  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  VerifyGraphsMatch(item.graph, g, __LINE__);
}
578 
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  // Two adjacent Neg nodes and two adjacent Reciprocal nodes are each a pair
  // of involutions; both pairs should be removed entirely.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  // Run only the RemoveInvolution stage.
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other.
  // Only the Const and the Identity remain, with `id` reading `c` directly.
  // NOTE(review): node(1) assumes the output node ordering puts `id` second.
  ASSERT_EQ(output.node_size(), 2);
  EXPECT_EQ(output.node(1).name(), "id");
  ASSERT_EQ(output.node(1).input_size(), 1);
  EXPECT_EQ(output.node(1).input(0), "c");

  // Values are unchanged by the rewrite.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
610 
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  // The two Reciprocal nodes are separated by a value-preserving chain
  // (Identity, Squeeze); the stage should still cancel them across it.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  // Run only the RemoveInvolution stage.
  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(output.node_size(), 3);

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      found++;
    } else if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "squeeze");
      found++;
    }
  }
  // Both rewired nodes must have been seen.
  EXPECT_EQ(found, 2);

  // Values are unchanged by the rewrite.
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
656 
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  // Here recip2 reads `c` directly and is only connected to the first
  // Reciprocal chain via a control dependency, so there is no involution
  // pair to cancel; the stage must not touch the graph.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // Values are unchanged.
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
689 
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  // x + x should be rewritten by SimplifyAggregation into 2 * x: a generated
  // Const(2.0f) feeding a generated Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  // Run the full optimizer (all stages enabled).
  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  // The generated Const carries a control dependency on `x`, and its value
  // is 2.0f serialized as 4 little-endian bytes ("\0\0\0@" == 0x40000000).
  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The generated Mul computes const * x.
  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 2);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");

  // `id` is rewired to read the new Mul instead of the Add.
  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  // Values are unchanged by the rewrite.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
735 
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // Same x + x => 2 * x rewrite as above, but the Add has a control
  // dependency on `y`, which must be carried over to the generated Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  // Note: item.fetch is intentionally left empty; a local fetch vector is
  // used for evaluation instead.
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  // Run the full optimizer (all stages enabled).
  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 6);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  // The generated Const carries a control dependency on `x`, and its value
  // is 2.0f serialized as 4 little-endian bytes ("\0\0\0@" == 0x40000000).
  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The generated Mul computes const * x and keeps the ^y control edge.
  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 3);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");
  EXPECT_EQ(new_mul->input(2), "^y");

  // `id` is rewired to read the new Mul instead of the Add.
  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  // Values are unchanged by the rewrite.
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
783 
TEST_F(ArithmeticOptimizerTest,TrivialSumsRepeatedAdd)784 TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
785   // Test case from b/69059093.
786   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
787   Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
788   Output add = ops::Add(s.WithOpName("Add"), p, p);
789   Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
790   Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
791   Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
792   Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
793   Output id = ops::Identity(s.WithOpName("id"), add6);
794 
795   GrapplerItem item;
796   item.fetch = {"id"};
797   TF_CHECK_OK(s.ToGraphDef(&item.graph));
798 
799   const std::vector<string> devices{
800       "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
801       "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
802   };
803   for (int i = 0; i < item.graph.node_size(); ++i) {
804     item.graph.mutable_node(i)->set_device(devices[i]);
805   }
806 
807   ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
808   DisableAddToAddNCombining(&optimizer);
809 
810   GraphDef output;
811   DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);
812 
813   // We expect the following rewrite(s) to occur:
814   //
815   // Mul(p,
816   //     Add_6(Add_4(Const(2), Const(2)),
817   //           Add_5(Const(2), Const(2)))
818   NodeMap node_map(&output);
819 
820   EXPECT_EQ(output.node_size(), 8);
821 
822   const NodeDef* id_node = node_map.GetNode("id");
823   ASSERT_NE(id_node, nullptr);
824   ASSERT_EQ(id_node->input_size(), 1);
825   EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));
826 
827   const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
828   ASSERT_NE(mul_node, nullptr);
829   ASSERT_EQ(mul_node->input_size(), 2);
830   EXPECT_EQ(mul_node->input(0), "Placeholder");
831   EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
832 
833   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
834   ASSERT_NE(add_6_node, nullptr);
835   ASSERT_EQ(add_6_node->input_size(), 2);
836   EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
837   EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));
838 
839   const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
840   ASSERT_NE(add_4_node, nullptr);
841   EXPECT_EQ(add_4_node->op(), "Add");
842   ASSERT_EQ(2, add_4_node->input_size());
843   EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
844   EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));
845 
846   const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
847   ASSERT_NE(add_5_node, nullptr);
848   EXPECT_EQ(add_5_node->op(), "Add");
849   ASSERT_EQ(add_5_node->input_size(), 2);
850   EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
851   EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));
852 
853   const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
854   ASSERT_NE(add_const_node, nullptr);
855   EXPECT_EQ(add_const_node->op(), "Const");
856   ASSERT_EQ(add_const_node->input_size(), 1);
857   EXPECT_EQ(add_const_node->input(0), "^Placeholder");
858 
859   const NodeDef* add_1_const_node =
860       node_map.GetNode(AggregationConstName("Add_1"));
861   ASSERT_NE(add_1_const_node, nullptr);
862   EXPECT_EQ(add_1_const_node->op(), "Const");
863   ASSERT_EQ(add_1_const_node->input_size(), 1);
864   EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
865 }
866 
TEST_F(ArithmeticOptimizerTest,HoistFactorMul)867 TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
868   for (bool matching_shapes : {true, false}) {
869     for (bool use_addn : {true, false}) {
870       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
871       Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
872       Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
873       Output y2 = matching_shapes
874                       ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
875                       : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
876       Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
877       Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
878       Output id =
879           use_addn ? ops::Identity(s.WithOpName("id"),
880                                    ops::AddN(s.WithOpName("add"), {mul1, mul2}))
881                    : ops::Identity(s.WithOpName("id"),
882                                    ops::Add(s.WithOpName("add"), mul1, mul2));
883 
884       GrapplerItem item;
885       item.fetch = {"id"};
886       TF_CHECK_OK(s.ToGraphDef(&item.graph));
887       auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
888       ASSERT_EQ(tensors_expected.size(), 1);
889       ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
890       EnableOnlyHoistCommonFactor(&optimizer);
891 
892       GraphDef output;
893       OptimizeTwice(&optimizer, &item, &output);
894 
895       // We expect the following rewrite(s) to occur:
896       //
897       //        Add                 Mul
898       //      /    \               /   \
899       //    Mul    Mul       ->   x    Add
900       //    / \    / \                 / \
901       //   x  y1  y2  x              y1   y2
902       //
903       // If "root" op is AddN and shapes does not match, this rewrite is not
904       // possible and graph should stay intact.
905       NodeMap node_map(&output);
906 
907       if (use_addn && !matching_shapes) {
908         VerifyGraphsMatch(item.graph, output, __LINE__);
909       } else {
910         EXPECT_EQ(output.node_size(), 9);
911 
912         const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
913         ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
914         ASSERT_EQ(new_add_node->input_size(), 2);
915         EXPECT_EQ(new_add_node->input(0), "y1");
916         EXPECT_EQ(new_add_node->input(1), "y2");
917 
918         const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
919         ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
920         ASSERT_EQ(new_mul_node->input_size(), 2);
921         EXPECT_EQ(new_mul_node->input(0), "x");
922         EXPECT_EQ(new_mul_node->input(1), new_add_node->name());
923 
924         const NodeDef* id_node = node_map.GetNode("id");
925         ASSERT_NE(id_node, nullptr) << "Id node not found";
926         EXPECT_EQ(id_node->name(), "id");
927         ASSERT_EQ(id_node->input_size(), 1);
928         EXPECT_EQ(id_node->input(0), HoistMulName("add"));
929       }
930       auto tensors = EvaluateNodes(output, item.fetch);
931       ASSERT_EQ(tensors.size(), 1);
932       test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
933     }
934   }
935 }
936 
TEST_F(ArithmeticOptimizerTest,HoistFactorDiv)937 TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
938   for (bool matching_shapes : {true, false}) {
939     for (bool use_addn : {true, false}) {
940       for (bool use_ints : {true, false}) {
941         tensorflow::Scope s = tensorflow::Scope::NewRootScope();
942         Output x = use_ints
943                        ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
944                        : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
945         Output y1 = use_ints
946                         ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
947                         : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
948         Output y2;
949         if (matching_shapes) {
950           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
951                         : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
952         } else {
953           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
954                         : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
955         }
956         Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
957         Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
958         Output id =
959             use_addn
960                 ? ops::Identity(s.WithOpName("id"),
961                                 ops::AddN(s.WithOpName("add"), {div1, div2}))
962                 : ops::Identity(s.WithOpName("id"),
963                                 ops::Add(s.WithOpName("add"), div1, div2));
964 
965         GrapplerItem item;
966         item.fetch = {"id"};
967         TF_CHECK_OK(s.ToGraphDef(&item.graph));
968 
969         auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
970         ASSERT_EQ(tensors_expected.size(), 1);
971 
972         ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
973         EnableOnlyHoistCommonFactor(&optimizer);
974 
975         GraphDef output;
976         OptimizeTwice(&optimizer, &item, &output);
977 
978         // We expect the following rewrite(s) to occur:
979         //
980         //        Add                 Div
981         //      /    \               /   \
982         //    Div    Div       ->  Add    x
983         //    / \    / \           / \
984         //   y1  x  y2  x         y1  y2
985         //
986         // If "root" op is AddN and shapes does not match, this rewrite is not
987         // possible and graph should stay intact.
988         NodeMap node_map(&output);
989 
990         if ((use_addn && !matching_shapes) || use_ints) {
991           VerifyGraphsMatch(item.graph, output, __LINE__);
992         } else {
993           EXPECT_EQ(output.node_size(), 9);
994 
995           const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
996           ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
997           ASSERT_EQ(new_add_node->input_size(), 2);
998           EXPECT_EQ(new_add_node->input(0), "y1");
999           EXPECT_EQ(new_add_node->input(1), "y2");
1000 
1001           const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
1002           ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
1003           ASSERT_EQ(new_div_node->input_size(), 2);
1004           EXPECT_EQ(new_div_node->input(0), new_add_node->name());
1005           EXPECT_EQ(new_div_node->input(1), "x");
1006 
1007           const NodeDef* id_node = node_map.GetNode("id");
1008           ASSERT_TRUE(id_node != nullptr) << "Id node not found";
1009           EXPECT_EQ("id", id_node->name());
1010           ASSERT_EQ(id_node->input_size(), 1);
1011           EXPECT_EQ(id_node->input(0), HoistDivName("add"));
1012         }
1013         auto tensors = EvaluateNodes(output, item.fetch);
1014         ASSERT_EQ(tensors.size(), 1);
1015         if (use_ints) {
1016           test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
1017         } else {
1018           test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
1019         }
1020       }
1021     }
1022   }
1023 }
1024 
// Transpose(Conj(z)) must be fused into a single ConjugateTranspose node.
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real, imag);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp = ops::Transpose(scope.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = absl::StrCat(prefix, "_", "trans");

  // The fused node reads directly from "z", bypassing the Conj.
  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1062 
// ConjugateTranspose(Conj(z)) cancels the conjugation, so the pair must
// collapse into a plain Transpose node.
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output real =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real, imag);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp =
      ops::ConjugateTranspose(scope.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = absl::StrCat(prefix, "_", "conjugate_trans");

  // Conj + ConjugateTranspose cancel out into a Transpose reading "z".
  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "Transpose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1102 
// Conj(Transpose(z)) — the reverse composition order — must also be fused
// into a single ConjugateTranspose node.
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real, imag);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output trans = ops::Transpose(scope.WithOpName("trans"), z, perm);
  Output conj = ops::Conj(scope.WithOpName("conj"), trans);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string prefix = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
  const string fused_name = absl::StrCat(prefix, "_", "conj");

  // The fused node reads directly from "z", bypassing the Transpose.
  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1140 
TEST_F(ArithmeticOptimizerTest,FoldTransposeIntoMatMul)1141 TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
1142   for (const string matmul_type :
1143        {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
1144     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1145 
1146     Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
1147     Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
1148     Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
1149     Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
1150     Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
1151 
1152     Output matmul;
1153     auto matmul_op = s.WithOpName("matmul");
1154     if (matmul_type == "MatMul") {
1155       matmul = ops::MatMul(matmul_op, trans_a, trans_b);
1156     } else if (matmul_type == "SparseMatMul") {
1157       matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
1158     } else if (matmul_type == "BatchMatMul") {
1159       matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
1160     } else if (matmul_type == "BatchMatMulV2") {
1161       matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
1162     }
1163 
1164     auto identity = ops::Identity(s.WithOpName("identity"), matmul);
1165 
1166     GrapplerItem item;
1167     item.fetch = {"identity"};
1168     TF_CHECK_OK(s.ToGraphDef(&item.graph));
1169 
1170     auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1171     ASSERT_EQ(tensors_expected.size(), 1);
1172 
1173     ArithmeticOptimizer optimizer;
1174     EnableOnlyFoldTransposeIntoMatMul(&optimizer);
1175     GraphDef output;
1176     OptimizeTwice(&optimizer, &item, &output);
1177     NodeMap node_map(&output);
1178 
1179     EXPECT_EQ(output.node_size(), 8);
1180 
1181     const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
1182     const string optimized_name = absl::StrCat(p, "_", "matmul");
1183 
1184     const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
1185     ASSERT_NE(matmul_fused_node, nullptr);
1186     ASSERT_EQ(matmul_fused_node->input_size(), 2);
1187     EXPECT_EQ(matmul_fused_node->input(0), "a");
1188     EXPECT_EQ(matmul_fused_node->input(1), "b");
1189 
1190     if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
1191       EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
1192       EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
1193     } else {
1194       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
1195       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
1196     }
1197 
1198     const NodeDef* identity_node = node_map.GetNode("identity");
1199     ASSERT_NE(identity_node, nullptr);
1200     ASSERT_EQ(identity_node->input_size(), 1);
1201     EXPECT_EQ(identity_node->input(0), optimized_name);
1202 
1203     auto tensors = EvaluateNodes(output, item.fetch);
1204     ASSERT_EQ(tensors.size(), 1);
1205     test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
1206   }
1207 }
1208 
// ConjugateTranspose on both operands of a BatchMatMul should be folded into
// the matmul's adj_x/adj_y attributes.
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output re_a =
      ops::Const(scope.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a =
      ops::Const(scope.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(scope.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b =
      ops::Const(scope.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(scope.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(scope.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(scope.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(scope.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(scope.WithOpName("matmul"), trans_a, trans_b);
  Output identity = ops::Identity(scope.WithOpName("identity"), matmul);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  NodeMap node_map(&optimized_graph);
  EXPECT_EQ(optimized_graph.node_size(), 12);

  const string prefix = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
  const string fused_name = absl::StrCat(prefix, "_", "matmul");

  // The fused matmul reads "a" and "b" directly, with both adjoint flags set.
  const NodeDef* fused_matmul = node_map.GetNode(fused_name);
  ASSERT_NE(fused_matmul, nullptr);
  ASSERT_EQ(fused_matmul->input_size(), 2);
  EXPECT_EQ(fused_matmul->input(0), "a");
  EXPECT_EQ(fused_matmul->input(1), "b");
  EXPECT_TRUE(fused_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(fused_matmul->attr().at("adj_y").b());

  auto actual = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<complex64>(actual[0], expected[0], 1e-6);
}
1257 
// A Reshape (or BroadcastTo) whose target shape is computed from the input's
// own shape is an identity and should be removed.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                     ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(scope, inputs);
    // Target shape = concat(symbolic batch size, [3, 28, 28]), i.e. provably
    // equal to the input's own shape.
    Output batch_size =
        ops::Slice(scope, inputs_shape, ops::Const(scope, {0}, {1}),
                   ops::Const(scope, {1}, {1}));
    Output target_shape = ops::Concat(
        scope.WithOpName("target_shape"),
        {batch_size, ops::Const(scope, {3, 28, 28}, {3})},
        ops::Const(scope, {0}, {}));
    // Only the registration of the ops in the scope matters here; the C++
    // handles are not used afterwards.
    if (is_broadcastto) {
      ops::Identity(scope.WithOpName("outputs"),
                    ops::BroadcastTo(scope, inputs, target_shape));
    } else {
      ops::Identity(scope.WithOpName("outputs"),
                    ops::Reshape(scope, inputs, target_shape));
    }

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto expected =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(expected.size(), 1);

    GraphDef optimized_graph;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

    EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(optimized_graph, "BroadcastTo"), 0);
    auto actual =
        EvaluateNodes(optimized_graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(actual.size(), 1);
    test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
  }
}
1299 
// A reshape whose target shape is assembled from the input's own symbolic
// dimensions (batch, 3, height, width) is an identity; in aggressive mode
// the optimizer trusts the feed shapes and removes it.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(scope, inputs);
  // Rebuild the input's shape piecewise: batch size, literal 3, height,
  // width.
  Output batch_size =
      ops::Slice(scope, inputs_shape, ops::Const(scope, {0}, {1}),
                 ops::Const(scope, {1}, {1}));
  Output height =
      ops::Slice(scope, inputs_shape, ops::Const(scope, {2}, {1}),
                 ops::Const(scope, {1}, {1}));
  Output width =
      ops::Slice(scope, inputs_shape, ops::Const(scope, {3}, {1}),
                 ops::Const(scope, {1}, {1}));
  Output target_shape =
      ops::Concat(scope.WithOpName("target_shape"),
                  {batch_size, ops::Const(scope, {3}, {1}), height, width},
                  ops::Const(scope, {0}, {}));
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef optimized_graph;
  // Aggressive mode assumes the feed shapes are valid.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 0);
  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1341 
// In default (non-aggressive) mode the optimizer may not assume a feed's
// shape matches the placeholder's declared shape, so the reshape stays.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // The reshape is preserved because the actual feed may have a shape that
  // differs from the placeholder's declared shape.
  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 1);

  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1372 
// Same graph as RemoveRedundantReshapeNotAssumeValidFeeds, but in aggressive
// mode the feed shape is trusted and the identity reshape is removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 0);
  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1401 
// Reshaping [-1,3,28,28] to [8,-1,28,28] is NOT an identity: e.g. a
// [4,3,28,28] input becomes [8,6,28,28]. The reshape must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape =
      ops::Reshape(scope, inputs, ops::Const(scope, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 1);
  auto actual = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1429 
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  // A target shape with more than one unknown (-1) dimension cannot be proven
  // to be an identity reshape, so the Reshape node must be preserved.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output in =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped = ops::Reshape(scope, in, ops::Const(scope, {-1, -1}, {2}));
  Output fetched = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 1);
}
1449 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
  // Checks that two back-to-back Reshape ops (optionally separated by a
  // value-preserving unary op, Cos) are merged into a single Reshape, and
  // that the merge is suppressed when the intermediate reshape output is
  // still needed (fed directly, or consumed by another node).
  for (bool include_unary_chain : {false, true}) {
    // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The
    // two reshapes should be combined.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output nchw_vect_c =
        ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_FLOAT,
                         ops::Placeholder::Shape({8, 3, 28, 28, 4}));
    Output transpose =
        ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
                       ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
    Output nhwc = ops::Reshape(
        s.WithOpName("nhwc"), transpose,
        ops::Const(
            s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
            {8, 28, 28, 12}, {4}));
    Output flatten = ops::Reshape(
        s.WithOpName("flatten"),
        (include_unary_chain ? ops::Cos(s.WithOpName("Cos"), nhwc) : nhwc),
        ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
    Output output0 = ops::Identity(s.WithOpName("output0"), flatten);
    Output output1 = ops::Identity(s.WithOpName("output1"), flatten);

    GraphDef graph;
    TF_CHECK_OK(s.ToGraphDef(&graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28, 4}));
    // Evaluate the original graph once to capture both the final output and
    // the intermediate "nhwc" tensor (used as a feed in the second sub-case).
    auto eval =
        EvaluateNodes(graph, {"output0", "nhwc"}, {{"nchw_vect_c", x_t}});

    ASSERT_EQ(eval.size(), 2);
    auto expected_output_t = eval[0];
    auto nhwc_t = eval[1];

    // Base case: the two reshapes collapse into one.
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is the feed tensor.
    // (Expected no reshape removal to happen.)
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nhwc", nhwc_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is consumed by multiple nodes
    // (Expected no reshape removal to happen.)
    {
      // Note: "output2" is added to the same scope `s`, so this graph
      // contains all previously declared nodes plus the extra consumer.
      Output output2 = ops::Identity(s.WithOpName("output2"), nhwc);
      GraphDef graph;
      TF_CHECK_OK(s.ToGraphDef(&graph));
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1", "output2"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 3);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[2], nhwc_t);
    }
  }
}
1546 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  // Builds Cast(uint8->float) -> Transpose and expects the optimizer to swap
  // the pair: the surviving Transpose must run on DT_UINT8 and feed the Cast.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose should remain, and it should operate on uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);  // at most one Transpose
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    }
  }
  ASSERT_NE(transpose_node, nullptr);

  // Every remaining Cast must consume the Transpose's output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): this evaluates item.graph rather than `output`;
  // OptimizeAndPrune appears to swap the optimized graph into item.graph —
  // confirm against arithmetic_optimizer_test_utils.h.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1591 
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // Same reordering as ReorderTransposeCastProducerIsCast, but with
  // SpaceToDepth as the value-shuffling op: the surviving SpaceToDepth must
  // run on DT_UINT8 and feed the Cast.
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth should remain, operating on uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);  // at most one SpaceToDepth
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      s2d_node = &node;
    }
  }
  ASSERT_NE(s2d_node, nullptr);

  // Every remaining Cast must consume the SpaceToDepth's output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): evaluates item.graph rather than `output`; OptimizeAndPrune
  // appears to swap the optimized graph into item.graph — confirm.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1638 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  // Builds Transpose -> Cast(float->uint8) and expects the optimizer to swap
  // the pair: the Cast should consume the Placeholder directly, and the
  // remaining Transpose should operate on DT_UINT8.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // A constant-valued input is used here (unlike the random inputs in the
  // companion tests) — presumably so the float->uint8 conversion is
  // well-defined for every element; confirm before changing.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast should remain, consuming the Placeholder directly.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);  // at most one Cast
      cast_node = &node;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  ASSERT_NE(cast_node, nullptr);

  // Any remaining Transpose must run on uint8 and take the Cast's output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): evaluates item.graph rather than `output`; OptimizeAndPrune
  // appears to swap the optimized graph into item.graph — confirm.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<uint8>(tensors[0], tensors_expected[0]);
}
1686 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  // Builds Cast(uint8->float) -> Reverse -> Transpose and expects both
  // value-shuffling ops to be hoisted before the Cast so they operate on
  // DT_UINT8, i.e. Placeholder -> ReverseV2 -> Transpose -> Cast.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Locate the single ReverseV2 and Transpose (both expected on uint8) and
  // whichever Cast remains.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  // Verify the rewired chain: Placeholder -> ReverseV2 -> Transpose -> Cast.
  ASSERT_EQ(reverse_node->input_size(), 2);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // NOTE(review): evaluates item.graph rather than `output`; OptimizeAndPrune
  // appears to swap the optimized graph into item.graph — confirm.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1743 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  // With a CheckNumerics consuming the Cast output, the optimizer is expected
  // to make no changes: the optimized graph must equal the input graph.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src = ops::Placeholder(scope, DT_UINT8,
                                ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_float = ops::Cast(scope, src, DT_FLOAT);
  Output checked = ops::CheckNumerics(scope, as_float, "foo");
  Output fetched = ops::Identity(scope.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1760 
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  // Cast(float->uint8) already precedes the Transpose, so no rewrite should
  // fire: the optimized graph must be identical to the input graph.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src = ops::Placeholder(scope, DT_FLOAT,
                                ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_uint8 = ops::Cast(scope, src, DT_UINT8);
  Output transposed =
      ops::Transpose(scope, as_uint8, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output fetched = ops::Identity(scope.WithOpName("outputs"), transposed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1778 
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  // Transpose on uint8 followed by Cast(uint8->float): no rewrite should
  // fire, so the optimized graph must be identical to the input graph.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src = ops::Placeholder(scope, DT_UINT8,
                                ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed =
      ops::Transpose(scope, src, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output as_float = ops::Cast(scope, transposed, DT_FLOAT);
  Output fetched = ops::Identity(scope.WithOpName("outputs"), as_float);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1796 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  // perm1 ({0,2,3,1}) followed by perm2 ({0,3,1,2}) composes to the identity
  // permutation, and perm3 ({0,1,2,3}) is itself the identity, so every
  // Transpose (and its perm constant) should be removed from the graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output data =
      ops::RandomUniform(scope.WithOpName("inputs"), shape, DT_FLOAT);
  Output perm_a = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm_b = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm_c = ops::Const(scope.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output t1 = ops::Transpose(scope.WithOpName("transpose1"), data, perm_a);
  Output t2 = ops::Transpose(scope.WithOpName("transpose2"), t1, perm_b);
  Output t3 = ops::Transpose(scope.WithOpName("transpose3"), data, perm_c);
  Output id1 = ops::Identity(scope.WithOpName("id1"), t2);
  Output id2 = ops::Identity(scope.WithOpName("id2"), t3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Only the source nodes and the fetched identities should remain.
  const std::set<string> expected_nodes = {"id1", "id2", "inputs_shape",
                                           "inputs"};
  std::set<string> remaining;
  for (const NodeDef& node : optimized.node()) {
    remaining.insert(node.name());
  }
  EXPECT_EQ(remaining, expected_nodes);
}
1829 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  // A ConjugateTranspose with an identity permutation ({0, 1} on a 2-D input)
  // still conjugates its input, so it cannot simply be removed: the optimizer
  // should replace it with a plain Conj node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), re, im);
  Output perm = ops::Const(s.WithOpName("perm"), {0, 1}, {2});
  Output transpose = ops::ConjugateTranspose(s.WithOpName("trans"), z, perm);
  Output id = ops::Identity(s.WithOpName("id"), transpose);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // 5 nodes remain: re, im, z, id, and the rewritten trans node — the Conj
  // replacement takes only one input, so the perm constant gets pruned.
  EXPECT_EQ(output.node_size(), 5);

  // The rewritten node carries the optimizer's rewrite-name prefix.
  const string p = "ArithmeticOptimizer/RemoveIdentityTranspose";
  const string optimized_name = absl::StrCat(p, "_", "trans");

  const NodeDef* conj = node_map.GetNode(optimized_name);
  ASSERT_NE(conj, nullptr);
  EXPECT_EQ(conj->op(), "Conj");
  ASSERT_EQ(conj->input_size(), 1);
  EXPECT_EQ(conj->input(0), "z");
}
1860 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  // Splits the input into three branches; the middle branch goes through a
  // Transpose pair whose permutations ({0,2,3,1} then {0,3,1,2}) compose to
  // the identity. After optimization, Concat should consume the Split
  // outputs directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // NOTE(review): inputs_shape is never referenced below — looks vestigial.
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // With the transpose pair gone, all three Concat inputs should be the
  // Split op's outputs directly.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Concat") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1903 
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  // Two cancelling transposes ({1,0} applied twice) feed only a control
  // dependency of "outputs". After their removal, the control edge should be
  // rerouted to the chain's source, the Placeholder.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
  Output outputs =
      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
                    ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // "outputs" keeps its data input and now carries a control edge directly
  // to the Placeholder (the "^" prefix marks a control input).
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1938 
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  // Applying perm {1,2,3,0} twice yields {2,3,0,1}, which is not the
  // identity, so neither Transpose may be removed: all 6 nodes must survive.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output data =
      ops::RandomUniform(scope.WithOpName("inputs"), shape, DT_FLOAT);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output first = ops::Transpose(scope.WithOpName("transpose1"), data, perm);
  Output second = ops::Transpose(scope.WithOpName("transpose2"), first, perm);
  Output fetched = ops::Identity(scope.WithOpName("outputs"), second);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(optimized.node_size(), 6);
}
1962 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  // Two cancelling transposes ({0,2,3,1} then {0,3,1,2}) separated by an
  // Identity op. In AGGRESSIVE mode the optimizer should remove both
  // transposes across the intervening node, leaving inputs -> id -> id1.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  // transpose1 also carries a control dependency on perm2.
  Output transpose1 = ops::Transpose(
      s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
  Output identity = ops::Identity(s.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Only the source, its shape const, and the two identities should remain,
  // rewired as inputs -> id -> id1.
  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    nodes_after_optimization.insert(node.name());
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "inputs");
    }
    if (node.name() == "id1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id");
    }
  }
  EXPECT_EQ(nodes_after_optimization,
            std::set<string>({"id", "id1", "inputs_shape", "inputs"}));
}
2002 
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  // A scalar multiply on the convolution input should be folded into the
  // convolution weights even with a Transpose between the Mul and the Conv2D.
  // Exercises both operand orders of the Mul.
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
    Output scaled_inputs = ops::Multiply(s.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                    "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    //    LOG(INFO) << "Before:\n" << item.graph.DebugString();
    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    //    LOG(INFO) << "After:\n"  << output.DebugString();
    NodeMap node_map(&output);
    // `conv` is now a folded convolution with scaled weights.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);

    // The convolution's weight input has been rewritten to a Mul node.
    const NodeDef* folded_conv_weights =
        node_map.GetNode(folded_conv->input(1));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // Its input should be a transpose of `inputs`.
    const NodeDef* transpose =
        node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
2052 
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  // When the Transpose output ("inputs_nchw") is a feed — and must therefore
  // be preserved — the scalar multiply must NOT be folded across it into the
  // convolution: the Transpose has to keep "scaled_inputs" as its data input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  // Zero-filled feed tensor for the transpose output. flat<float>().setZero()
  // replaces the previous const_cast + memset hack on Tensor::tensor_data(),
  // which cast away constness to write through a read-only view.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  inputs_nchw_tensor.flat<float>().setZero();

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  // The preserved Transpose must still consume the (unfolded) scaled input.
  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
2094 
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  // A scalar multiply feeding a Conv3D should be folded into the weights:
  // after optimization, `conv` consumes `inputs` directly and its weight
  // input is a Mul node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  // Use EXPECT_EQ rather than CHECK_EQ: CHECK_EQ aborts the whole test binary
  // on mismatch instead of reporting a gtest failure (matches sibling tests).
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
2129 
TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
  // This unit test exercises two optimizations, folding mul into conv, and
  // reordering cast and transpose.
  //
  //   Conv2D(Transpose(Mul(Cast(I), S)), W)
  //     =>
  //   Conv2D(Transpose(Cast(I)), W*S)
  //     =>
  //   Conv2D(Cast(Transpose(I)), W*S)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  // NOTE: most ops below are built without WithOpName, so they receive
  // auto-generated default names (e.g. "Placeholder", "Conv2D") that the
  // lookups further down depend on; do not reorder these constructions.
  Output inputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output cast = ops::Cast(s, inputs, DT_FLOAT);
  Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
  Output transpose =
      ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
                            ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;  // all optimization stages are on
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
  NodeMap node_map(&output);

  // Expected names for reordered cast and transpose.
  const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
  const string optimized_cast_name = absl::StrCat(p, "float_Cast");
  const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");

  // Expected names for folded multiply and conv.
  const string optimized_weights =
      "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";

  const NodeDef* inputs_node = node_map.GetNode("Placeholder");
  const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
  const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");

  ASSERT_NE(inputs_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(conv_node, nullptr);

  // Final graph: Placeholder -> Transpose -> Cast -> Conv2D, plus perm,
  // scaled weights, and the Identity output node (7 nodes total).
  EXPECT_EQ(output.node_size(), 7);
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(cast_node->input(0), transpose_node->name());
  ASSERT_EQ(conv_node->input_size(), 2);
  EXPECT_EQ(conv_node->input(0), cast_node->name());
  EXPECT_EQ(conv_node->input(1), weights_node->name());
}
2193 
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // This unit test exercises optimization of folding mul into conv for
  // multiple nodes in the graph.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  // Build two independent Mul -> Conv2D chains. The ops are created without
  // WithOpName, so they get auto-generated names ("Conv2D", "Conv2D_1", ...)
  // that the assertions below rely on; keep the construction order stable.
  for (int i = 0; i < 2; ++i) {
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);

  NodeMap node_map(&output);

  using absl::StrCat;
  // Names given by the FoldMultiplyIntoConv stage to the scaled weights of
  // each convolution.
  const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
  const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  // Both convolutions should now read the pre-scaled weights directly.
  ASSERT_EQ(conv_node->input_size(), 2);
  ASSERT_EQ(conv_node_1->input_size(), 2);
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
2243 
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  // Two chained Bitcasts (uint8 -> qint8 -> int8) should collapse into one.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Evaluate the original graph on random data to establish a baseline.
  auto input_values = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_values}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Bitcasts combined into a single op and inputs redirected to updated Bitcast
  EXPECT_EQ(output.node_size(), 3);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  // The rewritten graph must still produce bit-identical results.
  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2277 
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  // A Bitcast chain that ends on the original dtype (int8 -> qint8 -> int8)
  // is a no-op and should be removed entirely.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope, bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Baseline evaluation of the unoptimized graph.
  auto input_values = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_values}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Bitcasts removed and inputs redirected to outputs
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // Results must match the baseline exactly.
  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2311 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  // A Cast whose source and destination dtypes agree (int8 -> int8) is a
  // no-op and should be eliminated from the graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output cast = ops::Cast(scope, inputs, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Baseline evaluation of the unoptimized graph.
  auto input_values = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_values}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Cast removed and inputs redirected to outputs
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // Results must match the baseline exactly.
  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2344 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteAddOpsOfIdenticalShape)2345 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
2346   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2347   tensorflow::Scope sx = s.NewSubScope("x");
2348   tensorflow::Scope sy = s.NewSubScope("y");
2349 
2350   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
2351   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
2352   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
2353   auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
2354   auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);
2355 
2356   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2357 
2358   GrapplerItem item;
2359   item.fetch = {"outputs"};
2360   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2361 
2362   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2363   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2364   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2365   std::vector<std::pair<string, Tensor>> feed = {
2366       {"a", a_t}, {"b", b_t}, {"c", c_t}};
2367   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2368   ASSERT_EQ(tensors_expected.size(), 1);
2369 
2370   GraphDef output;
2371   ArithmeticOptimizer optimizer;
2372   EnableOnlyAddToAddNCombining(&optimizer);
2373 
2374   OptimizeAndPrune(&optimizer, &item, &output);
2375 
2376   // We expect the following rewrite(s) to occur:
2377   //
2378   //     +
2379   //    / \
2380   //   a   +         -->    AddN(a, b, c)
2381   //      / \
2382   //     b   c
2383   EXPECT_EQ(output.node_size(), 5);
2384 
2385   NodeMap node_map(&output);
2386 
2387   // check add tree was replaced with AddN
2388   const NodeDef* collapsed_add =
2389       node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2390   ASSERT_NE(collapsed_add, nullptr);
2391 
2392   EXPECT_EQ(collapsed_add->op(), "AddN");
2393   ASSERT_EQ(collapsed_add->input_size(), 3);
2394   EXPECT_EQ(collapsed_add->input(0), "a");
2395   EXPECT_EQ(collapsed_add->input(1), "b");
2396   EXPECT_EQ(collapsed_add->input(2), "c");
2397 
2398   // check output was re-wired to new node
2399   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2400   ASSERT_NE(updated_outputs, nullptr);
2401   ASSERT_EQ(updated_outputs->input_size(), 1);
2402   EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());
2403 
2404   auto tensors = EvaluateNodes(output, item.fetch, feed);
2405   ASSERT_EQ(tensors.size(), 1);
2406   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2407 }
2408 
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
  // Two independent Add trees feeding one Mul: the AddOpsRewrite stage must
  // collapse each tree into its own AddN and re-wire the Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Left Add tree: (a + b) + c.
  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  // Right Add tree: (x + y) + z.
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), mul);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline evaluation with random values for all six variables.
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //         *
  //      /     \
  //     +       +                        *
  //    / \     / \                    /     \
  //   +   c   x   + -->    AddN(a, b, c)  AddN(x, y, z))
  //  / \         / \
  // a   b       y   z
  EXPECT_EQ(output.node_size(), 10);

  NodeMap node_map(&output);

  // check left Add subtree replaced with AddN
  const NodeDef* collapsed_left =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_left, nullptr);

  EXPECT_EQ(collapsed_left->op(), "AddN");
  ASSERT_EQ(collapsed_left->input_size(), 3);
  EXPECT_EQ(collapsed_left->input(0), "a");
  EXPECT_EQ(collapsed_left->input(1), "b");
  EXPECT_EQ(collapsed_left->input(2), "c");

  // check right Add subtree replaced with AddN
  const NodeDef* collapsed_right =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
  ASSERT_NE(collapsed_right, nullptr);

  EXPECT_EQ(collapsed_right->op(), "AddN");
  ASSERT_EQ(collapsed_right->input_size(), 3);
  EXPECT_EQ(collapsed_right->input(0), "x");
  EXPECT_EQ(collapsed_right->input(1), "y");
  EXPECT_EQ(collapsed_right->input(2), "z");

  // check that Mul inputs re-wired to new Nodes
  const NodeDef* updated_mul = node_map.GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);

  EXPECT_EQ(updated_mul->op(), "Mul");
  ASSERT_EQ(updated_mul->input_size(), 2);
  EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
  EXPECT_EQ(updated_mul->input(1), collapsed_right->name());

  // Numerical results of the rewritten graph must match the baseline.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2496 
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  // When the same tensor feeds the Add tree through two branches, the
  // collapsed AddN must list it once per occurrence.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(scope.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(scope.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(scope.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Baseline evaluation with random variable values.
  auto a_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feeds = {
      {"a", a_val}, {"b", b_val}, {"c", c_val}};
  auto expected = EvaluateNodes(item.graph, item.fetch, feeds);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   +     -->    AddN(a, b, b, c)
  //  / \ / \                   ^
  // a   b   c                  b added twice!
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // The whole Add tree collapses into one AddN with `b` listed twice.
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 4);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "b");
  EXPECT_EQ(collapsed_add->input(3), "c");

  // Numerical results must match the baseline.
  auto actual = EvaluateNodes(output, item.fetch, feeds);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
2553 
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
  // The AddOpsRewrite stage must also fire when shapes are only known
  // symbolically (unknown batch dimension propagated through unary ops).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // unknown input shape propagated symbolically through the graph
  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);

  // [a, b, c] have symbolically equal shapes
  auto a = ops::Sqrt(s.WithOpName("a"), input);
  auto b = ops::Square(s.WithOpName("b"), input);
  auto c = ops::Round(s.WithOpName("c"), input);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline evaluation; a concrete batch size of 2 stands in for the
  // symbolic -1 dimension.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   c      -->    AddN(a, b, c)
  //  / \
  // a   b
  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);
  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // Numerical results must match the baseline.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2617 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMinimizeBCast)2618 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
2619   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2620 
2621   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2622   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
2623   auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
2624   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2625   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2626 
2627   auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
2628   auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
2629   auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
2630   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
2631   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
2632 
2633   auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
2634   auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
2635 
2636   GrapplerItem item;
2637   item.fetch = {"outputs"};
2638   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2639 
2640   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2641   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2642   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2643   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2644   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2645   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2646   std::vector<std::pair<string, Tensor>> feed = {
2647       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
2648   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2649   ASSERT_EQ(tensors_expected.size(), 1);
2650 
2651   GraphDef output;
2652   ArithmeticOptimizer optimizer;
2653   EnableOnlyAddToAddNCombining(&optimizer);
2654 
2655   OptimizeAndPrune(&optimizer, &item, &output);
2656 
2657   // We expect the following rewrite(s) to occur:
2658   //  1) [a, x], [b, y], [c, z] - aggregate same shapes first
2659   //  2) Build an aggregation tree minimizing cost of broadcast
2660   //
2661   //         +                              +
2662   //      /     \                       /       \
2663   //     +       +                     +       AddN(c, z)
2664   //    / \     / \                 /     \
2665   //   +   c   x   + -->    AddN(a, x)  AddN(b, y)
2666   //  / \         / \
2667   // a   b       y   z
2668   EXPECT_EQ(output.node_size(), 12);
2669   NodeMap node_map(&output);
2670 
2671   // expected names of outer and inner nodes
2672   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
2673   string outer_0_add_name =
2674       "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
2675   string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
2676   string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
2677   string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
2678 
2679   // Add [a, x] first
2680   const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
2681   ASSERT_NE(add_ax_node, nullptr);
2682   EXPECT_EQ(add_ax_node->op(), "AddN");
2683   ASSERT_EQ(add_ax_node->input_size(), 2);
2684   EXPECT_EQ(add_ax_node->input(0), "a");
2685   EXPECT_EQ(add_ax_node->input(1), "x");
2686 
2687   // Then add [b, y]
2688   const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
2689   ASSERT_NE(add_by_node, nullptr);
2690   EXPECT_EQ(add_by_node->op(), "AddN");
2691   ASSERT_EQ(2, add_by_node->input_size());
2692   EXPECT_EQ(add_by_node->input(0), "b");
2693   EXPECT_EQ(add_by_node->input(1), "y");
2694 
2695   // Then add [c, z]
2696   const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
2697   ASSERT_NE(add_cz_node, nullptr);
2698   EXPECT_EQ(add_cz_node->op(), "AddN");
2699   ASSERT_EQ(add_cz_node->input_size(), 2);
2700   EXPECT_EQ(add_cz_node->input(0), "c");
2701   EXPECT_EQ(add_cz_node->input(1), "z");
2702 
2703   // Then add results together starting from smaller shapes [a, x] + [b, y]
2704   const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
2705   ASSERT_NE(outer_0_node, nullptr);
2706   EXPECT_EQ(outer_0_node->op(), "AddV2");
2707   ASSERT_EQ(outer_0_node->input_size(), 2);
2708   EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
2709   EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);
2710 
2711   // And finally top level Add node
2712   const NodeDef* outer_node = node_map.GetNode(outer_add_name);
2713   ASSERT_NE(outer_node, nullptr);
2714   EXPECT_EQ(outer_node->op(), "AddV2");
2715   ASSERT_EQ(outer_node->input_size(), 2);
2716   EXPECT_EQ(outer_node->input(0), outer_0_add_name);
2717   EXPECT_EQ(outer_node->input(1), inner_2_add_name);
2718 
2719   // And outputs reading new top level Add node
2720   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2721   ASSERT_NE(updated_outputs, nullptr);
2722   ASSERT_EQ(updated_outputs->input_size(), 1);
2723   EXPECT_EQ(updated_outputs->input(0), outer_add_name);
2724 
2725   auto tensors = EvaluateNodes(output, item.fetch, feed);
2726   ASSERT_EQ(tensors.size(), 1);
2727   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2728 }
2729 
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
  // Broadcast-minimizing aggregation must also work when shapes are only
  // symbolically known (shared unknown dimension).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // We have a small input with one unknown dimension
  auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);

  // And second input which is larger, but has the same unknown dimension
  // device spec prevents this node from rewriting
  auto d = "/device:CPU:0";
  auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
  auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);

  // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
  auto a = ops::Sqrt(s.WithOpName("a"), small);
  auto b = ops::Square(s.WithOpName("b"), large);
  auto c = ops::Round(s.WithOpName("c"), small);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Baseline evaluation; batch size 8 stands in for the symbolic dimension.
  auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
  auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: it's much cheaper to add small
  // tensors, and do the broadcast just once
  //
  //     +                  +
  //    / \                / \
  //   +   c      -->     +   b
  //  / \                / \
  // a   b              a   c
  EXPECT_EQ(output.node_size(), 9);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
  string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";

  // outer Add node
  const NodeDef* outer_add = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_add, nullptr);
  EXPECT_EQ(outer_add->op(), "AddV2");
  ASSERT_EQ(outer_add->input_size(), 2);
  EXPECT_EQ(outer_add->input(0), inner_add_name);
  EXPECT_EQ(outer_add->input(1), "b");

  // inner AddN node
  const NodeDef* inner_add = node_map.GetNode(inner_add_name);
  ASSERT_NE(inner_add, nullptr);
  ASSERT_EQ(inner_add->input_size(), 2);
  EXPECT_EQ(inner_add->input(0), "a");
  EXPECT_EQ(inner_add->input(1), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // Numerical results must match the baseline.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
2808 
TEST_F(ArithmeticOptimizerTest,RemoveNegation)2809 TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
2810   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2811   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
2812   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
2813   Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
2814   Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
2815   Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
2816   Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
2817   Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
2818   Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
2819   Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
2820   Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
2821   Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
2822   Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
2823   Output neg_x_with_dep = ops::Neg(
2824       s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
2825   Output add_negx_with_dep_y =
2826       ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
2827   auto add_all =
2828       ops::AddN(s.WithOpName("add_all"),
2829                 {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
2830                  sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
2831 
2832   GrapplerItem item;
2833   item.fetch = {"add_all"};
2834   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2835 
2836   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2837   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2838   std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
2839   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2840   ASSERT_EQ(tensors_expected.size(), 1);
2841 
2842   GraphDef output;
2843   ArithmeticOptimizer optimizer;
2844   EnableOnlyRemoveNegation(&optimizer);
2845   OptimizeTwice(&optimizer, &item, &output);
2846 
2847   EXPECT_EQ(output.node_size(), item.graph.node_size());
2848   int found = 0;
2849   for (int i = 0; i < output.node_size(); ++i) {
2850     const NodeDef& node = output.node(i);
2851     if (node.name() == "Add_negx_y") {
2852       ++found;
2853       EXPECT_EQ(node.op(), "Sub");
2854       ASSERT_EQ(node.input_size(), 2);
2855       EXPECT_EQ(node.input(0), "y");
2856       EXPECT_EQ(node.input(1), "x");
2857     } else if (node.name() == "Add_x_negy") {
2858       ++found;
2859       EXPECT_EQ(node.op(), "Sub");
2860       ASSERT_EQ(node.input_size(), 2);
2861       EXPECT_EQ(node.input(0), "x");
2862       EXPECT_EQ(node.input(1), "y");
2863     } else if (node.name() == "Add_negx_negy") {
2864       ++found;
2865       EXPECT_EQ(node.op(), "Sub");
2866       ASSERT_EQ(node.input_size(), 2);
2867       EXPECT_EQ(node.input(0), "Neg_x");
2868       EXPECT_EQ(node.input(1), "y");
2869     } else if (node.name() == "Sub_x_negy") {
2870       ++found;
2871       EXPECT_EQ(node.op(), "AddV2");
2872       ASSERT_EQ(node.input_size(), 2);
2873       EXPECT_EQ(node.input(0), "x");
2874       EXPECT_EQ(node.input(1), "y");
2875     } else if (node.name() == "Sub_negx_negy") {
2876       ++found;
2877       EXPECT_EQ(node.op(), "Sub");
2878       ASSERT_EQ(node.input_size(), 2);
2879       EXPECT_EQ(node.input(0), "y");
2880       EXPECT_EQ(node.input(1), "x");
2881     } else if (node.name() == "Add_negx_with_dep_y") {
2882       ++found;
2883       EXPECT_EQ(node.op(), "Sub");
2884       ASSERT_EQ(node.input_size(), 3);
2885       EXPECT_EQ(node.input(0), "y");
2886       EXPECT_EQ(node.input(1), "x");
2887       EXPECT_EQ(node.input(2), "^Add_x_y");
2888     }
2889   }
2890   EXPECT_EQ(found, 6);
2891 
2892   auto tensors = EvaluateNodes(output, item.fetch, feed);
2893   ASSERT_EQ(tensors.size(), 1);
2894   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2895 }
2896 
// Verifies the SqrtDivToRsqrtMul rewrite: x / Sqrt(y) => x * Rsqrt(y).
// The Sqrt node is converted in place to Rsqrt and the Div to Mul, so the
// node count is unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(s.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::Div(s.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Count the matched nodes so the checks below cannot pass vacuously if
  // either node is missing from the optimized graph.
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "output") {
      ++found;
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      ++found;
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
  EXPECT_EQ(found, 2);
}
2933 
// The Sqrt node ("output0") is itself a fetch output, so SqrtDivToRsqrtMul
// must leave the graph untouched: rewriting Sqrt into Rsqrt in place would
// change the value of a fetched node.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Count the matched nodes so the checks below cannot pass vacuously if
  // either node is missing from the optimized graph.
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "grad") {
      ++found;
      // Still a Div: the rewrite was (correctly) skipped.
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      ++found;
      // The fetched Sqrt node is untouched.
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
  EXPECT_EQ(found, 2);
}
2975 
// The SqrtDivToRsqrtMul rewrite only applies to real Div; a FloorDiv over a
// Sqrt divisor must be left completely untouched.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(scope.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::FloorDiv(scope.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Both nodes must survive the pass with their original op and inputs.
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
3012 
// Verifies the FuseSquaredDiff rewrite: Square(Sub(x, y)) is fused into
// Identity(SquaredDifference(x, y)). With complex operands the rewrite is
// skipped and the graph must stay unchanged.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  // Run the same scenario with real and with complex inputs.
  for (bool is_complex : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output complex_x = ops::Complex(s.WithOpName("complex_x"), x, x);
    Output complex_y = ops::Complex(s.WithOpName("complex_y"), y, y);
    Output sub_x_y =
        is_complex ? ops::Sub(s.WithOpName("sub_x_y"), complex_x, complex_y)
                   : ops::Sub(s.WithOpName("sub_x_y"), x, y);
    Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    const auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);

    if (is_complex) {
      test::ExpectTensorNear<std::complex<float>>(tensors[0],
                                                  tensors_expected[0], 1e-6);
      // No rewrite for complex inputs: graph size is unchanged.
      EXPECT_EQ(output.node_size(), item.graph.node_size());
    } else {
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
      // The two unused Complex nodes should get pruned.
      EXPECT_EQ(output.node_size(), item.graph.node_size() - 2);
    }
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      if (node.name() == "output") {
        // Real case: the Square is downgraded to an Identity over the fused
        // SquaredDifference below. Complex case: untouched.
        EXPECT_EQ(node.op(), is_complex ? "Square" : "Identity");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        EXPECT_EQ(node.op(), is_complex ? "Sub" : "SquaredDifference");
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), is_complex ? "complex_x" : "x");
        EXPECT_EQ(node.input(1), is_complex ? "complex_y" : "y");
      }
    }
  }
}
3062 
// The Sub node ("sub_x_y") is itself a fetch output, so FuseSquaredDiff must
// leave the graph untouched: fusing it into SquaredDifference would change
// the value of a fetched node.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Count the matched nodes so the checks below cannot pass vacuously if
  // either node is missing from the optimized graph.
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "output") {
      ++found;
      // Still a Square: the rewrite was (correctly) skipped.
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
  EXPECT_EQ(found, 2);
}
3102 
// Verifies the LogSoftmax rewrite: Log(Softmax(x)) is fused into a single
// LogSoftmax node, removing the intermediate Softmax.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output softmax = ops::Softmax(scope.WithOpName("softmax"), x);
  Output logsoftmax = ops::Log(scope.WithOpName("output"), softmax);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // The standalone Softmax node is folded away, shrinking the graph by one.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "LogSoftmax");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
3133 
// The Softmax node is itself a fetch output, so the LogSoftmax rewrite must
// not fire: fusing Log(Softmax) would remove a fetched node.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3163 
// Verifies ConvertPow's strength reduction of Pow with constant exponents,
// as pinned by the `want` graph below:
//   x^2   -> Square(x)
//   x^3   -> x * Square(x)      (split across two nodes, device preserved)
//   x^1   -> Identity(x)
//   x^0.5 -> Sqrt(x)
//   x^0   -> Const 1, keeping a control dependency on x
//   x^-0.5-> Rsqrt(x)
//   x^-1  -> Reciprocal(x)
// A Pow with a non-special exponent ("out") and Pows whose operands
// broadcast to a different shape ("out_bcast1/2") are left unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  // Scalar base against non-scalar exponents: the result shape depends on
  // broadcasting, so these Pows must not be rewritten.
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  Output out3 =
      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2",   "out3",  "out1", "out.5",      "out0",
                "out_.5", "out_1", "out",  "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 10);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 10);

  // The rewritten graph must still compute the same values.
  for (int i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Build the exact expected graph; CompareGraphs pins node ops, inputs,
  // and devices. The exponent Consts are gone (pruned after the rewrite).
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x"}, {}, &want);
  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
          &want)
      ->set_device("/device:CPU:0");
  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
          {}, &want)
      ->set_device("/device:CPU:0");
  AddNode("out1", "Identity", {"x"}, {}, &want);
  AddNode("out.5", "Sqrt", {"x"}, {}, &want);
  AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);
  AddNode("out_1", "Reciprocal", {"x"}, {}, &want);
  AddNode("out", "Pow", {"x", "y"}, {}, &want);
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
3233 
// Verifies the Log1p rewrite: Log(Add(1, x)) -> Log1p(x).
// a12 = x1 + x2 where x1 is all ones, so out1 = Log(a12) becomes Log1p(x2)
// and keeps a12's control dependency on x3. a23 = x2 + x3 has no all-ones
// input, so out2 stays a plain Log. x1 and a12 are then pruned away.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(s.WithOpName("out1"), a12);
  Output out2 = ops::Log(s.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // The rewritten graph must still compute the same values.
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Exact expected graph: note x1 and a12 are absent (pruned).
  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
3272 
// Verifies the Expm1 rewrite: Exp(x) - 1 -> Expm1(x). Only out1 subtracts
// the all-ones constant ("x2"), so only out1 is rewritten; out2 stays a Sub
// and keeps the original Exp node (and its control dependency) alive.
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto c_two = ops::Const(scope.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto c_one = ops::Const(scope.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto c_three = ops::Const(scope.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto exp_of_two =
      ops::Exp(scope.WithOpName("exp1").WithControlDependencies(c_three),
               c_two);
  Output out1 = ops::Sub(scope.WithOpName("out1"), exp_of_two, c_one);
  Output out2 = ops::Sub(scope.WithOpName("out2"), exp_of_two, c_three);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Values must be unchanged by the rewrite.
  for (int i : {0, 1}) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Exact expected graph: out1 became Expm1(x1) with the control dependency
  // on x3 preserved; the all-ones Const "x2" was pruned.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
3310 
// MinimizeBroadcasts should reorder the Mul chain so the two vector operands
// are combined first and the matrix operand joins last:
//
//     *                  *
//    / \                / \
//   *   c      -->     *   b
//  / \                / \
// a   b              a   c
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_value}, {"b", b_value}, {"c", c_value}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  NodeMap node_map(&output);

  // After the swap, "mul1" multiplies the two vectors...
  const NodeDef* new_mul1 = node_map.GetNode("mul1");
  ASSERT_NE(new_mul1, nullptr);
  ASSERT_EQ(new_mul1->input_size(), 2);
  EXPECT_EQ(new_mul1->input(0), "a");
  EXPECT_EQ(new_mul1->input(1), "c");

  // ...and "mul2" brings in the matrix last.
  const NodeDef* new_mul2 = node_map.GetNode("mul2");
  ASSERT_NE(new_mul2, nullptr);
  ASSERT_EQ(new_mul2->input_size(), 2);
  EXPECT_EQ(new_mul2->input(0), "mul1");
  EXPECT_EQ(new_mul2->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3366 
// Verifies MinimizeBroadcasts on a tall chain of four Muls: the chain is
// rebalanced into a flat tree and the matrix-shaped operand ("b") is pushed
// to the top so intermediate products stay vector-shaped.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // "b" is the only matrix operand; a, c, d, e are vectors.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *      b
  //    / \                 /  \
  //   *   c      -->     *      *
  //  / \                / \    / \
  // a   b              a   c  d   e
  NodeMap node_map(&output);

  // Leaf products combine only vectors.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "d");
  EXPECT_EQ(mul2_node->input(1), "e");

  // The two vector products are combined next...
  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "mul1");
  EXPECT_EQ(mul3_node->input(1), "mul2");

  // ...and the matrix joins last, at the root.
  const NodeDef* mul4_node = node_map.GetNode("mul4");
  ASSERT_NE(mul4_node, nullptr);
  ASSERT_EQ(mul4_node->input_size(), 2);
  EXPECT_EQ(mul4_node->input(0), "mul3");
  EXPECT_EQ(mul4_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
3445 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_BuildTreeUp)3446 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
3447   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3448 
3449   // [a, b, c] - scalars, [d] - matrix
3450   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
3451   auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
3452   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
3453   auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
3454 
3455   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
3456   auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
3457   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
3458 
3459   auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
3460 
3461   GrapplerItem item;
3462   item.fetch = {"outputs"};
3463   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3464 
3465   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3466   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3467   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3468   auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
3469   std::vector<std::pair<string, Tensor>> feed = {
3470       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
3471   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
3472   ASSERT_EQ(tensors_expected.size(), 1);
3473 
3474   GraphDef output;
3475   ArithmeticOptimizer optimizer;
3476   EnableOnlyMinimizeBroadcasts(&optimizer);
3477 
3478   OptimizeAndPrune(&optimizer, &item, &output);
3479 
3480   // We expect the following rewrite(s) to occur:
3481   //
3482   //                              *
3483   //                            /  \
3484   //       *                   *    D
3485   //     /   \                / \
3486   //    *     *      ->      *   c
3487   //   / \   / \            / \
3488   //  a   b c   D          a   b
3489   NodeMap node_map(&output);
3490 
3491   const NodeDef* mul1_node = node_map.GetNode("mul2");
3492   ASSERT_NE(mul1_node, nullptr);
3493   ASSERT_EQ(mul1_node->input_size(), 2);
3494   EXPECT_EQ(mul1_node->input(0), "a");
3495   EXPECT_EQ(mul1_node->input(1), "b");
3496 
3497   const NodeDef* mul2_node = node_map.GetNode("mul1");
3498   ASSERT_NE(mul2_node, nullptr);
3499   ASSERT_EQ(mul2_node->input_size(), 2);
3500   EXPECT_EQ(mul2_node->input(0), "mul2");
3501   EXPECT_EQ(mul2_node->input(1), "c");
3502 
3503   const NodeDef* mul3_node = node_map.GetNode("mul3");
3504   ASSERT_NE(mul3_node, nullptr);
3505   ASSERT_EQ(mul3_node->input_size(), 2);
3506   EXPECT_EQ(mul3_node->input(0), "D");
3507   EXPECT_EQ(mul3_node->input(1), "mul1");
3508 
3509   auto tensors = EvaluateNodes(output, item.fetch, feed);
3510   ASSERT_EQ(tensors.size(), 1);
3511   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
3512 }
3513 
// Verifies that the full (default-stage) optimizer does not hoist the two
// per-branch Relu ops above the Concat: both Relu nodes must survive, and
// the fetched output must be numerically unchanged.
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // Two Conv+BiasAdd+Relu branches concatenated along the channel axis.
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Guard the indexed accesses below against a failed evaluation.
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3554 
// Verifies the HoistCWiseUnaryChains stage: identical element-wise unary ops
// (or chains of them) applied to every input of a Concat are hoisted above
// the Concat so they run once on the concatenated tensor. Control
// dependencies on the hoisted ops must survive (checked indirectly by
// re-evaluating the fetch nodes at the end).
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Anchors for control dependencies attached to the unary ops below.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //       Concat({Exp(a), Exp(b), Exp(c)})
  // into
  //       Exp(Concat({a, b, c})).
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //       Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //       Cos(Exp(Concat({a, b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // After the rewrite "concat" consumes the raw branch inputs (plus axis),
    // and the surviving "exp_a" node is reused as the single hoisted Exp.
    if (node.name() == "concat") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat");
      found++;
    }
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a");
      found++;
    }

    // Length-2 chain: the hoisted ops form concat2 -> exp_a2 -> cos_exp_a2.
    if (node.name() == "concat2") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat2");
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a2");
      found++;
    }
    if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "cos_exp_a2");
      found++;
    }
  }
  EXPECT_EQ(found, 7);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3663 
// Verifies that HoistCWiseUnaryChains also handles the mirror case for
// Split/SplitV: a chain of element-wise unary ops applied to every output of
// a Split is pushed to the Split's input so it runs once on the un-split
// tensor. New hoisted nodes are given "ArithmeticOptimizer/_<op>_<split>"
// names and the per-output copies are pruned.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Anchors for control dependencies attached to the unary ops below.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //          [Sin(y) for y in Split(x)]
  // into
  //          [y for y in Split(Sin(x))].
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  // Note: exp_b only follows split1[1], not split1[0], so only the common
  // Sin is hoistable from split1.
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //          [Cos(Exp(y)) for y in Split(x)]
  // into
  //          [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    // split1 now consumes the hoisted Sin instead of x directly.
    if (node.name() == "split1") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "axis");
      EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ(node.op(), "Sin");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "id_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1");
      found++;
    }
    if (node.name() == "exp_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1:1");
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_b");
      found++;
    }
    // The length-2 chain is hoisted as x -> Exp -> Cos -> split2.
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Exp");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Cos");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
      found++;
    }
    if (node.name() == "split2") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
      EXPECT_EQ(node.input(1), "size_splits2");
      EXPECT_EQ(node.input(2), "axis");
      found++;
    }
    if (node.name() == "id_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2");
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2:1");
      found++;
    }
  }
  EXPECT_EQ(found, 10);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3788 
// Verifies the RemoveIdempotent stage: f(f(x)) -> f(x) for idempotent ops.
// Snapshot(Snapshot(a)) collapses to a single Snapshot and
// Identity(Identity(a)) to a single Identity; consumers are rewired to the
// inner op.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // OptimizeTwice does not prune, so all 7 original nodes remain; the
  // duplicated idempotent ops are merely bypassed via the rewired inputs
  // checked below. (actual, expected) argument order to match the convention
  // used throughout this file.
  EXPECT_EQ(output.node_size(), 7);
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      ASSERT_EQ(node.input_size(), 1);
      // sn2 is skipped; out1 reads the inner Snapshot directly.
      EXPECT_EQ(node.input(0), "sn1");
      found++;
    } else if (node.name() == "out2") {
      ASSERT_EQ(node.input_size(), 1);
      // id2 is skipped; out2 reads the inner Identity directly.
      EXPECT_EQ(node.input(0), "id1");
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
      found++;
    }
  }
  EXPECT_EQ(found, 3);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3835 
// Verifies the RemoveLogicalNot stage: LogicalNot(cmp(a, b)) is eliminated
// by replacing the comparison op with its negation (Equal <-> NotEqual,
// Less <-> GreaterEqual, LessEqual <-> Greater) and rewiring the consumer
// past the LogicalNot.
TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
  Output eq = ops::Equal(s.WithOpName("eq"), a, b);
  Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
  Output lt = ops::Less(s.WithOpName("lt"), a, b);
  Output le = ops::LessEqual(s.WithOpName("le"), a, b);
  Output gt = ops::Greater(s.WithOpName("gt"), a, b);
  Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
  // not_eq is reserved
  Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
  Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
  Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
  Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
  Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
  Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
  Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
  Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
  Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
  Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
  Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
  Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);

  GrapplerItem item;
  item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
                "id_not_le", "id_not_gt",  "id_not_ge"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveLogicalNot(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  int found = 0;
  // Each Identity now reads the (negated) comparison directly; the
  // LogicalNot nodes are gone from the data path.
  for (const NodeDef& node : output.node()) {
    if (node.name() == "id_not_eq") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "eq");
      ++found;
    }
    if (node.name() == "id_not_neq") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "neq");
      ++found;
    }
    if (node.name() == "id_not_lt") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "lt");
      ++found;
    }
    if (node.name() == "id_not_le") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "le");
      ++found;
    }
    if (node.name() == "id_not_gt") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "gt");
      ++found;
    }
    if (node.name() == "id_not_ge") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ge");
      ++found;
    }

    // The comparison nodes keep their names but carry the negated op.
    if (node.name() == "eq") {
      EXPECT_EQ(node.op(), "NotEqual");
      ++found;
    }
    if (node.name() == "neq") {
      EXPECT_EQ(node.op(), "Equal");
      ++found;
    }
    if (node.name() == "lt") {
      EXPECT_EQ(node.op(), "GreaterEqual");
      ++found;
    }
    if (node.name() == "le") {
      EXPECT_EQ(node.op(), "Greater");
      ++found;
    }
    if (node.name() == "gt") {
      EXPECT_EQ(node.op(), "LessEqual");
      ++found;
    }
    if (node.name() == "ge") {
      EXPECT_EQ(node.op(), "Less");
      ++found;
    }
  }
  EXPECT_EQ(found, 12);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
  }
}
3939 
// Max(Sqrt(x)) is rewritten as Sqrt(Max(x)): for a monotonically increasing
// element-wise op the reduction can be performed first.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // No node was added or removed; only the edges were rewired.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // The Sqrt now consumes the reduction, and the reduction reads x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++matched;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3980 
// ArgMax(Sqrt(x)) == ArgMax(x) for an increasing element-wise op, so the
// Sqrt can be dropped entirely and pruned from the graph.
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_of_x = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output argmax = ops::ArgMax(s.WithOpName("arg_max"), sqrt_of_x, 1);
  Output fetched = ops::Identity(s.WithOpName("final_out"), argmax);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
  // Exactly one node (the Sqrt) was removed.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  // ArgMax now reads x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++matched;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4021 
// The intermediate "sqrt" node is itself fetched, so swapping it with the
// reduction would change an observable output; the optimizer must leave the
// graph untouched.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_out = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output max_out = ops::Max(s.WithOpName("reduce_max"), sqrt_out, {0});
  Output fetched = ops::Identity(s.WithOpName("final_out"), max_out);

  GrapplerItem item;
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a NoOp: rewriting the fetched "sqrt" is not allowed.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
4045 
// Here the reduction itself ("z") is the fetch node, so it must not be
// rewritten either; the optimizer leaves the graph untouched.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {2, 3}, {1, 2});
  Output flattened = ops::Reshape(s.WithOpName("reshape"), x, {-1});
  Output negated = ops::Neg(s.WithOpName("y"), flattened);
  Output reduced = ops::Max(s.WithOpName("z"), negated, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a NoOp: changing the output of a fetch node is forbidden.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int>(tensors[0], tensors_expected[0]);
  // max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(tensors[0], Tensor(-2));
}
4074 
// Max(Neg(x)) is rewritten as Neg(Min(x)): for a monotonically
// non-increasing element-wise op the reduction is swapped below the op and
// flipped (Max becomes Min).
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // No node was added or removed; only edges (and the reduction op) changed.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      // Max was flipped to Min because Neg is non-increasing.
      EXPECT_EQ(node.op(), "Min");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // (actual, expected) argument order to match the convention used by the
  // other tests in this file.
  EXPECT_EQ(required_node_count, 2);
}
4116 
// MaxPool(Neg(x)) must NOT be rewritten: unlike Max/Min reductions, MaxPool
// has no "MinPool" counterpart to flip to for a non-increasing op, so the
// optimizer leaves the graph untouched.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output negated = ops::Neg(s.WithOpName("neg"), x);
  Output pooled = ops::MaxPool(s.WithOpName("max_pool"), negated, {1, 2, 2, 1},
                               {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a NoOp.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4143 
// A Conv2D -> BiasAdd -> Relu -> MaxPool pipeline must survive the
// OptimizeMaxOrMinOfMonotonic stage unchanged.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), input, weights, {1, 1, 1, 1}, "SAME");
  Output bias_out = ops::BiasAdd(s.WithOpName("biasadd"), conv, biases);
  Output relu_out = ops::Relu(s.WithOpName("relu"), bias_out);
  Output pooled = ops::MaxPool(s.WithOpName("max_pool"), relu_out, {1, 2, 2, 1},
                               {1, 2, 2, 1}, "VALID");
  Output out = ops::Identity(s.WithOpName("output"), pooled);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &new_graph);

  // Expect a NoOp.
  VerifyGraphsMatch(item.graph, new_graph, __LINE__);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4178 
// MaxPool(Sqrt(x)) is rewritten as Sqrt(MaxPool(x)): pooling commutes with a
// monotonically increasing element-wise op.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt_out = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output pooled = ops::MaxPool(s.WithOpName("max_pool"), sqrt_out, {1, 2, 2, 1},
                               {1, 2, 2, 1}, "VALID");
  Output fetched = ops::Identity(s.WithOpName("final_out"), pooled);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // No node was added or removed; only the edges were rewired.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // The Sqrt now consumes the pool output, and the pool reads x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++matched;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4220 
// Verifies the UnaryOpsComposition stage: a chain of element-wise unary ops
// (Sqrt -> Log -> Relu) placed on CPU is replaced by one _UnaryOpsComposition
// node that carries the fused op list in its "op_names" attribute.
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output log = ops::Log(s.WithOpName("log"), sqrt);
  Output relu = ops::Relu(s.WithOpName("relu"), log);
  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Only x, the fused op, and final_out remain.
  EXPECT_EQ(output.node_size(), 3);

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++required_node_count;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      // The fused op records the original chain, in execution order.
      auto op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(op_names.size(), 3);
      EXPECT_EQ(op_names[0], "Sqrt");
      EXPECT_EQ(op_names[1], "Log");
      EXPECT_EQ(op_names[2], "Relu");
      ++required_node_count;
    }
  }
  EXPECT_EQ(required_node_count, 2);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4277 
TEST_F(ArithmeticOptimizerTest,RemoveStackStridedSliceSameAxis)4278 TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
4279   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
4280   auto a_in =
4281       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
4282   auto b_in =
4283       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
4284   auto c_in =
4285       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
4286   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
4287                                        PartialTensorShape({-1, -1}));
4288   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
4289                                        PartialTensorShape({-1, -1}));
4290   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
4291                                        PartialTensorShape({-1, -1}));
4292   // stacked = tf.stack((a, b, c), axis=1).
4293   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
4294   auto stacked =
4295       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
4296                  ops::Stack::Axis(1));
4297   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
4298   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
4299   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
4300   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
4301   auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
4302   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
4303   auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
4304   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
4305   auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
4306   auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
4307   auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
4308 
4309   // stacked[:, 0]
4310   using SS = ops::StridedSlice;
4311   auto pa_slice = ops::Identity(
4312       s.WithOpName("pa_slice_out"),
4313       SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
4314          SS::BeginMask(0b0101)  // 5
4315              .EllipsisMask(0)
4316              .EndMask(0b0101)  // 5
4317              .NewAxisMask(0)
4318              .ShrinkAxisMask(0b0010)));  // 2
4319 
4320   // stacked[:, 1]
4321   auto pb_slice = ops::Identity(
4322       s.WithOpName("pb_slice_out"),
4323       SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
4324          SS::BeginMask(0b0101)  // 5
4325              .EllipsisMask(0)
4326              .EndMask(0b0101)  // 5
4327              .NewAxisMask(0)
4328              .ShrinkAxisMask(0b0010)));  // 2
4329 
4330   // stacked[:, 2]
4331   auto pc_slice = ops::Identity(
4332       s.WithOpName("pc_slice_out"),
4333       SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
4334          SS::BeginMask(0b0101)  // 5
4335              .EllipsisMask(0)
4336              .EndMask(0b0101)  // 5
4337              .NewAxisMask(0)
4338              .ShrinkAxisMask(0b0010)));  // 2
4339 
4340   // stacked[:, 0:1, :]
4341   auto pa_slice_01 = ops::Identity(
4342       s.WithOpName("pa_slice_01_out"),
4343       SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
4344          SS::BeginMask(0b0101)  // 5
4345              .EllipsisMask(0)
4346              .EndMask(0b0101)  // 5
4347              .NewAxisMask(0)
4348              .ShrinkAxisMask(0)));
4349 
4350   // stacked[:, :1, :]
4351   auto pa_slice_to1 = ops::Identity(
4352       s.WithOpName("pa_slice_to1_out"),
4353       SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
4354          SS::BeginMask(0b0111)  // 7
4355              .EllipsisMask(0)
4356              .EndMask(0b0101)  // 5
4357              .NewAxisMask(0)
4358              .ShrinkAxisMask(0)));
4359 
4360   // stacked[:, 1:2, :]
4361   auto pb_slice_12 = ops::Identity(
4362       s.WithOpName("pb_slice_12_out"),
4363       SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
4364          SS::BeginMask(0b0101)  // 5
4365              .EllipsisMask(0)
4366              .EndMask(0b0101)  // 5
4367              .NewAxisMask(0)
4368              .ShrinkAxisMask(0)));
4369 
4370   // stacked[:, 2:, :].
4371   auto pc_slice_2to = ops::Identity(
4372       s.WithOpName("pc_slice_2to_out"),
4373       SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
4374          SS::BeginMask(0b0101)  // 5
4375              .EllipsisMask(0)
4376              .EndMask(0b0111)  // 7
4377              .NewAxisMask(0)
4378              .ShrinkAxisMask(0)));
4379 
4380   GrapplerItem item;
4381   item.fetch = {"a",
4382                 "b",
4383                 "c",
4384                 "pa_slice_out",
4385                 "pb_slice_out",
4386                 "pc_slice_out",
4387                 "expanded_a",
4388                 "expanded_b",
4389                 "expanded_c",
4390                 "pa_slice_01_out",
4391                 "pa_slice_to1_out",
4392                 "pb_slice_12_out",
4393                 "pc_slice_2to_out"};
4394   enum FetchItem {
4395     fA,
4396     fB,
4397     fC,
4398     fASliceOut,
4399     fBSliceOut,
4400     fCSliceOut,
4401     fExpandedA,
4402     fExpandedB,
4403     fExpandedC,
4404     fASlice01Out,
4405     fASliceTo1Out,
4406     fBSlice12Out,
4407     fCSlice2ToOut,
4408   };
4409   TF_CHECK_OK(s.ToGraphDef(&item.graph));
4410   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
4411 
4412   // stacked[:, 0, :] == a.
4413   test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
4414                                  tensors_expected[fA]);
4415   // stacked[:, 1, :] == b.
4416   test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
4417                                  tensors_expected[fB]);
4418   // stacked[:, 2, :] == c.
4419   test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
4420                                  tensors_expected[fC]);
4421 
4422   // stacked[:, 0:1, :] == expand_dims(a, 1).
4423   test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
4424                                  tensors_expected[fExpandedA]);
4425 
4426   // stacked[:, :1, :] == expand_dims(a, 1).
4427   test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
4428                                  tensors_expected[fExpandedA]);
4429 
4430   // stacked[:, 1:2, :] == expand_dims(b, 1).
4431   test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
4432                                  tensors_expected[fExpandedB]);
4433   // stacked[:, 2:, :] == expand_dims(c, 1).
4434   test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
4435                                  tensors_expected[fExpandedC]);
4436 
4437   GraphDef output;
4438   ArithmeticOptimizer optimizer;
4439   EnableOnlyRemoveStackSliceSameAxis(&optimizer);
4440   OptimizeAndPrune(&optimizer, &item, &output);
4441 
4442   for (const auto& node : output.node()) {
4443     if (node.name() == "pa_slice_out") {
4444       ASSERT_EQ(node.input_size(), 1);
4445       EXPECT_EQ(node.input(0), "a");
4446     } else if (node.name() == "pb_slice_out") {
4447       ASSERT_EQ(node.input_size(), 1);
4448       EXPECT_EQ(node.input(0), "b");
4449     } else if (node.name() == "pc_slice_out") {
4450       ASSERT_EQ(node.input_size(), 1);
4451       EXPECT_EQ(node.input(0), "c");
4452     } else if (str_util::EndsWith(node.name(), "_out")) {
4453       ASSERT_EQ(node.input_size(), 1);
4454       EXPECT_EQ(
4455           absl::StrCat(node.input(0), "_out"),
4456           absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
4457                        node.name()));
4458     }
4459   }
4460 
4461   auto tensors = EvaluateNodes(output, item.fetch);
4462 
4463   // stacked[:, 0, :] == a.
4464   test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);
4465 
4466   // stacked[:, 1, :] == b.
4467   test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
4468   // stacked[:, 2, :] == c.
4469   test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);
4470 
4471   // stacked[:, 0:1, :] == expand_dims(a, 1).
4472   test::ExpectTensorEqual<float>(tensors[fASlice01Out],
4473                                  tensors_expected[fExpandedA]);
4474 
4475   // stacked[:, :1, :] == expand_dims(a, 1).
4476   test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
4477                                  tensors_expected[fExpandedA]);
4478 
4479   // stacked[:, 1:2, :] == expand_dims(b, 1).
4480   test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
4481                                  tensors_expected[fExpandedB]);
4482   // stacked[:, 2:, :] == expand_dims(c, 1).
4483   test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
4484                                  tensors_expected[fExpandedC]);
4485 }
4486 
// Verifies RemoveStackStridedSliceSameAxis on plain Slice(Pack(...)):
// a size-1 Slice along the pack axis that selects exactly one packed input
// should be rewritten to an ExpandDims of that input.
TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  // Placeholders with unknown static shape so the rewrite cannot rely on
  // fully-defined input shapes.
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  // Reference values for what each optimized slice should equal.
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  // Slice sizes: -1 means "to the end" in dims 0 and 2; dim 1 (the pack
  // axis) has size 1, which is what makes the rewrite applicable.
  auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});

  // stacked[:, 0:1, :]
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));

  // stacked[:, 1:2, :]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));

  // stacked[:, 2:3, :]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c"};
  // Indices into item.fetch / the evaluated tensor vectors, in fetch order.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph before rewriting.
  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fExpandedA]);
  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The optimizer replaces each Slice with an ExpandDims node whose name is
  // "<prefix><original slice name>"; the slice names all start with 'p'.
  const string kExpandDimsNamePrefix(
      "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
    } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
      EXPECT_EQ(node.op(), "ExpandDims");
      // The input is "a", "b", or "c", as appropriate.
      EXPECT_EQ(node.input(0),
                node.name().substr(kExpandDimsNamePrefix.size(), 1));
    }
  }

  // The optimized graph must produce the same values as the original.
  auto tensors = EvaluateNodes(output, item.fetch);

  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors[fASliceOut],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors[fCSliceOut],
                                 tensors_expected[fExpandedC]);
}
4602 
// AddN of two identical bfloat16 tensors should be rewritten by the
// SimplifyAggregation stage into a multiplication by a constant 2 while
// producing exactly the same values.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output values = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output as_bf16 = ops::Cast(scope.WithOpName("cast"), values, DT_BFLOAT16);
  Output summed = ops::AddN(scope.WithOpName("add"), {as_bf16, as_bf16});
  Output fetched = ops::Identity(scope.WithOpName("id"), summed);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // x, cast, add, id, plus one extra node holding the multiplier constant.
  EXPECT_EQ(optimized.node_size(), 5);

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<bfloat16>(actual[0], expected[0]);
}
4628 
// Gather(embeddings, Unique(indices).y) feeding SparseSegmentSum over
// Unique(indices).idx is equivalent to running SparseSegmentSum directly on
// the raw indices; the SimplifyEmbeddingLookup stage should remove the
// Unique/Gather pair for both supported Unique index dtypes.
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  for (DataType idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output table = ops::Const(scope.WithOpName("embeddings"),
                              {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segments =
        ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output lookup_ids =
        ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto uniq = ops::Unique(scope.WithOpName("unique"), lookup_ids,
                            /*attrs=*/{idx_type});
    Output rows =
        ops::Gather(scope.WithOpName("gathered_rows"), table, uniq.y);
    Output summed = ops::SparseSegmentSum(scope.WithOpName("result"), rows,
                                          uniq.idx, segments);
    Output fetched = ops::Identity(scope.WithOpName("id"), summed);

    GrapplerItem item;
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    GraphDef optimized;
    OptimizeAndPrune(&optimizer, &item, &optimized);

    for (const auto& node : optimized.node()) {
      // The reduction must now read the table and raw indices directly.
      if (node.name() == "result") {
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
      // The Unique/Gather pair must be gone from the graph.
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }

    // Results are unchanged by the rewrite.
    auto actual = EvaluateNodes(optimized, item.fetch);
    ASSERT_EQ(actual.size(), 1);
    test::ExpectTensorEqual<float>(actual[0], expected[0]);
  }
}
4672 
// Same rewrite as SimplifyEmbeddingLookup, but with the embedding table held
// in a resource variable read via ResourceGather: the optimizer should
// replace Unique+ResourceGather with a ReadVariableOp feeding
// SparseSegmentSum directly, carrying over the gather's "_class" attr.
TEST_F(ArithmeticOptimizerTest, SimplifyResourceEmbeddingLookup) {
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(s.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    Output ids = unique.y;
    Output idx = unique.idx;

    // The embedding table lives in a resource variable initialized from the
    // "embeddings" constant.
    auto var =
        ops::VarHandleOp(s.WithOpName("var"), DT_FLOAT, TensorShape({2, 2}));
    ops::AssignVariableOp assign_op(s.WithOpName("assign_var_handle"), var,
                                    embeddings);

    // Gather after the assignment; the "_class" attr should survive the
    // rewrite onto the replacement ReadVariableOp.
    Output gathered_rows = ops::ResourceGather(
        s.WithOpName("gathered_rows")
            .WithControlDependencies(std::vector<Operation>{assign_op}),
        var, ids, DT_FLOAT);
    gathered_rows.node()->AddAttr("_class", {"test_class"});
    Output result =
        ops::SparseSegmentSum(s.WithOpName("result").WithControlDependencies(
                                  std::vector<Operation>{assign_op}),
                              gathered_rows, idx, segment_ids);
    Output id = ops::Identity(s.WithOpName("id"), result);

    GrapplerItem item;
    item.init_ops.push_back("assign_var_handle");
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    bool read_var_node_found = false;
    for (const auto& node : output.node()) {
      if (node.name() == "result") {
        // The reduction now reads from the new ReadVariableOp and the raw
        // indices instead of the gathered rows and Unique's idx output.
        EXPECT_EQ(
            node.input(0),
            "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result");
        EXPECT_EQ(node.input(1), "indices");
      }
      if (node.op() == "ReadVariableOp") {
        read_var_node_found = true;
        EXPECT_EQ(node.attr().at("_class").list().s(0), "test_class");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }
    EXPECT_TRUE(read_var_node_found);
    // Add a control dependency to the ReadVar to do the AssignVar first. This
    // shouldn't be an issue in actual use as variables are assumed initialized
    // during setup.
    for (int i = 0; i < output.node_size(); ++i) {
      if (output.node(i).name() ==
          "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result") {
        output.mutable_node(i)->add_input("^assign_var_handle");
      }
    }

    // Results are unchanged by the rewrite.
    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}
4744 
// SparseSegmentSum accepts both int32 and int64 for its indices and
// segment_ids inputs, so Casts feeding those inputs are redundant; the
// RemoveCastIntoSegmentReduction stage should wire the original tensors in
// directly for every combination of cast target dtypes.
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  for (DataType index_type : {DT_INT32, DT_INT64}) {
    for (DataType segment_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      Output table = ops::Const(scope.WithOpName("embeddings"),
                                {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      Output cast_idx = ops::Cast(
          scope.WithOpName("cast_indices"),
          ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
          index_type);
      Output cast_seg = ops::Cast(
          scope.WithOpName("cast_segment_ids"),
          ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          segment_type);
      Output summed = ops::SparseSegmentSum(scope.WithOpName("result"), table,
                                            cast_idx, cast_seg);
      Output fetched = ops::Identity(scope.WithOpName("id"), summed);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};
      auto expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(expected.size(), 1);

      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      GraphDef optimized;
      OptimizeAndPrune(&optimizer, &item, &optimized);

      for (const auto& node : optimized.node()) {
        if (node.name() == "result") {
          // Inputs 1 and 2 now come straight from the uncast constants.
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        // No Cast nodes may remain after pruning.
        EXPECT_NE(node.op(), "Cast");
      }

      // Results are unchanged by the rewrite.
      auto actual = EvaluateNodes(optimized, item.fetch);
      ASSERT_EQ(actual.size(), 1);
      test::ExpectTensorEqual<float>(actual[0], expected[0]);
    }
  }
}
4788 
4789 }  // namespace grappler
4790 }  // namespace tensorflow
4791