• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
17 #include "tensorflow/cc/ops/array_ops.h"
18 #include "tensorflow/cc/ops/array_ops_internal.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/utils.h"
25 #include "tensorflow/core/grappler/utils/grappler_test.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/tensor_coding.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 namespace {
34 
35 class ConstantFoldingTest : public GrapplerTest {
36  protected:
37   template <DataType DTYPE>
SimpleNeutralElementTest()38   void SimpleNeutralElementTest() {
39     for (bool use_snapshot : {false, true}) {
40       typedef typename EnumToDataType<DTYPE>::Type T;
41       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
42       Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
43                                   ops::Placeholder::Shape(TensorShape({2, 2})));
44       Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
45       Tensor zeros_t(DTYPE, TensorShape({2, 2}));
46       Tensor ones_t(DTYPE, TensorShape({2, 2}));
47       Tensor x_t(DTYPE, TensorShape({2, 2}));
48       for (int i = 0; i < 4; ++i) {
49         zeros_t.flat<T>()(i) = T(0);
50         ones_t.flat<T>()(i) = T(1);
51         x_t.flat<T>()(i) = T(i + 1);
52       }
53       Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
54       Output ones = ops::Const(s.WithOpName("ones"), ones_t);
55       Output mul1;
56       Output mul2;
57       Output add1;
58       Output add2;
59       if (DTYPE == DT_BOOL) {
60         mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
61         mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
62         add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
63         add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
64       } else {
65         mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
66         mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
67         add1 = ops::Add(s.WithOpName("add1"), x, zeros);
68         add1 = ops::Add(s.WithOpName("add2"), x, ones);
69       }
70       if (use_snapshot) {
71         // Add an op with ref input to prevent Snapshot from being
72         // turned into Identity.
73         ops::Assign(s.WithOpName("assign"), v, ones);
74       }
75       GrapplerItem item;
76       TF_CHECK_OK(s.ToGraphDef(&item.graph));
77       item.fetch = {"mul1", "mul2", "add1", "add2"};
78       ConstantFolding optimizer(/*cpu_device=*/nullptr);
79       GraphDef output;
80       Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
81       TF_EXPECT_OK(status);
82 
83       EXPECT_EQ(7, output.node_size());
84       const string snapshot_or_identity =
85           use_snapshot ? "Snapshot" : "Identity";
86       for (int i = 0; i < output.node_size(); ++i) {
87         const NodeDef& node = output.node(i);
88         const string& name = node.name();
89         if (name == "mul1") {
90           EXPECT_EQ("Const", node.op());
91           EXPECT_EQ("^x", node.input(0));
92           EXPECT_EQ("^zeros", node.input(1));
93         } else if (name == "mul2") {
94           EXPECT_EQ(snapshot_or_identity, node.op());
95           EXPECT_EQ("x", node.input(0));
96           EXPECT_EQ("^ones", node.input(1));
97         } else if (name == "add1") {
98           EXPECT_EQ(snapshot_or_identity, node.op());
99           EXPECT_EQ("x", node.input(0));
100           EXPECT_EQ("^zeros", node.input(1));
101         } else if (name == "add2") {
102           if (DTYPE == DT_BOOL) {
103             EXPECT_EQ("Const", node.op());
104             EXPECT_EQ("^x", node.input(0));
105             EXPECT_EQ("^ones", node.input(1));
106           } else {
107             EXPECT_EQ("Add", node.op());
108             EXPECT_EQ("x", node.input(0));
109             EXPECT_EQ("ones", node.input(1));
110           }
111         }
112       }
113       auto tensors_expected =
114           EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
115       auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
116       EXPECT_EQ(4, tensors_expected.size());
117       EXPECT_EQ(4, tensors.size());
118       for (int i = 0; i < item.fetch.size(); ++i) {
119         test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
120       }
121     }
122   }
123 
MulConvPushDownTest(const TensorShape & input_shape,const TensorShape & filter_shape,const TensorShape & mul_const_input_shape,const bool use_3d_conv,const char * padding,const char * data_format,const bool expect_folded)124   void MulConvPushDownTest(const TensorShape& input_shape,
125                            const TensorShape& filter_shape,
126                            const TensorShape& mul_const_input_shape,
127                            const bool use_3d_conv, const char* padding,
128                            const char* data_format, const bool expect_folded) {
129     // Tests if the following rewrite is performed:
130     //
131     //         *                       Conv2D
132     //        / \                       / \
133     //       c  Conv2D        -->      x  (c * filter)
134     //           / \
135     //          x  filter
136     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
137 
138     Tensor filter_values(DT_FLOAT, filter_shape);
139     for (int i = 0; i < filter_values.NumElements(); ++i) {
140       filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
141     }
142     Output filter =
143         ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));
144 
145     Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
146                                     ops::Placeholder::Shape(input_shape));
147 
148     Output conv;
149     if (use_3d_conv) {
150       conv = ops::Conv3D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1, 1},
151                          padding, ops::Conv3D::DataFormat(data_format));
152     } else {
153       conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
154                          padding, ops::Conv2D::DataFormat(data_format));
155     }
156     Tensor mul_const_input(DT_FLOAT, mul_const_input_shape);
157     for (int i = 0; i < mul_const_input.NumElements(); ++i) {
158       mul_const_input.flat<float>()(i) = static_cast<float>(i + 3);
159     }
160     Output c =
161         ops::Const(s.WithOpName("c"), Input::Initializer(mul_const_input));
162     Output mul = ops::Mul(s.WithOpName("mul"), c, conv);
163 
164     GrapplerItem item;
165     TF_CHECK_OK(s.ToGraphDef(&item.graph));
166 
167     ConstantFolding optimizer(/*cpu_device=*/nullptr);
168     GraphDef output;
169     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
170     TF_EXPECT_OK(status);
171 
172     EXPECT_EQ(5, output.node_size());
173     int found = 0;
174     if (expect_folded) {
175       for (const auto& node : output.node()) {
176         if (node.name() == "mul") {
177           found++;
178           EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
179           EXPECT_EQ(2, node.input_size());
180           EXPECT_EQ("x", node.input(0));
181           EXPECT_EQ("conv/merged_input", node.input(1));
182         } else if (node.name() == "conv/merged_input") {
183           found++;
184           EXPECT_EQ("Const", node.op());
185           EXPECT_EQ(0, node.input_size());
186         }
187       }
188     } else {
189       for (const auto& node : output.node()) {
190         if (node.name() == "mul") {
191           found++;
192           EXPECT_EQ("Mul", node.op());
193           EXPECT_EQ(2, node.input_size());
194           EXPECT_EQ("c", node.input(0));
195           EXPECT_EQ("conv", node.input(1));
196         } else if (node.name() == "conv") {
197           found++;
198           EXPECT_EQ(use_3d_conv ? "Conv3D" : "Conv2D", node.op());
199           EXPECT_EQ(2, node.input_size());
200           EXPECT_EQ("x", node.input(0));
201           EXPECT_EQ("filter", node.input(1));
202         }
203       }
204     }
205     EXPECT_EQ(2, found);
206 
207     // Check that const folded multiplication node has the expected value.
208     std::vector<string> fetch = {"mul"};
209     Tensor value(DT_FLOAT, input_shape);
210     for (int i = 0; i < value.NumElements(); ++i) {
211       value.flat<float>()(i) = i;
212     }
213     auto actual = EvaluateNodes(output, fetch, {{"x", value}});
214     auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
215     test::ExpectTensorEqual<float>(expected[0], actual[0]);
216   }
217 };
218 
TEST_F(ConstantFoldingTest, SimpleFolding) {
  // d = AddN(b, AddN(a, b)) with constant a and b: every node can be computed
  // at optimization time, so the whole graph is expected to collapse into a
  // single Const node named "d".
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output const_a = ops::Const(root.WithOpName("a"), 1.0f, {1});
  Output const_b = ops::Const(root.WithOpName("b"), 2.0f, {1});
  Output sum_c =
      ops::AddN(root.WithOpName("c").WithDevice("/CPU:0"), {const_a, const_b});
  Output sum_d = ops::AddN(root.WithOpName("d"), {const_b, sum_c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphDef optimized;
  ConstantFolding folder(/*cpu_device=*/nullptr);
  TF_EXPECT_OK(folder.Optimize(/*cluster=*/nullptr, item, &optimized));

  EXPECT_EQ(1, optimized.node_size());

  const NodeDef& folded = optimized.node(0);
  EXPECT_EQ("d", folded.name());
  EXPECT_EQ("Const", folded.op());

  // The folded constant must hold the same value the original graph computes.
  const std::vector<string> fetch_nodes = {"d"};
  const auto reference = EvaluateNodes(item.graph, fetch_nodes);
  const auto result = EvaluateNodes(optimized, fetch_nodes);
  EXPECT_EQ(1, reference.size());
  EXPECT_EQ(1, result.size());
  test::ExpectTensorEqual<float>(reference[0], result[0]);
}
250 
TEST_F(ConstantFoldingTest,AddTree)251 TEST_F(ConstantFoldingTest, AddTree) {
252   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
253 
254   Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
255   Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
256   Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
257                               ops::Placeholder::Shape(TensorShape({2, 2})));
258   Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
259   Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
260                          1.0f, {1});
261   Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
262 
263   Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
264                               ops::Placeholder::Shape(TensorShape({2, 2})));
265   Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
266   Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
267   Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
268   Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
269   Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
270   Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
271   Output addmul_parent =
272       ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);
273 
274   GrapplerItem item;
275   item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
276   TF_CHECK_OK(s.ToGraphDef(&item.graph));
277 
278   ConstantFolding optimizer(/*cpu_device=*/nullptr);
279   GraphDef output;
280   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
281   TF_EXPECT_OK(status);
282 
283   // We expect the following rewrite(s) to occur:
284   //
285   //    +                +             +
286   //   / \              / \           / \
287   // 1.0  +     -->    x   +    -->  x  3.0
288   //     / \              / \
289   //   2.0  x           1.0 2.0
290   //
291   //    *                *             *
292   //   / \              / \           / \
293   // 4.0  *     -->    y   *    -->  y  20.0
294   //     / \              / \
295   //   5.0  y           4.0 5.0
296 
297   EXPECT_EQ(11, output.node_size());
298   for (const auto& node : output.node()) {
299     if (node.name() == "add_child") {
300       EXPECT_EQ("Const", node.op());
301       TensorProto t = node.attr().at("value").tensor();
302       EXPECT_EQ(1, t.tensor_shape().dim_size());
303       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
304     } else if (node.name() == "add_parent") {
305       EXPECT_EQ("Add", node.op());
306       EXPECT_EQ(2, node.input_size());
307       EXPECT_EQ("x", node.input(0));
308       EXPECT_EQ("add_child", node.input(1));
309     } else if (node.name() == "mul_child") {
310       EXPECT_EQ("Const", node.op());
311       TensorProto t = node.attr().at("value").tensor();
312       EXPECT_EQ(1, t.tensor_shape().dim_size());
313       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
314     } else if (node.name() == "mul_parent") {
315       EXPECT_EQ("Mul", node.op());
316       EXPECT_EQ(2, node.input_size());
317       EXPECT_EQ("y", node.input(0));
318       EXPECT_EQ("mul_child", node.input(1));
319     } else if (node.name() == "addmul_child") {
320       // Unchanged.
321       EXPECT_EQ("Add", node.op());
322       EXPECT_EQ(2, node.input_size());
323       EXPECT_EQ("c4", node.input(0));
324       EXPECT_EQ("x", node.input(1));
325     }
326   }
327 
328   // Check that the result nodes have the expected value.
329   std::vector<string> fetch = {"c3", "c20"};
330   auto tensor_expected = EvaluateNodes(item.graph, fetch);
331   EXPECT_EQ(fetch.size(), tensor_expected.size());
332   fetch = {"add_child", "mul_child"};
333   auto tensors = EvaluateNodes(output, fetch);
334   EXPECT_EQ(fetch.size(), tensors.size());
335   for (int i = 0; i < fetch.size(); i++) {
336     test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
337   }
338 }
339 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
  // A scalar multiplier is expected to be pushed into the Conv2D filter for
  // every data layout exercised here (NCHW only when built with CUDA).
  for (const string data_format : {
         "NHWC",
#if GOOGLE_CUDA
             "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    const TensorShape input_shape = (data_format == "NHWC")
                                        ? TensorShape{4, 10, 10, 3}
                                        : TensorShape{4, 3, 10, 10};
    MulConvPushDownTest(input_shape,
                        /*filter_shape=*/{2, 2, 3, 5},
                        /*mul_const_input_shape=*/{},
                        /*use_3d_conv=*/false,
                        /*padding=*/"VALID",
                        /*data_format=*/data_format.c_str(),
                        /*expect_folded=*/true);
  }
}
357 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_SingletonConst) {
  // Single-element multipliers of rank 1 and rank 4 (no higher rank than the
  // Conv2D output) are expected to be folded into the filter.
  for (const string data_format : {
         "NHWC",
#if GOOGLE_CUDA
             "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    const TensorShape input_shape = (data_format == "NHWC")
                                        ? TensorShape{4, 10, 10, 3}
                                        : TensorShape{4, 3, 10, 10};
    for (const TensorShape& singleton_shape :
         {TensorShape{1}, TensorShape{1, 1, 1, 1}}) {
      MulConvPushDownTest(input_shape,
                          /*filter_shape=*/{2, 2, 3, 5},
                          /*mul_const_input_shape=*/singleton_shape,
                          /*use_3d_conv=*/false,
                          /*padding=*/"VALID",
                          /*data_format=*/data_format.c_str(),
                          /*expect_folded=*/true);
    }
  }
}
377 
TEST_F(ConstantFoldingTest,
       MulConvPushDownTest_Conv2D_SingletonConst_ShapeMismatch) {
  // A single-element multiplier of rank 5 has a higher rank than the rank-4
  // Conv2D output, so the Mul is expected to be left in place.
  for (const string data_format : {
         "NHWC",
#if GOOGLE_CUDA
             "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    const bool is_nhwc = data_format == "NHWC";
    const TensorShape input_shape =
        is_nhwc ? TensorShape{4, 10, 10, 3} : TensorShape{4, 3, 10, 10};
    MulConvPushDownTest(input_shape,
                        /*filter_shape=*/{2, 2, 3, 5},
                        /*mul_const_input_shape=*/{1, 1, 1, 1, 1},
                        /*use_3d_conv=*/false,
                        /*padding=*/"VALID",
                        /*data_format=*/data_format.c_str(),
                        /*expect_folded=*/false);
  }
}
396 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1x3Const) {
  // A {3, 1, 3} multiplier has more than one non-unit dimension and is
  // expected not to be folded, for either data layout.
  const TensorShape input_shape{3, 3, 3, 3};
  const TensorShape filter_shape{3, 3, 3, 3};
  const TensorShape mul_const_shape{3, 1, 3};
  for (const char* data_format : {
         "NHWC",
#if GOOGLE_CUDA
             "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(input_shape, filter_shape, mul_const_shape,
                        /*use_3d_conv=*/false,
                        /*padding=*/"SAME",
                        /*data_format=*/data_format,
                        /*expect_folded=*/false);
  }
}
413 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NHWC_VectorLikeConst) {
  // In NHWC the channel dimension is innermost, so multipliers of shape {3},
  // {1, 3} and {1, 1, 1, 3} vary only along channels and are all expected to
  // be folded into the filter.
  const TensorShape input_shape{3, 3, 3, 3};
  const TensorShape filter_shape{3, 3, 3, 3};
  for (const TensorShape& per_channel_shape :
       {TensorShape{3}, TensorShape{1, 3}, TensorShape{1, 1, 1, 3}}) {
    MulConvPushDownTest(input_shape, filter_shape, per_channel_shape,
                        /*use_3d_conv=*/false,
                        /*padding=*/"SAME",
                        /*data_format=*/"NHWC",
                        /*expect_folded=*/true);
  }
}
426 
#if GOOGLE_CUDA
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_NCHW_VectorLikeConst) {
  // NCHW counterpart of the vector-like multiplier test (CUDA builds only).
  // The {3}, {3, 1, 1} and {1, 3, 1, 1} multipliers are currently expected to
  // stay unfolded.
  const TensorShape input_shape{3, 3, 3, 3};
  const TensorShape filter_shape{3, 3, 3, 3};
  for (const TensorShape& vector_like_shape :
       {TensorShape{3}, TensorShape{3, 1, 1}, TensorShape{1, 3, 1, 1}}) {
    MulConvPushDownTest(input_shape, filter_shape, vector_like_shape,
                        /*use_3d_conv=*/false,
                        /*padding=*/"SAME",
                        /*data_format=*/"NCHW",
                        // TODO(laigd): optimization should happen in this case.
                        /*expect_folded=*/false);
  }
}
#endif  // GOOGLE_CUDA
442 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_3x1Const) {
  // A {3, 1} multiplier is expected not to be folded, for either data layout.
  const TensorShape input_shape{3, 3, 3, 3};
  const TensorShape filter_shape{3, 3, 3, 3};
  const TensorShape mul_const_shape{3, 1};
  for (const char* data_format : {
         "NHWC",
#if GOOGLE_CUDA
             "NCHW"
#endif  // GOOGLE_CUDA
       }) {
    MulConvPushDownTest(input_shape, filter_shape, mul_const_shape,
                        /*use_3d_conv=*/false,
                        /*padding=*/"SAME",
                        /*data_format=*/data_format,
                        /*expect_folded=*/false);
  }
}
459 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NDHWC_1x1x3Const) {
  // Conv3D in NDHWC: a {1, 1, 3} multiplier varies only along the innermost
  // dimension and is expected to be folded into the filter.
  const TensorShape volume_shape{3, 3, 3, 3, 3};
  MulConvPushDownTest(/*input_shape=*/volume_shape,
                      /*filter_shape=*/volume_shape,
                      /*mul_const_input_shape=*/{1, 1, 3},
                      /*use_3d_conv=*/true,
                      /*padding=*/"SAME",
                      /*data_format=*/"NDHWC",
                      /*expect_folded=*/true);
}
470 
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv3D_NCDHW_3x1x1x1Const) {
  // NOTE(review): despite the test name, the graph is built with the "NDHWC"
  // data format. The {3, 1, 1, 1} multiplier therefore lines up with the
  // trailing D, H, W, C dimensions rather than with the channel dimension of
  // an NCDHW tensor. Presumably NDHWC is used because an NCDHW Conv3D may not
  // run on CPU; confirm before renaming the test or changing the format.
  const TensorShape volume_shape{3, 3, 3, 3, 3};
  MulConvPushDownTest(/*input_shape=*/volume_shape,
                      /*filter_shape=*/volume_shape,
                      /*mul_const_input_shape=*/{3, 1, 1, 1},
                      /*use_3d_conv=*/true,
                      /*padding=*/"SAME",
                      /*data_format=*/"NDHWC",
                      // TODO(laigd): optimization should happen in this case.
                      /*expect_folded=*/false);
}
482 
TEST_F(ConstantFoldingTest,NeutralElement)483 TEST_F(ConstantFoldingTest, NeutralElement) {
484   int kConst = 0;
485   int kLike = 1;
486   int kFill = 2;
487   for (int const_type : {kConst, kLike, kFill}) {
488     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
489     Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
490                                 ops::Placeholder::Shape(TensorShape({2, 2})));
491     Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
492                                 ops::Placeholder::Shape(TensorShape({2, 2})));
493     Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
494                                 ops::Placeholder::Shape(TensorShape({3, 2})));
495     Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
496                                 ops::Placeholder::Shape(TensorShape({2, 3})));
497     Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
498                                    ops::Placeholder::Shape(TensorShape({2})));
499     Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
500     Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
501     Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
502     Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
503     Output zeros = const_type == kConst
504                        ? zeros_const
505                        : (const_type == kLike ? zeros_like : zeros_fill);
506     Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
507     Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
508     Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
509     Output ones = const_type == kConst
510                       ? ones_const
511                       : (const_type == kLike ? ones_like : ones_fill);
512     Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
513     Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
514     Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
515     Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
516     Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_1d);
517     Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_1d, y);
518     Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
519     Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
520     Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
521     Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
522     Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
523     Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
524     Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
525     Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
526     Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
527     Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
528     Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
529     Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
530     Output concat =
531         ops::Stack(s.WithOpName("stack"),
532                    {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
533                     matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
534     GrapplerItem item;
535     TF_CHECK_OK(s.ToGraphDef(&item.graph));
536     item.fetch = {"stack", "matmul3", "matmul4"};
537 
538     ConstantFolding optimizer(/*cpu_device=*/nullptr);
539     GraphDef output;
540     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
541     TF_EXPECT_OK(status);
542 
543     const string suffix =
544         (const_type == kConst ? "_const"
545                               : (const_type == kLike ? "_like" : "_fill"));
546     const string zeros_name = strings::StrCat("zeros", suffix);
547     const string ones_name = strings::StrCat("ones", suffix);
548     const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
549     const string ctrl_ones_name = strings::StrCat("^ones", suffix);
550     EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
551     for (int i = 0; i < output.node_size(); ++i) {
552       const NodeDef& node = output.node(i);
553       const string& name = node.name();
554       if (name == "mul1") {
555         EXPECT_EQ("Const", node.op());
556         EXPECT_EQ("^x", node.input(0));
557         EXPECT_EQ(ctrl_zeros_name, node.input(1));
558       } else if (name == "mul2") {
559         EXPECT_EQ("Const", node.op());
560         EXPECT_EQ(ctrl_zeros_name, node.input(0));
561         EXPECT_EQ("^y", node.input(1));
562       } else if (name == "mul3") {
563         EXPECT_EQ("Identity", node.op());
564         EXPECT_EQ("x", node.input(0));
565         EXPECT_EQ(ctrl_ones_name, node.input(1));
566       } else if (name == "mul4") {
567         EXPECT_EQ("Identity", node.op());
568         EXPECT_EQ("y", node.input(0));
569         EXPECT_EQ(ctrl_ones_name, node.input(1));
570       } else if (name == "mul5") {
571         EXPECT_EQ("Const", node.op());
572         EXPECT_EQ("^x", node.input(0));
573         EXPECT_EQ("^zeros_1d", node.input(1));
574       } else if (name == "mul6") {
575         EXPECT_EQ("Const", node.op());
576         EXPECT_EQ("^zeros_1d", node.input(0));
577         EXPECT_EQ("^y", node.input(1));
578       } else if (name == "div1") {
579         EXPECT_EQ("Identity", node.op());
580         EXPECT_EQ("x", node.input(0));
581         EXPECT_EQ(ctrl_ones_name, node.input(1));
582       } else if (name == "div2") {
583         EXPECT_EQ("Reciprocal", node.op());
584         EXPECT_EQ("y", node.input(0));
585         EXPECT_EQ(ctrl_ones_name, node.input(1));
586       } else if (name == "matmul1") {
587         EXPECT_EQ("Const", node.op());
588         EXPECT_EQ("^x", node.input(0));
589         EXPECT_EQ(ctrl_zeros_name, node.input(1));
590       } else if (name == "matmul2") {
591         EXPECT_EQ("Const", node.op());
592         EXPECT_EQ(ctrl_zeros_name, node.input(0));
593         EXPECT_EQ("^y", node.input(1));
594       } else if (name == "matmul3") {
595         EXPECT_EQ("Const", node.op());
596         EXPECT_EQ("^a", node.input(0));
597         EXPECT_EQ(ctrl_zeros_name, node.input(1));
598         TensorProto t = node.attr().at("value").tensor();
599         EXPECT_EQ(1, t.float_val_size());
600         EXPECT_EQ(0, t.float_val(0));
601         EXPECT_EQ(2, t.tensor_shape().dim_size());
602         EXPECT_EQ(3, t.tensor_shape().dim(0).size());
603         EXPECT_EQ(2, t.tensor_shape().dim(1).size());
604       } else if (name == "matmul4") {
605         EXPECT_EQ("Const", node.op());
606         EXPECT_EQ(ctrl_zeros_name, node.input(0));
607         EXPECT_EQ("^b", node.input(1));
608         TensorProto t = node.attr().at("value").tensor();
609         EXPECT_EQ(1, t.float_val_size());
610         EXPECT_EQ(0, t.float_val(0));
611         EXPECT_EQ(2, t.tensor_shape().dim_size());
612         EXPECT_EQ(2, t.tensor_shape().dim(0).size());
613         EXPECT_EQ(3, t.tensor_shape().dim(1).size());
614       } else if (name == "add1") {
615         EXPECT_EQ("Identity", node.op());
616         EXPECT_EQ("x", node.input(0));
617         EXPECT_EQ(ctrl_zeros_name, node.input(1));
618       } else if (name == "add2") {
619         EXPECT_EQ("Identity", node.op());
620         EXPECT_EQ("y", node.input(0));
621         EXPECT_EQ(ctrl_zeros_name, node.input(1));
622       } else if (name == "bias_add1") {
623         EXPECT_EQ("Identity", node.op());
624         EXPECT_EQ("x", node.input(0));
625         EXPECT_EQ("^zeros_1d", node.input(1));
626       } else if (name == "bias_add2") {
627         // We don't eliminate this one, because it requires broadcasting.
628         EXPECT_EQ("BiasAdd", node.op());
629         EXPECT_EQ(zeros_name, node.input(0));
630         EXPECT_EQ("bias", node.input(1));
631       } else if (name == "sub1") {
632         EXPECT_EQ("Identity", node.op());
633         EXPECT_EQ("x", node.input(0));
634         EXPECT_EQ(ctrl_zeros_name, node.input(1));
635       } else if (name == "sub2") {
636         EXPECT_EQ("Neg", node.op());
637         EXPECT_EQ("y", node.input(0));
638         EXPECT_EQ(ctrl_zeros_name, node.input(1));
639       }
640       const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
641                                                "mul6", "matmul1", "matmul2"};
642       if (square_zero_const.count(name) > 0) {
643         TensorProto t = node.attr().at("value").tensor();
644         EXPECT_EQ(1, t.float_val_size());
645         EXPECT_EQ(0, t.float_val(0));
646         EXPECT_EQ(2, t.tensor_shape().dim_size());
647         EXPECT_EQ(2, t.tensor_shape().dim(0).size());
648         EXPECT_EQ(2, t.tensor_shape().dim(1).size());
649       }
650     }
651     auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
652     auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
653     auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
654     auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
655     auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
656 
657     auto tensors_expected = EvaluateNodes(
658         item.graph, item.fetch,
659         {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
660     EXPECT_EQ(item.fetch.size(), tensors_expected.size());
661     auto tensors = EvaluateNodes(
662         output, item.fetch,
663         {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
664     EXPECT_EQ(item.fetch.size(), tensors.size());
665     for (int i = 0; i < item.fetch.size(); ++i) {
666       test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
667     }
668   }
669 }
670 
TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
  // Reuse the shared x*0 / x*1 / x+0 / x+1 checks for types outside the main
  // float test: bool (which SimpleNeutralElementTest builds with
  // LogicalAnd/LogicalOr) and the two 16-bit floating-point types.
  SimpleNeutralElementTest<DT_BOOL>();     // logical ops
  SimpleNeutralElementTest<DT_HALF>();     // IEEE half precision
  SimpleNeutralElementTest<DT_BFLOAT16>(); // brain float 16
}
676 
TEST_F(ConstantFoldingTest,StrengthReduce_Reciprocal)677 TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
678   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
679   Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
680   Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
681                                ops::Placeholder::Shape(TensorShape({2, 2})));
682   Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
683                                ops::Placeholder::Shape(TensorShape({2, 2})));
684   Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
685   Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
686   Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
687   Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
688   Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);
689 
690   GrapplerItem item;
691   TF_CHECK_OK(s.ToGraphDef(&item.graph));
692   item.fetch = {"div_f", "div_i", "realdiv"};
693   ConstantFolding optimizer(/*cpu_device=*/nullptr);
694   GraphDef output;
695   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
696   TF_EXPECT_OK(status);
697 
698   EXPECT_EQ(8, output.node_size());
699   for (int i = 0; i < output.node_size(); ++i) {
700     const NodeDef& node = output.node(i);
701     const string& name = node.name();
702     if (name == "div_i") {
703       // Integer division is unchanged.
704       EXPECT_EQ("Div", node.op());
705       EXPECT_EQ("xi", node.input(0));
706       EXPECT_EQ("ci", node.input(1));
707     } else if (name == "div_f") {
708       EXPECT_EQ("Mul", node.op());
709       EXPECT_EQ("xf", node.input(0));
710       EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
711     } else if (name == "realdiv") {
712       EXPECT_EQ("Mul", node.op());
713       EXPECT_EQ("xf", node.input(0));
714       EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
715     } else if (name == "ConstantFolding/div_f_recip") {
716       EXPECT_EQ("Const", node.op());
717       EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
718       TensorProto t = node.attr().at("value").tensor();
719       EXPECT_EQ(DT_FLOAT, t.dtype());
720       EXPECT_EQ(1, t.tensor_shape().dim_size());
721       EXPECT_EQ(1, t.tensor_shape().dim(0).size());
722     } else if (name == "ConstantFolding/realdiv_recip") {
723       EXPECT_EQ("Const", node.op());
724       EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
725       TensorProto t = node.attr().at("value").tensor();
726       EXPECT_EQ(DT_FLOAT, t.dtype());
727       EXPECT_EQ(1, t.tensor_shape().dim_size());
728       EXPECT_EQ(1, t.tensor_shape().dim(0).size());
729     }
730   }
731 
732   // Check that the reciprocals have the expected value.
733   std::vector<string> fetch = {"cf_half"};
734   auto tensor_expected = EvaluateNodes(item.graph, fetch);
735   EXPECT_EQ(fetch.size(), tensor_expected.size());
736   fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
737   auto tensors = EvaluateNodes(output, fetch);
738   EXPECT_EQ(fetch.size(), tensors.size());
739   for (int i = 0; i < fetch.size(); i++) {
740     test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
741   }
742 }
743 
TEST_F(ConstantFoldingTest,NeutralElement_PartialShape_UnknownOutputShape)744 TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
745   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
746   Output x_known =
747       ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
748                        ops::Placeholder::Shape(TensorShape({2, 2})));
749   Output x_partially_known =
750       ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
751                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
752   Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
753   Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
754   Output zeros_partially_known =
755       ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
756   Output zeros_unknown =
757       ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
758 
759   // Multiplies without any additional ops to supply the output shape.
760   int count = 0;
761   std::vector<Output> muls;
762   std::unordered_set<string> not_converted;
763   std::unordered_set<string> to_const;
764   std::unordered_set<string> to_identity;
765   for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
766     for (const auto* zeros :
767          {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
768       const string name = strings::StrCat("mul_", count++);
769       muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
770       if (x == &x_partially_known && zeros == &zeros_partially_known) {
771         to_identity.insert(name);
772       } else if (x == &x_unknown || zeros == &zeros_unknown) {
773         not_converted.insert(name);
774       } else {
775         to_const.insert(name);
776       }
777     }
778   }
779 
780   GrapplerItem item;
781   TF_CHECK_OK(s.ToGraphDef(&item.graph));
782 
783   ConstantFolding optimizer(/*cpu_device=*/nullptr);
784   GraphDef output;
785   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
786   TF_EXPECT_OK(status);
787 
788   EXPECT_EQ(15, output.node_size());
789   for (int i = 0; i < output.node_size(); ++i) {
790     const NodeDef& node = output.node(i);
791     const string& name = node.name();
792     if (to_const.count(name) > 0) {
793       EXPECT_EQ("Const", node.op()) << node.name();
794     } else if (to_identity.count(name) > 0) {
795       EXPECT_EQ("Identity", node.op()) << node.name();
796     } else if (not_converted.count(name) > 0) {
797       EXPECT_EQ("Mul", node.op()) << node.name();
798     }
799   }
800 
801   const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
802   auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
803   auto x_partially_unknown_t =
804       GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
805   auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
806   auto expected_tensors =
807       EvaluateNodes(item.graph, fetch,
808                     {{"x_known", x_known_t},
809                      {"x_partially_unknown", x_partially_unknown_t},
810                      {"x_unknown", x_unknown_t}});
811   EXPECT_EQ(fetch.size(), expected_tensors.size());
812   auto tensors = EvaluateNodes(output, fetch,
813                                {{"x_known", x_known_t},
814                                 {"x_partially_unknown", x_partially_unknown_t},
815                                 {"x_unknown", x_unknown_t}});
816   EXPECT_EQ(fetch.size(), tensors.size());
817   for (int i = 0; i < tensors.size(); i++)
818     test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
819 }
820 
TEST_F(ConstantFoldingTest,NeutralElement_PartialShape_KnownOutputShape)821 TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
822   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
823   Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
824   Output x_partially_known =
825       ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
826                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
827   Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
828   Output zeros_partially_known =
829       ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
830   Output zeros_unknown =
831       ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
832 
833   // If at least one of the inputs to AddN has a known shape, shape inference
834   // will propagate the shape back to the inputs of AddN, making the
835   // output shapes of all its inputs known
836   std::vector<Output> muls_deduced_output_shape;
837   std::unordered_set<string> to_const;
838   int count = 0;
839   for (const auto& x : {x_partially_known, x_unknown}) {
840     for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
841       const string name = strings::StrCat("mul_", count++);
842       muls_deduced_output_shape.push_back(
843           ops::Mul(s.WithOpName(name), x, zeros));
844       to_const.insert(name);
845     }
846   }
847   // We add a known shape as input to AddN to propagate it back to the
848   // multiplies above, which means they can all be turned into Const nodes.
849   muls_deduced_output_shape.push_back(known_shape);
850   Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);
851 
852   GrapplerItem item;
853   TF_CHECK_OK(s.ToGraphDef(&item.graph));
854 
855   ConstantFolding optimizer(/*cpu_device=*/nullptr);
856   GraphDef output;
857   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
858   TF_EXPECT_OK(status);
859 
860   EXPECT_EQ(10, output.node_size());
861   for (int i = 0; i < output.node_size(); ++i) {
862     const NodeDef& node = output.node(i);
863     const string& name = node.name();
864     if (to_const.count(name) > 0) {
865       EXPECT_EQ("Const", node.op()) << node.name();
866       EXPECT_EQ(2, node.input_size());
867       EXPECT_TRUE(IsControlInput(node.input(0)));
868       EXPECT_TRUE(IsControlInput(node.input(1)));
869     }
870   }
871   const std::vector<string> fetch = {"addn1"};
872   auto x_partially_unknown_t =
873       GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
874   auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
875   auto expected_tensors =
876       EvaluateNodes(item.graph, fetch,
877                     {{"x_partially_unknown", x_partially_unknown_t},
878                      {"x_unknown", x_unknown_t}});
879   EXPECT_EQ(1, expected_tensors.size());
880   auto tensors = EvaluateNodes(output, fetch,
881                                {{"x_partially_unknown", x_partially_unknown_t},
882                                 {"x_unknown", x_unknown_t}});
883   EXPECT_EQ(1, tensors.size());
884   test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
885 }
886 
TEST_F(ConstantFoldingTest,CreateConstNodes)887 TEST_F(ConstantFoldingTest, CreateConstNodes) {
888   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
889 
890 #define MAKE_TEST_GRAPH(TYPE)                                               \
891   Output TYPE##_const =                                                     \
892       ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
893   Output TYPE##_mul =                                                       \
894       ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
895   Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)
896 
897   MAKE_TEST_GRAPH(float);
898   MAKE_TEST_GRAPH(double);
899   MAKE_TEST_GRAPH(int64);
900   MAKE_TEST_GRAPH(int32);
901   MAKE_TEST_GRAPH(int16);
902   MAKE_TEST_GRAPH(int8);
903   MAKE_TEST_GRAPH(uint8);
904 #undef MAKE_TEST_GRAPH
905 
906   Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
907   Output bool_and =
908       ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
909   Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);
910 
911   GrapplerItem item;
912   TF_CHECK_OK(s.ToGraphDef(&item.graph));
913   ConstantFolding optimizer(/*cpu_device=*/nullptr);
914   GraphDef output;
915   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
916   TF_EXPECT_OK(status);
917 
918   EXPECT_EQ(24, output.node_size());
919   for (const NodeDef& node : output.node()) {
920 #define CHECK_RESULT(TYPE, FIELD)                                             \
921   if (node.name() == #TYPE "_mul") {                                          \
922     EXPECT_EQ(5,                                                              \
923               node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
924     EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
925     EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
926   }
927 
928     CHECK_RESULT(float, float);
929     CHECK_RESULT(double, double);
930     CHECK_RESULT(int64, int64);
931     CHECK_RESULT(int32, int);
932     CHECK_RESULT(int16, int);
933     CHECK_RESULT(int8, int);
934     CHECK_RESULT(uint8, int);
935 #undef CHECK_RESULT
936 
937     if (node.name() == "bool_and") {
938       EXPECT_EQ(5,
939                 node.attr().at("value").tensor().tensor_shape().dim(0).size());
940       EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
941       EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
942     }
943   }
944 }
945 
// A node with two outputs (Unique: y and idx) fed by a constant can be folded;
// both fetched Identity chains should end up as Const nodes.
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {5});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
  Output e = ops::Identity(s.WithOpName("e"), {c});
  Output f = ops::Identity(s.WithOpName("f"), {d});

  GrapplerItem item;
  item.fetch.push_back("e");
  item.fetch.push_back("f");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: nodes 0 and 1 are accessed unconditionally below.
  ASSERT_EQ(2, output.node_size());

  const NodeDef& node_e = output.node(0);
  EXPECT_EQ("e", node_e.name());
  EXPECT_EQ("Const", node_e.op());

  const NodeDef& node_f = output.node(1);
  EXPECT_EQ("f", node_f.name());
  EXPECT_EQ("Const", node_f.op());

  std::vector<string> fetch = {"e", "f"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(fetch.size(), tensors_expected.size());
  ASSERT_EQ(fetch.size(), tensors.size());
  for (size_t i = 0; i < fetch.size(); ++i) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}
