1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17
#include <algorithm>
#include <complex>
19
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/cc/ops/array_ops.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24 #include "tensorflow/cc/ops/resource_variable_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor_testutil.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
30 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
31 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
39 namespace {
40
41 constexpr char kHoistFactorOptimizerDiv[] =
42 "ArithmeticOptimizer/HoistCommonFactor_Div_";
43
44 constexpr char kHoistFactorOptimizerMul[] =
45 "ArithmeticOptimizer/HoistCommonFactor_Mul_";
46
47 constexpr char kHoistFactorOptimizerAdd[] =
48 "ArithmeticOptimizer/HoistCommonFactor_AddV2_";
49
50 constexpr char kSimplifyAggregationConst[] =
51 "ArithmeticOptimizer/SimplifyAggregation_Const_";
52
53 constexpr char kSimplifyAggregationMul[] =
54 "ArithmeticOptimizer/SimplifyAggregation_Mul_";
55
56 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
string HoistMulName(const string& name) {
  // Empty delimiter: the prefix already ends with "_".
  return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, /*delimiter=*/"");
}
60
61 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
string HoistDivName(const string& name) {
  // Empty delimiter: the prefix already ends with "_".
  return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, /*delimiter=*/"");
}
65
66 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
string HoistAddName(const string& name) {
  // Empty delimiter: the prefix already ends with "_".
  return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, /*delimiter=*/"");
}
70
71 // Optimized name of Const node by SimplifyAggregation.
string AggregationConstName(const string& name) {
  // Empty delimiter: the prefix already ends with "_".
  return AddPrefixToNodeName(name, kSimplifyAggregationConst, /*delimiter=*/"");
}
75
76 // Optimized name of Mul node by SimplifyAggregation.
string AggregationMulName(const string& name) {
  // Empty delimiter: the prefix already ends with "_".
  return AddPrefixToNodeName(name, kSimplifyAggregationMul, /*delimiter=*/"");
}
80
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)81 void VerifyGraphsMatch(const GraphDef& original_graph,
82 const GraphDef& optimized_graph, int line) {
83 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
84 for (int i = 0; i < original_graph.node_size(); ++i) {
85 const NodeDef& original = original_graph.node(i);
86 const NodeDef& optimized = optimized_graph.node(i);
87 EXPECT_EQ(original.name(), optimized.name()) << line;
88 EXPECT_EQ(original.op(), optimized.op()) << line;
89 EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
90 for (int j = 0; j < original.input_size(); ++j) {
91 EXPECT_EQ(original.input(j), optimized.input(j)) << line;
92 }
93 }
94 }
95 } // namespace
96
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // A trivially generated graph contains nothing the optimizer can rewrite,
  // so the optimized output must be identical to the input graph.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
109
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
  // Multiplying by a constant tensor of ones whose shape only broadcasts the
  // input should be rewritten as a single Tile op. Graph from b/176172427.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
  // Baseline result from the unoptimized graph for numerical comparison.
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The Mul must be gone, replaced by exactly one Tile.
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // Rewritten nodes are named "<optimizer prefix>_<op>_<original name>".
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(t, nullptr);
  ASSERT_NE(c, nullptr);
  EXPECT_EQ(t->op(), "Tile");
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), "input");
  EXPECT_EQ(t->input(1), c->name());
  EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type());

  // The optimized graph must produce the same values as the original.
  auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
154
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
  // When the constant of ones carries a control dependency, the rewrite must
  // move that dependency onto the replacement Const node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
                           1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // The new Const's only input must be the control edge from "input".
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(c, nullptr);
  ASSERT_EQ(c->input_size(), 1);
  EXPECT_TRUE(IsControlInput(c->input(0)));
  EXPECT_EQ(c->input(0), "^input");
}
188
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
  // The input already has the same shape as the ones tensor, so the Mul does
  // not broadcast and the rewrite must not fire (graph stays unchanged).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  // The unnamed placeholder gets the default node name "Placeholder".
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
217
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
  // The rewrite only applies when one multiplicand is a constant tensor of
  // ones; multiplying two placeholders must leave the graph unchanged.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 1, 1}));
  Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 2, 1}));
  Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch,
                                {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  VerifyGraphsMatch(item.graph, g, __LINE__);

  // Evaluate the *optimized* graph `g` (previously this re-evaluated
  // item.graph, which made the numerical check vacuous).
  auto result = EvaluateNodes(g, item.fetch,
                              {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
249
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
  // The broadcasting constant is 2.0, not 1.0, so replacing the Mul with a
  // Tile would change the result; the rewrite must not fire.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 2.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  // The unnamed placeholder gets the default node name "Placeholder".
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
278
TEST_F(ArithmeticOptimizerTest, ReduceUpsamplingDims) {
  // A Reshape->Tile->Reshape upsampling pattern over a 6-D intermediate
  // should be rewritten to use fewer dimensions, preserving the output.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 22, 48, 64}));
  Output reshape_a = ops::Reshape(
      s.WithOpName("reshape_a"), input,
      ops::Const(s.WithOpName("shape_a"), {1, 22, 1, 48, 1, 64}, {6}));
  Output tile =
      ops::Tile(s.WithOpName("tile"), reshape_a,
                ops::Const(s.WithOpName("multiples"), {1, 1, 2, 1, 2, 1}, {6}));
  Output reshape_b =
      ops::Reshape(s.WithOpName("reshape_b"), tile,
                   ops::Const(s.WithOpName("shape_b"), {1, 44, 96, 64}));
  Output output = ops::Identity(s.WithOpName("output"), reshape_b);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 22, 48, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReduceUpsamplingDims(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 8);

  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
  ASSERT_EQ(CountOpNodes(g, "Reshape"), 2);
  ASSERT_EQ(CountOpNodes(g, "Const"), 3);

  // Verify the rewritten chain: input -> new Reshape -> new Tile -> original
  // reshape_b. New nodes carry the optimizer's name prefix.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReduceUpsamplingDims";
  const NodeDef* ra =
      node_map.GetNode(absl::StrCat(p, "_", "Reshape_reshape_b"));
  const NodeDef* rb = node_map.GetNode("reshape_b");
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_reshape_b"));
  ASSERT_NE(ra, nullptr);
  ASSERT_NE(rb, nullptr);
  ASSERT_NE(t, nullptr);

  ASSERT_EQ(rb->input_size(), 2);
  EXPECT_EQ(rb->input(0), t->name());
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), ra->name());
  ASSERT_EQ(ra->input_size(), 2);
  EXPECT_EQ(ra->input(0), "input");

  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }

  // Check to make sure the first reshape is removed
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 6);

  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }
}
345
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  // Mul(x, x) and MulNoNan(x, x) should both be rewritten to Square(x), with
  // any control dependency on the Mul carried over to the new Square node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  // The control dependency on "d" must survive the rewrite.
  EXPECT_EQ(square_node->input(1), "^d");

  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // Previously only the first fetch was compared; also verify the MulNoNan
  // branch produces identical values after the rewrite.
  test::ExpectTensorNear<float>(tensors[1], tensors_expected[1], 1e-6);
}
389
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshape) {
  // Two nested Pack ops that duplicate a tensor along new axes should be
  // rewritten as a single Tile followed by a Reshape.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output o = ops::Identity(s.WithOpName("output"), c);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // Both Pack nodes collapse into one Tile + one Reshape plus two Consts
  // (the tile multiples and the final shape).
  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
  const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
  const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
  const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
  const NodeDef* r_node = node_map.GetNode(absl::StrCat(p, "_", "Reshape_c"));
  const NodeDef* a_node = node_map.GetNode("a");
  ASSERT_NE(t_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(s_node, nullptr);
  ASSERT_NE(r_node, nullptr);
  ASSERT_NE(a_node, nullptr);

  EXPECT_EQ(c_node->op(), "Const");
  EXPECT_EQ(s_node->op(), "Const");

  // Check Reshape properties
  ASSERT_EQ(r_node->input_size(), 2);
  EXPECT_EQ(r_node->op(), "Reshape");
  EXPECT_EQ(r_node->input(0), t_node->name());
  EXPECT_EQ(r_node->input(1), s_node->name());

  // Check Tile properties
  ASSERT_EQ(t_node->input_size(), 2);
  EXPECT_EQ(t_node->op(), "Tile");
  EXPECT_EQ(t_node->input(0), a_node->name());
  EXPECT_EQ(t_node->input(1), c_node->name());
  EXPECT_EQ(t_node->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t_node->attr().at("Tmultiples").type(),
            c_node->attr().at("dtype").type());

  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
452
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeControlDeps) {
  // Control dependencies on the original Pack nodes must be hoisted onto the
  // replacement Tile, and the new Const nodes anchored to the data input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));

  Output x = ops::Identity(s.WithOpName("x"), a);
  Output y = ops::Identity(s.WithOpName("y"), a);

  Output b = ops::Stack(s.WithOpName("b").WithControlDependencies(x), {a, a},
                        ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c").WithControlDependencies(y), {b, b},
                        ops::Stack::Axis(2));
  Output o = ops::Identity(s.WithOpName("output"), c);

  GrapplerItem item;
  item.fetch = {"output"};
  // Keep x and y alive so their control edges must be preserved by pruning.
  item.keep_ops = {"x", "y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);
  EXPECT_EQ(CountOpNodes(g, "Identity"), 3);

  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
  const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
  const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
  const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
  const NodeDef* a_node = node_map.GetNode("a");
  ASSERT_NE(t_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(s_node, nullptr);
  ASSERT_NE(a_node, nullptr);

  // The Tile inherits both original control deps (^y from "c", ^x from "b").
  ASSERT_EQ(t_node->input_size(), 4);
  EXPECT_EQ(t_node->op(), "Tile");
  EXPECT_EQ(t_node->input(0), a_node->name());
  EXPECT_EQ(t_node->input(1), c_node->name());
  EXPECT_EQ(t_node->input(2), "^y");
  EXPECT_EQ(t_node->input(3), "^x");

  // The generated Const nodes are anchored to "a" via control edges.
  ASSERT_EQ(c_node->input_size(), 1);
  EXPECT_EQ(c_node->input(0), "^a");

  ASSERT_EQ(s_node->input_size(), 1);
  ASSERT_EQ(s_node->input(0), "^a");

  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
515
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileRemoveReshape) {
  // After the Pack->Tile+Reshape rewrite, the generated Reshape becomes
  // adjacent to an existing Reshape; a second pass with the redundant-reshape
  // stage should then merge them.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output r =
      ops::Reshape(s.WithOpName("r"), c, ops::Const(s, {3, 10, 14, 11}, {4}));
  Output o = ops::Identity(s.WithOpName("output"), r);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // First pass: Packs replaced, leaving two back-to-back Reshapes.
  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 3);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 2);

  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // Second pass: the redundant Reshape (and its shape Const) are removed.
  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
558
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeOutOfRange) {
  // Packing along axis 4 of a rank-4 tensor is out of range for the rewrite,
  // so the optimizer must leave the graph untouched.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates a Pack node under the hood.
  Output stacked =
      ops::Stack(scope.WithOpName("b"), {a, a}, ops::Stack::Axis(4));
  Output out = ops::Identity(scope.WithOpName("output"), stacked);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"output"};

  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  GraphDef g;
  OptimizeAndPrune(&optimizer, &item, &g);

  VerifyGraphsMatch(item.graph, g, __LINE__);
}
578
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  // Directly adjacent involution pairs (Neg(Neg(x)), Reciprocal(Reciprocal(x)))
  // should cancel out entirely, leaving the Const feeding the Identity.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other.
  // Only "c" and "id" remain; "id" now reads "c" directly.
  ASSERT_EQ(output.node_size(), 2);
  EXPECT_EQ(output.node(1).name(), "id");
  ASSERT_EQ(output.node(1).input_size(), 1);
  EXPECT_EQ(output.node(1).input(0), "c");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
