/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {
namespace grappler {

31 class MklRemapperTest : public GrapplerTest {
32  public:
33   const string kAddNOp = "AddN";
34   const string kAddOp = "Add";
35   const string kAddV2Op = "AddV2";
36 
37  protected:
FuseConv2DWithBiasAndAddNOrAdd(const string & data_format,const string & activation,string add_op,bool add_with_bcast)38   void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format,
39                                       const string& activation, string add_op,
40                                       bool add_with_bcast) {
41     using ::tensorflow::ops::Placeholder;
42 
43     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
44 
45     auto input_shape = (data_format == "NHWC")
46                            ? ops::Placeholder::Shape({8, 32, 32, 3})
47                            : ops::Placeholder::Shape({8, 3, 32, 32});
48     auto input_shape_addn = ops::Placeholder::Shape({});
49     if (data_format == "NHWC") {
50       if (add_with_bcast)
51         input_shape_addn = ops::Placeholder::Shape({128});
52       else
53         input_shape_addn = ops::Placeholder::Shape({8, 32, 32, 128});
54     } else {
55       if (add_with_bcast)
56         input_shape_addn = ops::Placeholder::Shape({32});
57       else
58         input_shape_addn = ops::Placeholder::Shape({8, 128, 32, 32});
59     }
60     auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
61     auto bias_shape = ops::Placeholder::Shape({128});
62 
63     auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
64     auto input_addn =
65         Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn);
66     auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
67     auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
68 
69     std::vector<int> strides = {1, 1, 1, 1};
70     auto conv =
71         ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME",
72                     ops::Conv2D::Attrs().DataFormat(data_format));
73     auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias,
74                                  ops::BiasAdd::Attrs().DataFormat(data_format));
75 
76     auto addfetch = [&](::tensorflow::Input addop) {
77       auto activate = s.WithOpName("activation");
78       auto fetch = s.WithOpName("fetch");
79       if (activation == "Relu") {
80         ops::Identity(fetch, ops::Relu(activate, addop));
81       } else if (activation == "Relu6") {
82         ops::Identity(fetch, ops::Relu6(activate, addop));
83       } else if (activation == "Elu") {
84         ops::Identity(fetch, ops::Elu(activate, addop));
85       } else if (activation == "LeakyRelu") {
86         ops::Identity(fetch, ops::internal::LeakyRelu(activate, addop));
87       } else {
88         DCHECK(activation == "None");
89         ops::Identity(fetch, addop);
90       }
91     };
92 
93     if (add_op == kAddNOp) {
94       auto addn = ops::AddN(s.WithOpName(add_op),
95                             std::initializer_list<Input>{input_addn, bias_add});
96       addfetch(addn);
97     } else if (add_op == kAddV2Op) {
98       auto add = ops::AddV2(s.WithOpName(add_op), input_addn, bias_add);
99       addfetch(add);
100     } else {
101       auto add = ops::Add(s.WithOpName(add_op), input_addn, bias_add);
102       addfetch(add);
103     }
104     auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
105         TensorShape(input_shape.shape_.dim_sizes()));
106     auto input_addn_tensor = GenerateRandomTensor<DT_FLOAT>(
107         TensorShape(input_shape_addn.shape_.dim_sizes()));
108     auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
109         TensorShape(filter_shape.shape_.dim_sizes()));
110     auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
111         TensorShape(bias_shape.shape_.dim_sizes()));
112 
113     GrapplerItem item;
114     item.fetch = {"fetch"};
115     item.feed = {{"input", input_tensor},
116                  {"filter", filter_tensor},
117                  {"bias", bias_tensor},
118                  {"input_addn", input_addn_tensor}};
119     TF_CHECK_OK(s.ToGraphDef(&item.graph));
120 
121     // Place all nodes on CPU.
122     for (int i = 0; i < item.graph.node_size(); ++i) {
123       item.graph.mutable_node(i)->set_device("/device:CPU:0");
124     }
125 
126     // Set Rewriter config to AGGRESSIVE so that we can use Placeholder shape
127     // to test that Add with both inputs having same shape get fused with
128     // Conv2D. Setting this config to AGGRESSIVE is not required for the feature
129     // though.
130     Remapper optimizer(RewriterConfig::AGGRESSIVE);
131     GraphDef output;
132     TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
133 
134     bool check_fusion = !add_with_bcast;
135     int found = 0;
136     for (const NodeDef& node : output.node()) {
137       auto fetch_node_name = activation != "None" ? "activation" : add_op;
138       if (node.name() == fetch_node_name) {
139         if (check_fusion) {
140           EXPECT_EQ("_FusedConv2D", node.op());
141           EXPECT_EQ("input", node.input(0));
142           EXPECT_EQ("filter", node.input(1));
143 
144           EXPECT_EQ(2, node.attr().at("num_args").i());
145           EXPECT_EQ("bias", node.input(2));
146           EXPECT_EQ("input_addn", node.input(3));
147 
148           const auto fused_ops = node.attr().at("fused_ops").list().s();
149           if (activation != "None") {
150             EXPECT_EQ(3, fused_ops.size());
151             EXPECT_EQ("BiasAdd", fused_ops[0]);
152             EXPECT_EQ("Add", fused_ops[1]);
153             EXPECT_EQ(activation, fused_ops[2]);
154           } else {
155             EXPECT_EQ(2, fused_ops.size());
156             EXPECT_EQ("BiasAdd", fused_ops[0]);
157             EXPECT_EQ("Add", fused_ops[1]);
158           }
159         } else {
160           if (activation != "None") {
161             EXPECT_EQ(node.op(), activation);
162             ASSERT_EQ(node.input_size(), 1);
163             EXPECT_EQ(node.input(0), add_op);
164           } else {
165             EXPECT_EQ(node.op(), add_op);
166             ASSERT_EQ(node.input_size(), 2);
167           }
168         }
169         found++;
170       }
171     }
172     EXPECT_EQ(1, found);
173 
174     auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
175     auto tensors = EvaluateNodes(output, item.fetch, item.feed);
176     EXPECT_EQ(1, tensors_expected.size());
177     EXPECT_EQ(1, tensors.size());
178     // Using relative tolerance since oneDNN could produce different results
179     // when float32 numbers need to be rounded during accumulation.
180     test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
181   }
182 };
183 
184 #define CREATE_CONV2DFUSION_TEST(data_format, addop, activation, bcast)                          \
185   TEST_F(                                                                                        \
186       MklRemapperTest,                                                                           \
187       FuseConv2DWithBiasAnd##addop##_##data_format##_activation##activation##_addbcast##bcast) { \
188     FuseConv2DWithBiasAndAddNOrAdd(#data_format, #activation, #addop, bcast);                    \
189   }
190 
191 #define CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(data_format, addop, bcast) \
192   CREATE_CONV2DFUSION_TEST(data_format, addop, Relu, bcast);               \
193   CREATE_CONV2DFUSION_TEST(data_format, addop, Relu6, bcast);              \
194   CREATE_CONV2DFUSION_TEST(data_format, addop, Elu, bcast);                \
195   CREATE_CONV2DFUSION_TEST(data_format, addop, LeakyRelu, bcast);          \
196   CREATE_CONV2DFUSION_TEST(data_format, addop, None, bcast);
197 
198 #define CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(addop)            \
199   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
200   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false);
201 
202 CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(AddN);
203 
204 #define CREATE_CONV2DFUSION_ADD_BCAST_TEST(addop)              \
205   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
206   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false); \
207   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, true);  \
208   CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, true);
209 
210 CREATE_CONV2DFUSION_ADD_BCAST_TEST(Add);
211 CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
212 
213 #undef CREATE_CONV2DFUSION_ADD_NOBCAST_TEST
214 #undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
215 #undef CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST
216 #undef CREATE_CONV2DFUSION_TEST
217 
218 #define REGISTER_TEST(NAME, T, INPUT)                                         \
219   TEST_F(MklRemapperTest, NAME##_##T) {                                       \
220     using ::tensorflow::ops::Placeholder;                                     \
221                                                                               \
222     for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {       \
223       tensorflow::Scope s = tensorflow::Scope::NewRootScope();                \
224                                                                               \
225       auto input_shape = Placeholder::Shape({8, 32, 32, 3});                  \
226       auto filter_shape = Placeholder::Shape({1, 1, 3, 1});                   \
227       auto bias_shape = Placeholder::Shape({3});                              \
228                                                                               \
229       auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); \
230       auto filter =                                                           \
231           Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);        \
232       auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);    \
233                                                                               \
234       std::vector<int> strides = {1, 1, 1, 1};                                \
235       auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),  \
236                                              input, filter, strides, "SAME"); \
237       auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);     \
238                                                                               \
239       ops::Identity fetch = [&]() -> ops::Identity {                          \
240         auto activate = s.WithOpName("activation");                           \
241         auto fetch = s.WithOpName("fetch");                                   \
242                                                                               \
243         if (activation == "Relu") {                                           \
244           return ops::Identity(fetch, ops::Relu(activate, bias_add));         \
245         } else if (activation == "Relu6") {                                   \
246           return ops::Identity(fetch, ops::Relu6(activate, bias_add));        \
247         } else if (activation == "Elu") {                                     \
248           return ops::Identity(fetch, ops::Elu(activate, bias_add));          \
249         }                                                                     \
250                                                                               \
251         DCHECK(activation == "None");                                         \
252         return ops::Identity(fetch, bias_add);                                \
253       }();                                                                    \
254                                                                               \
255       auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});          \
256       auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});           \
257       auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});                      \
258                                                                               \
259       GrapplerItem item;                                                      \
260       item.fetch = {"fetch"};                                                 \
261       item.feed = {                                                           \
262           {"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};        \
263       TF_CHECK_OK(s.ToGraphDef(&item.graph));                                 \
264                                                                               \
265       for (int i = 0; i < item.graph.node_size(); ++i) {                      \
266         item.graph.mutable_node(i)->set_device("/device:CPU:0");              \
267       }                                                                       \
268                                                                               \
269       Remapper optimizer(RewriterConfig::ON);                                 \
270       GraphDef output;                                                        \
271       TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));                \
272                                                                               \
273       int found = 0;                                                          \
274       for (const NodeDef& node : output.node()) {                             \
275         if (node.name() != "bias_add" && node.name() != "activation")         \
276           continue;                                                           \
277                                                                               \
278         EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");                  \
279         ASSERT_EQ(node.input_size(), 3);                                      \
280         EXPECT_EQ(node.input(0), "input");                                    \
281         EXPECT_EQ(node.input(1), "filter");                                   \
282                                                                               \
283         EXPECT_EQ(node.attr().at("num_args").i(), 1);                         \
284         EXPECT_EQ(node.input(2), "bias");                                     \
285                                                                               \
286         const auto fused_ops = node.attr().at("fused_ops").list().s();        \
287         if (node.name() == "bias_add") {                                      \
288           ASSERT_EQ(fused_ops.size(), 1);                                     \
289           EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
290           found++;                                                            \
291         }                                                                     \
292         if (node.name() == "activation") {                                    \
293           ASSERT_EQ(fused_ops.size(), 2);                                     \
294           EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
295           EXPECT_EQ(fused_ops[1], activation);                                \
296           found++;                                                            \
297         }                                                                     \
298       }                                                                       \
299       EXPECT_EQ(found, 1);                                                    \
300                                                                               \
301       auto tensors_expected =                                                 \
302           EvaluateNodes(item.graph, item.fetch, item.feed);                   \
303       ASSERT_EQ(tensors_expected.size(), 1);                                  \
304       auto tensors = EvaluateNodes(output, item.fetch, item.feed);            \
305       ASSERT_EQ(tensors.size(), 1);                                           \
306       test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);   \
307     }                                                                         \
308   }
309 REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
310 #undef REGISTER_TEST
311 
312 #ifdef ENABLE_MKLDNN_V1
TEST_F(MklRemapperTest,FuseBatchNormWithRelu)313 TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
314   using ::tensorflow::ops::Placeholder;
315 
316   for (bool is_training : {true, false}) {
317     for (bool has_side_input : {true, false}) {
318       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
319 
320       const int num_channels = 24;
321 
322       TensorShape channel_shape({num_channels});
323       TensorShape empty_shape({0});
324 
325       auto input =
326           Placeholder(s.WithOpName("input"), DT_FLOAT,
327                       ops::Placeholder::Shape({2, 8, 8, num_channels}));
328       auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
329       auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
330       auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
331       auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
332       auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
333 
334       float epsilon = 0.1f;
335       auto fbn =
336           ops::FusedBatchNormV3(s.WithOpName("fused_batch_norm"), input_cast,
337                                 scale, offset, mean, var,
338                                 ops::FusedBatchNormV3::IsTraining(is_training)
339                                     .Epsilon(epsilon)
340                                     .DataFormat("NHWC"));
341 
342       if (has_side_input) {
343         auto side_input =
344             Placeholder(s.WithOpName("side_input"), DT_FLOAT,
345                         ops::Placeholder::Shape({2, 8, 8, num_channels}));
346         auto side_input_cast =
347             ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
348         auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
349         auto relu = ops::Relu(s.WithOpName("relu"), add);
350       } else {
351         auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
352       }
353 
354       auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
355       auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
356       auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
357       auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
358                                                                : channel_shape);
359       auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
360                                                               : channel_shape);
361       auto side_input_t =
362           GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
363 
364       GrapplerItem item;
365       item.fetch = {"relu"};
366       if (has_side_input)
367         item.feed = {{"input", input_t},   {"scale", scale_t},
368                      {"offset", offset_t}, {"mean", mean_t},
369                      {"var", var_t},       {"side_input", side_input_t}};
370       else
371         item.feed = {{"input", input_t},
372                      {"scale", scale_t},
373                      {"offset", offset_t},
374                      {"mean", mean_t},
375                      {"var", var_t}};
376       TF_ASSERT_OK(s.ToGraphDef(&item.graph));
377 
378       // Place all nodes on CPU.
379       for (int i = 0; i < item.graph.node_size(); ++i) {
380         item.graph.mutable_node(i)->set_device("/device:CPU:0");
381       }
382 
383       Remapper optimizer(RewriterConfig::AGGRESSIVE);
384       GraphDef output;
385       TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
386 
387       int found = 0;
388       if (has_side_input) {
389         for (const NodeDef& node : output.node()) {
390           if (node.name() == "add") {
391             EXPECT_EQ(node.op(), "Add");
392             ASSERT_EQ(node.input_size(), 2);
393             EXPECT_EQ(node.input(0), "fused_batch_norm");
394             EXPECT_EQ(node.input(1), "side_input_cast");
395             found++;
396           }
397           if (node.name() == "relu") {
398             EXPECT_EQ(node.op(), "Relu");
399             ASSERT_EQ(node.input_size(), 1);
400             EXPECT_EQ(node.input(0), "add");
401             found++;
402           }
403           if (node.name() == "fused_batch_norm") {
404             EXPECT_EQ(node.op(), "FusedBatchNormV3");
405             ASSERT_EQ(node.input_size(), 5);
406             EXPECT_EQ(node.input(0), "input_cast");
407             EXPECT_EQ(node.input(1), "scale");
408             EXPECT_EQ(node.input(2), "offset");
409             EXPECT_EQ(node.input(3), "mean");
410             EXPECT_EQ(node.input(4), "var");
411             found++;
412           }
413         }
414         EXPECT_EQ(found, 3);
415       } else {
416         for (const NodeDef& node : output.node()) {
417           if (node.name() == "relu") {
418             EXPECT_EQ(node.op(), "Identity");
419             ASSERT_EQ(node.input_size(), 1);
420             EXPECT_EQ(node.input(0), "fused_batch_norm");
421             found++;
422           }
423           if (node.name() == "fused_batch_norm") {
424             EXPECT_EQ(node.op(), "_FusedBatchNormEx");
425             ASSERT_EQ(node.input_size(), 5);
426             EXPECT_EQ(node.input(0), "input_cast");
427             EXPECT_EQ(node.input(1), "scale");
428             EXPECT_EQ(node.input(2), "offset");
429             EXPECT_EQ(node.input(3), "mean");
430             EXPECT_EQ(node.input(4), "var");
431 
432             auto attr = node.attr();
433             EXPECT_EQ(attr["num_side_inputs"].i(), 0);
434             EXPECT_EQ(attr["activation_mode"].s(), "Relu");
435             found++;
436           }
437         }
438         EXPECT_EQ(found, 2);
439       }
440 
441       auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
442       ASSERT_EQ(tensors_expected.size(), 1);
443       auto tensors = EvaluateNodes(output, item.fetch, item.feed);
444       ASSERT_EQ(tensors.size(), 1);
445       test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
446     }
447   }
448 }
449 
TEST_F(MklRemapperTest,FuseMatMulWithBiasAddAndAdd)450 TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
451   using ::tensorflow::ops::Placeholder;
452 
453   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
454 
455   auto input_shape = ops::Placeholder::Shape({4, 32});
456   auto input_shape_add = ops::Placeholder::Shape({4, 8});
457   auto filter_shape = ops::Placeholder::Shape({32, 8});
458   auto bias_shape = ops::Placeholder::Shape({8});
459 
460   auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
461   auto input_add =
462       Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
463   auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
464   auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
465 
466   auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
467   auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
468 
469   auto fetch = s.WithOpName("fetch");
470   auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
471 
472   ops::Identity(fetch, add);
473 
474   auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
475       TensorShape(input_shape.shape_.dim_sizes()));
476   auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
477       TensorShape(input_shape_add.shape_.dim_sizes()));
478   auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
479       TensorShape(filter_shape.shape_.dim_sizes()));
480   auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
481       TensorShape(bias_shape.shape_.dim_sizes()));
482 
483   GrapplerItem item;
484   item.fetch = {"fetch"};
485   item.feed = {{"input", input_tensor},
486                {"filter", filter_tensor},
487                {"bias", bias_tensor},
488                {"input_add", input_add_tensor}};
489   TF_CHECK_OK(s.ToGraphDef(&item.graph));
490 
491   // Place all nodes on CPU.
492   for (int i = 0; i < item.graph.node_size(); ++i) {
493     item.graph.mutable_node(i)->set_device("/device:CPU:0");
494   }
495 
496   Remapper optimizer(RewriterConfig::AGGRESSIVE);
497   GraphDef output;
498   TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
499 
500   int found = 0;
501   for (const NodeDef& node : output.node()) {
502     auto fetch_node_name = "add";
503     if (node.name() == fetch_node_name) {
504       EXPECT_EQ("_FusedMatMul", node.op());
505       EXPECT_EQ("input", node.input(0));
506       EXPECT_EQ("filter", node.input(1));
507 
508       EXPECT_EQ(2, node.attr().at("num_args").i());
509       EXPECT_EQ("bias", node.input(2));
510       EXPECT_EQ("input_add", node.input(3));
511 
512       const auto fused_ops = node.attr().at("fused_ops").list().s();
513       EXPECT_EQ(2, fused_ops.size());
514       EXPECT_EQ("BiasAdd", fused_ops[0]);
515       EXPECT_EQ("Add", fused_ops[1]);
516       found++;
517     }
518   }
519   EXPECT_EQ(1, found);
520 
521   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
522   auto tensors = EvaluateNodes(output, item.fetch, item.feed);
523   EXPECT_EQ(1, tensors_expected.size());
524   EXPECT_EQ(1, tensors.size());
525   test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
526 }
527 #endif  // ENABLE_MKLDNN_V1

}  // namespace grappler
}  // namespace tensorflow
#endif  // INTEL_MKL