986 
// Folding a chain that carries control dependencies must preserve those
// dependencies (^p1, ^p2) on the resulting fetched Const node.
TEST_F(ConstantFoldingTest, ControlDependencies) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  item.fetch.push_back("e");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "e"};
  // Fatal: expected_nodes is indexed by output position below.
  ASSERT_EQ(output.node_size(), expected_nodes.size());
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    EXPECT_EQ(expected_nodes[i], node.name());
    if (node.name() == "e") {
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"e"});
      auto expected = EvaluateNodes(item.graph, {"e"});
      ASSERT_EQ(1, expected.size());
      ASSERT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(1, found);
}
1030 
// With no fetch nodes every node is preserved; foldable nodes i1 and i2 become
// Const nodes that keep the control dependencies accumulated along the chain.
TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Verifies that `node` was folded into a Const whose value matches the
  // original graph and whose inputs are exactly `control_inputs`, in order.
  // Fatal failures inside only abort this helper, which keeps the indexed
  // accesses below safe.
  auto check_folded = [&](const NodeDef& node,
                          const std::vector<string>& control_inputs) {
    EXPECT_EQ("Const", node.op());
    auto folded = EvaluateNodes(output, {node.name()});
    auto expected = EvaluateNodes(item.graph, {node.name()});
    ASSERT_EQ(1, expected.size());
    ASSERT_EQ(1, folded.size());
    test::ExpectTensorEqual<int>(folded[0], expected[0]);
    ASSERT_EQ(static_cast<int>(control_inputs.size()), node.input_size());
    for (int k = 0; k < node.input_size(); ++k) {
      EXPECT_EQ(control_inputs[k], node.input(k));
    }
  };

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  // Fatal: expected_nodes is indexed by output position below.
  ASSERT_EQ(output.node_size(), expected_nodes.size());
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    EXPECT_EQ(expected_nodes[i], node.name());
    if (node.name() == "i1") {
      ++found;
      check_folded(node, {"^p1"});
    } else if (node.name() == "i2") {
      ++found;
      check_folded(node, {"^p1", "^p2"});
    }
  }
  EXPECT_EQ(2, found);
}
1085 
// Control dependencies that reach the folded node along several paths (p1 on
// both c and i1) must appear only once on the resulting Const.
TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1")
                                .WithControlDependencies(p2)
                                .WithControlDependencies(p1),
                            {c});
  Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});

  GrapplerItem item;
  item.fetch.push_back("i2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, tensors_expected.size());
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
  // Fatal: expected_nodes is indexed by output position below.
  ASSERT_EQ(output.node_size(), expected_nodes.size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    EXPECT_EQ(expected_nodes[i], node.name());
    if (node.name() == "i2") {
      EXPECT_EQ("Const", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
1126 
TEST_F(ConstantFoldingTest,VariableNumberOfOutputs)1127 TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
1128   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1129   // Add a DynamicPartition node to the graph
1130   Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
1131   Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
1132   int num_partitions = 4;
1133   ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
1134                              num_partitions);
1135 
1136   std::vector<string> outputs;
1137   for (int i = 0; i < num_partitions; ++i) {
1138     string part_out_name = strings::StrCat("part_out", i);
1139     ops::Identity partition_out(scope.WithOpName(part_out_name),
1140                                 {part.outputs[i]});
1141     outputs.push_back(part_out_name);
1142   }
1143 
1144   GrapplerItem item;
1145   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1146 
1147   // Add a ConcatOffset node to the graph
1148   Tensor initial_val(DT_INT32, TensorShape({3}));
1149   test::FillIota<int>(&initial_val, 7);
1150   for (int i = 1; i < 5; ++i) {
1151     TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
1152                     .Attr("dtype", DT_INT32)
1153                     .Attr("value", initial_val)
1154                     .Finalize(item.graph.add_node()));
1155   }
1156   Tensor concat_dim(DT_INT32, TensorShape({}));
1157   test::FillIota<int>(&concat_dim, 0);
1158   TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
1159                   .Attr("dtype", DT_INT32)
1160                   .Attr("value", concat_dim)
1161                   .Finalize(item.graph.add_node()));
1162 
1163   TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
1164                   .Input("concat_dim", 0, DT_INT32)
1165                   .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
1166                           NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
1167                           NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
1168                           NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
1169                   .Finalize(item.graph.add_node()));
1170 
1171   for (int i = 0; i < 4; ++i) {
1172     string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
1173     TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
1174                     .Attr("T", DT_INT32)
1175                     .Input("concat_offsets", i, DT_INT32)
1176                     .Finalize(item.graph.add_node()));
1177     outputs.push_back(concat_offset_out_name);
1178   }
1179 
1180   item.fetch = outputs;
1181   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1182   GraphDef output;
1183   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1184   TF_EXPECT_OK(status);
1185 
1186   int constant_folded = 0;
1187   for (const auto& node : output.node()) {
1188     if (node.name().find("part_out") != string::npos ||
1189         node.name().find("concat_offset_out") != string::npos) {
1190       ++constant_folded;
1191       EXPECT_EQ("Const", node.op());
1192     }
1193   }
1194   EXPECT_EQ(8, constant_folded);
1195 
1196   auto expected = EvaluateNodes(item.graph, outputs);
1197   auto optimized = EvaluateNodes(output, outputs);
1198   ASSERT_EQ(expected.size(), optimized.size());
1199   for (int i = 0; i < expected.size(); ++i) {
1200     test::ExpectTensorEqual<int>(expected[i], optimized[i]);
1201   }
1202 }
1203 
// Rank/Shape/Size of fully-defined variables are materialized into constants,
// so the fetched product p2 folds to a Const that keeps control dependencies
// on the variables.
TEST_F(ConstantFoldingTest, ShapeMaterialization) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      ASSERT_TRUE(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11*13
      // p2 = (715, 1001) = (5*143, 7*143)
      ASSERT_EQ(2, value.NumElements());
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(1, found);
  auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
  auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));

  auto tensors_expected = EvaluateNodes(
      item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  // Check the evaluation results themselves (not item.fetch) before indexing.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch,
                               {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
1254 
TEST_F(ConstantFoldingTest,ShapeMaterializationEmptyFetch)1255 TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
1256   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1257   Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
1258   Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
1259   Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
1260   Output rank = ops::Rank(scope.WithOpName("rank"), v1);
1261   Output shape = ops::Shape(scope.WithOpName("shape"), v2);
1262   Output size = ops::Size(scope.WithOpName("size"), v3);
1263   Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
1264   Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
1265 
1266   GrapplerItem item;
1267   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1268 
1269   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1270   GraphDef output;
1271   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1272   TF_EXPECT_OK(status);
1273 
1274   int found = 0;
1275   for (const auto& node : output.node()) {
1276     if (node.name() == "size") {
1277       ++found;
1278       EXPECT_EQ("Const", node.op());
1279       EXPECT_EQ(1, node.input_size());
1280       EXPECT_EQ("^v3", node.input(0));
1281       Tensor value;
1282       CHECK(value.FromProto(node.attr().at("value").tensor()));
1283       EXPECT_EQ(11 * 13, value.flat<int>()(0));
1284     } else if (node.name() == "rank") {
1285       ++found;
1286       EXPECT_EQ("Const", node.op());
1287       EXPECT_EQ(1, node.input_size());
1288       EXPECT_EQ("^v1", node.input(0));
1289       Tensor value;
1290       CHECK(value.FromProto(node.attr().at("value").tensor()));
1291       EXPECT_EQ(1, value.flat<int>()(0));
1292     } else if (node.name() == "shape") {
1293       ++found;
1294       EXPECT_EQ("Const", node.op());
1295       EXPECT_EQ(1, node.input_size());
1296       EXPECT_EQ("^v2", node.input(0));
1297       Tensor value;
1298       CHECK(value.FromProto(node.attr().at("value").tensor()));
1299       EXPECT_EQ(5, value.flat<int>()(0));
1300       EXPECT_EQ(7, value.flat<int>()(1));
1301     }
1302   }
1303   EXPECT_EQ(3, found);
1304 
1305   auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1306   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
1307   auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({11, 13}));
1308   std::vector<string> fetch_nodes = {"p2"};
1309   auto tensors_expected = EvaluateNodes(
1310       item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1311   EXPECT_EQ(1, tensors_expected.size());
1312   auto tensors = EvaluateNodes(output, fetch_nodes,
1313                                {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1314   EXPECT_EQ(1, tensors.size());
1315   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1316 }
1317 
TEST_F(ConstantFoldingTest,ShapeMaterializationShapeN)1318 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
1319   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1320   Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1321   Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
1322   Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
1323   auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
1324   Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
1325   Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
1326   Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
1327   Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
1328   Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
1329   Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
1330   Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
1331 
1332   GrapplerItem item;
1333   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1334 
1335   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1336   GraphDef output;
1337   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1338   TF_EXPECT_OK(status);
1339   int found = 0;
1340   for (const auto& node : output.node()) {
1341     EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1342               node.name());
1343     EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1344               node.name());
1345     if (node.name() == "i1a" || node.name() == "i1b") {
1346       ++found;
1347       EXPECT_EQ("s", node.input(0));
1348     }
1349     if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
1350       ++found;
1351       EXPECT_EQ("s:1", node.input(0));
1352     }
1353     if (node.name() == "i3a" || node.name() == "i3b") {
1354       ++found;
1355       EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
1356                 node.input(0));
1357     }
1358     if (node.name() == "s") {
1359       ++found;
1360       EXPECT_EQ("ShapeN", node.op());
1361       EXPECT_EQ("v1", node.input(0));
1362       EXPECT_EQ("v2", node.input(1));
1363       EXPECT_EQ("v3", node.input(2));
1364     }
1365     if (node.name() ==
1366         AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
1367       ++found;
1368       EXPECT_EQ("Const", node.op());
1369       EXPECT_EQ("^s", node.input(0));
1370       Tensor value;
1371       CHECK(value.FromProto(node.attr().at("value").tensor()));
1372       EXPECT_EQ(4, value.flat<int>()(0));
1373       EXPECT_EQ(6, value.flat<int>()(1));
1374     }
1375   }
1376   EXPECT_EQ(9, found);
1377 
1378   auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1379   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
1380   auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1381   const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
1382                                            "i2c", "i3a", "i3b"};
1383   auto tensors_expected = EvaluateNodes(
1384       item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1385   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
1386   auto tensors = EvaluateNodes(output, fetch_nodes,
1387                                {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
1388   EXPECT_EQ(fetch_nodes.size(), tensors.size());
1389   for (int i = 0; i < fetch_nodes.size(); i++)
1390     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1391 }
1392 
TEST_F(ConstantFoldingTest,ShapeMaterializationShapeN_MultipleOutputs)1393 TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
1394   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1395   Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
1396   Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
1397   auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
1398   auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
1399   Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
1400   Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
1401 
1402   GrapplerItem item;
1403   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1404   item.fetch.push_back("ia");
1405   item.fetch.push_back("ib");
1406 
1407   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1408   GraphDef output;
1409   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1410   TF_EXPECT_OK(status);
1411 
1412   int found = 0;
1413   for (const auto& node : output.node()) {
1414     EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
1415               node.name());
1416     if (node.name() == "s") {
1417       ++found;
1418       EXPECT_EQ("ShapeN", node.op());
1419       EXPECT_EQ("v1", node.input(0));
1420       EXPECT_EQ("v2", node.input(1));
1421     }
1422     if (node.name() == "id_n") {
1423       ++found;
1424       EXPECT_EQ("IdentityN", node.op());
1425       EXPECT_EQ("s", node.input(0));
1426       EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
1427                 node.input(1));
1428     }
1429     if (node.name() == "ia") {
1430       ++found;
1431       EXPECT_EQ("id_n", node.input(0));
1432     }
1433     if (node.name() == "ib") {
1434       ++found;
1435       EXPECT_EQ("Const", node.op());
1436       EXPECT_EQ("^s", node.input(0));
1437       EXPECT_EQ("^id_n", node.input(1));
1438     }
1439   }
1440   EXPECT_EQ(4, found);
1441 
1442   auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
1443   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1444   auto tensors_expected =
1445       EvaluateNodes(item.graph, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1446   EXPECT_EQ(2, tensors_expected.size());
1447   auto tensors =
1448       EvaluateNodes(output, item.fetch, {{"v1", v1_t}, {"v2", v2_t}});
1449   EXPECT_EQ(2, tensors.size());
1450   for (int i = 0; i < tensors.size(); i++)
1451     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
1452 }
1453 
TEST_F(ConstantFoldingTest,SwitchNodesEmptyFetch)1454 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
1455   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1456   ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1457   ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1458   ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1459   ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1460   ops::Identity i(scope.WithOpName("i"), s1.output_true);
1461   ops::Size size(scope.WithOpName("size"), i);
1462   ops::Square p1(scope.WithOpName("p1"), rank);
1463   ops::Square p2(scope.WithOpName("p2"), size);
1464   ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1465 
1466   Output predicate =
1467       ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1468   Output constant =
1469       ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1470   ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1471   ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1472   ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1473   ops::Merge m2(scope.WithOpName("m2"),
1474                 {statically_known.output, never_generated.output});
1475 
1476   GrapplerItem item;
1477   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1478 
1479   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1480   GraphDef output;
1481   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1482   TF_EXPECT_OK(status);
1483 
1484   std::set<string> present_nodes = {"v_in",     "v_ctrl",
1485                                     "switch",   "i",
1486                                     "p1",       "p2",
1487                                     "m",        "false",
1488                                     "constant", "switch2",
1489                                     "i2",       "i3",
1490                                     "m2",       "ConstantFoldingCtrl/switch_0",
1491                                     "rank",     "size"};
1492   std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
1493   EXPECT_EQ(present_nodes.size(), output.node_size());
1494   int found = 0;
1495   for (const auto& node : output.node()) {
1496     EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end())
1497         << node.name();
1498     EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end())
1499         << node.name();
1500     present_nodes.erase(node.name());
1501     not_present_nodes.erase(node.name());
1502     if (node.name() == "rank") {
1503       ++found;
1504       EXPECT_EQ("Const", node.op());
1505       EXPECT_EQ(1, node.input_size());
1506       EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
1507     }
1508     if (node.name() == "size") {
1509       ++found;
1510       EXPECT_EQ("Const", node.op());
1511       EXPECT_EQ(1, node.input_size());
1512       EXPECT_EQ("^i", node.input(0));
1513     }
1514     if (node.name() == "i2") {
1515       ++found;
1516       EXPECT_EQ("Const", node.op());
1517       EXPECT_EQ(0, node.input_size());
1518     }
1519     if (node.name() == "i3") {
1520       ++found;
1521       EXPECT_EQ("Identity", node.op());
1522       EXPECT_EQ(1, node.input_size());
1523       EXPECT_EQ("switch2:1", node.input(0));
1524     }
1525   }
1526   EXPECT_EQ(4, found);
1527 
1528   auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1529   Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1530 
1531   v_ctrl_t.flat<bool>()(0) = true;
1532   std::vector<string> fetch_nodes = {"m", "m2"};
1533   auto tensors_expected = EvaluateNodes(
1534       item.graph, fetch_nodes, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1535   EXPECT_EQ(2, tensors_expected.size());
1536   auto tensors = EvaluateNodes(output, fetch_nodes,
1537                                {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1538   EXPECT_EQ(2, tensors.size());
1539   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1540   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1541 
1542   v_ctrl_t.flat<bool>()(0) = false;
1543   tensors_expected = EvaluateNodes(item.graph, fetch_nodes,
1544                                    {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1545   EXPECT_EQ(2, tensors_expected.size());
1546   tensors = EvaluateNodes(output, fetch_nodes,
1547                           {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1548   EXPECT_EQ(2, tensors.size());
1549   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1550   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1551 }
1552 
TEST_F(ConstantFoldingTest,SwitchNodes)1553 TEST_F(ConstantFoldingTest, SwitchNodes) {
1554   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1555   ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
1556   ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
1557   ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
1558   ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
1559   ops::Identity i(scope.WithOpName("i"), s1.output_true);
1560   ops::Size size(scope.WithOpName("size"), i);
1561   ops::Square p1(scope.WithOpName("p1"), rank);
1562   ops::Square p2(scope.WithOpName("p2"), size);
1563   ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
1564 
1565   Output predicate =
1566       ops::Const(scope.WithOpName("false"), false, TensorShape({}));
1567   Output constant =
1568       ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
1569   ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
1570   ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
1571   ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
1572   ops::Merge m2(scope.WithOpName("m2"),
1573                 {statically_known.output, never_generated.output});
1574 
1575   GrapplerItem item;
1576   item.fetch.push_back("m");
1577   item.fetch.push_back("m2");
1578 
1579   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1580 
1581   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1582   GraphDef output;
1583   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1584   TF_EXPECT_OK(status);
1585   std::set<string> present_nodes = {"v_in",     "v_ctrl",
1586                                     "switch",   "i",
1587                                     "p1",       "p2",
1588                                     "m",        "false",
1589                                     "constant", "switch2",
1590                                     "i2",       "i3",
1591                                     "m2",       "ConstantFoldingCtrl/switch_0"};
1592   std::set<string> not_present_nodes = {"rank", "size",
1593                                         "ConstantFolding/switch2-0"};
1594   EXPECT_EQ(present_nodes.size(), output.node_size());
1595 
1596   int found = 0;
1597   for (const auto& node : output.node()) {
1598     EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
1599     EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
1600     present_nodes.erase(node.name());
1601     not_present_nodes.erase(node.name());
1602     if (node.name() == "i2") {
1603       ++found;
1604       EXPECT_EQ("Const", node.op());
1605       EXPECT_EQ(0, node.input_size());
1606     }
1607     if (node.name() == "i3") {
1608       ++found;
1609       EXPECT_EQ("Identity", node.op());
1610       EXPECT_EQ(1, node.input_size());
1611       EXPECT_EQ("switch2:1", node.input(0));
1612     }
1613   }
1614   EXPECT_EQ(2, found);
1615 
1616   auto v_in_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3}));
1617   Tensor v_ctrl_t(DT_BOOL, TensorShape({}));
1618   v_ctrl_t.flat<bool>()(0) = true;
1619   auto tensors_expected = EvaluateNodes(
1620       item.graph, item.fetch, {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1621   EXPECT_EQ(2, tensors_expected.size());
1622   auto tensors = EvaluateNodes(output, item.fetch,
1623                                {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1624   EXPECT_EQ(2, tensors.size());
1625   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1626   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1627 
1628   v_ctrl_t.flat<bool>()(0) = false;
1629   tensors_expected = EvaluateNodes(item.graph, item.fetch,
1630                                    {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1631   EXPECT_EQ(2, tensors_expected.size());
1632   tensors = EvaluateNodes(output, item.fetch,
1633                           {{"v_in", v_in_t}, {"v_ctrl", v_ctrl_t}});
1634   EXPECT_EQ(2, tensors.size());
1635   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
1636   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
1637 }
1638 
TEST_F(ConstantFoldingTest,MergeNodes)1639 TEST_F(ConstantFoldingTest, MergeNodes) {
1640   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1641 
1642   Output x =
1643       ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
1644   Output y =
1645       ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
1646   Output const1 =
1647       ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
1648                  TensorShape({3, 5}));
1649   Output const2 =
1650       ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
1651   Output const3 =
1652       ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
1653                  TensorShape({3, 5}));
1654 
1655   // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
1656   ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
1657   ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
1658   ops::Merge m3(scope.WithOpName("m3"), {x, y});
1659   // m4 is not foldable because the only constant input
1660   // has a control input, so we cannot know if it will be
1661   // triggered.
1662   ops::Merge m4(scope.WithOpName("m4"), {x, const1});
1663 
1664   ops::Identity out1(scope.WithOpName("out1"), m1.output);
1665   ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
1666   ops::Identity out2(scope.WithOpName("out2"), m2.output);
1667   ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
1668   ops::Identity out3(scope.WithOpName("out3"), m3.output);
1669   ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
1670   ops::Identity out4(scope.WithOpName("out4"), m4.output);
1671   ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
1672 
1673   GrapplerItem item;
1674   item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
1675   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1676 
1677   ConstantFolding optimizer(/*cpu_device=*/nullptr);
1678   GraphDef output;
1679   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
1680   TF_EXPECT_OK(status);
1681 
1682   EXPECT_EQ(19, output.node_size());
1683   int found_nodes = 0;
1684   for (const auto& node : output.node()) {
1685     if (node.name() == "out1") {
1686       EXPECT_EQ(1, node.input_size());
1687       EXPECT_EQ("^m1", node.input(0));
1688       ++found_nodes;
1689     } else if (node.name() == "idx1") {
1690       EXPECT_EQ(1, node.input_size());
1691       EXPECT_EQ("^m1", node.input(0));
1692       ++found_nodes;
1693     } else if (node.name() == "ConstantFolding/m1") {
1694       EXPECT_EQ("Const", node.op());
1695       EXPECT_EQ(1, node.input_size());
1696       EXPECT_EQ("^m1", node.input(0));
1697       ++found_nodes;
1698     } else if (node.name() == "ConstantFolding/m1_index") {
1699       EXPECT_EQ("Const", node.op());
1700       EXPECT_EQ(1, node.input_size());
1701       EXPECT_EQ("^m1", node.input(0));
1702       ++found_nodes;
1703     } else if (node.name() == "out2") {
1704       EXPECT_EQ(1, node.input_size());
1705       EXPECT_EQ("m2", node.input(0));
1706       ++found_nodes;
1707     } else if (node.name() == "idx2") {
1708       EXPECT_EQ(1, node.input_size());
1709       EXPECT_EQ("m2:1", node.input(0));
1710       ++found_nodes;
1711     } else if (node.name() == "out3") {
1712       EXPECT_EQ(1, node.input_size());
1713       EXPECT_EQ("m3", node.input(0));
1714       ++found_nodes;
1715     } else if (node.name() == "idx3") {
1716       EXPECT_EQ(1, node.input_size());
1717       EXPECT_EQ("m3:1", node.input(0));
1718       ++found_nodes;
1719     } else if (node.name() == "out4") {
1720       EXPECT_EQ(1, node.input_size());
1721       EXPECT_EQ("m4", node.input(0));
1722       ++found_nodes;
1723     } else if (node.name() == "idx4") {
1724       EXPECT_EQ(1, node.input_size());
1725       EXPECT_EQ("m4:1", node.input(0));
1726       ++found_nodes;
1727     }
1728   }
1729   // Make sure the graph contains all the nodes we're expecting.
1730   EXPECT_EQ(8, found_nodes);
1731 
1732   std::vector<string> fetch = {"out1", "idx1"};
1733   auto tensors = EvaluateNodes(output, fetch);
1734   EXPECT_EQ(2, tensors.size());
1735   const Tensor& out_value = tensors[0];
1736   EXPECT_EQ(3 * 5, out_value.NumElements());
1737   for (int i = 0; i < 3 * 5; ++i) {
1738     EXPECT_EQ(3.14f, out_value.flat<float>()(i));
1739   }
1740   const Tensor& out_idx = tensors[1];
1741   EXPECT_EQ(1, out_idx.NumElements());
1742   EXPECT_EQ(2, out_idx.flat<int32>()(0));
1743 }
1744 
// A Split into a single piece (s1) is expected to become an Identity of its
// input, keeping a control dependency on split_dim. A Split into two pieces
// (s2) must be kept. The optimized graph must compute the same "out" as the
// original.
TEST_F(ConstantFoldingTest, SplitRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
  ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
          &want);
  AddNode("s2", "Split", {"split_dim", "in2"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal: [0] below must exist in both result vectors.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
1788 
// A SplitV into a single piece (s1) is expected to become an Identity of its
// input, keeping control dependencies on size_splits1 and split_dim. A SplitV
// into two pieces (s2) must be kept. The optimized graph must compute the
// same "out" as the original.
TEST_F(ConstantFoldingTest, SplitVRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
  auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
  auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
  auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
  ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
  ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);

  ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("split_dim", "Const", {}, {}, &want);
  AddNode("size_splits1", "Const", {}, {}, &want);
  AddNode("size_splits2", "Const", {}, {}, &want);
  AddNode("s1", "Identity",
          {"in1", AsControlDependency("size_splits1"),
           AsControlDependency("split_dim")},
          {}, &want);
  AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal: [0] below must exist in both result vectors.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
1838 
// t1 permutes dimensions of sizes 2 and 4, so it must stay a Transpose. t2's
// permutation {3, 1, 2, 0} only swaps dims 0 and 3 of a {1, 4, 2, 1} input,
// both of size 1, so it is expected to become an Identity. The Identity keeps
// the original control dependency on in1 plus one on the permutation
// constant. Both shapes come out as {1, 4, 2, 1}, so t1 + t2 is well-formed
// and is evaluated to confirm the rewrite preserves values.
TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
                             DT_FLOAT);
  Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
  ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
  ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
                    p2);

  ops::Add out1(scope.WithOpName("out1"), t1, t2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("p1", "Const", {}, {}, &want);
  AddNode("p2", "Const", {}, {}, &want);
  AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
  AddNode("t2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
          &want);
  AddNode("out1", "Add", {"t1", "t2"}, {}, &want);

  CompareGraphs(want, got);

  // Check that the rewritten graph computes the same values as the original.
  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 4, 1}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 4, 2, 1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
1876 
// RandomShuffle of a scalar input is expected to become an Identity. s2 must
// keep its control dependency on in1. Both fetches ("out1" and "out2") must
// produce the same values before and after optimization.
TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 =
      ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
  Output in2 =
      ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
  ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
  ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
                        in2);

  ops::Add out1(scope.WithOpName("out1"), s1, s2);
  ops::Identity out2(scope.WithOpName("out2"), s2);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("s1", "Identity", {"in1"}, {}, &want);
  AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, {}, &want);
  AddNode("out1", "Add", {"s1", "s2"}, {}, &want);
  AddNode("out2", "Identity", {"s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal: the loop below indexes both vectors up to tensors.size().
  ASSERT_EQ(2, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(2, tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
  }
}
1921 
// r1 reverses every axis of a {1, 2, 4, 1} input, including the sizes 2 and
// 4, so it must stay a ReverseV2. r2 reverses only axes 0 and 3, both of
// size 1, so it is expected to become an Identity. The Identity keeps the
// original control dependency on in1 plus one on the axis constant. The sum
// r1 + r2 is evaluated to confirm the rewrite preserves values.
TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
                             DT_FLOAT);
  Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
  ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
  ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
                  a2);

  ops::Add out1(scope.WithOpName("out1"), r1, r2);

  GrapplerItem item;
  item.fetch = {"out1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("a1", "Const", {}, {}, &want);
  AddNode("a2", "Const", {}, {}, &want);
  AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
  AddNode("r2", "Identity",
          {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
          &want);
  AddNode("out1", "Add", {"r1", "r2"}, {}, &want);

  CompareGraphs(want, got);

  // Check that the rewritten graph computes the same values as the original.
  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 4, 1}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 4, 1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
1959 
TEST_F(ConstantFoldingTest,SliceWithSameDimensionRemoval)1960 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
1961   {  // size = {3, 5}
1962     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
1963 
1964     auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
1965     auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
1966     auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
1967     Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
1968     ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
1969     ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);
1970 
1971     ops::Add out(scope.WithOpName("out"), s1, s2);
1972 
1973     GrapplerItem item;
1974     item.fetch = {"out"};
1975     TF_CHECK_OK(scope.ToGraphDef(&item.graph));
1976 
1977     ConstantFolding optimizer(/*cpu_device=*/nullptr);
1978     GraphDef got;
1979     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
1980     TF_EXPECT_OK(status);
1981 
1982     GraphDef want;
1983     AddNode("in1", "VariableV2", {}, {}, &want);
1984     AddNode("in2", "VariableV2", {}, {}, &want);
1985     AddNode("begin", "Const", {}, {}, &want);
1986     AddNode("size", "Const", {}, {}, &want);
1987     AddNode("s1", "Identity",
1988             {"in1", AsControlDependency("begin"), AsControlDependency("size")},
1989             {}, &want);
1990     AddNode("s2", "Slice", {"in2", "begin", "size"}, {}, &want);
1991     AddNode("out", "Add", {"s1", "s2"}, {}, &want);
1992 
1993     CompareGraphs(want, got);
1994 
1995     auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
1996     auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
1997     auto tensors_expected =
1998         EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
1999     EXPECT_EQ(1, tensors_expected.size());
2000     auto tensors =
2001         EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2002     EXPECT_EQ(1, tensors.size());
2003     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2004   }
2005   {  // size = {-1, -1}
2006     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2007 
2008     auto in1 =
2009         ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
2010     auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
2011     auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
2012     auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
2013     Output in2 =
2014         ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
2015     ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
2016     ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);
2017 
2018     ops::Add out(scope.WithOpName("out"), s1, s2);
2019 
2020     GrapplerItem item;
2021     item.fetch = {"out"};
2022     TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2023 
2024     ConstantFolding optimizer(/*cpu_device=*/nullptr);
2025     GraphDef got;
2026     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2027     TF_EXPECT_OK(status);
2028 
2029     GraphDef want;
2030     AddNode("in1", "VariableV2", {}, {}, &want);
2031     AddNode("in2", "VariableV2", {}, {}, &want);
2032     AddNode("begin1", "Const", {}, {}, &want);
2033     AddNode("begin2", "Const", {}, {}, &want);
2034     AddNode("size", "Const", {}, {}, &want);
2035     AddNode("s1", "Identity",
2036             {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
2037             {}, &want);
2038     AddNode("s2", "Slice", {"in2", "begin2", "size"}, {}, &want);
2039     AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2040 
2041     CompareGraphs(want, got);
2042 
2043     auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5}));
2044     auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
2045     auto tensors_expected =
2046         EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2047     EXPECT_EQ(1, tensors_expected.size());
2048     auto tensors =
2049         EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2050     EXPECT_EQ(1, tensors.size());
2051     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2052   }
2053 }
2054 
TEST_F(ConstantFoldingTest,StridedSliceWithSameDimensionRemoval)2055 TEST_F(ConstantFoldingTest, StridedSliceWithSameDimensionRemoval) {
2056   {  // no mask
2057     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2058 
2059     auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5, 2}, DT_FLOAT);
2060     auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
2061     auto end = ops::Const(scope.WithOpName("end"), {3, 5}, {2});
2062     auto strides = ops::Const(scope.WithOpName("strides"), {1, 1}, {2});
2063     Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6, 2}, DT_FLOAT);
2064     ops::StridedSlice s1(scope.WithOpName("s1"), in1, begin, end, strides);
2065     ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin, end, strides);
2066 
2067     ops::Add out(scope.WithOpName("out"), s1, s2);
2068 
2069     GrapplerItem item;
2070     item.fetch = {"out"};
2071     TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2072 
2073     ConstantFolding optimizer(/*cpu_device=*/nullptr);
2074     GraphDef got;
2075     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2076     TF_EXPECT_OK(status);
2077 
2078     GraphDef want;
2079     AddNode("in1", "VariableV2", {}, {}, &want);
2080     AddNode("in2", "VariableV2", {}, {}, &want);
2081     AddNode("begin", "Const", {}, {}, &want);
2082     AddNode("end", "Const", {}, {}, &want);
2083     AddNode("strides", "Const", {}, {}, &want);
2084     AddNode("s1", "Identity",
2085             {"in1", AsControlDependency("begin"), AsControlDependency("end"),
2086              AsControlDependency("strides")},
2087             {}, &want);
2088     AddNode("s2", "StridedSlice", {"in2", "begin", "end", "strides"}, {},
2089             &want);
2090     AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2091 
2092     CompareGraphs(want, got);
2093 
2094     auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 2}));
2095     auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6, 2}));
2096     auto tensors_expected =
2097         EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2098     EXPECT_EQ(1, tensors_expected.size());
2099     auto tensors =
2100         EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2101     EXPECT_EQ(1, tensors.size());
2102     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2103   }
2104   {  // with begin/end/ellipsis mask
2105     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2106 
2107     // s1 = in1[:, ..., 0:5, 0:6]
2108     auto in1 =
2109         ops::Variable(scope.WithOpName("in1"), {2, 3, 4, 5, 6}, DT_FLOAT);
2110     auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0, 0}, {3});
2111     auto end1 = ops::Const(scope.WithOpName("end1"), {0, 5, 6}, {3});
2112     auto strides1 = ops::Const(scope.WithOpName("strides1"), {1, 1, 1}, {3});
2113     ops::StridedSlice s1(
2114         scope.WithOpName("s1"), in1, begin1, end1, strides1,
2115         ops::StridedSlice::Attrs().BeginMask(1).EndMask(1).EllipsisMask(2));
2116 
2117     Output in2 =
2118         ops::Variable(scope.WithOpName("in2"), {5, 8, 5, 6, 9}, DT_FLOAT);
2119     auto begin2 = ops::Const(scope.WithOpName("begin2"), {0, 0, 0, 0, 0}, {5});
2120     auto end2 = ops::Const(scope.WithOpName("end2"), {2, 3, 4, 5, 6}, {5});
2121     auto strides2 =
2122         ops::Const(scope.WithOpName("strides2"), {1, 1, 1, 1, 1}, {5});
2123     ops::StridedSlice s2(scope.WithOpName("s2"), in2, begin2, end2, strides2);
2124 
2125     ops::Add out(scope.WithOpName("out"), s1, s2);
2126 
2127     GrapplerItem item;
2128     item.fetch = {"out"};
2129     TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2130 
2131     ConstantFolding optimizer(/*cpu_device=*/nullptr);
2132     GraphDef got;
2133     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
2134     TF_EXPECT_OK(status);
2135 
2136     GraphDef want;
2137     AddNode("in1", "VariableV2", {}, {}, &want);
2138     AddNode("in2", "VariableV2", {}, {}, &want);
2139     AddNode("begin1", "Const", {}, {}, &want);
2140     AddNode("end1", "Const", {}, {}, &want);
2141     AddNode("strides1", "Const", {}, {}, &want);
2142     AddNode("s1", "Identity",
2143             {"in1", AsControlDependency("begin1"), AsControlDependency("end1"),
2144              AsControlDependency("strides1")},
2145             {}, &want);
2146     AddNode("begin2", "Const", {}, {}, &want);
2147     AddNode("end2", "Const", {}, {}, &want);
2148     AddNode("strides2", "Const", {}, {}, &want);
2149     AddNode("s2", "StridedSlice", {"in2", "begin2", "end2", "strides2"}, {},
2150             &want);
2151     AddNode("out", "Add", {"s1", "s2"}, {}, &want);
2152 
2153     CompareGraphs(want, got);
2154 
2155     auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3, 4, 5, 6}));
2156     auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 8, 5, 6, 9}));
2157     auto tensors_expected =
2158         EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2159     EXPECT_EQ(1, tensors_expected.size());
2160     auto tensors =
2161         EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
2162     EXPECT_EQ(1, tensors.size());
2163     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2164   }
2165 }
2166 
// A Tile whose multiples are all ones is rewritten to an Identity of its input,
// with the multiples demoted to a control dependency. A Tile with a non-trivial
// multiple (t2) is left untouched. Like the sibling no-op rewrite tests, the
// optimized graph must also compute the same values as the original.
TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
  auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
  auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
  auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});

  ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
  ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);

  ops::Add out(scope.WithOpName("out"), t1, t2);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("multiplies1", "Const", {}, {}, &want);
  AddNode("multiplies2", "Const", {}, {}, &want);
  AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, {},
          &want);
  AddNode("t2", "Tile", {"in2", "multiplies2"}, {}, &want);
  AddNode("out", "Add", {"t1", "t2"}, {}, &want);

  CompareGraphs(want, got);

  // Both Add operands are {4, 6}: t1 is in1 unchanged and t2 is the {4, 3}
  // in2 tiled twice along its second dimension.
  auto in1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
  auto in2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal size checks: the element accesses below would otherwise read out of
  // bounds on failure.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