610
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  // An involution pair separated by value-preserving ops (Identity, Squeeze)
  // should still cancel: Reciprocal(Squeeze(Identity(Reciprocal(c)))) == c
  // modulo the value-preserving chain.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(output.node_size(), 3);

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      found++;
    } else if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "squeeze");
      found++;
    }
  }
  // Both rewired nodes must have been seen exactly once.
  EXPECT_EQ(found, 2);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
656
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  // Here recip2 reads "c" directly and is only connected to the first
  // Reciprocal chain through a *control* edge, so the involution rewrite
  // must not fire across it.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
689
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  // Add(x, x) should be rewritten by SimplifyAggregation as Mul(Const(2), x).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  // The new Const(2) is anchored to "x" by a control edge.
  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  // Raw little-endian bytes 00 00 00 40 encode the float 2.0.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 2);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");

  // "id" now reads the Mul instead of the original Add.
  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
735
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // Same as TrivialSumsSimple, but the Add carries a control dependency on
  // "y" which must be transferred to the replacement Mul node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 6);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  // Raw little-endian bytes 00 00 00 40 encode the float 2.0.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The Mul inherits the control dependency on "y" from the original Add.
  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 3);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");
  EXPECT_EQ(new_mul->input(2), "^y");

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
783
// A tree of repeated x + x adds should be rewritten into a single
// Mul(x, tree-of-constant-multipliers), even with nodes spread across devices.
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
  // Test case from b/69059093.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
  Output add = ops::Add(s.WithOpName("Add"), p, p);
  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
  Output id = ops::Identity(s.WithOpName("id"), add6);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // devices[i] maps to the i-th node in graph construction order, so this
  // list must stay in sync with the node declarations above.
  const std::vector<string> devices{
      "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
      "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
  };
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device(devices[i]);
  }

  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  // Keep AddN combining out of the way so the hoist-common-factor rewrite is
  // the one that handles this pattern.
  DisableAddToAddNCombining(&optimizer);

  GraphDef output;
  DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  // Mul(p,
  //     Add_6(Add_4(Const(2), Const(2)),
  //           Add_5(Const(2), Const(2)))
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 8);

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));

  // The hoisted Mul multiplies the common factor (the placeholder) by the
  // summed constant multipliers.
  const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 2);
  EXPECT_EQ(mul_node->input(0), "Placeholder");
  EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));

  const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
  ASSERT_NE(add_6_node, nullptr);
  ASSERT_EQ(add_6_node->input_size(), 2);
  EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
  EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));

  const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
  ASSERT_NE(add_4_node, nullptr);
  EXPECT_EQ(add_4_node->op(), "Add");
  ASSERT_EQ(2, add_4_node->input_size());
  EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));

  const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
  ASSERT_NE(add_5_node, nullptr);
  EXPECT_EQ(add_5_node->op(), "Add");
  ASSERT_EQ(add_5_node->input_size(), 2);
  EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));

  // The synthesized constants are anchored to the placeholder with a control
  // input only.
  const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
  ASSERT_NE(add_const_node, nullptr);
  EXPECT_EQ(add_const_node->op(), "Const");
  ASSERT_EQ(add_const_node->input_size(), 1);
  EXPECT_EQ(add_const_node->input(0), "^Placeholder");

  const NodeDef* add_1_const_node =
      node_map.GetNode(AggregationConstName("Add_1"));
  ASSERT_NE(add_1_const_node, nullptr);
  EXPECT_EQ(add_1_const_node->op(), "Const");
  ASSERT_EQ(add_1_const_node->input_size(), 1);
  EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
}
866
// Hoisting a common multiplicative factor: x*y1 + y2*x -> x * (y1 + y2).
// Exercised over all four combinations of (matching y1/y2 shapes) x
// (Add vs AddN as the aggregation root).
TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
      Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
      // In the mismatched case y2 is a broadcastable {1,1} constant.
      Output y2 = matching_shapes
                      ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
                      : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
      // The common factor x appears on different sides of the two Muls.
      Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
      Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
      Output id =
          use_addn ? ops::Identity(s.WithOpName("id"),
                                   ops::AddN(s.WithOpName("add"), {mul1, mul2}))
                   : ops::Identity(s.WithOpName("id"),
                                   ops::Add(s.WithOpName("add"), mul1, mul2));

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);
      ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
      EnableOnlyHoistCommonFactor(&optimizer);

      GraphDef output;
      OptimizeTwice(&optimizer, &item, &output);

      // We expect the following rewrite(s) to occur:
      //
      //        Add                 Mul
      //       /   \               /   \
      //     Mul   Mul     ->     x    Add
      //     / \   / \                 / \
      //    x  y1 y2  x              y1   y2
      //
      // If "root" op is AddN and shapes does not match, this rewrite is not
      // possible and graph should stay intact.
      NodeMap node_map(&output);

      if (use_addn && !matching_shapes) {
        VerifyGraphsMatch(item.graph, output, __LINE__);
      } else {
        EXPECT_EQ(output.node_size(), 9);

        const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
        ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
        ASSERT_EQ(new_add_node->input_size(), 2);
        EXPECT_EQ(new_add_node->input(0), "y1");
        EXPECT_EQ(new_add_node->input(1), "y2");

        const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
        ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
        ASSERT_EQ(new_mul_node->input_size(), 2);
        EXPECT_EQ(new_mul_node->input(0), "x");
        EXPECT_EQ(new_mul_node->input(1), new_add_node->name());

        const NodeDef* id_node = node_map.GetNode("id");
        ASSERT_NE(id_node, nullptr) << "Id node not found";
        EXPECT_EQ(id_node->name(), "id");
        ASSERT_EQ(id_node->input_size(), 1);
        EXPECT_EQ(id_node->input(0), HoistMulName("add"));
      }
      // Whether or not the rewrite fired, numerics must be unchanged.
      auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
    }
  }
}
936
// Hoisting a common divisor: y1/x + y2/x -> (y1 + y2)/x. Exercised over
// (matching shapes) x (Add vs AddN root) x (int vs float dtype). The rewrite
// must NOT fire for integers: with truncating integer division,
// y1/x + y2/x != (y1 + y2)/x in general.
TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      for (bool use_ints : {true, false}) {
        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
        Output x = use_ints
                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
        Output y1 = use_ints
                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
        Output y2;
        if (matching_shapes) {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
        } else {
          // Broadcastable {1,1} constant for the mismatched-shape case.
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
        }
        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
        Output id =
            use_addn
                ? ops::Identity(s.WithOpName("id"),
                                ops::AddN(s.WithOpName("add"), {div1, div2}))
                : ops::Identity(s.WithOpName("id"),
                                ops::Add(s.WithOpName("add"), div1, div2));

        GrapplerItem item;
        item.fetch = {"id"};
        TF_CHECK_OK(s.ToGraphDef(&item.graph));

        auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
        ASSERT_EQ(tensors_expected.size(), 1);

        ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
        EnableOnlyHoistCommonFactor(&optimizer);

        GraphDef output;
        OptimizeTwice(&optimizer, &item, &output);

        // We expect the following rewrite(s) to occur:
        //
        //        Add                 Div
        //       /   \               /   \
        //     Div   Div     ->    Add    x
        //     / \   / \           / \
        //    y1  x y2  x        y1   y2
        //
        // If "root" op is AddN and shapes does not match, this rewrite is not
        // possible and graph should stay intact.
        NodeMap node_map(&output);

        if ((use_addn && !matching_shapes) || use_ints) {
          VerifyGraphsMatch(item.graph, output, __LINE__);
        } else {
          EXPECT_EQ(output.node_size(), 9);

          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
          ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
          ASSERT_EQ(new_add_node->input_size(), 2);
          EXPECT_EQ(new_add_node->input(0), "y1");
          EXPECT_EQ(new_add_node->input(1), "y2");

          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
          ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
          ASSERT_EQ(new_div_node->input_size(), 2);
          EXPECT_EQ(new_div_node->input(0), new_add_node->name());
          EXPECT_EQ(new_div_node->input(1), "x");

          const NodeDef* id_node = node_map.GetNode("id");
          ASSERT_TRUE(id_node != nullptr) << "Id node not found";
          EXPECT_EQ("id", id_node->name());
          ASSERT_EQ(id_node->input_size(), 1);
          EXPECT_EQ(id_node->input(0), HoistDivName("add"));
        }
        // Numerics must be unchanged in every configuration.
        auto tensors = EvaluateNodes(output, item.fetch);
        ASSERT_EQ(tensors.size(), 1);
        if (use_ints) {
          test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
        } else {
          test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
        }
      }
    }
  }
}
1024
// Transpose(Conj(z)) should be fused into a single ConjugateTranspose node
// that reads the original complex tensor and permutation directly.
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conjugated = ops::Conj(scope.WithOpName("conj"), z);
  Output transposed =
      ops::Transpose(scope.WithOpName("trans"), conjugated, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "trans");

  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "ConjugateTranspose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  // The fused graph must be numerically identical to the original.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1062
// ConjugateTranspose(Conj(z)) applies the conjugation twice, so the pair
// should collapse into a plain Transpose of the original tensor.
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conjugated = ops::Conj(scope.WithOpName("conj"), z);
  Output transposed = ops::ConjugateTranspose(
      scope.WithOpName("conjugate_trans"), conjugated, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name =
      absl::StrCat("ArithmeticOptimizer/FoldConjugateIntoTranspose", "_",
                   "conjugate_trans");

  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "Transpose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  // The fused graph must be numerically identical to the original.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1102
// Conj(Transpose(z)) — the mirror of FuseConjAndTranspose — should also fuse
// into a single ConjugateTranspose node.
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output transposed = ops::Transpose(scope.WithOpName("trans"), z, perm);
  Output conjugated = ops::Conj(scope.WithOpName("conj"), transposed);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conj");

  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "ConjugateTranspose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  // The fused graph must be numerically identical to the original.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1140
// For each matmul flavor, MatMul(Transpose(a), Transpose(b)) should fold into
// a single matmul node on the untransposed operands with the appropriate
// transpose/adjoint attributes set.
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
  for (const string matmul_type :
       {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
    Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
    Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
    Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);

    Output matmul;
    auto matmul_op = s.WithOpName("matmul");
    if (matmul_type == "MatMul") {
      matmul = ops::MatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "SparseMatMul") {
      matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMul") {
      matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMulV2") {
      matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
    }

    auto identity = ops::Identity(s.WithOpName("identity"), matmul);

    GrapplerItem item;
    item.fetch = {"identity"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlyFoldTransposeIntoMatMul(&optimizer);
    GraphDef output;
    OptimizeTwice(&optimizer, &item, &output);
    NodeMap node_map(&output);

    EXPECT_EQ(output.node_size(), 8);

    const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
    const string optimized_name = absl::StrCat(p, "_", "matmul");

    // The fused matmul consumes the original (untransposed) operands.
    const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
    ASSERT_NE(matmul_fused_node, nullptr);
    ASSERT_EQ(matmul_fused_node->input_size(), 2);
    EXPECT_EQ(matmul_fused_node->input(0), "a");
    EXPECT_EQ(matmul_fused_node->input(1), "b");

    // Batch variants express the fold via adj_x/adj_y; the 2D variants via
    // transpose_a/transpose_b.
    if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
    } else {
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
    }

    const NodeDef* identity_node = node_map.GetNode("identity");
    ASSERT_NE(identity_node, nullptr);
    ASSERT_EQ(identity_node->input_size(), 1);
    EXPECT_EQ(identity_node->input(0), optimized_name);

    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
1208
// BatchMatMul(ConjugateTranspose(a), ConjugateTranspose(b)) should fold into
// a single BatchMatMul with adj_x/adj_y set (adjoint == conjugate transpose).
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output re_a =
      ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a =
      ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b =
      ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(s.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(s.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
  Output identity = ops::Identity(s.WithOpName("identity"), matmul);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(output.node_size(), 12);

  const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
  const string optimized_name = absl::StrCat(p, "_", "matmul");

  // The fused matmul reads the original complex operands and carries the
  // conjugate-transposition in its adjoint attributes.
  const NodeDef* optimized_matmul = node_map.GetNode(optimized_name);
  ASSERT_NE(optimized_matmul, nullptr);
  ASSERT_EQ(optimized_matmul->input_size(), 2);
  EXPECT_EQ(optimized_matmul->input(0), "a");
  EXPECT_EQ(optimized_matmul->input(1), "b");
  EXPECT_TRUE(optimized_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(optimized_matmul->attr().at("adj_y").b());

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
}
1257
// A Reshape (or BroadcastTo) whose target shape is recomputed from the
// input's own shape is a provable no-op and should be removed entirely.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(s, inputs);
    // The target shape of the reshape is the concatenation of `batch_size` and
    // [3,28,28].
    Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                   ops::Const(s, {1}, {1}));
    Output target_shape = ops::Concat(
        s.WithOpName("target_shape"),
        {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
    // The `outputs` locals go out of scope immediately, but the graph nodes
    // they create (named "outputs") remain registered in the scope `s`.
    if (is_broadcastto) {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::BroadcastTo(s, inputs, target_shape));
    } else {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::Reshape(s, inputs, target_shape));
    }

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // Both the Reshape and the BroadcastTo variants must be pruned away.
    EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(output, "BroadcastTo"), 0);
    auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
1299
// Identity reshape detection with symbolic (unknown) dimensions: the target
// shape is rebuilt from slices of the input's own runtime shape, so in
// aggressive mode the Reshape can be removed even though the static shape is
// only partially known.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(s, inputs);
  // The target shape of the reshape is the concatenation of `batch_size`, 3,
  // `height`, and `width`.
  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                 ops::Const(s, {1}, {1}));
  Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
                             ops::Const(s, {1}, {1}));
  Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
                            ops::Const(s, {1}, {1}));
  Output target_shape =
      ops::Concat(s.WithOpName("target_shape"),
                  {batch_size, ops::Const(s, {3}, {1}), height, width},
                  ops::Const(s, {0}, {}));
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1341
// In default (non-aggressive) mode, a reshape of a fed placeholder must NOT
// be removed, even when the static shapes agree.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape is preserved because the shape of the placeholder can be
  // different from the shape of the actual feed.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1372
// The aggressive-mode counterpart of NotAssumeValidFeeds: with the declared
// placeholder shape trusted, the identity Reshape is removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(scope, inputs, shape_const);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshaped);

  auto feed_tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.feed = {{"Placeholder", feed_tensor}};
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef output;
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1401
// A reshape that changes the (symbolic) shape must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
  // be from [4,3,28,28] to [8,6,28,28].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // The feed has batch size 8, for which this reshape happens to be identity —
  // the optimizer must still keep it because other batch sizes are possible.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1429
// A target shape containing more than one -1 cannot be proven to be an
// identity reshape, so the Reshape node must survive optimization.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped =
      ops::Reshape(scope, inputs, ops::Const(scope, {-1, -1}, {2}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  GraphDef output;
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
}
1449
// Two consecutive reshapes (possibly separated by a unary op) should be
// combined into one — except when the intermediate reshape's output is fed
// or has other consumers.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
  for (bool include_unary_chain : {false, true}) {
    // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The
    // two reshapes should be combined.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output nchw_vect_c =
        ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_FLOAT,
                         ops::Placeholder::Shape({8, 3, 28, 28, 4}));
    Output transpose =
        ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
                       ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
    Output nhwc = ops::Reshape(
        s.WithOpName("nhwc"), transpose,
        ops::Const(
            s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
            {8, 28, 28, 12}, {4}));
    // Optionally insert a unary op (Cos) between the two reshapes; combining
    // should still happen across it.
    Output flatten = ops::Reshape(
        s.WithOpName("flatten"),
        (include_unary_chain ? ops::Cos(s.WithOpName("Cos"), nhwc) : nhwc),
        ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
    Output output0 = ops::Identity(s.WithOpName("output0"), flatten);
    Output output1 = ops::Identity(s.WithOpName("output1"), flatten);

    GraphDef graph;
    TF_CHECK_OK(s.ToGraphDef(&graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28, 4}));
    // Also capture the intermediate "nhwc" tensor to use as a feed below.
    auto eval =
        EvaluateNodes(graph, {"output0", "nhwc"}, {{"nchw_vect_c", x_t}});

    ASSERT_EQ(eval.size(), 2);
    auto expected_output_t = eval[0];
    auto nhwc_t = eval[1];

    // Base case: the two reshapes collapse into one.
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is the feed tensor.
    // (Expected no reshape removal to happen.)
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nhwc", nhwc_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is consumed by multiple nodes
    // (Expected no reshape removal to happen.)
    {
      // Adding output2 to scope `s` and re-exporting gives "nhwc" a second
      // consumer; this inner `graph` intentionally shadows the outer one.
      Output output2 = ops::Identity(s.WithOpName("output2"), nhwc);
      GraphDef graph;
      TF_CHECK_OK(s.ToGraphDef(&graph));
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1", "output2"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 3);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[2], nhwc_t);
    }
  }
}
1546
// Cast(uint8->float) feeding a Transpose: the optimizer should reorder the
// pair so the Transpose runs on uint8 data and the Cast consumes its result.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output src_fp32 = ops::Cast(s, src_uint8, DT_FLOAT);
  Output transposed_fp32 =
      ops::Transpose(s, src_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose must remain, and it must operate on uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Transpose") continue;
    EXPECT_EQ(transpose_node, nullptr);
    EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
    transpose_node = &node;
  }
  ASSERT_NE(transpose_node, nullptr);

  // Every Cast in the optimized graph now reads the Transpose output.
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Cast") continue;
    ASSERT_EQ(node.input_size(), 1);
    EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
  }

  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1591
// Cast(uint8->float) feeding SpaceToDepth: the optimizer should swap the two
// ops so SpaceToDepth runs on the uint8 data.
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output net =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  net = ops::Cast(s, net, DT_FLOAT);
  net = ops::SpaceToDepth(s, net, 2);
  net = ops::Identity(s.WithOpName("outputs"), net);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth must remain, retyped to uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() != "SpaceToDepth") continue;
    EXPECT_EQ(s2d_node, nullptr);
    EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
    s2d_node = &node;
  }
  ASSERT_NE(s2d_node, nullptr);

  // Every Cast now consumes the SpaceToDepth output.
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Cast") continue;
    ASSERT_EQ(node.input_size(), 1);
    EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
  }

  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1638