2201 
// Two ConcatV2 nodes chained through their first input and sharing the same
// axis node are collapsed into a single ConcatV2 over in1, in2, in3.
TEST_F(ConstantFoldingTest, MergeConcat) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output x1 = ops::Variable(root.WithOpName("in1"), {4, 6}, DT_FLOAT);
  Output x2 = ops::Variable(root.WithOpName("in2"), {4, 6}, DT_FLOAT);
  Output x3 = ops::Variable(root.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output concat_axis = ops::Const(root.WithOpName("axis"), 0, {});

  ops::Concat inner(root.WithOpName("c1"), {x1, x2}, concat_axis);
  ops::Concat outer(root.WithOpName("c2"), {Output(inner), x3}, concat_axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  ConstantFolding folder(/*cpu_device=*/nullptr);
  GraphDef optimized;
  TF_EXPECT_OK(folder.Optimize(/*cluster=*/nullptr, item, &optimized));

  // The expected graph has no "c1": its inputs are spliced into "c2" ahead of
  // in3, with the shared axis kept as the final input.
  GraphDef expected;
  AddNode("in1", "VariableV2", {}, {}, &expected);
  AddNode("in2", "VariableV2", {}, {}, &expected);
  AddNode("in3", "VariableV2", {}, {}, &expected);
  AddNode("axis", "Const", {}, {}, &expected);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &expected);

  CompareGraphs(expected, optimized);
}
2231 
// When the inner ConcatV2 (c1) feeds the outer one (c2) twice, every occurrence
// is expanded in place, so c2 ends up listing in1, in2 twice.
TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output x1 = ops::Variable(root.WithOpName("in1"), {4, 6}, DT_FLOAT);
  Output x2 = ops::Variable(root.WithOpName("in2"), {4, 6}, DT_FLOAT);
  Output x3 = ops::Variable(root.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output concat_axis = ops::Const(root.WithOpName("axis"), 0, {});

  ops::Concat inner(root.WithOpName("c1"), {x1, x2}, concat_axis);
  const Output inner_out(inner);
  ops::Concat outer(root.WithOpName("c2"), {inner_out, x3, inner_out},
                    concat_axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  ConstantFolding folder(/*cpu_device=*/nullptr);
  GraphDef optimized;
  TF_EXPECT_OK(folder.Optimize(/*cluster=*/nullptr, item, &optimized));

  GraphDef expected;
  AddNode("in1", "VariableV2", {}, {}, &expected);
  AddNode("in2", "VariableV2", {}, {}, &expected);
  AddNode("in3", "VariableV2", {}, {}, &expected);
  AddNode("axis", "Const", {}, {}, &expected);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
          &expected);

  CompareGraphs(expected, optimized);
}
2262 
// Nested ConcatV2 nodes sharing one axis are merged even when their inputs have
// differing declared shapes.
// NOTE(review): in2 is declared with an empty (scalar) shape, so this graph is
// only compared structurally and is never executed.
TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output x1 = ops::Variable(root.WithOpName("in1"), {2, 6}, DT_FLOAT);
  Output x2 = ops::Variable(root.WithOpName("in2"), {}, DT_FLOAT);
  Output x3 = ops::Variable(root.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output concat_axis = ops::Const(root.WithOpName("axis"), 0, {});

  ops::Concat inner(root.WithOpName("c1"), {x1, x2}, concat_axis);
  ops::Concat outer(root.WithOpName("c2"), {Output(inner), x3}, concat_axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  ConstantFolding folder(/*cpu_device=*/nullptr);
  GraphDef optimized;
  TF_EXPECT_OK(folder.Optimize(/*cluster=*/nullptr, item, &optimized));

  GraphDef expected;
  AddNode("in1", "VariableV2", {}, {}, &expected);
  AddNode("in2", "VariableV2", {}, {}, &expected);
  AddNode("in3", "VariableV2", {}, {}, &expected);
  AddNode("axis", "Const", {}, {}, &expected);
  AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &expected);

  CompareGraphs(expected, optimized);
}
2292 
// Negative case: c1 concatenates along axis2 (value 1) while c2 uses axis1
// (value 0). The two nodes must NOT be merged, so both survive unchanged.
TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  Output x1 = ops::Variable(root.WithOpName("in1"), {2, 5}, DT_FLOAT);
  Output x2 = ops::Variable(root.WithOpName("in2"), {}, DT_FLOAT);
  Output x3 = ops::Variable(root.WithOpName("in3"), {4, 6}, DT_FLOAT);
  Output outer_axis = ops::Const(root.WithOpName("axis1"), 0, {});
  Output inner_axis = ops::Const(root.WithOpName("axis2"), 1, {});

  ops::Concat inner(root.WithOpName("c1"), {x1, x2}, inner_axis);
  ops::Concat outer(root.WithOpName("c2"), {Output(inner), x3}, outer_axis);

  GrapplerItem item;
  item.fetch = {"c2"};
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  ConstantFolding folder(/*cpu_device=*/nullptr);
  GraphDef optimized;
  TF_EXPECT_OK(folder.Optimize(/*cluster=*/nullptr, item, &optimized));

  GraphDef expected;
  AddNode("in1", "VariableV2", {}, {}, &expected);
  AddNode("in2", "VariableV2", {}, {}, &expected);
  AddNode("in3", "VariableV2", {}, {}, &expected);
  AddNode("axis1", "Const", {}, {}, &expected);
  AddNode("axis2", "Const", {}, {}, &expected);
  AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &expected);
  AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &expected);

  CompareGraphs(expected, optimized);
}
2325 
// A PadV2 whose paddings are all zero is rewritten to an Identity of its input,
// with the paddings and pad value demoted to control dependencies. A PadV2 with
// non-zero paddings (p2) is left untouched. Both graphs must produce identical
// int32 results.
TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
  auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
  auto paddings1 =
      ops::Const(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
  auto paddings2 =
      ops::Const(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
  auto c1 = ops::Const(scope.WithOpName("c1"), 1);
  auto c2 = ops::Const(scope.WithOpName("c2"), 1);

  ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
  ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);

  ops::Add out(scope.WithOpName("out"), p1, p2);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("paddings1", "Const", {}, {}, &want);
  AddNode("paddings2", "Const", {}, {}, &want);
  AddNode("c1", "Const", {}, {}, &want);
  AddNode("c2", "Const", {}, {}, &want);
  AddNode("p1", "Identity",
          {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
          {}, &want);
  AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, {}, &want);
  AddNode("out", "Add", {"p1", "p2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({4, 6}));
  auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 2}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal size checks: element [0] is read below and must exist.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
2377 
// A Squeeze over an input with no size-1 dimensions (in1 is {2, 3}) is a no-op
// and is rewritten to Identity. A Squeeze that actually drops dimensions (in2
// is {1, 2, 3, 1}) is kept. Both graphs must produce identical int32 results.
TEST_F(ConstantFoldingTest, SqueezeWithAllDimesionsGreaterThanOne) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
  auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);

  ops::Squeeze s1(scope.WithOpName("s1"), in1);
  ops::Squeeze s2(scope.WithOpName("s2"), in2);

  ops::Add out(scope.WithOpName("out"), s1, s2);

  GrapplerItem item;
  item.fetch = {"out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef got;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &got);
  TF_EXPECT_OK(status);

  GraphDef want;
  AddNode("in1", "VariableV2", {}, {}, &want);
  AddNode("in2", "VariableV2", {}, {}, &want);
  AddNode("s1", "Identity", {"in1"}, {}, &want);
  AddNode("s2", "Squeeze", {"in2"}, {}, &want);
  AddNode("out", "Add", {"s1", "s2"}, {}, &want);

  CompareGraphs(want, got);

  auto in1_t = GenerateRandomTensor<DT_INT32>(TensorShape({2, 3}));
  auto in2_t = GenerateRandomTensor<DT_INT32>(TensorShape({1, 2, 3, 1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  // Fatal size checks: element [0] is read below and must exist.
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors =
      EvaluateNodes(got, item.fetch, {{"in1", in1_t}, {"in2", in2_t}});
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
2417 
// Reductions that cannot change their input are rewritten to Identity:
//  - p reduces v over an empty axis list (c has shape {0});
//  - p2 reduces v2 over axis 2, which has size 1, with keep_dims=true.
// In both cases the reduction-indices input is demoted to a control dependency,
// and the optimized graph must compute the same values as the original.
TEST_F(ConstantFoldingTest, NoOpReduction) {
  // Build a simple graph with reductions that can be reduced to the
  // identity.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
  Output i = ops::Identity(scope.WithOpName("i"), c);
  Output p = ops::Prod(scope.WithOpName("p"), v, i);
  Output s = ops::Square(scope.WithOpName("s"), p);

  Output v2 = ops::Variable(scope.WithOpName("v2"), {3, 5, 1}, DT_FLOAT);
  // NOTE(review): c2 takes its control dependency on v rather than v2, which
  // looks copied from c above. It appears not to matter for the expected
  // rewrite of p2 — confirm intent before changing it.
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(v), 2, {1});
  ops::Prod::Attrs attr;
  attr = attr.KeepDims(true);
  Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);

  GrapplerItem item;
  item.fetch = {"s", "p2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p") {
      found++;
      EXPECT_EQ("Identity", node.op());
      // Fatal: input(1) below would be an out-of-range access otherwise.
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v", node.input(0));
      EXPECT_EQ("^i", node.input(1));
    } else if (node.name() == "p2") {
      found++;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v2", node.input(0));
      EXPECT_EQ("^c2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);

  auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
  auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"v", v_t}, {"v2", v2_t}});
  // Fatal size checks: elements [0] and [1] are read below.
  ASSERT_EQ(2, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}});
  ASSERT_EQ(2, tensors.size());
  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
  test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
}
2474 
TEST_F(ConstantFoldingTest,SingleElementEmptyAxisReduction)2475 TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
2476   // Build a simple graph with reductions that involve single-element input and
2477   // no axes to reduce along.
2478   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2479 
2480   Output input_var_three_dim = ops::Variable(
2481       scope.WithOpName("input_var_three_dim"), {1, 1, 1}, DT_FLOAT);
2482   Output input_var_one_dim =
2483       ops::Variable(scope.WithOpName("input_var_one_dim"), {1}, DT_FLOAT);
2484   Output one_axis = ops::Const(scope.WithOpName("one_axis"), {0}, {1});
2485   Output multiple_axes =
2486       ops::Const(scope.WithOpName("multiple_axes"), {1, 0}, {2});
2487   Output variable_axis =
2488       ops::Variable(scope.WithOpName("input_var_axis"), {1}, DT_INT32);
2489   ops::Mean::Attrs attr;
2490   attr = attr.KeepDims(false);
2491   // Should be optimized to Reshape.
2492   Output mean_1 = ops::Mean(scope.WithOpName("mean_1"), input_var_three_dim,
2493                             one_axis, attr.KeepDims(false));
2494   Output mean_2 = ops::Mean(scope.WithOpName("mean_2"), input_var_three_dim,
2495                             multiple_axes, attr.KeepDims(false));
2496   // Should remain as-is, since OutputProperties will not be known this node.
2497   Output mean_3 = ops::Mean(scope.WithOpName("mean_3"), input_var_one_dim,
2498                             one_axis, attr.KeepDims(false));
2499   // Should remain as-is.
2500   Output mean_4 = ops::Mean(scope.WithOpName("mean_4"), input_var_three_dim,
2501                             variable_axis, attr.KeepDims(false));
2502   // Should be optimized to Identity, since KeepDims=true.
2503   Output mean_5 = ops::Mean(scope.WithOpName("mean_5"), input_var_three_dim,
2504                             multiple_axes, attr.KeepDims(true));
2505 
2506   GrapplerItem item;
2507   item.fetch = {"mean_1", "mean_2", "mean_3", "mean_4", "mean_5"};
2508   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2509 
2510   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2511   GraphDef output;
2512   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2513   TF_EXPECT_OK(status);
2514 
2515   // Ensure Mean node is optimized to Reshape.
2516   int found = 0;
2517   for (const auto& node : output.node()) {
2518     if (node.name() == "mean_1" || node.name() == "mean_2") {
2519       found++;
2520       EXPECT_EQ("Reshape", node.op());
2521       EXPECT_EQ(2, node.input_size());
2522       EXPECT_EQ("input_var_three_dim", node.input(0));
2523     } else if (node.name() == "mean_3") {
2524       found++;
2525       EXPECT_EQ("Mean", node.op());
2526       EXPECT_EQ(2, node.input_size());
2527       EXPECT_EQ("input_var_one_dim", node.input(0));
2528     } else if (node.name() == "mean_4") {
2529       found++;
2530       EXPECT_EQ("Mean", node.op());
2531       EXPECT_EQ(2, node.input_size());
2532       EXPECT_EQ("input_var_three_dim", node.input(0));
2533     } else if (node.name() == "mean_5") {
2534       found++;
2535       EXPECT_EQ("Identity", node.op());
2536       EXPECT_EQ(2, node.input_size());
2537       EXPECT_EQ("^multiple_axes", node.input(1));
2538     }
2539   }
2540   EXPECT_EQ(5, found);
2541 
2542   // Ensure resultant values from Mean and Reshape are the same.
2543   auto input_var_three_dim_t =
2544       GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
2545   auto input_var_one_dim_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2546   Tensor input_var_axis_t(DT_INT32, TensorShape({1}));
2547   input_var_axis_t.flat<int32>()(0) = 0;
2548   auto tensors_expected =
2549       EvaluateNodes(item.graph, item.fetch,
2550                     {{"input_var_three_dim", input_var_three_dim_t},
2551                      {"input_var_one_dim", input_var_one_dim_t},
2552                      {"input_var_axis", input_var_axis_t}});
2553   EXPECT_EQ(5, tensors_expected.size());
2554   auto tensors = EvaluateNodes(output, item.fetch,
2555                                {{"input_var_three_dim", input_var_three_dim_t},
2556                                 {"input_var_one_dim", input_var_one_dim_t},
2557                                 {"input_var_axis", input_var_axis_t}});
2558   EXPECT_EQ(5, tensors.size());
2559   for (int i = 0; i < 5; ++i) {
2560     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2561   }
2562 }
2563 
TEST_F(ConstantFoldingTest,NoOpReshape)2564 TEST_F(ConstantFoldingTest, NoOpReshape) {
2565   // Build a simple graph with a reshape that can be reduced to the identity.
2566   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
2567 
2568   // A reshape than can be optimized
2569   Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
2570   Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
2571   Output c1 =
2572       ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
2573   Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
2574   Output r1 =
2575       ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
2576   Output s1 = ops::Square(scope.WithOpName("s1"), r1);
2577 
2578   // A multi dimensional reshape than can be optimized
2579   Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
2580   Output c3 =
2581       ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
2582   Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
2583   Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
2584   Output s3 = ops::Square(scope.WithOpName("s3"), r3);
2585 
2586   // A multi dimensional partially defined reshape than can be optimized
2587   Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
2588   Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
2589                          {5, -1, 5}, {3});
2590   Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
2591   Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
2592   Output s4 = ops::Square(scope.WithOpName("s4"), r4);
2593 
2594   // A reshape that can't be optimized
2595   Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
2596   Output c2 =
2597       ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
2598   Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
2599   Output s2 = ops::Square(scope.WithOpName("s2"), r2);
2600 
2601   GrapplerItem item;
2602   item.fetch = {"s1", "s2", "s3", "s4"};
2603   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
2604 
2605   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2606   GraphDef output;
2607   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2608   TF_EXPECT_OK(status);
2609 
2610   int found = 0;
2611   for (const auto& node : output.node()) {
2612     if (node.name() == "r1") {
2613       ++found;
2614       EXPECT_EQ("Identity", node.op());
2615       ASSERT_EQ(3, node.input_size());
2616       EXPECT_EQ("v1", node.input(0));
2617       EXPECT_EQ("^i1", node.input(1));
2618       EXPECT_EQ("^d1", node.input(2));
2619     } else if (node.name() == "r3") {
2620       ++found;
2621       EXPECT_EQ("Identity", node.op());
2622       ASSERT_EQ(2, node.input_size());
2623       EXPECT_EQ("v3", node.input(0));
2624       EXPECT_EQ("^i3", node.input(1));
2625     } else if (node.name() == "r4") {
2626       ++found;
2627       EXPECT_EQ("Identity", node.op());
2628       ASSERT_EQ(2, node.input_size());
2629       EXPECT_EQ("v4", node.input(0));
2630       EXPECT_EQ("^i4", node.input(1));
2631     } else if (node.name() == "r2") {
2632       ++found;
2633       EXPECT_EQ("Reshape", node.op());
2634       ASSERT_EQ(2, node.input_size());
2635       EXPECT_EQ("v2", node.input(0));
2636       EXPECT_EQ("c2", node.input(1));
2637     }
2638   }
2639   EXPECT_EQ(4, found);
2640 
2641   auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17}));
2642   auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({17, 1}));
2643   auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2644   auto v4_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 5, 5}));
2645   auto tensors_expected =
2646       EvaluateNodes(item.graph, item.fetch,
2647                     {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2648   EXPECT_EQ(4, tensors_expected.size());
2649   auto tensors =
2650       EvaluateNodes(output, item.fetch,
2651                     {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}, {"v4", v4_t}});
2652   EXPECT_EQ(4, tensors.size());
2653   for (int i = 0; i < tensors.size(); i++)
2654     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
2655 }
2656 
// Folding a 1000-element constant that feeds two Identity nodes must preserve
// the fetched values and produce a compactly encoded GraphDef.
TEST_F(ConstantFoldingTest, Packing) {
  // Build a simple graph with a large constant that can be folded.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c);
  Output i2 = ops::Identity(scope.WithOpName("i2"), c);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  const std::vector<string> fetch_nodes = {"i1", "i2"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
  // Fatal size checks: the loop below indexes both result vectors.
  ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
  auto tensors = EvaluateNodes(output, fetch_nodes);
  ASSERT_EQ(fetch_nodes.size(), tensors.size());
  for (size_t i = 0; i < fetch_nodes.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
  }

  // Make sure that the representation of the folded constant is space
  // efficient: in particular, the whole message should be smaller than 8k
  // (the size needed to naively encode 1000 floats folded twice, i.e.
  // 1000 * 4 bytes * 2).
  EXPECT_GT(8000, output.ByteSizeLong());
}
2685 
TEST_F(ConstantFoldingTest,MaterializeBroadcastGradientArgs)2686 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
2687   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2688   Output a =
2689       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2690                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
2691   Output b = ops::Square(s.WithOpName("b"), a);
2692   Output c = ops::Mul(s.WithOpName("c"), a, b);
2693   Output d = ops::Shape(s.WithOpName("d"), a);
2694   Output e = ops::Shape(s.WithOpName("e"), b);
2695 
2696   auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
2697   Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
2698   Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
2699 
2700   Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
2701                               ops::Placeholder::Shape(PartialTensorShape({1})));
2702   Output h = ops::Shape(s.WithOpName("h"), g);
2703   auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
2704   Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
2705   Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
2706 
2707   GrapplerItem item;
2708   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2709 
2710   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2711   GraphDef output;
2712   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2713   TF_EXPECT_OK(status);
2714 
2715   std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
2716   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
2717   auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
2718   auto tensors_expected =
2719       EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
2720   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2721 
2722   // Run a second time to make sure the optimization is idempotent.
2723   item.graph.Swap(&output);
2724   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2725   TF_EXPECT_OK(status);
2726 
2727   int found = 0;
2728   for (const auto& node : output.node()) {
2729     if (node.name() == "o1") {
2730       ++found;
2731       EXPECT_EQ(1, node.input_size());
2732       EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
2733     } else if (node.name() == "o2") {
2734       ++found;
2735       EXPECT_EQ(1, node.input_size());
2736       EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
2737     } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
2738       ++found;
2739       EXPECT_EQ("Const", node.op());
2740       EXPECT_EQ(1, node.input_size());
2741       EXPECT_EQ("^f", node.input(0));
2742       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2743                        .num_elements());
2744     } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
2745       ++found;
2746       EXPECT_EQ("Const", node.op());
2747       EXPECT_EQ(1, node.input_size());
2748       EXPECT_EQ("^f", node.input(0));
2749       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2750                        .num_elements());
2751     } else if (node.name() == "p1") {
2752       ++found;
2753       EXPECT_EQ(1, node.input_size());
2754       EXPECT_EQ("i", node.input(0));
2755     } else if (node.name() == "p2") {
2756       ++found;
2757       EXPECT_EQ(1, node.input_size());
2758       EXPECT_EQ("i:1", node.input(0));
2759     }
2760   }
2761   EXPECT_EQ(6, found);
2762 
2763   auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
2764   EXPECT_EQ(fetch_nodes.size(), tensors.size());
2765   for (int i = 0; i < fetch_nodes.size(); i++)
2766     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
2767 }
2768 
TEST_F(ConstantFoldingTest,MaterializeBroadcastGradientArgs_InfiniteLoop)2769 TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
2770   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2771   Output a =
2772       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
2773                        ops::Placeholder::Shape(PartialTensorShape({2, 2})));
2774   Output b = ops::Square(s.WithOpName("b"), a);
2775   Output c = ops::Mul(s.WithOpName("c"), a, b);
2776   Output d = ops::Shape(s.WithOpName("d"), a);
2777   Output e = ops::Shape(s.WithOpName("e"), b);
2778 
2779   auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
2780   Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
2781   Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
2782 
2783   GrapplerItem item;
2784   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2785 
2786   std::vector<string> fetch_nodes = {"o1", "o2"};
2787   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2788   auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
2789   EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
2790 
2791   ConstantFolding optimizer(/*cpu_device=*/nullptr);
2792   GraphDef output;
2793   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2794   TF_EXPECT_OK(status);
2795 
2796   // Run a second time to make sure the optimization is idempotent.
2797   item.graph.Swap(&output);
2798   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2799   TF_EXPECT_OK(status);
2800 
2801   EXPECT_EQ(11, output.node_size());
2802   int found = 0;
2803   for (const auto& node : output.node()) {
2804     if (node.name() == "ConstantFolding/f-folded-1") {
2805       ++found;
2806       EXPECT_EQ("Const", node.op());
2807       EXPECT_EQ(2, node.input_size());
2808       EXPECT_EQ("^a", node.input(0));
2809       EXPECT_EQ("^b", node.input(1));
2810     } else if (node.name() == "d") {
2811       ++found;
2812       EXPECT_EQ("Const", node.op());
2813       EXPECT_EQ(1, node.input_size());
2814       EXPECT_EQ("^a", node.input(0));
2815     } else if (node.name() == "e") {
2816       ++found;
2817       EXPECT_EQ("Const", node.op());
2818       EXPECT_EQ(1, node.input_size());
2819       EXPECT_EQ("^b", node.input(0));
2820     } else if (node.name() == "o1") {
2821       ++found;
2822       EXPECT_EQ(1, node.input_size());
2823       EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
2824     } else if (node.name() == "o2") {
2825       ++found;
2826       EXPECT_EQ(1, node.input_size());
2827       EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
2828     } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
2829       ++found;
2830       EXPECT_EQ("Const", node.op());
2831       EXPECT_EQ(1, node.input_size());
2832       EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
2833       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2834                        .num_elements());
2835     } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
2836       ++found;
2837       EXPECT_EQ("Const", node.op());
2838       EXPECT_EQ(1, node.input_size());
2839       EXPECT_EQ("^ConstantFolding/f-folded-1", node.input(0));
2840       EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
2841                        .num_elements());
2842     }
2843   }
2844   EXPECT_EQ(7, found);
2845   auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
2846   EXPECT_EQ(fetch_nodes.size(), tensors.size());
2847   for (int i = 0; i < fetch_nodes.size(); i++)
2848     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
2849 }
2850 
TEST_F(ConstantFoldingTest,MaterializeReductionIndices)2851 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
2852   for (bool use_reshape : {true, false}) {
2853     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2854     Output input =
2855         ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
2856                          ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
2857     // If use_reshape is false, we need to now the number of indices to apply
2858     // the rewrite.
2859     Output indices = ops::Placeholder(
2860         s.WithOpName("indices"), DT_INT32,
2861         ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
2862     Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
2863     if (use_reshape) {
2864       Output size = ops::Const(s.WithOpName("size"), 1, {1});
2865       Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
2866     }
2867 
2868     GrapplerItem item;
2869     TF_CHECK_OK(s.ToGraphDef(&item.graph));
2870     item.fetch.push_back(use_reshape ? "reshape" : "sum");
2871 
2872     auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
2873     Tensor indices_t(DT_INT32, TensorShape({2}));
2874     indices_t.flat<int>()(0) = 0;
2875     indices_t.flat<int>()(1) = 1;
2876     auto tensors_expected = EvaluateNodes(
2877         item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
2878     EXPECT_EQ(1, tensors_expected.size());
2879 
2880     // Use aggressive mode to force the shape inference to propagate placeholder
2881     // shapes.
2882     ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
2883                               /*cpu_device=*/nullptr);
2884     GraphDef output;
2885     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2886     TF_EXPECT_OK(status);
2887 
2888     // Run a second time to make sure the optimization is idempotent.
2889     item.graph.Swap(&output);
2890     status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
2891     TF_EXPECT_OK(status);
2892 
2893     int found = 0;
2894     for (const auto& node : output.node()) {
2895       if (node.name() == "ConstantFolding/sum-reduction_indices") {
2896         ++found;
2897         EXPECT_EQ("Const", node.op());
2898         EXPECT_EQ("^indices", node.input(0));
2899         EXPECT_EQ(2,
2900                   TensorShape(node.attr().at("value").tensor().tensor_shape())
2901                       .num_elements());
2902       } else if (node.name() == "sum") {
2903         ++found;
2904         EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
2905       } else if (node.name() == "indices") {
2906         ++found;
2907       }
2908     }
2909     EXPECT_EQ(3, found);
2910 
2911     auto tensors = EvaluateNodes(output, item.fetch,
2912                                  {{"input", input_t}, {"indices", indices_t}});
2913     EXPECT_EQ(1, tensors.size());
2914     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
2915   }
2916 }
2917 
TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
  // Two cases where Sum is not known to reduce every dimension of its input:
  //   - input rank is known (2), but `indices` holds a single index;
  //   - input rank is unknown, and `indices` holds two indices.
  // In both cases the optimized graph must be identical to the original.
  for (const bool input_rank_known : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Output input;
    if (input_rank_known) {
      input = ops::Placeholder(
          s.WithOpName("input"), DT_FLOAT,
          ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
    } else {
      input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    }

    const int num_indices = input_rank_known ? 1 : 2;
    Output indices = ops::Placeholder(
        s.WithOpName("indices"), DT_INT32,
        ops::Placeholder::Shape(PartialTensorShape({num_indices})));
    Output sum = ops::Sum(s.WithOpName("sum"), input, indices);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"sum"};

    // AGGRESSIVE makes shape inference propagate the placeholder shapes, so the
    // optimizer has all the information it could use for the rewrite.
    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                              /*cpu_device=*/nullptr);
    GraphDef output;
    TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

    CompareGraphs(item.graph, output);
  }
}
2947 
TEST_F(ConstantFoldingTest, LargeConstant) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Generate a 4k by 4k constant matrix.
  Output mat_diag =
      ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
  Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
  Output out = ops::Identity(scope.WithOpName("out"), mat);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch.push_back("out");

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Make sure the diag node hasn't been folded, since it would use too much
  // memory to encode the corresponding constant.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out") {
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("mat", node.input(0));
      ++found;
    } else if (node.name() == "mat") {
      EXPECT_EQ("Diag", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("mat_diag", node.input(0));
      ++found;
    }
  }
  EXPECT_EQ(2, found);

  // A folded 4096x4096 float matrix would need 64 MiB, so the optimized graph
  // must stay below 1 MiB. Both operands are size_t, which avoids a
  // signed/unsigned comparison.
  EXPECT_LT(output.ByteSizeLong(), size_t{1024 * 1024});

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(1, tensors_expected.size());
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
2991 
TEST_F(ConstantFoldingTest,SwitchIdenticalInputs)2992 TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
2993   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2994   Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL,
2995                               ops::Placeholder::Shape(TensorShape({})));
2996   ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x);
2997   Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false);
2998   Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true);
2999 
3000   GrapplerItem item;
3001   item.fetch.push_back("id_false");
3002   item.fetch.push_back("id_true");
3003   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3004 
3005   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3006   GraphDef output;
3007   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3008   TF_EXPECT_OK(status);
3009 
3010   EXPECT_EQ(6, output.node_size());
3011   int found = 0;
3012   for (const auto& node : output.node()) {
3013     if (node.name() == "switch" || node.name() == "x") {
3014       ++found;
3015     }
3016     if (node.name() == "id_false") {
3017       EXPECT_EQ("Const", node.op());
3018       EXPECT_EQ(1, node.input_size());
3019       EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
3020       ++found;
3021     }
3022     if (node.name() == "id_true") {
3023       EXPECT_EQ("Const", node.op());
3024       EXPECT_EQ(1, node.input_size());
3025       EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
3026       ++found;
3027     }
3028     if (node.name() == "ConstantFoldingCtrl/switch_0") {
3029       EXPECT_EQ("Identity", node.op());
3030       EXPECT_EQ(1, node.input_size());
3031       EXPECT_EQ("switch", node.input(0));
3032       ++found;
3033     }
3034     if (node.name() == "ConstantFoldingCtrl/switch_1") {
3035       EXPECT_EQ("Identity", node.op());
3036       EXPECT_EQ(1, node.input_size());
3037       EXPECT_EQ("switch:1", node.input(0));
3038       ++found;
3039     }
3040   }
3041   EXPECT_EQ(6, found);
3042 
3043   // Evaluate id_true when input tensor x is true.
3044   Tensor x_t(DT_BOOL, TensorShape({}));
3045   x_t.flat<bool>()(0) = true;
3046   auto tensors_expected = EvaluateNodes(item.graph, {"id_true"}, {{"x", x_t}});
3047   EXPECT_EQ(1, tensors_expected.size());
3048   auto tensors = EvaluateNodes(output, {"id_true"}, {{"x", x_t}});
3049   EXPECT_EQ(1, tensors.size());
3050   test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3051 
3052   // Evalute id_false when input tensor is false.
3053   x_t.flat<bool>()(0) = false;
3054   tensors_expected = EvaluateNodes(item.graph, {"id_false"}, {{"x", x_t}});
3055   EXPECT_EQ(1, tensors_expected.size());
3056   tensors = EvaluateNodes(output, {"id_false"}, {{"x", x_t}});
3057   EXPECT_EQ(1, tensors.size());
3058   test::ExpectTensorEqual<bool>(tensors_expected[0], tensors[0]);
3059 }
3060 
TEST_F(ConstantFoldingTest,PartialFolding_AssociativeAndCommutative)3061 TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
3062   std::function<Output(const Scope&, InputList)> addn_fun =
3063       [](const Scope& scope, InputList inputs) {
3064         return ops::AddN(scope, inputs);
3065       };
3066   std::function<Output(const Scope&, InputList)> accumulate_fun =
3067       [](const Scope& scope, InputList inputs) {
3068         return ops::AccumulateNV2(scope, inputs, TensorShape({2, 2}));
3069       };
3070   for (bool use_add_n : {true, false}) {
3071     auto fun = use_add_n ? addn_fun : accumulate_fun;
3072     const string op_name = use_add_n ? "AddN" : "AccumulateNV2";
3073     Scope s = Scope::NewRootScope();
3074     Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3075                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3076     Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3077                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3078     Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3079                                 ops::Placeholder::Shape(TensorShape({2, 2})));
3080     Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3081     Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3082     Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2, 2});
3083     Output acc0 = fun(s.WithOpName("acc0"), {c1, c2, c3});
3084     Output acc1 = fun(s.WithOpName("acc1"), {x, y, z});
3085     Output acc2 = fun(s.WithOpName("acc2"), {c1, x, y});
3086     Output acc3 = fun(s.WithOpName("acc3"), {c1, c2, z});
3087     Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
3088     Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
3089     Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
3090     Output stack = ops::Stack(s.WithOpName("stack"),
3091                               {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
3092 
3093     GrapplerItem item;
3094     TF_CHECK_OK(s.ToGraphDef(&item.graph));
3095     item.fetch = {"stack"};
3096 
3097     ConstantFolding optimizer(/*cpu_device=*/nullptr);
3098     GraphDef output;
3099     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3100     TF_EXPECT_OK(status);
3101 
3102     EXPECT_EQ(16, output.node_size());
3103     for (const NodeDef& node : output.node()) {
3104       if (node.name() == "acc0") {
3105         EXPECT_EQ("Const", node.op());
3106       }
3107       if (node.name() == "acc1") {
3108         EXPECT_EQ(op_name, node.op());
3109         EXPECT_EQ(3, node.input_size());
3110         EXPECT_EQ("x", node.input(0));
3111         EXPECT_EQ("y", node.input(1));
3112         EXPECT_EQ("z", node.input(2));
3113       }
3114       if (node.name() == "acc2") {
3115         EXPECT_EQ(op_name, node.op());
3116         EXPECT_EQ(3, node.input_size());
3117         EXPECT_EQ("c1", node.input(0));
3118         EXPECT_EQ("x", node.input(1));
3119         EXPECT_EQ("y", node.input(2));
3120       }
3121       if (node.name() == "acc3") {
3122         EXPECT_EQ(op_name, node.op());
3123         EXPECT_EQ(2, node.input_size());
3124         EXPECT_EQ("ConstantFolding/acc3_partial_split_2", node.input(0));
3125         EXPECT_EQ("z", node.input(1));
3126       }
3127       if (node.name() == "acc4") {
3128         EXPECT_EQ(op_name, node.op());
3129         EXPECT_EQ(2, node.input_size());
3130         EXPECT_EQ("ConstantFolding/acc4_partial_split_2", node.input(0));
3131         EXPECT_EQ("y", node.input(1));
3132       }
3133       if (node.name() == "acc5") {
3134         EXPECT_EQ(op_name, node.op());
3135         EXPECT_EQ(2, node.input_size());
3136         EXPECT_EQ("x", node.input(0));
3137         EXPECT_EQ("ConstantFolding/acc5_partial_split_2", node.input(1));
3138       }
3139       if (node.name() == "acc6") {
3140         EXPECT_EQ(op_name, node.op());
3141         EXPECT_EQ(3, node.input_size());
3142         EXPECT_EQ("x", node.input(0));
3143         EXPECT_EQ("ConstantFolding/acc6_partial_split_2", node.input(1));
3144         EXPECT_EQ("y", node.input(2));
3145       }
3146       if (str_util::StartsWith(node.name(), "ConstantFolding/")) {
3147         EXPECT_EQ("Const", node.op());
3148       }
3149     }
3150 
3151     std::vector<string> fetch = {"acc0"};
3152     auto tensors_expected = EvaluateNodes(item.graph, fetch);
3153     auto tensors = EvaluateNodes(output, fetch);
3154     EXPECT_EQ(1, tensors_expected.size());
3155     EXPECT_EQ(1, tensors.size());
3156     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3157   }
3158 }
3159 
TEST_F(ConstantFoldingTest,PartialFolding_Concat)3160 TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
3161   Scope s = Scope::NewRootScope();
3162   Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
3163                               ops::Placeholder::Shape(TensorShape({2, 2})));
3164   Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
3165                               ops::Placeholder::Shape(TensorShape({2, 2})));
3166   Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
3167                               ops::Placeholder::Shape(TensorShape({2, 2})));
3168   Output axis = ops::Const(s.WithOpName("axis"), 0, {});
3169   Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
3170   Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
3171   Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
3172   Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
3173   Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
3174   Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
3175   Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
3176   Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
3177   Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
3178   Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
3179   Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
3180   Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
3181 
3182   GrapplerItem item;
3183   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3184   item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
3185                 "concat5", "concat6", "concat7", "concat8", "concat9"};
3186 
3187   auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
3188   EXPECT_EQ(1, tensors_expected.size());
3189   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3190   GraphDef output;
3191   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3192   TF_EXPECT_OK(status);
3193   // Run the optimizer twice to make sure the rewrite is idempotent.
3194   item.graph.Swap(&output);
3195   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3196   TF_EXPECT_OK(status);
3197 
3198   EXPECT_EQ(21, output.node_size());
3199   for (int i = 0; i < output.node_size(); ++i) {
3200     const NodeDef& node = output.node(i);
3201     if (node.name() == "concat0") {
3202       EXPECT_EQ("Const", node.op());
3203     } else if (node.name() == "concat3") {
3204       EXPECT_EQ(3, node.input_size());
3205       EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
3206       EXPECT_EQ("z", node.input(1));
3207       EXPECT_EQ("axis", node.input(2));
3208     } else if (node.name() == "concat5") {
3209       EXPECT_EQ(3, node.input_size());
3210       EXPECT_EQ("x", node.input(0));
3211       EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
3212       EXPECT_EQ("axis", node.input(2));
3213     } else if (node.name() == "concat7") {
3214       EXPECT_EQ(4, node.input_size());
3215       EXPECT_EQ("x", node.input(0));
3216       EXPECT_EQ("y", node.input(1));
3217       EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
3218       EXPECT_EQ("axis", node.input(3));
3219     } else if (node.name() == "concat8") {
3220       EXPECT_EQ(4, node.input_size());
3221       EXPECT_EQ("x", node.input(0));
3222       EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
3223       EXPECT_EQ("y", node.input(2));
3224       EXPECT_EQ("axis", node.input(3));
3225     } else if (node.name() == "concat9") {
3226       EXPECT_EQ(4, node.input_size());
3227       EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
3228       EXPECT_EQ("x", node.input(1));
3229       EXPECT_EQ("y", node.input(2));
3230       EXPECT_EQ("axis", node.input(3));
3231     } else if (str_util::StartsWith(node.name(), "ConstantFolding/")) {
3232       EXPECT_EQ("Const", node.op());
3233     } else {
3234       EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
3235     }
3236   }
3237 
3238   auto tensors = EvaluateNodes(output, {"concat0"});
3239   EXPECT_EQ(1, tensors.size());
3240   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
3241 }
3242 
TEST_F(ConstantFoldingTest,PartialFolding_IdentityN)3243 TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
3244   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
3245   Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
3246                               ops::Placeholder::Shape(TensorShape({})));
3247   Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2});
3248   Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2});
3249   auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {c1, x, c2});
3250   auto id0 = ops::Identity(scope.WithOpName("id0"), id_n[0]);
3251   auto id1 = ops::Identity(scope.WithOpName("id1"), id_n[1]);
3252   auto add0 = ops::Add(scope.WithOpName("add0"), id_n[0], id_n[1]);
3253   auto add1 = ops::Add(scope.WithOpName("add1"), id_n[0], id_n[2]);
3254 
3255   GrapplerItem item;
3256   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
3257   item.fetch.push_back("id0");
3258   item.fetch.push_back("id1");
3259   item.fetch.push_back("add0");
3260   item.fetch.push_back("add1");
3261 
3262   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3263   GraphDef output;
3264   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3265   TF_EXPECT_OK(status);
3266   EXPECT_EQ(8, output.node_size());
3267   for (const auto& node : output.node()) {
3268     // id_n should remain unchanged.
3269     if (node.name() == "id_n") {
3270       EXPECT_EQ(3, node.input_size());
3271       EXPECT_EQ("c1", node.input(0));
3272       EXPECT_EQ("x", node.input(1));
3273       EXPECT_EQ("c2", node.input(2));
3274     }
3275     // id0 should be constant folded, and a control dependency from id_n.
3276     if (node.name() == "id0") {
3277       EXPECT_EQ("Const", node.op());
3278       EXPECT_EQ(1, node.input_size());
3279       EXPECT_EQ("^id_n", node.input(0));
3280     }
3281     // id1 is unchanged.
3282     if ("id1" == node.name()) {
3283       EXPECT_EQ(1, node.input_size());
3284       EXPECT_EQ("id_n:1", node.input(0));
3285     }
3286 
3287     if ("add0" == node.name()) {
3288       EXPECT_EQ(2, node.input_size());
3289       EXPECT_EQ("c1", node.input(0));
3290       EXPECT_EQ("id_n:1", node.input(1));
3291     }
3292     // add1 should bo constant folded and have a control dependency from id_n.
3293     if ("add1" == node.name()) {
3294       EXPECT_EQ("Const", node.op());
3295       EXPECT_EQ(1, node.input_size());
3296       EXPECT_EQ("^id_n", node.input(0));
3297     }
3298   }
3299 
3300   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({}));
3301   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
3302   EXPECT_EQ(4, tensors_expected.size());
3303   auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
3304   EXPECT_EQ(4, tensors.size());
3305   for (int i = 0; i < tensors.size(); i++) {
3306     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
3307   }
3308 }
3309 
TEST_F(ConstantFoldingTest, TrivialPack) {
  // A Stack (Pack) with a single input is expected to be rewritten as
  // ExpandDims with a constant axis input. An existing control dependency
  // (^y on "stack") is kept.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
  auto stack =
      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
                 ops::Stack::Axis(1));
  auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"stack", "stack_no_axis"};

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  EXPECT_EQ(7, output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "stack") {
      EXPECT_EQ("ExpandDims", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
      EXPECT_EQ("^y", node.input(2));
      ++found;
    } else if (node.name() == "stack_no_axis") {
      EXPECT_EQ("ExpandDims", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
      ++found;
    } else if (node.name() == "ConstantFolding/stack_const_axis") {
      EXPECT_EQ("Const", node.op());
      ASSERT_EQ(1, node.input_size());
      EXPECT_EQ("^x", node.input(0));
      ++found;
    }
  }
  // gtest convention: expected value first.
  EXPECT_EQ(3, found);

  // x comes from RandomNormal, so the two evaluations draw different values;
  // only the output shapes can be compared.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(2, tensors_expected.size());
  ASSERT_EQ(2, tensors.size());
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
  EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}