// Transpose(fp32) followed by Cast(fp32->uint8): the optimizer should move
// the Cast in front of the Transpose so the permutation runs on uint8 data.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed_fp32 =
      ops::Transpose(s, src_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output transposed_uint8 = ops::Cast(s, transposed_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // The single remaining Cast must read directly from the placeholder.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Cast") continue;
    EXPECT_EQ(cast_node, nullptr);
    cast_node = &node;
    ASSERT_EQ(node.input_size(), 1);
    EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
  }
  ASSERT_NE(cast_node, nullptr);

  // All Transposes now operate on uint8 and consume the Cast output.
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Transpose") continue;
    EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
    ASSERT_EQ(node.input_size(), 2);
    EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
  }

  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<uint8>(tensors[0], tensors_expected[0]);
}
1686
// Cast -> Reverse -> Transpose: the assertions below pin the rewritten chain
// Placeholder -> ReverseV2(uint8) -> Transpose(uint8) -> Cast, i.e. both
// value-preserving ops are pushed in front of the Cast.
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output src_fp32 = ops::Cast(s, src_uint8, DT_FLOAT);
  Output reversed_fp32 =
      ops::Reverse(s, src_fp32, ops::Const(s, {0}, {1}));
  Output transposed_fp32 =
      ops::Transpose(s, reversed_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Locate the three ops of interest; Transpose and ReverseV2 must be unique
  // and retyped to uint8.
  const NodeDef* reverse = nullptr;
  const NodeDef* transpose = nullptr;
  const NodeDef* cast = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse = &node;
    } else if (node.op() == "Cast") {
      cast = &node;
    }
  }
  ASSERT_NE(cast, nullptr);
  ASSERT_NE(reverse, nullptr);
  ASSERT_NE(transpose, nullptr);
  ASSERT_EQ(reverse->input_size(), 2);
  EXPECT_EQ(NodeName(reverse->input(0)), "Placeholder");
  ASSERT_EQ(transpose->input_size(), 2);
  EXPECT_EQ(NodeName(transpose->input(0)), reverse->name());
  ASSERT_EQ(cast->input_size(), 1);
  EXPECT_EQ(NodeName(cast->input(0)), transpose->name());

  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", feed_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1743
// Cast followed by CheckNumerics: the optimized graph is expected to be
// identical to the input graph (no reordering happens here).
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output in_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output in_fp32 = ops::Cast(s, in_uint8, DT_FLOAT);
  Output checked = ops::CheckNumerics(s, in_fp32, "foo");
  Output outputs = ops::Identity(s.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1760
// Cast(float->uint8) followed by Transpose: no reordering is expected; the
// optimizer must leave the graph unchanged.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output src_uint8 = ops::Cast(s, src_fp32, DT_UINT8);
  Output transposed_uint8 =
      ops::Transpose(s, src_uint8, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1778
// Transpose(uint8) followed by Cast(uint8->float): no reordering is expected;
// the optimizer must leave the graph unchanged.
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed_uint8 =
      ops::Transpose(s, src_uint8, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output transposed_fp32 = ops::Cast(s, transposed_uint8, DT_FLOAT);
  Output outputs = ops::Identity(s.WithOpName("outputs"), transposed_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1796
// transpose2(transpose1(x)) composes to the identity permutation, and
// transpose3 uses the identity permutation outright; all three transposes
// (and their perm consts) should be optimized away and pruned.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_in =
      ops::RandomUniform(s.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output to_nhwc = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output to_nchw = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output identity_perm = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output t1 = ops::Transpose(s.WithOpName("transpose1"), random_in, to_nhwc);
  Output t2 = ops::Transpose(s.WithOpName("transpose2"), t1, to_nchw);
  Output t3 =
      ops::Transpose(s.WithOpName("transpose3"), random_in, identity_perm);
  Output id1 = ops::Identity(s.WithOpName("id1"), t2);
  Output id2 = ops::Identity(s.WithOpName("id2"), t3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Only the input nodes and the fetched identities should survive pruning.
  std::set<string> remaining;
  for (const NodeDef& node : output.node()) {
    remaining.insert(node.name());
  }
  EXPECT_EQ(remaining,
            std::set<string>({"id1", "id2", "inputs_shape", "inputs"}));
}
1829
// A ConjugateTranspose whose permutation is the identity should be rewritten
// into a plain Conj on the complex input.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output real =
      ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag =
      ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(s.WithOpName("z"), real, imag);
  Output identity_perm = ops::Const(s.WithOpName("perm"), {0, 1}, {2});
  Output conj_t =
      ops::ConjugateTranspose(s.WithOpName("trans"), z, identity_perm);
  Output id = ops::Identity(s.WithOpName("id"), conj_t);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  // The rewritten node is named with the optimizer's stage prefix.
  const string stage = "ArithmeticOptimizer/RemoveIdentityTranspose";
  const string optimized_name = absl::StrCat(stage, "_", "trans");

  const NodeDef* conj = node_map.GetNode(optimized_name);
  ASSERT_NE(conj, nullptr);
  EXPECT_EQ(conj->op(), "Conj");
  ASSERT_EQ(conj->input_size(), 1);
  EXPECT_EQ(conj->input(0), "z");
}
1860
// A cancelling Transpose pair on one branch of a Split should be bypassed,
// leaving the Concat wired straight to all three Split outputs.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // NOTE(review): this Const is not consumed by any node built below.
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output to_nhwc = ops::Const(s, {0, 2, 3, 1}, {4});
  Output to_nchw = ops::Const(s, {0, 3, 1, 2}, {4});
  // Middle branch goes through transpose + inverse transpose.
  Output mid = ops::Transpose(s, ops::Transpose(s, split[1], to_nhwc), to_nchw);
  Output concat = ops::Concat(s, {split[0], mid, split[2]}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All three Concat inputs now come directly from the Split.
  for (const NodeDef& node : output.node()) {
    if (node.op() != "Concat") continue;
    ASSERT_EQ(node.input_size(), 3);
    EXPECT_EQ(node.input(0), "Split");
    EXPECT_EQ(node.input(1), "Split:1");
    EXPECT_EQ(node.input(2), "Split:2");
  }

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1903
// A cancelling Transpose pair reachable only through a control dependency is
// removed; the control edge is rerouted to the Placeholder.
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output feed_in =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output flipped = ops::Transpose(s, feed_in, ops::Const(s, {1, 0}));
  Output flipped_back = ops::Transpose(s, flipped, ops::Const(s, {1, 0}));
  Output outputs = ops::Identity(
      s.WithOpName("outputs").WithControlDependencies(flipped_back),
      ops::Const(s.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // "outputs" keeps its data input and now depends on the Placeholder.
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1938
// Applying the permutation {1, 2, 3, 0} twice does not yield the identity,
// so neither transpose may be removed.
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_in =
      ops::RandomUniform(s.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output rotate = ops::Const(s.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output t1 = ops::Transpose(s.WithOpName("transpose1"), random_in, rotate);
  Output t2 = ops::Transpose(s.WithOpName("transpose2"), t1, rotate);
  Output outputs = ops::Identity(s.WithOpName("outputs"), t2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All six original nodes survive.
  EXPECT_EQ(output.node_size(), 6);
}
1962
// A cancelling Transpose pair separated by an Identity node is detected in
// aggressive mode; both transposes are removed and the chain is rewired as
// inputs -> id -> id1.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_const =
      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_in =
      ops::RandomUniform(s.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output to_nhwc = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output to_nchw = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output t1 = ops::Transpose(
      s.WithOpName("transpose1").WithControlDependencies(to_nchw), random_in,
      to_nhwc);
  Output passthrough = ops::Identity(s.WithOpName("id"), t1);
  Output t2 = ops::Transpose(s.WithOpName("transpose2"), passthrough, to_nchw);
  Output id1 = ops::Identity(s.WithOpName("id1"), t2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  std::set<string> remaining;
  for (const NodeDef& node : output.node()) {
    remaining.insert(node.name());
    if (node.name() == "id") {
      // The intermediate Identity now reads the raw inputs.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "inputs");
    }
    if (node.name() == "id1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id");
    }
  }
  EXPECT_EQ(remaining,
            std::set<string>({"id", "id1", "inputs_shape", "inputs"}));
}
2002
// Mul(inputs, scale) feeding Transpose+Conv2D should be folded into the
// convolution by scaling its (constant) weights; the Transpose is then fed
// by the raw `inputs`. Exercises both operand orders of the Multiply.
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
    // Multiply is commutative; cover both scale*inputs and inputs*scale.
    Output scaled_inputs = ops::Multiply(s.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                    "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // LOG(INFO) << "Before:\n" << item.graph.DebugString();
    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // LOG(INFO) << "After:\n" << output.DebugString();
    NodeMap node_map(&output);
    // `conv` is now a folded convolution with scaled weights.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);

    // The second conv input is a Mul that scales the original weights.
    const NodeDef* folded_conv_weights =
        node_map.GetNode(folded_conv->input(1));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // Its input should be a transpose of `inputs`.
    const NodeDef* transpose =
        node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
2052
// When the Transpose output (`inputs_nchw`) is a feed — and therefore must be
// preserved — the Multiply must not be folded across it: after optimization
// the Transpose still reads `scaled_inputs`.
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  // Zero-filled feed tensor for the preserved `inputs_nchw` node.
  // (const_cast + memset writes directly into the tensor's buffer.)
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0,
         inputs_nchw_tensor.tensor_data().size());

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  // The preserved Transpose must still consume the scaled inputs, proving the
  // Multiply was not folded into the convolution.
  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
2094
// Multiply(inputs, scale) feeding a Conv3D should be folded into the
// convolution: the folded conv reads the raw `inputs` directly and its
// weights input becomes a Mul that scales the constant filter.
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  // Use gtest expectations instead of CHECK_EQ: CHECK_EQ aborts the entire
  // test binary on mismatch instead of reporting a test failure.
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
2129
TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
  // This unit test exercises two optimizations, folding mul into conv, and
  // reordering cast and transpose.
  //
  // Conv2D(Transpose(Mul(Cast(I), S)), W)
  // =>
  // Conv2D(Transpose(Cast(I)), W*S)
  // =>
  // Conv2D(Cast(Transpose(I)), W*S)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  Output inputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output cast = ops::Cast(s, inputs, DT_FLOAT);
  Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
  Output transpose =
      ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
                            ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;  // all optimization stages are on
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
  NodeMap node_map(&output);

  // Expected names for reordered cast and transpose.
  // These names are produced by the ReorderCastLikeAndValuePreserving stage
  // and encode the result dtype of each rewritten node.
  const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
  const string optimized_cast_name = absl::StrCat(p, "float_Cast");
  const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");

  // Expected names for folded multiply and conv.
  const string optimized_weights =
      "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";

  const NodeDef* inputs_node = node_map.GetNode("Placeholder");
  const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
  const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");

  ASSERT_NE(inputs_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(conv_node, nullptr);

  // Final wiring: Placeholder -> Transpose -> Cast -> Conv2D, with the conv
  // reading the folded (scaled) weights as its second input.
  EXPECT_EQ(output.node_size(), 7);
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(cast_node->input(0), transpose_node->name());
  ASSERT_EQ(conv_node->input_size(), 2);
  EXPECT_EQ(conv_node->input(0), cast_node->name());
  EXPECT_EQ(conv_node->input(1), weights_node->name());
}
2193
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // This unit test exercises optimization of folding mul into conv for
  // multiple nodes in the graph.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  // Build two independent Mul -> Conv2D chains; default node naming gives
  // them the suffixed names "Conv2D" / "Conv2D_1", "weights" / "weights_1".
  for (int i = 0; i < 2; ++i) {
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);

  NodeMap node_map(&output);

  // Names generated by the FoldMultiplyIntoConv stage for each chain.
  using absl::StrCat;
  const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
  const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  // Both convolutions must now read their own folded (scaled) weights.
  ASSERT_EQ(conv_node->input_size(), 2);
  ASSERT_EQ(conv_node_1->input_size(), 2);
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
2243
// Two chained Bitcasts (uint8 -> qint8 -> int8) should collapse into a single
// Bitcast fed directly by the placeholder.
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // A single Bitcast remains and "inputs" feeds it directly.
  EXPECT_EQ(output.node_size(), 3);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2277
// Bitcasting int8 -> qint8 -> int8 is a round trip back to the original type,
// so both Bitcasts should be removed entirely.
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // No Bitcast left; "outputs" is wired straight to "inputs".
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2311
// A Cast whose destination dtype equals its source dtype (int8 -> int8) is
// redundant and should be pruned from the graph.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                        ops::Placeholder::Shape({2, 3}));
  Output identity_cast = ops::Cast(scope, placeholder, DT_INT8);
  Output identity = ops::Identity(scope.WithOpName("outputs"), identity_cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const auto input_tensor = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_tensor}};
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  // The Cast is gone and "outputs" reads the placeholder directly.
  EXPECT_EQ(optimized_graph.node_size(), 2);
  EXPECT_EQ(CountOpNodes(optimized_graph, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // The rewrite must be numerically transparent.
  const auto tensors = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2344
// AddOpsRewrite: a tree of Add ops over identically shaped inputs collapses
// into a single AddN node, and consumers are re-wired to it.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Sub-scopes prefix node names; the AddN produced below is expected under
  // the "y/" prefix because the root of the collapsed tree lives in `sy`.
  tensorflow::Scope sx = s.NewSubScope("x");
  tensorflow::Scope sy = s.NewSubScope("y");

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
  auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Random feeds so the optimized graph can be compared numerically against
  // the original at the end of the test.
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   a   +     -->  AddN(a, b, c)
  //      / \
  //     b   c
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2408
// AddOpsRewrite: two independent Add chains feeding a Mul are each collapsed
// into their own AddN node, and the Mul's inputs are re-wired to the new
// nodes.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Left chain: (a + b) + c.
  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  // Right chain: (x + y) + z.
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  // The Mul separates the two trees so each is rewritten independently.
  auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), mul);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Random feeds for the numerical before/after comparison at the end.
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //         *
  //      /     \
  //     +       +                        *
  //    / \     / \                    /     \
  //   +   c   x   +  -->    AddN(a, b, c)  AddN(x, y, z))
  //  / \         / \
  // a   b       y   z
  EXPECT_EQ(output.node_size(), 10);

  NodeMap node_map(&output);

  // check left Add subtree replaced with AddN
  const NodeDef* collapsed_left =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_left, nullptr);

  EXPECT_EQ(collapsed_left->op(), "AddN");
  ASSERT_EQ(collapsed_left->input_size(), 3);
  EXPECT_EQ(collapsed_left->input(0), "a");
  EXPECT_EQ(collapsed_left->input(1), "b");
  EXPECT_EQ(collapsed_left->input(2), "c");

  // check right Add subtree replaced with AddN
  const NodeDef* collapsed_right =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
  ASSERT_NE(collapsed_right, nullptr);

  EXPECT_EQ(collapsed_right->op(), "AddN");
  ASSERT_EQ(collapsed_right->input_size(), 3);
  EXPECT_EQ(collapsed_right->input(0), "x");
  EXPECT_EQ(collapsed_right->input(1), "y");
  EXPECT_EQ(collapsed_right->input(2), "z");

  // check that Mul inputs re-wired to new Nodes
  const NodeDef* updated_mul = node_map.GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);

  EXPECT_EQ(updated_mul->op(), "Mul");
  ASSERT_EQ(updated_mul->input_size(), 2);
  EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
  EXPECT_EQ(updated_mul->input(1), collapsed_right->name());

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2496
// AddOpsRewrite: an input that is reachable through more than one branch of
// the Add tree ("b" below, feeding both Add_ab and Add_bc) must appear once
// per path in the resulting AddN so the sum is unchanged.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Random feeds for the numerical before/after comparison at the end.
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   +     -->  AddN(a, b, b, c)
  //  / \ / \                ^
  // a   b   c          b added twice!
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // check Add tree replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 4);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "b");
  EXPECT_EQ(collapsed_add->input(3), "c");

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2553
// AddOpsRewrite: the rewrite also fires when the input shapes are unknown at
// graph-construction time but symbolically equal (all derived from the same
// {-1, 2} variable).
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // unknown input shape propagated symbolically through the graph
  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);

  // [a, b, c] have symbolically equal shapes
  auto a = ops::Sqrt(s.WithOpName("a"), input);
  auto b = ops::Square(s.WithOpName("b"), input);
  auto c = ops::Round(s.WithOpName("c"), input);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Concrete {2, 2} feed for the unknown-dimension input.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //     +
  //    / \
  //   +   c     -->  AddN(a, b, c)
  //  / \
  // a   b
  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);
  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2617
// AddOpsRewrite: with mixed input shapes the rewrite first groups
// same-shaped inputs into AddN leaves, then sums the partial results from
// the smallest shape upward so broadcasting happens as late (and as cheaply)
// as possible.
//
// Fix: one assertion used ASSERT_EQ(expected, actual); flipped to the
// (actual, expected) order used everywhere else in this file so failure
// messages label the values correctly.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // First chain over shapes {32}, {32,32}, {32,32,32}.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  // Second chain with the same shape pattern.
  auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Random feeds for the numerical before/after comparison at the end.
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  // 1) [a, x], [b, y], [c, z] - aggregate same shapes first
  // 2) Build an aggregation tree minimizing cost of broadcast
  //
  //         +                              +
  //      /     \                       /       \
  //     +       +                     +       AddN(c, z)
  //    / \     / \                 /     \
  //   +   c   x   +  -->  AddN(a, x)   AddN(b, y)
  //  / \         / \
  // a   b       y   z
  EXPECT_EQ(output.node_size(), 12);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
  string outer_0_add_name =
      "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
  string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
  string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
  string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";

  // Add [a, x] first
  const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
  ASSERT_NE(add_ax_node, nullptr);
  EXPECT_EQ(add_ax_node->op(), "AddN");
  ASSERT_EQ(add_ax_node->input_size(), 2);
  EXPECT_EQ(add_ax_node->input(0), "a");
  EXPECT_EQ(add_ax_node->input(1), "x");

  // Then add [b, y]
  const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
  ASSERT_NE(add_by_node, nullptr);
  EXPECT_EQ(add_by_node->op(), "AddN");
  // (actual, expected) order, consistent with the rest of this file.
  ASSERT_EQ(add_by_node->input_size(), 2);
  EXPECT_EQ(add_by_node->input(0), "b");
  EXPECT_EQ(add_by_node->input(1), "y");

  // Then add [c, z]
  const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
  ASSERT_NE(add_cz_node, nullptr);
  EXPECT_EQ(add_cz_node->op(), "AddN");
  ASSERT_EQ(add_cz_node->input_size(), 2);
  EXPECT_EQ(add_cz_node->input(0), "c");
  EXPECT_EQ(add_cz_node->input(1), "z");

  // Then add results together starting from smaller shapes [a, x] + [b, y]
  const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
  ASSERT_NE(outer_0_node, nullptr);
  EXPECT_EQ(outer_0_node->op(), "AddV2");
  ASSERT_EQ(outer_0_node->input_size(), 2);
  EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
  EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);

  // And finally top level Add node
  const NodeDef* outer_node = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_node, nullptr);
  EXPECT_EQ(outer_node->op(), "AddV2");
  ASSERT_EQ(outer_node->input_size(), 2);
  EXPECT_EQ(outer_node->input(0), outer_0_add_name);
  EXPECT_EQ(outer_node->input(1), inner_2_add_name);

  // And outputs reading new top level Add node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2729
// AddOpsRewrite: with symbolic (unknown) dimensions, inputs that share the
// small symbolic shape are summed first so the expensive broadcast to the
// large shape happens only once.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // We have a small input with one unknown dimension
  auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);

  // And second input which is larger, but has the same unknown dimension
  // device spec prevents this node from rewriting
  auto d = "/device:CPU:0";
  auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
  auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);

  // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
  auto a = ops::Sqrt(s.WithOpName("a"), small);
  auto b = ops::Square(s.WithOpName("b"), large);
  auto c = ops::Round(s.WithOpName("c"), small);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Concrete feeds: the unknown dimension resolves to 8 at run time.
  auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
  auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: it's much cheaper to add small
  // tensors, and do the broadcast just once
  //
  //     +                  +
  //    / \                / \
  //   +   c      -->     +   b
  //  / \                / \
  // a   b              a   c
  EXPECT_EQ(output.node_size(), 9);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
  string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";

  // outer Add node
  const NodeDef* outer_add = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_add, nullptr);
  EXPECT_EQ(outer_add->op(), "AddV2");
  ASSERT_EQ(outer_add->input_size(), 2);
  EXPECT_EQ(outer_add->input(0), inner_add_name);
  EXPECT_EQ(outer_add->input(1), "b");

  // inner AddN node
  const NodeDef* inner_add = node_map.GetNode(inner_add_name);
  ASSERT_NE(inner_add, nullptr);
  ASSERT_EQ(inner_add->input_size(), 2);
  EXPECT_EQ(inner_add->input(0), "a");
  EXPECT_EQ(inner_add->input(1), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
2808
// RemoveNegation: Add/Sub nodes whose operands go through a Neg are rewritten
// to fold the negation into the arithmetic op itself. The expected rewrites
// (checked in the loop below):
//   Add(-x, y)  -> Sub(y, x)
//   Add(x, -y)  -> Sub(x, y)
//   Add(-x, -y) -> Sub(Neg_x, y)
//   Sub(x, -y)  -> AddV2(x, y)
//   Sub(-x, -y) -> Sub(y, x)
// Control dependencies carried by a Neg must survive the rewrite.
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
  Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
  Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
  Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
  Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
  Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
  Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
  Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
  Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
  Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
  // A Neg with a control dependency: the dependency must be preserved on the
  // rewritten consumer.
  Output neg_x_with_dep = ops::Neg(
      s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
  Output add_negx_with_dep_y =
      ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
  // Sum everything so a single fetch covers all rewritten nodes.
  auto add_all =
      ops::AddN(s.WithOpName("add_all"),
                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});

  GrapplerItem item;
  item.fetch = {"add_all"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveNegation(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Rewrites happen in place; no nodes are added or removed.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "Add_negx_y") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_x_negy") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Add_negx_negy") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Neg_x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_x_negy") {
      ++found;
      EXPECT_EQ(node.op(), "AddV2");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_negx_negy") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_negx_with_dep_y") {
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      // Two data inputs plus the control edge inherited from the Neg.
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
      EXPECT_EQ(node.input(2), "^Add_x_y");
    }
  }
  // All six rewritten nodes must have been visited.
  EXPECT_EQ(found, 6);

  // The rewrite must preserve the numerical result.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2896