3361 
3362 // The test does not evalute the optimized and original graphs to check if their
3363 // outputs are the same. See b/78233179.
TEST_F(ConstantFoldingTest,Enter)3364 TEST_F(ConstantFoldingTest, Enter) {
3365   GrapplerItem item;
3366   AttrValue frame_name;
3367   frame_name.set_s("foo");
3368   AttrValue is_constant_true;
3369   is_constant_true.set_b(true);
3370   AttrValue is_constant_false;
3371   is_constant_false.set_b(false);
3372   AttrValue type;
3373   type.set_type(DT_FLOAT);
3374   AttrValue value;
3375   Tensor value_tensor(DT_FLOAT, TensorShape({}));
3376   value_tensor.flat<float>()(0) = 1;
3377   value_tensor.AsProtoTensorContent(value.mutable_tensor());
3378 
3379   GraphDef& graph = item.graph;
3380   AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
3381   AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
3382   AddNode("enter1", "Enter", {"x"},
3383           {{"T", type},
3384            {"frame_name", frame_name},
3385            {"is_constant", is_constant_true}},
3386           &graph);
3387   AddNode("enter2", "Enter", {"c1"},
3388           {{"T", type},
3389            {"frame_name", frame_name},
3390            {"is_constant", is_constant_true}},
3391           &graph);
3392   AddNode("enter3", "Enter", {"c1"},
3393           {{"T", type},
3394            {"frame_name", frame_name},
3395            {"is_constant", is_constant_false}},
3396           &graph);
3397   AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
3398   AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
3399   AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
3400   AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
3401   item.fetch.push_back("id1");
3402   item.fetch.push_back("id2");
3403   item.fetch.push_back("id3");
3404   item.fetch.push_back("id4");
3405 
3406   ConstantFolding optimizer(/*cpu_device=*/nullptr);
3407   GraphDef output;
3408   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3409   TF_EXPECT_OK(status);
3410   // Run the optimizer twice to make sure the rewrite is idempotent.
3411   item.graph.Swap(&output);
3412   status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3413   TF_EXPECT_OK(status);
3414 
3415   EXPECT_EQ(9, output.node_size());
3416   for (const NodeDef& node : output.node()) {
3417     if (node.name() == "id1") {
3418       EXPECT_EQ("Identity", node.op());
3419       EXPECT_EQ(1, node.input_size());
3420       EXPECT_EQ("enter1", node.input(0));
3421     }
3422     if (node.name() == "id2" || node.name() == "id3") {
3423       EXPECT_EQ("Const", node.op());
3424       EXPECT_EQ(1, node.input_size());
3425       EXPECT_EQ("^enter2", node.input(0));
3426     }
3427     if (node.name() == "id4") {
3428       EXPECT_EQ("Identity", node.op());
3429       EXPECT_EQ(1, node.input_size());
3430       EXPECT_EQ("enter3", node.input(0));
3431     }
3432   }
3433 }
3434 
TEST_F(ConstantFoldingTest, TensorArraySize) {
  // Only the size of the statically sized TensorArray is expected to fold to
  // a Const. The dynamically sized array and the size read through a
  // placeholder resource handle must remain TensorArraySizeV3. placeholder_sz
  // is not evaluated, since that would require feeding a resource handle.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
                       ops::Placeholder::Shape(TensorShape({2})));
  Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
  auto dynamic_array =
      ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(true));
  auto static_array =
      ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
                       ops::TensorArray::DynamicSize(false));
  auto dynamic_sz = ops::TensorArraySize(
      scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
  auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
                                        static_array.handle, static_array.flow);
  auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
                                             placeholder, foo);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected =
      EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
  ASSERT_EQ(2, tensors_expected.size());

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Run the optimizer twice to make sure the rewrite is idempotent.
  item.graph.Swap(&output);
  status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal check: the positional lookups below read nodes 5 through 7.
  ASSERT_EQ(8, output.node_size());
  EXPECT_EQ("dynamic_sz", output.node(5).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
  EXPECT_EQ("static_sz", output.node(6).name());
  EXPECT_EQ("Const", output.node(6).op());
  EXPECT_EQ("placeholder_sz", output.node(7).name());
  EXPECT_EQ("TensorArraySizeV3", output.node(7).op());

  auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
  ASSERT_EQ(2, tensors_actual.size());
  test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
  test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}