// Div(x, Sqrt(y)) is rewritten in place: the Sqrt node becomes Rsqrt and the
// Div becomes a Mul of x with it.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_node = ops::Sqrt(scope.WithOpName("sqrt_y"), radicand);
  Output quotient = ops::Div(scope.WithOpName("output"), numerator, sqrt_node);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&optimizer, &item, &optimized_graph);
  const auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // Numerically equivalent, and no nodes added or removed.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(optimized_graph.node_size(), item.graph.node_size());
  for (int node_idx = 0; node_idx < optimized_graph.node_size(); ++node_idx) {
    const NodeDef& node = optimized_graph.node(node_idx);
    if (node.name() == "output") {
      // Div -> Mul, still fed by x and the (now Rsqrt) "sqrt_y" node.
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      // Sqrt -> Rsqrt in place.
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2933
// SqrtDivToRsqrtMul must NOT fire when the Sqrt is itself a fetch node:
// rewriting it to Rsqrt would change the value returned for that fetch.
//
// Fix: the first verification loop compared a signed `int` index against
// `tensors.size()` (a `size_t`), triggering -Wsign-compare; the index is now
// `size_t`.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  // "output0" (the Sqrt) is fetched directly, which must block the rewrite.
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Both fetched tensors must be numerically unchanged.
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // The graph is untouched: "grad" stays a Div and "output0" stays a Sqrt.
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "grad") {
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
}
2975
// FloorDiv(x, Sqrt(y)) must NOT be rewritten: the Rsqrt/Mul transformation
// is only valid for true division.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_node = ops::Sqrt(scope.WithOpName("sqrt_y"), radicand);
  Output quotient =
      ops::FloorDiv(scope.WithOpName("output"), numerator, sqrt_node);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef optimized_graph;
  OptimizeAndPrune(&optimizer, &item, &optimized_graph);
  const auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // Numerically identical, and the graph structure is untouched.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(optimized_graph.node_size(), item.graph.node_size());
  for (int node_idx = 0; node_idx < optimized_graph.node_size(); ++node_idx) {
    const NodeDef& node = optimized_graph.node(node_idx);
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
3012
// FuseSquaredDiff: Square(Sub(x, y)) is fused into SquaredDifference(x, y)
// (with the old Square reduced to an Identity) — but only for real inputs;
// for complex inputs the graph is left unchanged.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  // Exercise both the real and the complex branch.
  for (bool is_complex : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output complex_x = ops::Complex(s.WithOpName("complex_x"), x, x);
    Output complex_y = ops::Complex(s.WithOpName("complex_y"), y, y);
    // In the complex case Sub operates on the Complex nodes; otherwise the
    // Complex nodes are left dangling (and get pruned below).
    Output sub_x_y =
        is_complex ? ops::Sub(s.WithOpName("sub_x_y"), complex_x, complex_y)
                   : ops::Sub(s.WithOpName("sub_x_y"), x, y);
    Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    const auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);

    if (is_complex) {
      test::ExpectTensorNear<std::complex<float>>(tensors[0],
                                                  tensors_expected[0], 1e-6);
      // No rewrite for complex inputs: node count is unchanged.
      EXPECT_EQ(output.node_size(), item.graph.node_size());
    } else {
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
      // The two unused Complex nodes should get pruned.
      EXPECT_EQ(output.node_size(), item.graph.node_size() - 2);
    }
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      if (node.name() == "output") {
        // Real case: the Square collapses to an Identity over the fused op.
        EXPECT_EQ(node.op(), is_complex ? "Square" : "Identity");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        // Real case: the Sub becomes the fused SquaredDifference.
        EXPECT_EQ(node.op(), is_complex ? "Sub" : "SquaredDifference");
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), is_complex ? "complex_x" : "x");
        EXPECT_EQ(node.input(1), is_complex ? "complex_y" : "y");
      }
    }
  }
}
3062
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  // "sub_x_y" is itself a fetch node, so rewriting it into SquaredDifference
  // would change a fetched output; the optimizer must leave the graph alone.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(scope.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(scope.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  // Graph structure must be unchanged: same node count, same ops and inputs.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
}
3102
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  // Log(Softmax(x)) should be fused into the single LogSoftmax op.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output softmax = ops::Softmax(scope.WithOpName("softmax"), x);
  Output logsoftmax = ops::Log(scope.WithOpName("output"), softmax);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // Softmax and Log collapse into one node, so the graph shrinks by one.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "LogSoftmax");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
3133
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  // "softmax" is fetched directly, so folding it into LogSoftmax would
  // remove a fetched output; the optimizer must not rewrite here.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(scope.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(scope.WithOpName("softmax"), floats);
  Output final_output = ops::Log(scope.WithOpName("final_output"), softmax);

  GrapplerItem item;
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  for (int i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3163
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  // Pow(x, c) with special constant exponents is rewritten to cheaper ops:
  //   c == 2   -> Square(x)
  //   c == 3   -> x * Square(x)
  //   c == 1   -> Identity(x)
  //   c == 0.5 -> Sqrt(x)
  //   c == 0   -> Const(1) (with a control dependency on x)
  //   c == -0.5-> Rsqrt(x)
  //   c == -1  -> Reciprocal(x)
  // A non-constant exponent ("y") and exponents whose shape would broadcast
  // the scalar base ("ones"/"zeros" vs scalar "z") must stay as Pow.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  // The cube rewrite creates a helper Square node; it should inherit out3's
  // explicitly assigned device (checked against `want` below).
  Output out3 =
      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2", "out3", "out1", "out.5", "out0",
                "out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 10);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 10);

  // The rewrites must be numerically equivalent to the original Pow ops.
  for (int i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Build the exact expected graph and compare structurally. Note that the
  // exponent constants y2..y_1 are gone: they were only consumed by the
  // rewritten Pow nodes and got pruned.
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x"}, {}, &want);
  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
          &want)
      ->set_device("/device:CPU:0");
  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
          {}, &want)
      ->set_device("/device:CPU:0");
  AddNode("out1", "Identity", {"x"}, {}, &want);
  AddNode("out.5", "Sqrt", {"x"}, {}, &want);
  AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);
  AddNode("out_1", "Reciprocal", {"x"}, {}, &want);
  AddNode("out", "Pow", {"x", "y"}, {}, &want);
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
3233
TEST_F(ArithmeticOptimizerTest, Log1p) {
  // Log(Add(1, x)) is rewritten to Log1p(x) when one addend is the constant
  // one (here x1); the Add's control dependency on x3 must be carried over.
  // out2's addends are 2 and 3, so it must stay a plain Log.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(scope.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(scope.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(scope.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto a12 =
      ops::Add(scope.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto a23 = ops::Add(scope.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(scope.WithOpName("out1"), a12);
  Output out2 = ops::Log(scope.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected graph: out1 becomes Log1p(x2) (keeping the ^x3 control input);
  // x1 and a12 are pruned; out2 is untouched.
  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
3272
TEST_F(ArithmeticOptimizerTest, Expm1) {
  // Sub(Exp(x), 1) is rewritten to Expm1(x) when the subtrahend is the
  // constant one (here x2); control dependencies on the Exp are preserved.
  // out2 subtracts x3 (not one) and must keep its Exp/Sub form.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(scope.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(scope.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto x3 = ops::Const(scope.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto exp1 =
      ops::Exp(scope.WithOpName("exp1").WithControlDependencies(x3), x1);
  Output out1 = ops::Sub(scope.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(scope.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected graph: out1 becomes Expm1(x1) with exp1's ^x3 control input;
  // x2 is pruned; exp1 survives because out2 still consumes it.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
3310
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  // (a * b) * c with vector a, matrix b, vector c broadcasts in both Muls.
  // Swapping b and c keeps a * c a cheap vector op and defers the single
  // broadcast to the final multiply.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feeds = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feeds);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite:
  //
  //     *                  *
  //    / \                / \
  //   *   c      -->     *   b
  //  / \                / \
  // a   b              a   c
  NodeMap node_map(&output);

  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "mul1");
  EXPECT_EQ(mul2_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feeds);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3366
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  // A left-leaning chain of Muls mixing vectors (a, c, d, e) and one matrix
  // (b) should be rebalanced so the broadcasting multiply with b happens
  // last, once, at the root.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *      b
  //    / \                 /   \
  //   *   c      -->     *      *
  //  / \                / \    / \
  // a   b              a   c  d   e
  NodeMap node_map(&output);

  // mul1 multiplies the two leftmost vectors ...
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  // ... mul2 the remaining pair of vectors ...
  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "d");
  EXPECT_EQ(mul2_node->input(1), "e");

  // ... mul3 combines the two vector products ...
  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "mul1");
  EXPECT_EQ(mul3_node->input(1), "mul2");

  // ... and mul4 performs the single broadcasting multiply with matrix b.
  const NodeDef* mul4_node = node_map.GetNode("mul4");
  ASSERT_NE(mul4_node, nullptr);
  ASSERT_EQ(mul4_node->input_size(), 2);
  EXPECT_EQ(mul4_node->input(0), "mul3");
  EXPECT_EQ(mul4_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
3445
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
  // A balanced Mul tree over three vectors and one matrix should be
  // re-linearized so the matrix multiply happens last, minimizing how much
  // of the tree operates on broadcasted (matrix-sized) values.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // [a, b, c] - scalars, [d] - matrix
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
  auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //                              *
  //                            /   \
  //       *                   *     D
  //     /   \                / \
  //    *     *      ->      *   c
  //   / \   / \            / \
  //  a   b c   D          a   b
  //
  // NOTE(review): the lookups below are intentionally crossed — after the
  // rewrite the node NAMED "mul2" holds a*b and the node NAMED "mul1" holds
  // (a*b)*c, so the local variable names describe the position in the new
  // tree, not the node name they fetch.
  NodeMap node_map(&output);

  const NodeDef* mul1_node = node_map.GetNode("mul2");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "b");

  const NodeDef* mul2_node = node_map.GetNode("mul1");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "mul2");
  EXPECT_EQ(mul2_node->input(1), "c");

  // The root multiply is the only one touching the matrix D.
  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "D");
  EXPECT_EQ(mul3_node->input(1), "mul1");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3513
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  // Builds two conv -> bias-add -> relu branches feeding one Concat and runs
  // the full (default-enabled) arithmetic optimizer; the two Relu nodes are
  // expected to survive in place rather than being hoisted past the Concat.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Guard against EvaluateNodes failing and returning an empty vector,
  // which would make the comparison loop below index out of bounds.
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3554
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  // When every input of a Concat is produced by the same element-wise unary
  // op (or the same chain of them), the chain can be applied once, after the
  // Concat, instead of once per input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //       Concat({Exp(Sin(a)), Exp(b), Exp(c)})
  // into
  //       Exp(Concat({Sin(a), b, c})).
  // Sin(a) is not shared by all inputs, so it stays below the Concat.
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //       Concat({Cos(Exp(Sin(a))), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //       Cos(Exp(Concat({Sin(a), b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // After the rewrite the surviving unary nodes (exp_a, exp_a2, cos_exp_a2)
  // are reused as the post-concat chain; their siblings are pruned.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "concat") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat");
      found++;
    }
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a");
      found++;
    }

    if (node.name() == "concat2") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat2");
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a2");
      found++;
    }
    if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "cos_exp_a2");
      found++;
    }
  }
  EXPECT_EQ(found, 7);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3663
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  // When every output of a Split feeds the same element-wise unary chain,
  // the chain can be applied once to the Split's input instead of once per
  // output. The hoisted copies get "ArithmeticOptimizer/_<node>_<split>"
  // names and the per-output originals are pruned.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //          [Sin(y) for y in Split(x)]
  // into
  //          [y for y in Split(Sin(x))].
  // Note that only Sin is common to both outputs; exp_b applies to the
  // second output only and must stay below the Split.
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //          [Cos(Exp(y)) for y in Split(x)]
  // into
  //          [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    // Case 1: split1 now consumes the hoisted Sin of x ...
    if (node.name() == "split1") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "axis");
      EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ(node.op(), "Sin");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    // ... and the consumers read the Split outputs directly (exp_b, which
    // was not common to both outputs, remains on output 1).
    if (node.name() == "id_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1");
      found++;
    }
    if (node.name() == "exp_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1:1");
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_b");
      found++;
    }
    // Case 2: the whole Cos(Exp(.)) chain is hoisted above split2.
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Exp");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Cos");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
      found++;
    }
    if (node.name() == "split2") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
      EXPECT_EQ(node.input(1), "size_splits2");
      EXPECT_EQ(node.input(2), "axis");
      found++;
    }
    if (node.name() == "id_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2");
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2:1");
      found++;
    }
  }
  EXPECT_EQ(found, 10);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3788
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  // Chains of idempotent ops (Snapshot(Snapshot(x)), Identity(Identity(x)))
  // should be collapsed to a single instance of the op.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // One node removed from each of the two chains: 9 -> 7.
  // (actual, expected) argument order, consistent with the rest of the file.
  EXPECT_EQ(output.node_size(), 7);
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sn1");  // sn2 collapsed into sn1
      found++;
    } else if (node.name() == "out2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id1");  // id2 collapsed into id1
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
      found++;
    }
  }
  EXPECT_EQ(found, 3);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3835
// RemoveLogicalNot should fold LogicalNot(comparison) into the inverse
// comparison op, and rewire consumers of the LogicalNot to the comparison.
TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(scope.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(scope.WithOpName("b"), -3.14f, {32});
  Output eq = ops::Equal(scope.WithOpName("eq"), a, b);
  Output neq = ops::NotEqual(scope.WithOpName("neq"), a, b);
  Output lt = ops::Less(scope.WithOpName("lt"), a, b);
  Output le = ops::LessEqual(scope.WithOpName("le"), a, b);
  Output gt = ops::Greater(scope.WithOpName("gt"), a, b);
  Output ge = ops::GreaterEqual(scope.WithOpName("ge"), a, b);
  // "not_eq" is a reserved name, hence "not_eq1".
  Output not_eq1 = ops::LogicalNot(scope.WithOpName("not_eq1"), eq);
  Output not_neq = ops::LogicalNot(scope.WithOpName("not_neq"), neq);
  Output not_lt = ops::LogicalNot(scope.WithOpName("not_lt"), lt);
  Output not_le = ops::LogicalNot(scope.WithOpName("not_le"), le);
  Output not_gt = ops::LogicalNot(scope.WithOpName("not_gt"), gt);
  Output not_ge = ops::LogicalNot(scope.WithOpName("not_ge"), ge);
  Output id_not_eq = ops::Identity(scope.WithOpName("id_not_eq"), not_eq1);
  Output id_not_neq = ops::Identity(scope.WithOpName("id_not_neq"), not_neq);
  Output id_not_lt = ops::Identity(scope.WithOpName("id_not_lt"), not_lt);
  Output id_not_le = ops::Identity(scope.WithOpName("id_not_le"), not_le);
  Output id_not_gt = ops::Identity(scope.WithOpName("id_not_gt"), not_gt);
  Output id_not_ge = ops::Identity(scope.WithOpName("id_not_ge"), not_ge);

  GrapplerItem item;
  item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
                "id_not_le", "id_not_gt", "id_not_ge"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveLogicalNot(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Six Identity nodes must now read the comparison directly, and each of
  // the six comparison nodes must carry the inverse op.
  int checks_passed = 0;
  for (const NodeDef& node : output.node()) {
    const auto& name = node.name();
    if (name == "id_not_eq" || name == "id_not_neq" || name == "id_not_lt" ||
        name == "id_not_le" || name == "id_not_gt" || name == "id_not_ge") {
      ASSERT_EQ(node.input_size(), 1);
      // Dropping the "id_not_" prefix (7 chars) recovers the name of the
      // comparison node that should now feed this Identity.
      EXPECT_EQ(node.input(0), name.substr(7));
      ++checks_passed;
    } else if (name == "eq") {
      EXPECT_EQ(node.op(), "NotEqual");
      ++checks_passed;
    } else if (name == "neq") {
      EXPECT_EQ(node.op(), "Equal");
      ++checks_passed;
    } else if (name == "lt") {
      EXPECT_EQ(node.op(), "GreaterEqual");
      ++checks_passed;
    } else if (name == "le") {
      EXPECT_EQ(node.op(), "Greater");
      ++checks_passed;
    } else if (name == "gt") {
      EXPECT_EQ(node.op(), "LessEqual");
      ++checks_passed;
    } else if (name == "ge") {
      EXPECT_EQ(node.op(), "Less");
      ++checks_passed;
    }
  }
  EXPECT_EQ(checks_passed, 12);

  // The rewritten graph must still produce identical boolean results.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
  }
}
3939
// A monotonically increasing element-wise op commutes with a Max reduction:
// Max(Sqrt(x)) == Sqrt(Max(x)). The optimizer should swap the two nodes.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);

  test::ExpectTensorNear<float>(optimized[0], baseline[0], 1e-6);
  // Same node count: the rewrite only rewires inputs.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Verify the inputs were switched: sqrt now consumes the reduction, and
  // the reduction consumes x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++matched;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3980
// ArgMax indices are invariant under a monotonically increasing map, so
// ArgMax(Sqrt(x)) can be rewritten as ArgMax(x) and the Sqrt dropped.
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(scope.WithOpName("arg_max"), sqrt, 1);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);

  test::ExpectTensorEqual<int64>(optimized[0], baseline[0]);
  // The Sqrt node is pruned away entirely, hence one fewer node.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  // Verify arg_max now reads x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++matched;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4021
// When the element-wise node ("sqrt") is itself fetched, rewiring it would
// change an observable output, so the optimizer must leave the graph alone.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  // Fetching "sqrt" pins its value and blocks the rewrite.
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(baseline.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
4045
// When the reduction node is the fetch node, the rewrite (which would alter
// the fetched node's op/inputs) must be suppressed.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {2, 3}, {1, 2});
  Output reshape = ops::Reshape(scope.WithOpName("reshape"), x, {-1});
  Output y = ops::Neg(scope.WithOpName("y"), reshape);
  Output z = ops::Max(scope.WithOpName("z"), y, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);
  test::ExpectTensorEqual<int>(optimized[0], baseline[0]);
  // max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(optimized[0], Tensor(-2));
}
4074
// Neg is monotonically *decreasing*, so swapping it with the reduction must
// also flip the reduction: Max(Neg(x)) => Neg(Min(x)).
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // Only inputs are rewired; no node is added or removed.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      // The reduction flipped from Max to Min because Neg is decreasing.
      EXPECT_EQ(node.op(), "Min");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // Argument order is (actual, expected), consistent with the rest of the
  // file (previously written as EXPECT_EQ(2, required_node_count)).
  EXPECT_EQ(required_node_count, 2);
}
4116
// MaxPool(Neg(x)) must NOT be rewritten: flipping the pooling direction for
// a decreasing element-wise op has no corresponding "MinPool" form, so the
// optimizer is expected to leave the graph unchanged.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output neg = ops::Neg(scope.WithOpName("neg"), x);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), neg,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Should be a NoOp
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);
  test::ExpectTensorNear<float>(optimized[0], baseline[0], 1e-6);
}
4143
// A Conv2D -> BiasAdd -> Relu -> MaxPool chain is expected to survive the
// monotonic-op optimizer untouched (verified as a strict graph match).
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(scope.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output input = ops::Const(scope.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output net = ops::Conv2D(scope.WithOpName("conv"), input, weights,
                           {1, 1, 1, 1}, "SAME");
  net = ops::BiasAdd(scope.WithOpName("biasadd"), net, biases);
  net = ops::Relu(scope.WithOpName("relu"), net);
  net = ops::MaxPool(scope.WithOpName("max_pool"), net, {1, 2, 2, 1},
                     {1, 2, 2, 1}, "VALID");
  net = ops::Identity(scope.WithOpName("output"), net);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &new_graph);

  // Should be a NoOp
  VerifyGraphsMatch(item.graph, new_graph, __LINE__);

  auto optimized = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(optimized.size(), 1);
  test::ExpectTensorNear<float>(optimized[0], baseline[0], 1e-6);
}
4178
// MaxPool commutes with a monotonically increasing element-wise op:
// MaxPool(Sqrt(x)) => Sqrt(MaxPool(x)). The two nodes should be swapped.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), sqrt,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output final_out = ops::Identity(scope.WithOpName("final_out"), max_pool);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);

  test::ExpectTensorNear<float>(optimized[0], baseline[0], 1e-6);
  // Only inputs are rewired; no node is added or removed.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Verify the inputs were switched.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++matched;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4220