3484 
TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
  // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
  // Make sure constant folding behaves the same way as TensorFlow.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a =
      ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
  Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
  Output c = ops::Mul(s.WithOpName("c"), a, b);

  GrapplerItem item;
  item.fetch.push_back("c");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal check: output.node(0) is read next.
  ASSERT_EQ(1, output.node_size());

  const NodeDef& node_c = output.node(0);
  EXPECT_EQ("c", node_c.name());
  EXPECT_EQ("Const", node_c.op());

  // The folded constant must hold exactly the value TensorFlow computes.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(1, tensors_expected.size());
  ASSERT_EQ(1, tensors.size());
  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
3517 
// Builds result = Concat(Concat(nonconst, const1), const2) where const1 and
// const2 are large float constants (size * 4 bytes = 5 MiB each), and checks
// that optimization succeeds and preserves the output shape. Judging by the
// test name, this guards against the optimizer looping while trying to merge
// the nested Concats around large constants.
// NOTE(review): confirm against the Concat-merging rewrite in ConstantFolding.
TEST_F(ConstantFoldingTest, EvaluatingLargeConstantNoFoldingMergingLoop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Number of float elements per constant: 10 MiB / sizeof(float) / 2.
  const int size = 10 * 1024 * 1024 / 4 / 2;
  Output nonconst =
      ops::RandomUniform(s.WithOpName("nonconst"), {size, 1}, DT_FLOAT);
  Output const1 = ops::Const(s.WithOpName("const1"), 0.0f, {size, 1});
  Output const2 = ops::Const(s.WithOpName("const2"), 1.0f, {size, 1});
  Output axis = ops::Const(s.WithOpName("axis"), -1, {});
  Output concat1 =
      ops::Concat(s.WithOpName("concat1"), {nonconst, const1}, axis);
  Output result = ops::Concat(s.WithOpName("result"), {concat1, const2}, axis);

  GrapplerItem item;
  item.fetch.push_back("result");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  std::vector<string> fetch = {"result"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  // ASSERT so a missing result aborts instead of indexing out of range.
  ASSERT_EQ(1, tensors_expected.size());
  ASSERT_EQ(1, tensors.size());
  // "nonconst" is RandomUniform, so values differ between the two
  // evaluations; only the shape is comparable.
  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
}
3547 
3548 class ConstantFoldingCastConstTest : public GrapplerTest {
3549  protected:
ConstantFoldingCastConst(bool fetch_const,bool fetch_cast,bool fetch_const_child,bool fetch_cast_child)3550   void ConstantFoldingCastConst(bool fetch_const, bool fetch_cast,
3551                                 bool fetch_const_child, bool fetch_cast_child) {
3552     if (!fetch_const && !fetch_cast && !fetch_const_child &&
3553         !fetch_cast_child) {
3554       return;
3555     }
3556 
3557     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3558     CreateCastConstGraph(s);
3559     GrapplerItem item;
3560     int expected_output_size = SetFetch(&item, fetch_const, fetch_cast,
3561                                         fetch_const_child, fetch_cast_child);
3562     TF_CHECK_OK(s.ToGraphDef(&item.graph));
3563 
3564     GraphDef output = ConstantFoldingOptimize(item);
3565     EXPECT_EQ(expected_output_size, output.node_size());
3566 
3567     EvaluateAndCompareUnoptimized(item.graph, output, item.fetch);
3568   }
3569 
3570  private:
CreateCastConstGraph(const tensorflow::Scope & s)3571   void CreateCastConstGraph(const tensorflow::Scope& s) {
3572     Output const1 = ops::Const(s.WithOpName("const1"), 2, {5, 5});
3573     Output cast = ops::Cast(s.WithOpName("cast"), const1, DT_FLOAT);
3574     Output const1_child = ops::Identity(s.WithOpName("const1_child"), const1);
3575     Output cast_child = ops::Identity(s.WithOpName("cast_child"), cast);
3576   }
3577 
SetFetch(GrapplerItem * item,bool fetch_const,bool fetch_cast,bool fetch_const_child,bool fetch_cast_child)3578   int SetFetch(GrapplerItem* item, bool fetch_const, bool fetch_cast,
3579                bool fetch_const_child, bool fetch_cast_child) {
3580     int expected_output_size = 0;
3581     if (fetch_const) {
3582       item->fetch.push_back("const1");
3583       expected_output_size++;
3584     }
3585     if (fetch_cast) {
3586       item->fetch.push_back("cast");
3587       expected_output_size++;
3588     }
3589     if (fetch_const_child) {
3590       item->fetch.push_back("const1_child");
3591       expected_output_size++;
3592     }
3593     if (fetch_cast_child) {
3594       item->fetch.push_back("cast_child");
3595       expected_output_size++;
3596     }
3597     return expected_output_size;
3598   }
3599 
ConstantFoldingOptimize(const GrapplerItem & item)3600   GraphDef ConstantFoldingOptimize(const GrapplerItem& item) {
3601     ConstantFolding optimizer(/*cpu_device=*/nullptr);
3602     GraphDef output;
3603     Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
3604     TF_EXPECT_OK(status);
3605     return output;
3606   }
3607 
EvaluateAndCompareUnoptimized(const GraphDef & unoptimized_graph,const GraphDef & optimized_graph,const std::vector<string> & fetch_nodes)3608   void EvaluateAndCompareUnoptimized(const GraphDef& unoptimized_graph,
3609                                      const GraphDef& optimized_graph,
3610                                      const std::vector<string>& fetch_nodes) {
3611     auto tensors_expected = EvaluateNodes(unoptimized_graph, fetch_nodes);
3612     auto tensors = EvaluateNodes(optimized_graph, fetch_nodes);
3613     ASSERT_EQ(fetch_nodes.size(), tensors_expected.size());
3614     ASSERT_EQ(fetch_nodes.size(), tensors.size());
3615     for (int i = 0; i < fetch_nodes.size(); i++) {
3616       if (fetch_nodes[i] == "const1" || fetch_nodes[i] == "const1_child") {
3617         test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
3618       } else {
3619         test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
3620       }
3621     }
3622   }
3623 };
3624 
// Exercises ConstantFoldingCastConst for every combination of fetched nodes.
// The all-false combination returns early inside the helper.
TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
  for (bool fetch_const : {false, true}) {
    for (bool fetch_cast : {false, true}) {
      for (bool fetch_const_child : {false, true}) {
        for (bool fetch_cast_child : {false, true}) {
          // Label any failure with the flag combination that produced it;
          // without this, all 16 iterations report identical locations.
          SCOPED_TRACE(::testing::Message()
                       << "fetch_const=" << fetch_const
                       << " fetch_cast=" << fetch_cast
                       << " fetch_const_child=" << fetch_const_child
                       << " fetch_cast_child=" << fetch_cast_child);
          ConstantFoldingCastConst(fetch_const, fetch_cast, fetch_const_child,
                                   fetch_cast_child);
        }
      }
    }
  }
}
3637 
// Checks that OnesLike/ZerosLike of a placeholder with a fully defined shape,
// and Fill with constant dims and value, are rewritten into Const nodes that
// keep only control dependencies on their former inputs, and that the fetched
// values are unchanged (x is fed with a random tensor in both evaluations).
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // x, the three fetched nodes, and (presumably) the two implicit Const
  // inputs that ops::Fill created for its dims and value.
  EXPECT_EQ(output.node_size(), 6);
  for (const auto& node : output.node()) {
    if (node.name() != "x") {
      EXPECT_EQ(node.op(), "Const");
    }
    if (node.name() == "ones_like" || node.name() == "zeros_like") {
      // Only a control dependency on x should remain.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "^x");
    }
    if (node.name() == "fill") {
      // Both former data inputs should be turned into control dependencies.
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
  ASSERT_EQ(item.fetch.size(), tensors.size());
  ASSERT_EQ(tensors_expected.size(), tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    SCOPED_TRACE(item.fetch[i]);
    // "fill" yields int32 (from the int literal 42); the *_like nodes yield
    // float. Dispatch on the dtype instead of hard-coding node names.
    if (tensors_expected[i].dtype() == DT_INT32) {
      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
    }
  }
}
3685 
// Checks that Fill is materialized as a Const even when its scalar value
// input stores the value in tensor_content (instead of int_val) and the
// output shape is 1024^5 elements. That output is far too large to evaluate,
// so only the rewritten graph structure is checked.
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNodeHugeFill) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output value = ops::Const(scope.WithOpName("value"), 42, {});
  Output fill_huge = ops::Fill(scope.WithOpName("fill_huge"),
                               {1024, 1024, 1024, 1024, 1024}, value);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  // Manually convert the input value format to tensor_content to test this
  // case. (Named value_node so it is not shadowed by the loop variable below.)
  NodeDef* value_node = item.graph.mutable_node(0);
  ASSERT_EQ(value_node->name(), "value");
  TensorProto* t = (*value_node->mutable_attr())["value"].mutable_tensor();
  t->clear_int_val();
  const int val = 42;
  port::CopyFromArray(t->mutable_tensor_content(),
                      reinterpret_cast<const char*>(&val), sizeof(val));
  item.fetch = {"fill_huge"};
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(output.node_size(), 3);
  for (const auto& node : output.node()) {
    EXPECT_EQ(node.op(), "Const");
    if (node.name() == "fill_huge") {
      // Both former data inputs should be turned into control dependencies.
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }
}
3719 
// Round-trips int64 bit patterns through Bitcast(int64 -> float -> int64).
// Reinterpreted as floats, some of these bits form denormals (e.g. the low
// word of the value 1). The test expects both Bitcasts to fold into a single
// Const "z" whose value equals the unoptimized graph's result.
TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Four elements alternating INT64_MAX and 1.
  Tensor x_t(DT_INT64, TensorShape({2, 2}));
  auto x_flat = x_t.flat<int64>();
  for (int idx = 0; idx < 4; ++idx) {
    x_flat(idx) =
        (idx % 2 == 0) ? std::numeric_limits<int64>::max() : int64{1};
  }
  Output x = ops::Const(scope.WithOpName("x"), x_t);
  Output y = ops::Bitcast(scope.WithOpName("y"), x, DT_FLOAT);
  Output z = ops::Bitcast(scope.WithOpName("z"), y, DT_INT64);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"z"};
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

  // Only the folded "z" should survive.
  ASSERT_EQ(output.node_size(), 1);
  const NodeDef& folded = output.node(0);
  EXPECT_EQ(folded.name(), "z");
  EXPECT_EQ(folded.op(), "Const");

  const auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 1);
  ASSERT_EQ(tensors_expected.size(), 1);
  test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
}
3752 
// Checks that constants whose elements all share one value are compressed to
// a single float_val entry, for both a Const ("zeros") and a HostConst
// ("host_ones"). The op types and evaluated values must stay the same.
TEST_F(ConstantFoldingTest, CompressConstants) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Tensor zeros_t(DT_FLOAT, TensorShape({64}));
  Tensor ones_t(DT_FLOAT, TensorShape({64}));
  zeros_t.flat<float>().setZero();
  ones_t.flat<float>().setConstant(1.0f);
  Output zeros = ops::Const(scope.WithOpName("zeros"), zeros_t);
  Output host_ones = ops::Const(scope.WithOpName("host_ones"), ones_t);
  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  ASSERT_EQ(item.graph.node(1).name(), "host_ones");
  // There is no C++ API for HostConst, so we manually change the node type
  // here.
  item.graph.mutable_node(1)->set_op("HostConst");
  item.fetch = {"zeros", "host_ones"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

  // Both output.node(0) and output.node(1) are read below.
  ASSERT_EQ(output.node_size(), 2);
  {
    const NodeDef& node = output.node(0);
    EXPECT_EQ(node.name(), "zeros");
    EXPECT_EQ(node.op(), "Const");
    const TensorProto& zeros_proto = node.attr().at("value").tensor();
    // ASSERT: float_val(0) is an out-of-range read if the field is empty.
    ASSERT_EQ(zeros_proto.float_val_size(), 1);
    EXPECT_EQ(zeros_proto.float_val(0), 0.0f);
  }
  {
    const NodeDef& node = output.node(1);
    EXPECT_EQ(node.name(), "host_ones");
    EXPECT_EQ(node.op(), "HostConst");
    // Named ones_proto so it does not shadow the ones_t Tensor above.
    const TensorProto& ones_proto = node.attr().at("value").tensor();
    ASSERT_EQ(ones_proto.float_val_size(), 1);
    EXPECT_EQ(ones_proto.float_val(0), 1.0f);
  }

  auto tensors = EvaluateNodes(output, item.fetch, {});
  ASSERT_EQ(tensors.size(), 2);
  ASSERT_EQ(tensors_expected.size(), 2);
  for (int i = 0; i < 2; ++i) {
    test::ExpectTensorEqual<float>(tensors[i], tensors_expected[i]);
  }
}
3801 
3802 }  // namespace
3803 }  // namespace grappler
3804 }  // namespace tensorflow
3805