// Sqrt -> Log -> Relu should fuse into a single _UnaryOpsComposition node
// that records the original op sequence in its "op_names" attribute.
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output log = ops::Log(scope.WithOpName("log"), sqrt);
  Output relu = ops::Relu(scope.WithOpName("relu"), log);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:CPU:0");
  }

  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // x, the fused op, and final_out remain.
  EXPECT_EQ(output.node_size(), 3);

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++matched;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      // The fused node lists the composed ops in execution order.
      auto op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(op_names.size(), 3);
      EXPECT_EQ(op_names[0], "Sqrt");
      EXPECT_EQ(op_names[1], "Log");
      EXPECT_EQ(op_names[2], "Relu");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);

  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);
  test::ExpectTensorNear<float>(optimized[0], baseline[0], 1e-6);
}
4277
// Verifies that slicing a Stack (Pack) result along the stacked axis is
// rewritten to read the corresponding Stack input directly (single-element
// shrink slices become plain Identity-style forwards; range slices become
// ExpandDims of the input).
TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
  auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
  auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});

  // stacked[:, 0]
  using SS = ops::StridedSlice;
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 1]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 2]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 0:1, :]
  auto pa_slice_01 = ops::Identity(
      s.WithOpName("pa_slice_01_out"),
      SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, :1, :]
  auto pa_slice_to1 = ops::Identity(
      s.WithOpName("pa_slice_to1_out"),
      SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0111)  // 7
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 1:2, :]
  auto pb_slice_12 = ops::Identity(
      s.WithOpName("pb_slice_12_out"),
      SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 2:, :].
  auto pc_slice_2to = ops::Identity(
      s.WithOpName("pc_slice_2to_out"),
      SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0111)  // 7
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c",
                "pa_slice_01_out",
                "pa_slice_to1_out",
                "pb_slice_12_out",
                "pc_slice_2to_out"};
  // Indices into the fetch list above, for readable tensor comparisons.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
    fASlice01Out,
    fASliceTo1Out,
    fBSlice12Out,
    fCSlice2ToOut,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph first.
  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fA]);
  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fB]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fC]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "b");
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
    } else if (absl::EndsWith(node.name(), "_out")) {
      // str_util::EndsWith is deprecated; use absl::EndsWith (match.h is
      // already included) consistently with absl::StartsWith elsewhere.
      ASSERT_EQ(node.input_size(), 1);
      // The remaining *_out nodes must read the optimizer-generated node
      // whose name is the rewrite prefix plus the slice's base name.
      EXPECT_EQ(
          absl::StrCat(node.input(0), "_out"),
          absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
                       node.name()));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // The optimized graph must produce the same values as the original.
  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);

  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors[fASlice01Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
                                 tensors_expected[fExpandedC]);
}
4486
// Verifies that a plain Slice of a Stack (Pack) result along the stacked
// axis is rewritten to an ExpandDims of the corresponding Stack input,
// bypassing the Stack node entirely.
TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  // Placeholders with unknown static shape, so the rewrite cannot rely on
  // fully-known input shapes.
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  // Reference values: what each single-element slice should equal.
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  // Size -1 means "to the end" on the outer axes; 1 selects a single entry
  // along the stacked axis.
  auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});

  // stacked[:, 0:1, :]
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));

  // stacked[:, 1:2, :]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));

  // stacked[:, 2:3, :]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c"};
  // Indices into the fetch list above, for readable tensor comparisons.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph first.
  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fExpandedA]);
  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Name prefix the rewrite gives the ExpandDims nodes it inserts; it is
  // followed by "a_slice", "b_slice", or "c_slice" (the "p" is part of the
  // prefix so that the original "pa_slice" etc. names line up).
  const string kExpandDimsNamePrefix(
      "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
    } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
      EXPECT_EQ(node.op(), "ExpandDims");
      // The input is "a", "b", or "c", as appropriate.
      EXPECT_EQ(node.input(0),
                node.name().substr(kExpandDimsNamePrefix.size(), 1));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // The optimized graph must produce the same values as the original.
  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors[fASliceOut],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors[fCSliceOut],
                                 tensors_expected[fExpandedC]);
}
4602
// AddN(x, x) should simplify to a multiply by a constant 2, and the rewrite
// must hold for bfloat16 inputs as well.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output cast = ops::Cast(scope.WithOpName("cast"), x, DT_BFLOAT16);
  Output add = ops::AddN(scope.WithOpName("add"), {cast, cast});
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  auto baseline = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(baseline.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Extra node created for multiplier.
  EXPECT_EQ(output.node_size(), 5);

  auto optimized = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(optimized.size(), 1);
  test::ExpectTensorEqual<bfloat16>(optimized[0], baseline[0]);
}
4628
// The Unique/Gather/SparseSegmentSum embedding-lookup pattern should collapse
// so that SparseSegmentSum reads the embedding table and raw indices
// directly, eliminating the Unique and Gather nodes. Both index dtypes for
// Unique are exercised.
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(scope.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices =
        ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(scope.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    Output ids = unique.y;
    Output idx = unique.idx;
    Output gathered_rows =
        ops::Gather(scope.WithOpName("gathered_rows"), embeddings, ids);
    Output result = ops::SparseSegmentSum(scope.WithOpName("result"),
                                          gathered_rows, idx, segment_ids);
    Output id = ops::Identity(scope.WithOpName("id"), result);

    GrapplerItem item;
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto baseline = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(baseline.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);

    for (const auto& node : output.node()) {
      if (node.name() == "result") {
        // The sum now reads the table and the raw indices directly.
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
      // Unique and Gather must be gone from the optimized graph.
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }

    auto optimized = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(optimized.size(), 1);
    test::ExpectTensorEqual<float>(optimized[0], baseline[0]);
  }
}
4672
// Same as SimplifyEmbeddingLookup, but the embedding matrix lives in a
// resource variable and is fetched via ResourceGather. The rewrite should
// insert a ReadVariableOp (propagating the "_class" colocation attr) and
// again eliminate the Unique/Gather pair.
TEST_F(ArithmeticOptimizerTest, SimplifyResourceEmbeddingLookup) {
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(s.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    Output ids = unique.y;
    Output idx = unique.idx;

    auto var =
        ops::VarHandleOp(s.WithOpName("var"), DT_FLOAT, TensorShape({2, 2}));
    ops::AssignVariableOp assign_op(s.WithOpName("assign_var_handle"), var,
                                    embeddings);

    Output gathered_rows = ops::ResourceGather(
        s.WithOpName("gathered_rows")
            .WithControlDependencies(std::vector<Operation>{assign_op}),
        var, ids, DT_FLOAT);
    // The colocation class attr must be carried over to the inserted read.
    gathered_rows.node()->AddAttr("_class", {"test_class"});
    Output result =
        ops::SparseSegmentSum(s.WithOpName("result").WithControlDependencies(
                                  std::vector<Operation>{assign_op}),
                              gathered_rows, idx, segment_ids);
    Output id = ops::Identity(s.WithOpName("id"), result);

    GrapplerItem item;
    item.init_ops.push_back("assign_var_handle");
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    bool read_var_node_found = false;
    for (const auto& node : output.node()) {
      if (node.name() == "result") {
        EXPECT_EQ(
            node.input(0),
            "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result");
        EXPECT_EQ(node.input(1), "indices");
      }
      if (node.op() == "ReadVariableOp") {
        read_var_node_found = true;
        EXPECT_EQ(node.attr().at("_class").list().s(0), "test_class");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }
    EXPECT_TRUE(read_var_node_found);
    // Add a control dependency to the ReadVar to do the AssignVar first. This
    // shouldn't be an issue in actual use as variables are assumed initialized
    // during setup.
    for (int i = 0; i < output.node_size(); ++i) {
      if (output.node(i).name() ==
          "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result") {
        output.mutable_node(i)->add_input("^assign_var_handle");
      }
    }

    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}
4744
// Verifies that the RemoveCastIntoSegmentReduction stage drops Cast nodes
// feeding the indices and segment_ids inputs of SparseSegmentSum, wiring
// the original integer constants in directly. All four int32/int64 cast
// combinations are exercised.
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  for (DataType indices_type : {DT_INT32, DT_INT64}) {
    for (DataType segment_ids_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                     {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      Output indices =
          ops::Cast(s.WithOpName("cast_indices"),
                    ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
                    indices_type);
      Output segment_ids = ops::Cast(
          s.WithOpName("cast_segment_ids"),
          ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          segment_ids_type);
      Output result = ops::SparseSegmentSum(s.WithOpName("result"), embeddings,
                                            indices, segment_ids);
      Output id = ops::Identity(s.WithOpName("id"), result);

      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"id"};
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      OptimizeAndPrune(&optimizer, &item, &output);

      // The reduction must read the un-cast constants, and every Cast node
      // must have been pruned from the graph.
      for (const auto& node : output.node()) {
        if (node.name() == "result") {
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        EXPECT_NE(node.op(), "Cast");
      }

      // The optimized graph must compute the same value as the original.
      auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}
4788
4789 } // namespace grappler
4790 } // namespace tensorflow
4791