• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/remapper.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/framework/versions.pb.h"
20 #include "tensorflow/core/grappler/costs/graph_properties.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/grappler/utils/graph_view.h"
27 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
28 #include "tensorflow/core/grappler/utils/topological_sort.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/util/env_var.h"
32 #include "tensorflow/core/util/util.h"
33 
34 #if GOOGLE_CUDA
35 #include "third_party/gpus/cudnn/cudnn.h"
36 #endif  // GOOGLE_CUDA
37 
38 #ifdef INTEL_MKL
39 #include "tensorflow/core/graph/mkl_graph_util.h"
40 #endif  // INTEL_MKL
41 
42 namespace tensorflow {
43 namespace grappler {
44 
45 // Supported patterns:
46 //
47 // Conv2D + ... -> _FusedConv2D
48 //   (1) Conv2D + BiasAdd + <Activation>
49 //   (2) Conv2D + FusedBatchNorm + <Activation>
50 //   (3) Conv2D + Squeeze + BiasAdd
51 //
52 // MatMul + ... -> _FusedMatMul:
53 //   (1) MatMul + BiasAdd + <Activation>
54 //
55 // DepthwiseConv2dNative + ... -> _FusedDepthwiseConv2dNative:
56 //   (1) DepthwiseConv2dNative + BiasAdd + <Activation>
57 //
58 // FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
59 //   (1) FusedBatchNorm + <Activation>
60 //   (2) FusedBatchNorm + SideInput + <Activation>
61 //
// In all cases, the supported activation functions are Relu, Relu6, Elu and
// LeakyRelu (and additionally Tanh when building with Intel MKL).
//
// Both Conv2D and MatMul are implemented as Tensor contractions (on CPU), so
// all the patterns are "ContractionWith...".
66 namespace {
67 
// Op names of the fused kernels this pass rewrites matched patterns into.
constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";

// Attribute names read from candidate nodes during matching.
constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";

// Sentinel node index meaning "this slot of the pattern was not matched".
constexpr int kMissingIndex = -1;
77 
// Mutable state shared by all pattern matchers in this pass.
struct RemapperContext {
  explicit RemapperContext(GrapplerItem* item, Status* status)
      : nodes_to_preserve(item->NodesToPreserve()),
        graph_view(&item->graph, status),
        graph_properties(*item),
        inferred_graph_properties(false) {}

  // Names of nodes that must survive the rewrite untouched.
  std::unordered_set<string> nodes_to_preserve;
  // Mutable view over item->graph; all pattern structs index into this view.
  utils::MutableGraphView graph_view;
  GraphProperties graph_properties;
  // Whether graph_properties has been populated. NOTE(review): inference is
  // presumably triggered lazily by code outside this chunk — confirm.
  bool inferred_graph_properties;
};
90 
91 // FusedBatchNorm that can be replaced with a cheaper set of primitives.
struct FusedBatchNorm {
  FusedBatchNorm() = default;
  // `fused_batch_norm` is the graph-view index of the matched node.
  explicit FusedBatchNorm(int fused_batch_norm)
      : fused_batch_norm(fused_batch_norm) {}

  int fused_batch_norm = kMissingIndex;  // kMissingIndex when unmatched.
};
99 
100 // FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
  FusedBatchNormEx() = default;

  int fused_batch_norm = kMissingIndex;
  int side_input = kMissingIndex;  // optional fused side input
  int activation = kMissingIndex;  // optional fused activation
  // Add node that will be invalidated by fusing side input and fused batch
  // norm.
  int invalidated = kMissingIndex;
};
110 
111 // Contraction node followed by a BiasAdd.
// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
  ContractionWithBiasAdd() = default;
  // Both arguments are graph-view node indices.
  ContractionWithBiasAdd(int contraction, int bias_add)
      : contraction(contraction), bias_add(bias_add) {}

  int contraction = kMissingIndex;  // Conv2D/MatMul/DepthwiseConv2dNative
  int bias_add = kMissingIndex;
};
120 
121 // Contraction node followed by a BiasAdd and Activation.
// Contraction node followed by a BiasAdd and Activation.
struct ContractionWithBiasAddAndActivation {
  ContractionWithBiasAddAndActivation() = default;
  // All arguments are graph-view node indices.
  ContractionWithBiasAddAndActivation(int contraction, int bias_add,
                                      int activation)
      : contraction(contraction), bias_add(bias_add), activation(activation) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int activation = kMissingIndex;
};
132 
133 // Contraction node followed by a Squeeze and BiasAdd.
// Contraction node followed by a Squeeze and BiasAdd.
struct ContractionWithSqueezeAndBiasAdd {
  ContractionWithSqueezeAndBiasAdd() = default;
  // All arguments are graph-view node indices.
  ContractionWithSqueezeAndBiasAdd(int contraction, int squeeze, int bias_add)
      : contraction(contraction), squeeze(squeeze), bias_add(bias_add) {}

  int contraction = kMissingIndex;
  int squeeze = kMissingIndex;
  int bias_add = kMissingIndex;
};
143 
144 // Contraction node followed by a FusedBatchNorm.
// Contraction node followed by a FusedBatchNorm.
struct ContractionWithBatchNorm {
  ContractionWithBatchNorm() = default;
  ContractionWithBatchNorm(int contraction, int fused_batch_norm,
                           float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  // The batch norm's "epsilon" attribute, carried over to the fused op.
  float epsilon = 0.0;
};
157 
158 // Contraction node followed by a FusedBatchNorm and Activation.
// Contraction node followed by a FusedBatchNorm and Activation.
struct ContractionWithBatchNormAndActivation {
  ContractionWithBatchNormAndActivation() = default;
  ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm,
                                        int activation, float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        activation(activation),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  int activation = kMissingIndex;
  // The batch norm's "epsilon" attribute, carried over to the fused op.
  float epsilon = 0.0;
};
173 
174 #ifdef INTEL_MKL
175 // Contraction node followed by a BiasAdd and Add.
// Contraction node followed by a BiasAdd and Add.
struct ContractionWithBiasAddAndAdd {
  ContractionWithBiasAddAndAdd() = default;
  ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
                               int port_id)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  // Which input port of the Add/AddN carries the BiasAdd output (0 or 1).
  int port_id = 0;
};
190 
191 // Contraction node followed by a BiasAdd, Add and Relu.
// Contraction node followed by a BiasAdd, Add and Relu.
struct ContractionWithBiasAndAddActivation {
  ContractionWithBiasAndAddActivation() = default;
  ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
                                      int port_id, int activation)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id),
        activation(activation) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  // Which input port of the Add/AddN carries the BiasAdd output (0 or 1).
  int port_id = 0;
  int activation = kMissingIndex;
};
208 #endif  // INTEL_MKL
209 
IsInPreserveSet(const RemapperContext & ctx,const NodeDef * node)210 bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
211   return ctx.nodes_to_preserve.count(node->name()) > 0;
212 }
213 
HaveSameDataType(const NodeDef * lhs,const NodeDef * rhs,const string & type_attr="T")214 bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
215                       const string& type_attr = "T") {
216   DataType lhs_attr = GetDataTypeFromAttr(*lhs, type_attr);
217   DataType rhs_attr = GetDataTypeFromAttr(*rhs, type_attr);
218 
219   return lhs_attr != DT_INVALID && rhs_attr != DT_INVALID &&
220          lhs_attr == rhs_attr;
221 }
222 
HasDataType(const NodeDef * node,const DataType & expected,const string & type_attr="T")223 bool HasDataType(const NodeDef* node, const DataType& expected,
224                  const string& type_attr = "T") {
225   DataType dtype = GetDataTypeFromAttr(*node, type_attr);
226   return dtype == expected;
227 }
228 
// Returns true if the contraction's dtype is supported by the CPU fused
// kernels for the build configuration.
// NOTE: the braces below intentionally span preprocessor arms — each #if arm
// opens an `if` whose closing brace is provided by the common `} else {` at
// the bottom. Edit with care.
bool IsCpuCompatibleDataType(const NodeDef* contraction,
                             const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
#if defined(INTEL_MKL)
#if defined(ENABLE_INTEL_MKL_BFLOAT16)
  // MKL with bfloat16: Conv2D/DepthwiseConv2D/MatMul in fp32 or bf16.
  if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
      IsMatMul(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_BFLOAT16;
#else
  // MKL without bfloat16: fp32 only.
  if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
      IsMatMul(*contraction)) {
    return dtype == DT_FLOAT;
#endif  // ENABLE_INTEL_MKL_BFLOAT16
#else
  // Default (non-MKL) CPU kernels.
  if (IsConv2D(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_DOUBLE;
  } else if (IsMatMul(*contraction)) {
    return dtype == DT_FLOAT;
#endif  // INTEL_MKL
  } else {
    return false;
  }
}
252 
253 bool IsGpuCompatibleDataType(const NodeDef* contraction,
254                              const string& type_attr = "T") {
255   DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
256   if (IsConv2D(*contraction)) {
257     return dtype == DT_FLOAT;
258   } else {
259     return false;
260   }
261 }
262 
// Returns true if the Conv2D's layout is supported by the CPU fused kernel.
// The default build handles only NHWC; the MKL build also handles NCHW.
bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  const string& data_format = conv2d->attr().at(kDataFormat).s();
#ifndef INTEL_MKL
  return data_format == "NHWC";
#else
  return data_format == "NHWC" || data_format == "NCHW";
#endif  // !INTEL_MKL
}
272 
273 bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) {
274   DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
275   const string& data_format = conv2d->attr().at(kDataFormat).s();
276   return data_format == "NHWC" || data_format == "NCHW";
277 }
278 
279 bool IsCpuCompatibleConv2D(const NodeDef* conv2d) {
280   DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
281   return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
282          IsCpuCompatibleDataFormat(conv2d);
283 }
284 
285 bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {
286   DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
287   return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
288          IsGpuCompatibleDataFormat(conv2d);
289 }
290 
291 bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
292   DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
293   return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
294 }
295 
296 bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
297   DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
298       << "Expected DepthwiseConv2dNative op";
299   return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
300 }
301 
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
// Dispatches on the op type of the pattern's contraction node; works for any
// pattern struct that has a `contraction` index member.
template <typename Pattern>
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
  const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
  if (IsConv2D(node)) {
    return IsCpuCompatibleConv2D(&node);
  } else if (IsDepthwiseConv2dNative(node)) {
#ifdef INTEL_MKL
    // Depthwise fusion is only available through the MKL kernels, and only
    // when MKL has not been disabled at runtime.
    if (DisableMKL()) {
      return false;
    }
    return IsCpuCompatibleDepthwiseConv2dNative(&node);
#else
    return false;
#endif  // INTEL_MKL
  } else if (IsMatMul(node)) {
    return IsCpuCompatibleMatMul(&node);
  } else {
    return false;
  }
}
323 
// Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAddAndActivation& matched) {
#if TENSORFLOW_USE_ROCM
  // ROCm does not support _FusedConv2D
  return false;
#endif
  const GraphDef* graph = ctx.graph_view.graph();
  const NodeDef& contraction_node = graph->node(matched.contraction);
  if (!IsConv2D(contraction_node)) return false;

  const std::vector<OpInfo::TensorProperties>& input_props =
      ctx.graph_properties.GetInputProperties(contraction_node.name());
  // Filter is the second input; fall back to an empty shape when properties
  // were not inferred. (The ternary materializes a temporary that is
  // lifetime-extended by the const reference, so this binding is safe.)
  const TensorShapeProto& filter_shape =
      input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();

  // FusedConv2D on GPU with 1x1 convolution is marginally faster than
  // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
  // and significantly slower in large scale benchmarks.
  bool is_spatial_conv = Rank(filter_shape) == 4 &&          //
                         IsKnown(filter_shape.dim(1)) &&     //
                         IsKnown(filter_shape.dim(2)) &&     //
                         filter_shape.dim(1).size() != 1 &&  //
                         filter_shape.dim(2).size() != 1;

  // We rely on cuDNN for fused convolution, and it currently supports only Relu
  // activation.
  const NodeDef& activation_node = graph->node(matched.activation);
  bool is_relu = IsRelu(activation_node);

  return is_relu && is_spatial_conv && IsGpuCompatibleConv2D(&contraction_node);
}
// Only the Conv2D+BiasAdd+Activation pattern has a GPU implementation; a bare
// contraction+BiasAdd is never fused on GPU.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAdd& matched) {
  return false;
}
// Conv2D+Squeeze+BiasAdd has no GPU implementation; CPU only.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithSqueezeAndBiasAdd& matched) {
  return false;
}
364 
365 // Returns true if the given pattern is supported on the assigned device.
366 template <typename Pattern>
367 bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
368   return IsCpuCompatible(ctx, matched) || IsGpuCompatible(ctx, matched);
369 }
370 
371 bool IsSupportedActivation(const NodeDef& node) {
372 #ifdef INTEL_MKL
373   return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node) ||
374          IsTanh(node);
375 #else
376   return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
377 #endif
378 }
379 
380 inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
381   return node_view.NumControllingFanins() > 0 ||
382          node_view.NumControlledFanouts() > 0;
383 }
384 
385 // Returns true if at most one fanout reads output at port 0 (output used once).
386 inline bool HasAtMostOneFanoutAtPort0(const utils::MutableNodeView& node_view) {
387   return node_view.GetRegularFanout(0).size() <= 1;
388 }
389 
390 // Returns true if at most one fanout reads actual tensor data at output port 0
391 // (output used once for any data computation).
392 inline bool HasAtMostOneDataFanoutAtPort0(
393     const utils::MutableNodeView& node_view) {
394   const auto predicate = [](const auto& fanout) -> bool {
395     const NodeDef* node = fanout.node_view()->node();
396     return !IsShape(*node) && !IsRank(*node);
397   };
398   return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
399 }
400 
// Matches the BiasAdd-rooted pattern {Conv2D,MatMul,DepthwiseConv2dNative} ->
// BiasAdd rooted at `node_index` and fills `*matched` on success.
// `check_device_compatible` is disabled by callers that embed this pattern in
// a larger one and run the device check on the full pattern instead.
bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
                             ContractionWithBiasAdd* matched,
                             bool check_device_compatible = true) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be a BiasAdd.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  if (!IsBiasAdd(*node_def)) return false;

  // Input to the BiasAdd must be a Conv2D or a MatMul.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* contraction_node_view = regular_fanin_0.node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Conv2D, MatMul or DepthwiseConv2D
  bool is_contraction = IsConv2D(*contraction_node_def) ||
                        IsMatMul(*contraction_node_def) ||
                        IsDepthwiseConv2dNative(*contraction_node_def);

  // The contraction's output must feed only this BiasAdd, carry the same
  // dtype, have no control edges, and not be a node we must preserve.
  if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
      HasControlFaninOrFanout(*contraction_node_view) ||
      !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
      IsInPreserveSet(ctx, contraction_node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
                                       node_index};
  if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
    return false;

  // We successfully found a {Conv2D, MatMul}+BiasAdd pattern.
  *matched = pattern;

  return true;
}
440 
// Matches contraction -> BiasAdd -> Activation rooted at `node_index`
// (the activation node) and fills `*matched` on success.
bool FindContractionWithBiasAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBiasAddAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be an activation node.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // And input to the activation node must match ContractionWithBiasAdd pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* bias_add_node_view = regular_fanin_0.node_view();
  const auto* bias_add_node_def = bias_add_node_view->node();

  // Device compatibility is deferred to the check on the full pattern below.
  ContractionWithBiasAdd base;
  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), &base,
                               /*check_device_compatible=*/false) ||
      !HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;

  // Get the contraction node
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(0).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only matmul + bias + tanh is enabled.
  if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;

  // Currently, only (conv | matmul) + bias + leakyrelu is enabled
  if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def)) &&
      IsLeakyRelu(*node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAddAndActivation pattern{base.contraction,
                                                    base.bias_add, node_index};
  if (!IsDeviceCompatible(ctx, pattern)) return false;

  // We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
  *matched = pattern;

  return true;
}
489 
// Matches Conv2D -> Squeeze -> BiasAdd rooted at `node_index` (the BiasAdd)
// and fills `*matched` on success.
bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx, int node_index,
                                  ContractionWithSqueezeAndBiasAdd* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be a BiasAdd.
  const auto* node_def = node_view->node();
  if (!IsBiasAdd(*node_def)) return false;

  // Input to the BiasAdd must be a Squeeze.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* squeeze_node_view = regular_fanin_0.node_view();
  const auto* squeeze_node_def = squeeze_node_view->node();

  if (!IsSqueeze(*squeeze_node_def) ||
      !HaveSameDataType(node_def, squeeze_node_def, "T") ||
      HasControlFaninOrFanout(*squeeze_node_view) ||
      !HasAtMostOneFanoutAtPort0(*squeeze_node_view) ||
      IsInPreserveSet(ctx, squeeze_node_def))
    return false;

  // Squeeze must not squeeze output channel dimension.
  // NOTE(review): dim 3 is the channel dimension only for NHWC layout;
  // presumably guaranteed by the device-compatibility check below — confirm.
  std::vector<int32> dims;
  if (!TryGetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims)) return false;
  for (auto dim : dims) {
    if (dim == 3) return false;
  }

  // Input to the Squeeze must be a Conv2D.
  if (squeeze_node_view->NumRegularFanins() < 1) return false;
  const auto& squeeze_regular_fanin_0 = squeeze_node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = squeeze_regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();

  if (!IsConv2D(*conv2d_node_def) ||
      !HaveSameDataType(node_def, conv2d_node_def, "T") ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithSqueezeAndBiasAdd pattern{
      conv2d_node_view->node_index(), squeeze_node_view->node_index(),
      node_index};
  if (!IsDeviceCompatible(ctx, pattern)) return false;

  // We successfully found a Conv2D+Squeeze+BiasAdd pattern.
  *matched = pattern;

  return true;
}
544 
// Matches Conv2D -> FusedBatchNorm (inference mode) rooted at `node_index`
// (the batch norm) and fills `*matched`, including the epsilon attribute,
// on success. CPU only.
bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
                             ContractionWithBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // Root of the pattern must be a FusedBatchNorm.
  if (!IsFusedBatchNorm(*node_def)) return false;

  // FusedBatchNormV2 and V3 have an extra type parameter.
  if (node_view->GetOp() != "FusedBatchNorm" &&
      !HasDataType(node_def, DT_FLOAT, "U"))
    return false;

  // Check that batch normalization is in inference mode.
  const auto* training_attr = node_view->GetAttr(kIsTraining);
  if (training_attr != nullptr && training_attr->b()) return false;

  // Check that only 0th output is consumed by other nodes.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view) ||
      !node_view->GetRegularFanout(1).empty() ||  // batch_mean
      !node_view->GetRegularFanout(2).empty() ||  // batch_variance
      !node_view->GetRegularFanout(3).empty() ||  // reserve_space_1
      !node_view->GetRegularFanout(4).empty())    // reserve_space_2
    return false;

  // Input to the FusedBatchNorm must be a Conv2D.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();

  // This fusion is implemented only for CPU, so the device/dtype/format
  // checks are done inline instead of via IsDeviceCompatible.
  if (!IsConv2D(*conv2d_node_def) || !NodeIsOnCpu(conv2d_node_def) ||
      !HaveSameDataType(node_def, conv2d_node_def) ||
      !IsCpuCompatibleDataType(conv2d_node_def) ||
      !IsCpuCompatibleDataFormat(conv2d_node_def) ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm pattern.
  matched->contraction = conv2d_node_view->node_index();
  matched->fused_batch_norm = node_index;
  if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;

  return true;
}
592 
// Matches Conv2D -> FusedBatchNorm -> Activation rooted at `node_index`
// (the activation node) and fills `*matched` on success.
bool FindConv2DWithBatchNormAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBatchNormAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // And input to the activation node must match Conv2DWithBatchNorm pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* batch_norm_node_view = regular_fanin_0.node_view();

  ContractionWithBatchNorm base;
  if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base))
    return false;

  // The batch norm's output must feed only this activation, with a matching
  // dtype, and must not be a preserved node.
  const auto* fused_batch_norm_node_view =
      ctx.graph_view.GetNode(base.fused_batch_norm);
  const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
  if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
      !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
      IsInPreserveSet(ctx, fused_batch_norm_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
  matched->contraction = base.contraction;
  matched->fused_batch_norm = base.fused_batch_norm;
  matched->activation = node_index;
  matched->epsilon = base.epsilon;

  return true;
}
629 
630 #ifdef INTEL_MKL
// As AddN has multiple inputs, this function tries to find Conv2D + Bias
// pattern in specific input port. On success fills `*base` with the matched
// contraction+BiasAdd indices.
bool FindContractionWithBiasInPort(const RemapperContext& ctx,
                                   const utils::MutableNodeView& add_node_view,
                                   const NodeDef& add_node_def, int port_id,
                                   ContractionWithBiasAdd* base) {
  // Input to AddN must match ContractionWithBiasAdd pattern.
  if (add_node_view.NumRegularFanins() < port_id + 1) return false;
  // node_view() yields a pointer; the reference binds to that pointer value.
  const auto& bias_add_node_view =
      add_node_view.GetRegularFanin(port_id).node_view();
  if (bias_add_node_view == nullptr) return false;
  const auto* bias_add_node_def = bias_add_node_view->node();

  // Device compatibility is checked later on the full Add pattern.
  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base,
                               /*check_device_compatible=*/false))
    return false;
  if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(&add_node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;
  return true;
}
653 
654 bool IsAddWithNoBroadcast(const RemapperContext& ctx, const NodeDef& node) {
655   if (!IsAdd(node)) return false;
656 
657   // Check if this is case of broadcasting - Add node supports broadcasting.
658   const auto& props = ctx.graph_properties.GetInputProperties(node.name());
659   if (props.size() == 2 &&
660       ShapesSymbolicallyEqual(props[0].shape(), props[1].shape())) {
661     return true;
662   }
663   return false;
664 }
665 
// Matches contraction -> BiasAdd -> {AddN, Add (no broadcast)} rooted at
// `node_view` (the Add) and fills `*matched`, recording which Add input port
// carried the BiasAdd in `matched->port_id`. MKL-only.
bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
                                      const utils::MutableNodeView& node_view,
                                      ContractionWithBiasAddAndAdd* matched) {
  if (DisableMKL()) return false;
  // Fusion with AddN is supported only when it has two inputs.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(node_view) || node_view.NumRegularFanins() != 2)
    return false;

  // Root of the pattern must be a AddN or Add with same input shapes
  // (no broadcasting).
  const auto* node_def = node_view.node();
  if (!IsAddN(*node_def) && !IsAddWithNoBroadcast(ctx, *node_def)) return false;

#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // MKL AddN ops only support float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;
#else
  // MKL AddN ops only support float data type.
  if (!HasDataType(node_def, DT_FLOAT)) return false;
#endif  // ENABLE_INTEL_MKL_BFLOAT16

  ContractionWithBiasAdd base;
  matched->port_id = 0;

  // Find the conv+bias pattern in specific port.
  if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                     matched->port_id, &base)) {
    matched->port_id = 1;
    if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                       matched->port_id, &base)) {
      return false;
    }
  }

  // We successfully found a Conv2D+BiasAdd+{AddN,Add} pattern.
  matched->contraction = base.contraction;
  matched->bias_add = base.bias_add;
  matched->add = node_view.node_index();

  return true;
}
709 
710 bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
711                                       int node_index,
712                                       ContractionWithBiasAddAndAdd* matched) {
713   const auto* node_view = ctx.graph_view.GetNode(node_index);
714   return FindContractionWithBiasAddAndAdd(ctx, *node_view, matched);
715 }
716 
// Matches contraction -> BiasAdd -> {AddN, Add} -> Activation rooted at
// `node_index` (the activation node) and fills `*matched` on success.
// MKL-only.
bool FindContractionWithBiasAndAddActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBiasAndAddActivation* matched) {
  if (DisableMKL()) return false;
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (node_def == nullptr) return false;
  if (!IsSupportedActivation(*node_def)) return false;

  // Currently, Contraction + Bias + Add + Tanh pattern is not supported
  if (IsTanh(*node_def)) return false;

#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // MKL activation op only supports float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;
#else
  // MKL activation op only supports float data type.
  if (!HasDataType(node_def, DT_FLOAT)) return false;
#endif  // ENABLE_INTEL_MKL_BFLOAT16

  // And input to activation must match ContractionWithBiasAddAndAdd pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* add_node_view = regular_fanin_0.node_view();

  ContractionWithBiasAddAndAdd base;

  if (!FindContractionWithBiasAddAndAdd(ctx, *add_node_view, &base)) {
    return false;
  }

  // Get the contraction node
  const auto* bias_add_node_view =
      add_node_view->GetRegularFanin(base.port_id).node_view();
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(0).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only conv + bias + add + leakyrelu is enabled
  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;

  // We successfully found a Conv2D+BiasAdd+AddN+activation pattern.
  const ContractionWithBiasAndAddActivation pattern{
      base.contraction, base.bias_add, base.add, base.port_id, node_index};
  *matched = pattern;

  return true;
}
770 #endif
771 
// Matches a float FusedBatchNorm running in inference mode whose inputs are
// sufficiently constant for the node to be decomposed into primitive ops.
bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
                        FusedBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  if (!IsFusedBatchNorm(*node_def)) return false;
  // Only the float variant is rewritten.
  if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;

  // Check that the node is in inference mode.
  bool is_training = true;
  if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
  if (is_training) return false;

  const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());

  // a. Scaling factor can be const folded:
  //      scaling_factor = (variance + epsilon).rsqrt() * scale
  bool const_scaling_factor =
      props.size() == 5 &&     // [x, scale, offset, mean, variance]
      props[1].has_value() &&  // scale
      props[4].has_value();    // variance aka estimated variance

  // b. Or input can be const folded into some other expression.
  auto const_inputs = std::count_if(
      props.begin(), props.end(),
      [](const OpInfo::TensorProperties& props) { return props.has_value(); });

  // TODO(bsteiner): use the cost model to compare the cost of fused batch
  // norm against that of the optimized form.
  bool can_remap = const_scaling_factor || const_inputs >= 4;
  if (!can_remap) return false;

  // The optimized version only generates the first output.
  if (node_view->GetRegularFanouts().size() > 1) {
    return false;
  }

  // We found a fused batch norm node that can be replaced with primitive ops.
  matched->fused_batch_norm = node_index;

  return true;
}
813 
// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
bool BatchnormSpatialPersistentEnabled() {
#if CUDNN_VERSION >= 7402
  // Read the opt-in environment variable once and cache the answer for the
  // lifetime of the process.
  static const bool enabled = [] {
    bool flag = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
        /*default_val=*/false, &flag));
    return flag;
  }();
  return enabled;
#else
  // The persistent batchnorm mode requires cuDNN >= 7.4.2.
  return false;
#endif
}
830 
// Matches Relu(FusedBatchNorm(...)) or Relu(Add(FusedBatchNorm(...), side))
// so the subgraph can be rewritten into a single _FusedBatchNormEx node.
bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
                          FusedBatchNormEx* matched) {
  // Root of the pattern must be a Relu.
  // TODO(ezhulenev): Forward control dependencies.
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (!IsRelu(*node_def) || HasControlFaninOrFanout(*node_view)) return false;

  // Returns true iff the node is a compatible FusedBatchNorm node.
  const auto valid_batch_norm =
      [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
    const auto* fused_batch_norm_node_def = fused_batch_norm.node();
    if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;

#ifndef ENABLE_MKLDNN_V1
    // We fuse FusedBatchNorm on GPU or MKL CPU.
    if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
#else
    if (DisableMKL()) return false;
#endif

    DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
#ifndef ENABLE_MKLDNN_V1
    if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
#else
    if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false;
#endif

    // Get the FusedBatchNorm training mode.
    bool is_training;
    if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
             .ok())
      return false;
    // In training mode we rely on cuDNN for computing FusedBatchNorm with side
    // inputs and activation, and it has its own limitations. In inference mode
    // we have a custom CUDA kernel that does not have these constraints.
    if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
      // cuDNN only supports NHWC data layout.
      string data_format;
      if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
               .ok())
        return false;
      if (data_format != "NHWC") return false;

      // Data type must be DT_HALF.
      if (t_dtype != DT_HALF) return false;

      // Channel dimension must be a multiple of 4.
      const auto& props = ctx.graph_properties.GetInputProperties(
          fused_batch_norm_node_def->name());

      const bool valid_channel_dim = !props.empty() &&
                                     props[0].shape().dim_size() == 4 &&
                                     props[0].shape().dim(3).size() % 4 == 0;
      if (!valid_channel_dim) return false;

      // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
      if (!BatchnormSpatialPersistentEnabled()) return false;
    }

    // FusedBatchNormV2 and V3 have an extra type parameter.
    if ((fused_batch_norm_node_def->op() != "FusedBatchNorm") &&
        !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
      return false;

    // Check that only one node consumes the 0-th output of a FusedBatchNorm.
    if (HasControlFaninOrFanout(fused_batch_norm) ||
        !HasAtMostOneDataFanoutAtPort0(fused_batch_norm) ||
        IsInPreserveSet(ctx, fused_batch_norm_node_def))
      return false;

    return true;
  };

  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* relu_fanin_0_node_view = regular_fanin_0.node_view();
  const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();

  // Input to a Relu can be a FusedBatchNorm.
  if (valid_batch_norm(*relu_fanin_0_node_view)) {
    matched->activation = node_index;
    matched->fused_batch_norm = regular_fanin_0.node_index();
    return true;
  }

  // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
  if (IsAdd(*relu_fanin_0_node_def)) {
    // Currently no CPU implementation for "FusedBatchNorm + SideInput +
    // <Activation>""
#ifdef ENABLE_MKLDNN_V1
    return false;
#endif

    // Check that only Relu node consumes the output of an Add node.
    if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
        !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
        IsInPreserveSet(ctx, relu_fanin_0_node_def))
      return false;

    // Add node supports broadcasting, FusedBatchNormEx does not.
    const auto& props =
        ctx.graph_properties.GetInputProperties(relu_fanin_0_node_def->name());
    if (props.size() < 2 ||
        !ShapesSymbolicallyEqual(props[0].shape(), props[1].shape()))
      return false;

    if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
    const auto& add_regular_fanin_0 =
        relu_fanin_0_node_view->GetRegularFanin(0);
    const auto& add_regular_fanin_1 =
        relu_fanin_0_node_view->GetRegularFanin(1);

    // The batch norm can be on either side of the Add; the opposite input
    // becomes the fused op's side input.
    if (valid_batch_norm(*add_regular_fanin_0.node_view())) {
      matched->activation = node_index;
      matched->side_input = add_regular_fanin_1.node_index();
      matched->fused_batch_norm = add_regular_fanin_0.node_index();
      matched->invalidated = regular_fanin_0.node_index();
      return true;
    }

    if (valid_batch_norm(*add_regular_fanin_1.node_view())) {
      matched->activation = node_index;
      matched->side_input = add_regular_fanin_0.node_index();
      matched->fused_batch_norm = add_regular_fanin_1.node_index();
      matched->invalidated = regular_fanin_0.node_index();
      return true;
    }
  }

  return false;
}
964 
965 void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
966                           const NodeDef* activation = nullptr) {
967   DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
968 
969   auto* attr = fused_conv2d->mutable_attr();
970   auto& src_attr = conv2d.attr();
971 
972   (*attr)["T"] = src_attr.at("T");
973   (*attr)["strides"] = src_attr.at("strides");
974   (*attr)["padding"] = src_attr.at("padding");
975   (*attr)["explicit_paddings"] = src_attr.at("explicit_paddings");
976   (*attr)["dilations"] = src_attr.at("dilations");
977   (*attr)["data_format"] = src_attr.at("data_format");
978   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
979   // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha
980   if (activation != nullptr && IsLeakyRelu(*activation)) {
981     auto& activation_attr = activation->attr();
982     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
983   }
984 }
985 
986 void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
987                                          NodeDef* fused_dw_conv2d,
988                                          const NodeDef* activation = nullptr) {
989   DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
990       << "Input node must be a DepthwiseConv2dNative";
991 
992   auto* attr = fused_dw_conv2d->mutable_attr();
993   auto& src_attr = dw_conv2d.attr();
994 
995   (*attr)["T"] = src_attr.at("T");
996   (*attr)["strides"] = src_attr.at("strides");
997   (*attr)["padding"] = src_attr.at("padding");
998   (*attr)["dilations"] = src_attr.at("dilations");
999   (*attr)["data_format"] = src_attr.at("data_format");
1000   // Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr leakyrelu_alpha
1001   if (activation != nullptr && IsLeakyRelu(*activation)) {
1002     auto& activation_attr = activation->attr();
1003     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
1004   }
1005 }
1006 
1007 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
1008                                   NodeDef* fused_batch_norm_ex) {
1009   DCHECK(IsFusedBatchNorm(fused_batch_norm))
1010       << "Input node must be a FusedBatchNorm";
1011 
1012   auto* attr = fused_batch_norm_ex->mutable_attr();
1013   auto src_attr = fused_batch_norm.attr();
1014 
1015   (*attr)["T"] = src_attr.at("T");
1016   (*attr)["is_training"] = src_attr.at("is_training");
1017   (*attr)["data_format"] = src_attr.at("data_format");
1018   (*attr)["epsilon"] = src_attr.at("epsilon");
1019   (*attr)["exponential_avg_factor"] = src_attr.at("exponential_avg_factor");
1020 
1021   // FusedBatchNormV2 and V3 have an extra type parameter.
1022   if (fused_batch_norm.op() != "FusedBatchNorm") {
1023     SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
1024   } else {
1025 #ifndef ENABLE_MKLDNN_V1
1026     SetAttrValue(src_attr.at("T"), &(*attr)["U"]);
1027 #else
1028     SetAttrValue(DT_FLOAT, &(*attr)["U"]);
1029 #endif
1030   }
1031 }
1032 
1033 void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
1034                           const NodeDef* activation = nullptr) {
1035   DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
1036 
1037   auto* attr = fused_matmul->mutable_attr();
1038   auto& src_attr = matmul.attr();
1039 
1040   (*attr)["T"] = src_attr.at("T");
1041   (*attr)["transpose_a"] = src_attr.at("transpose_a");
1042   (*attr)["transpose_b"] = src_attr.at("transpose_b");
1043   // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
1044   if (activation != nullptr && IsLeakyRelu(*activation)) {
1045     auto& activation_attr = activation->attr();
1046     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
1047   }
1048 }
1049 
1050 void SetFusedOpAttributes(NodeDef* fused,
1051                           const absl::Span<const absl::string_view> fused_ops,
1052                           int num_args = 1, float epsilon = 0.0) {
1053   auto* attr = fused->mutable_attr();
1054   SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
1055   SetAttrValue(num_args, &(*attr)["num_args"]);
1056   SetAttrValue(epsilon, &(*attr)["epsilon"]);  // required only for BatchNorm
1057 }
1058 
// Replaces a matched {Conv2D,DepthwiseConv2dNative,MatMul}+BiasAdd pair with
// a single fused op. The fused node takes over the BiasAdd's name so that
// downstream consumers are rewired automatically; the contraction node is
// scheduled for deletion.
Status AddFusedContractionNode(RemapperContext* ctx,
                               const ContractionWithBiasAdd& matched,
                               std::vector<bool>* invalidated_nodes,
                               std::vector<bool>* nodes_to_delete) {
  DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";

  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  const NodeDef& bias_add = graph->node(matched.bias_add);
  VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd: "
          << " bias_add=" << bias_add.name()
          << " contraction=" << contraction.name();

  NodeDef fused_op;
  fused_op.set_name(bias_add.name());
  fused_op.set_device(contraction.device());
  fused_op.add_input(contraction.input(0));  // 0: input
  fused_op.add_input(contraction.input(1));  // 1: filter
  fused_op.add_input(bias_add.input(1));     // 2: bias

  // Pick the fused op type matching the contraction kind and carry over its
  // attributes.
  if (IsConv2D(contraction)) {
    fused_op.set_op(kFusedConv2D);
    CopyConv2DAttributes(contraction, &fused_op);
  } else if (IsDepthwiseConv2dNative(contraction)) {
    fused_op.set_op(kFusedDepthwiseConv2dNative);
    CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
  } else if (IsMatMul(contraction)) {
    fused_op.set_op(kFusedMatMul);
    CopyMatMulAttributes(contraction, &fused_op);
  }

  SetFusedOpAttributes(&fused_op, {"BiasAdd"});

  // Commit the rewrite atomically via the graph view's mutation builder.
  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}
1103 
1104 Status AddFusedContractionNode(
1105     RemapperContext* ctx, const ContractionWithBiasAddAndActivation& matched,
1106     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
1107   DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
1108 
1109   const GraphDef* graph = ctx->graph_view.graph();
1110   const NodeDef& contraction = graph->node(matched.contraction);
1111   const NodeDef& bias_add = graph->node(matched.bias_add);
1112   const NodeDef& activation = graph->node(matched.activation);
1113 
1114   VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
1115           << activation.op() << ":"
1116           << " activation=" << activation.name()
1117           << " bias_add=" << bias_add.name()
1118           << " contraction=" << contraction.name();
1119 
1120   NodeDef fused_op;
1121   fused_op.set_name(activation.name());
1122   fused_op.set_device(contraction.device());
1123   fused_op.add_input(contraction.input(0));  // 0: input
1124   fused_op.add_input(contraction.input(1));  // 1: filter
1125   fused_op.add_input(bias_add.input(1));     // 2: bias
1126 
1127   if (IsConv2D(contraction)) {
1128     fused_op.set_op(kFusedConv2D);
1129     // leaky relu has a special attribute alpha
1130     CopyConv2DAttributes(contraction, &fused_op, &activation);
1131   } else if (IsDepthwiseConv2dNative(contraction)) {
1132     fused_op.set_op(kFusedDepthwiseConv2dNative);
1133     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
1134   } else if (IsMatMul(contraction)) {
1135     fused_op.set_op(kFusedMatMul);
1136     CopyMatMulAttributes(contraction, &fused_op, &activation);
1137   }
1138 
1139   SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});
1140 
1141   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
1142   Status status;
1143   mutation->AddNode(std::move(fused_op), &status);
1144   TF_RETURN_IF_ERROR(status);
1145   TF_RETURN_IF_ERROR(mutation->Apply());
1146 
1147   (*nodes_to_delete)[matched.contraction] = true;
1148   (*nodes_to_delete)[matched.bias_add] = true;
1149   (*invalidated_nodes)[matched.activation] = true;
1150 
1151   return Status::OK();
1152 }
1153 
// Rewrites Conv2D+Squeeze+BiasAdd into _FusedConv2D(BiasAdd)+Squeeze: the
// fused conv keeps the Conv2D's name, and the Squeeze is moved to take the
// BiasAdd's name so all downstream consumers stay connected.
Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithSqueezeAndBiasAdd& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";

  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

  const NodeDef& bias_add = graph->node(matched.bias_add);
  const NodeDef& squeeze = graph->node(matched.squeeze);
  VLOG(2) << "Fuse Conv2D with Squeeze and BiasAdd: "
          << " bias_add=" << bias_add.name() << " squeeze=" << squeeze.name()
          << " conv2d=" << contraction.name();

  // Replace Conv2D node with a fused Conv2D. Matched pattern guarantees that it
  // has single consumer (only the squeeze node).
  NodeDef fused_conv2d;
  fused_conv2d.set_name(contraction.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));  // 0: input
  fused_conv2d.add_input(contraction.input(1));  // 1: filter
  fused_conv2d.add_input(bias_add.input(1));     // 2: bias

  CopyConv2DAttributes(contraction, &fused_conv2d);
  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd"});

  // Replace BiasAdd node with a Squeeze.
  NodeDef remapped_squeeze = squeeze;
  remapped_squeeze.set_name(bias_add.name());
  remapped_squeeze.set_input(0, contraction.name());

  // Both replacement nodes are staged and applied in a single mutation.
  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  mutation->AddNode(std::move(remapped_squeeze), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.contraction] = true;
  (*invalidated_nodes)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.squeeze] = true;

  return Status::OK();
}
1202 
1203 Status AddFusedConv2DNode(RemapperContext* ctx,
1204                           const ContractionWithBatchNorm& matched,
1205                           std::vector<bool>* invalidated_nodes,
1206                           std::vector<bool>* nodes_to_delete) {
1207   const GraphDef* graph = ctx->graph_view.graph();
1208   const NodeDef& contraction = graph->node(matched.contraction);
1209   DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
1210   const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
1211   VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm="
1212           << fused_batch_norm.name() << " conv2d=" << contraction.name();
1213 
1214   NodeDef fused_conv2d;
1215   fused_conv2d.set_name(fused_batch_norm.name());
1216   fused_conv2d.set_op(kFusedConv2D);
1217   fused_conv2d.set_device(contraction.device());
1218   fused_conv2d.add_input(contraction.input(0));       // 0: input
1219   fused_conv2d.add_input(contraction.input(1));       // 1: filter
1220   fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
1221   fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
1222   fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
1223   fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance
1224 
1225   CopyConv2DAttributes(contraction, &fused_conv2d);
1226   SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"},
1227                        /*num_args=*/4, /*epsilon=*/matched.epsilon);
1228 
1229   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
1230   Status status;
1231   mutation->AddNode(std::move(fused_conv2d), &status);
1232   TF_RETURN_IF_ERROR(status);
1233   TF_RETURN_IF_ERROR(mutation->Apply());
1234 
1235   (*invalidated_nodes)[matched.fused_batch_norm] = true;
1236   (*nodes_to_delete)[matched.contraction] = true;
1237 
1238   return Status::OK();
1239 }
1240 
// Rewrites Conv2D+FusedBatchNorm+<Activation> into one _FusedConv2D node
// named after the activation; the conv and batch norm nodes are deleted.
Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithBatchNormAndActivation& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);

  DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

  const NodeDef& activation = graph->node(matched.activation);
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
          << ": activation=" << activation.name()
          << " batch_norm=" << fused_batch_norm.name()
          << " conv2d=" << contraction.name();

  NodeDef fused_conv2d;
  fused_conv2d.set_name(activation.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));       // 0: input
  fused_conv2d.add_input(contraction.input(1));       // 1: filter
  fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
  fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
  fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
  fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance

  // Pass the activation so a LeakyRelu `alpha` can be forwarded as
  // `leakyrelu_alpha`.
  CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
  SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()},
                       /*num_args=*/4, /*epsilon=*/matched.epsilon);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.activation] = true;
  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.fused_batch_norm] = true;

  return Status::OK();
}
1284 
1285 #ifdef INTEL_MKL
// (MKL-only) Rewrites {Conv2D,MatMul}+BiasAdd+{Add,AddN} into a single fused
// op named after the Add node. The non-pattern input of the Add becomes the
// fused op's 4th input.
Status AddFusedContractionNode(RemapperContext* ctx,
                               const ContractionWithBiasAddAndAdd& matched,
                               std::vector<bool>* invalidated_nodes,
                               std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  const NodeDef& bias_add = graph->node(matched.bias_add);

  // MKL version only support fusion for Conv2D and MatMul
  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));

  NodeDef contraction_node;
  const NodeDef& add = graph->node(matched.add);
  contraction_node.set_name(add.name());
  contraction_node.set_device(contraction.device());
  contraction_node.add_input(
      contraction.input(0));  // 0: input(conv) / a (matmul)
  contraction_node.add_input(
      contraction.input(1));  // 1: filter(conv) / b (matmul)
  contraction_node.add_input(bias_add.input(1));  // 2: bias

  // Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
  // previously, the other input to add is fused here.
  contraction_node.add_input(add.input(1 - matched.port_id));

  if (IsConv2D(contraction)) {
    contraction_node.set_op(kFusedConv2D);
    CopyConv2DAttributes(contraction, &contraction_node);
  } else if (IsMatMul(contraction)) {
    contraction_node.set_op(kFusedMatMul);
    CopyMatMulAttributes(contraction, &contraction_node);
  }

  // num_args = 2: bias plus the extra Add input.
  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(contraction_node), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.add] = true;
  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.bias_add] = true;

  return Status::OK();
}
1333 
// (MKL-only) Rewrites Conv2D+BiasAdd+{Add,AddN}+<Activation> into a single
// _FusedConv2D named after the activation node; all four original nodes are
// consumed.
Status AddFusedContractionNode(
    RemapperContext* ctx, const ContractionWithBiasAndAddActivation& matched,
    std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  // MKL version only support fusion for Conv2D
  const NodeDef& contraction = graph->node(matched.contraction);
  DCHECK(IsConv2D(contraction));
  const NodeDef& activation = graph->node(matched.activation);

  NodeDef fused_conv2d;
  fused_conv2d.set_name(activation.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));  // 0: input
  fused_conv2d.add_input(contraction.input(1));  // 1: filter
  const NodeDef& bias_add = graph->node(matched.bias_add);
  fused_conv2d.add_input(bias_add.input(1));  // 2: bias

  // Add OP has two inputs, one is conv+bias pattern matched previously,
  // the other input to add is fused here.
  const NodeDef& add = graph->node(matched.add);
  fused_conv2d.add_input(add.input(1 - matched.port_id));

  CopyConv2DAttributes(contraction, &fused_conv2d);
  // num_args = 2: bias plus the extra Add input.
  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.activation] = true;
  (*nodes_to_delete)[matched.add] = true;
  (*nodes_to_delete)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}
1373 #endif
1374 
// Rewrites the matched FusedBatchNorm (+ optional side input) + activation
// subgraph into a _FusedBatchNormEx node that keeps the batch norm's name,
// and turns the activation node into an Identity so its consumers still see
// a node with the activation's name.
Status AddFusedBatchNormExNode(RemapperContext* ctx,
                               const FusedBatchNormEx& matched,
                               std::vector<bool>* invalidated_nodes,
                               std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  const NodeDef& activation = graph->node(matched.activation);

  VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
          << " activation=" << activation.name() << " side_input="
          << (matched.side_input != kMissingIndex
                  ? graph->node(matched.side_input).name()
                  : "<none>")
          << " invalidated="
          << (matched.invalidated != kMissingIndex
                  ? graph->node(matched.invalidated).name()
                  : "<none>")
          << " fused_batch_norm=" << fused_batch_norm.name();

  // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
  NodeDef fused_op;
  fused_op.set_op(kFusedBatchNormEx);
  fused_op.set_name(fused_batch_norm.name());
  fused_op.set_device(fused_batch_norm.device());

  fused_op.add_input(fused_batch_norm.input(0));  // 0: input
  fused_op.add_input(fused_batch_norm.input(1));  // 1: scale
  fused_op.add_input(fused_batch_norm.input(2));  // 2: offset
  fused_op.add_input(fused_batch_norm.input(3));  // 3: estimated_mean
  fused_op.add_input(fused_batch_norm.input(4));  // 4: estimated_var

  CopyFusedBatchNormAttributes(fused_batch_norm, &fused_op);

  auto* attrs = fused_op.mutable_attr();
  SetAttrValue(activation.op(), &(*attrs)["activation_mode"]);

  // Optional side input (the other operand of the Add, when present).
  if (matched.side_input != kMissingIndex) {
    SetAttrValue(1, &(*attrs)["num_side_inputs"]);
    const NodeDef& side_input = graph->node(matched.side_input);
    fused_op.add_input(side_input.name());  // 5: side_input
  } else {
    SetAttrValue(0, &(*attrs)["num_side_inputs"]);
  }

  // Turn activation node into Identity node.
  NodeDef identity_op;
  identity_op.set_op("Identity");
  identity_op.set_name(activation.name());
  identity_op.set_device(fused_batch_norm.device());
  identity_op.add_input(fused_batch_norm.name());
  (*identity_op.mutable_attr())["T"] = attrs->at("T");

  // Stage both replacement nodes and apply them in one mutation.
  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  mutation->AddNode(std::move(identity_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.fused_batch_norm] = true;
  (*invalidated_nodes)[matched.activation] = true;
  // The intermediate Add node only exists in the side-input form.
  if (matched.side_input != kMissingIndex) {
    (*nodes_to_delete)[matched.invalidated] = true;
  }

  return Status::OK();
}
1443 
1444 Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
1445   const GraphDef* graph = ctx->graph_view.graph();
1446   const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
1447   VLOG(2) << "Optimizing fused batch norm node "
1448           << SummarizeNodeDef(fused_node);
1449 
1450   const string& x = fused_node.input(0);
1451   string scale = fused_node.input(1);
1452   string offset = fused_node.input(2);
1453   string mean = fused_node.input(3);
1454   string variance = fused_node.input(4);
1455 
1456   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
1457   Status status;
1458 
1459   string x_format = fused_node.attr().at(kDataFormat).s();
1460   if (x_format == "NCHW" || x_format == "NCDHW") {
1461     // Need to reshape the last 4 inputs
1462     NodeDef new_shape;
1463     const string new_shape_name =
1464         AddPrefixToNodeName(x_format + "Shape", fused_node.name());
1465     new_shape.set_name(new_shape_name);
1466     new_shape.set_op("Const");
1467     new_shape.set_device(fused_node.device());
1468     *new_shape.add_input() = AsControlDependency(scale);
1469     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
1470     if (x_format == "NCHW") {
1471       Tensor t(DT_INT32, {4});
1472       t.flat<int32>()(0) = 1;
1473       t.flat<int32>()(1) = -1;
1474       t.flat<int32>()(2) = 1;
1475       t.flat<int32>()(3) = 1;
1476       t.AsProtoTensorContent(
1477           (*new_shape.mutable_attr())["value"].mutable_tensor());
1478     } else {
1479       Tensor t(DT_INT32, {5});
1480       t.flat<int32>()(0) = 1;
1481       t.flat<int32>()(1) = -1;
1482       t.flat<int32>()(2) = 1;
1483       t.flat<int32>()(3) = 1;
1484       t.flat<int32>()(4) = 1;
1485       t.AsProtoTensorContent(
1486           (*new_shape.mutable_attr())["value"].mutable_tensor());
1487     }
1488     mutation->AddNode(std::move(new_shape), &status);
1489     TF_RETURN_IF_ERROR(status);
1490 
1491     NodeDef reshaped_scale;
1492     reshaped_scale.set_name(
1493         AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
1494     reshaped_scale.set_op("Reshape");
1495     reshaped_scale.set_device(fused_node.device());
1496     *reshaped_scale.add_input() = scale;
1497     *reshaped_scale.add_input() = new_shape_name;
1498     (*reshaped_scale.mutable_attr())["T"] = fused_node.attr().at("T");
1499     (*reshaped_scale.mutable_attr())["Tshape"].set_type(DT_INT32);
1500     scale = reshaped_scale.name();
1501     mutation->AddNode(std::move(reshaped_scale), &status);
1502     TF_RETURN_IF_ERROR(status);
1503 
1504     NodeDef reshaped_offset;
1505     reshaped_offset.set_name(
1506         AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
1507     reshaped_offset.set_op("Reshape");
1508     reshaped_offset.set_device(fused_node.device());
1509     *reshaped_offset.add_input() = offset;
1510     *reshaped_offset.add_input() = new_shape_name;
1511     (*reshaped_offset.mutable_attr())["T"] = fused_node.attr().at("T");
1512     (*reshaped_offset.mutable_attr())["Tshape"].set_type(DT_INT32);
1513     offset = reshaped_offset.name();
1514     mutation->AddNode(std::move(reshaped_offset), &status);
1515     TF_RETURN_IF_ERROR(status);
1516 
1517     NodeDef reshaped_mean;
1518     reshaped_mean.set_name(
1519         AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
1520     reshaped_mean.set_op("Reshape");
1521     reshaped_mean.set_device(fused_node.device());
1522     *reshaped_mean.add_input() = mean;
1523     *reshaped_mean.add_input() = new_shape_name;
1524     (*reshaped_mean.mutable_attr())["T"] = fused_node.attr().at("T");
1525     (*reshaped_mean.mutable_attr())["Tshape"].set_type(DT_INT32);
1526     mean = reshaped_mean.name();
1527     mutation->AddNode(std::move(reshaped_mean), &status);
1528     TF_RETURN_IF_ERROR(status);
1529 
1530     NodeDef reshaped_variance;
1531     reshaped_variance.set_name(
1532         AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
1533     reshaped_variance.set_op("Reshape");
1534     reshaped_variance.set_device(fused_node.device());
1535     *reshaped_variance.add_input() = variance;
1536     *reshaped_variance.add_input() = new_shape_name;
1537     (*reshaped_variance.mutable_attr())["T"] = fused_node.attr().at("T");
1538     (*reshaped_variance.mutable_attr())["Tshape"].set_type(DT_INT32);
1539     variance = reshaped_variance.name();
1540     mutation->AddNode(std::move(reshaped_variance), &status);
1541     TF_RETURN_IF_ERROR(status);
1542   }
1543 
1544   float epsilon = 0.0f;
1545   if (fused_node.attr().count("epsilon")) {
1546     epsilon = fused_node.attr().at("epsilon").f();
1547   }
1548   DataType dtype = fused_node.attr().at("T").type();
1549   Tensor value(dtype, TensorShape());
1550   value.scalar<float>()() = epsilon;
1551   NodeDef variance_epsilon;
1552   const string variance_epsilon_name =
1553       AddPrefixToNodeName("Const", fused_node.name());
1554   TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
1555       variance_epsilon_name, TensorValue(&value), &variance_epsilon));
1556   variance_epsilon.set_device(fused_node.device());
1557   mutation->AddNode(std::move(variance_epsilon), &status);
1558   TF_RETURN_IF_ERROR(status);
1559 
1560   NodeDef variance_plus_epsilon;
1561   const string variance_plus_epsilon_name =
1562       AddPrefixToNodeName("VarPlusEpsilon", fused_node.name());
1563   variance_plus_epsilon.set_name(variance_plus_epsilon_name);
1564   variance_plus_epsilon.set_op("Add");
1565   (*variance_plus_epsilon.mutable_attr())["T"].set_type(dtype);
1566   variance_plus_epsilon.set_device(fused_node.device());
1567   *variance_plus_epsilon.add_input() = variance;
1568   *variance_plus_epsilon.add_input() = variance_epsilon_name;
1569   mutation->AddNode(std::move(variance_plus_epsilon), &status);
1570   TF_RETURN_IF_ERROR(status);
1571 
1572   NodeDef inv;
1573   const string inv_name = AddPrefixToNodeName("Inv", fused_node.name());
1574   inv.set_name(inv_name);
1575   inv.set_op("Rsqrt");
1576   inv.set_device(fused_node.device());
1577   (*inv.mutable_attr())["T"].set_type(dtype);
1578   *inv.add_input() = variance_plus_epsilon_name;
1579   mutation->AddNode(std::move(inv), &status);
1580   TF_RETURN_IF_ERROR(status);
1581 
1582   NodeDef scaled;
1583   const string scaled_name = AddPrefixToNodeName("Scaled", fused_node.name());
1584   scaled.set_name(scaled_name);
1585   scaled.set_op("Mul");
1586   scaled.set_device(fused_node.device());
1587   (*scaled.mutable_attr())["T"].set_type(dtype);
1588   *scaled.add_input() = inv_name;
1589   *scaled.add_input() = scale;
1590   mutation->AddNode(std::move(scaled), &status);
1591   TF_RETURN_IF_ERROR(status);
1592 
1593   NodeDef a;
1594   const string a_name = AddPrefixToNodeName("Mul", fused_node.name());
1595   a.set_name(a_name);
1596   a.set_op("Mul");
1597   a.set_device(fused_node.device());
1598   (*a.mutable_attr())["T"].set_type(dtype);
1599   *a.add_input() = x;
1600   *a.add_input() = scaled_name;
1601   mutation->AddNode(std::move(a), &status);
1602   TF_RETURN_IF_ERROR(status);
1603 
1604   NodeDef b;
1605   const string b_name = AddPrefixToNodeName("Mul2", fused_node.name());
1606   b.set_name(b_name);
1607   b.set_op("Mul");
1608   b.set_device(fused_node.device());
1609   (*b.mutable_attr())["T"].set_type(dtype);
1610   *b.add_input() = mean;
1611   *b.add_input() = scaled_name;
1612   mutation->AddNode(std::move(b), &status);
1613   TF_RETURN_IF_ERROR(status);
1614 
1615   NodeDef c;
1616   const string c_name = AddPrefixToNodeName("Offset", fused_node.name());
1617   c.set_name(c_name);
1618   c.set_op("Sub");
1619   c.set_device(fused_node.device());
1620   (*c.mutable_attr())["T"].set_type(dtype);
1621   *c.add_input() = offset;
1622   *c.add_input() = b_name;
1623   mutation->AddNode(std::move(c), &status);
1624   TF_RETURN_IF_ERROR(status);
1625 
1626   NodeDef r;
1627   r.set_name(fused_node.name());
1628   r.set_op("Add");
1629   r.set_device(fused_node.device());
1630   (*r.mutable_attr())["T"].set_type(dtype);
1631   *r.add_input() = a_name;
1632   *r.add_input() = c_name;
1633   mutation->AddNode(std::move(r), &status);
1634   TF_RETURN_IF_ERROR(status);
1635 
1636   return mutation->Apply();
1637 }
1638 
1639 #ifdef INTEL_MKL
1640 bool IsConv2DOrMatMul(const NodeDef& node) {
1641   return IsConv2D(node) || IsMatMul(node);
1642 }
1643 
1644 bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
1645   const auto* node_view = ctx.graph_view.GetNode(node_index);
1646 
1647   // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
1648   //               MatMul + Add or MatMul + BiasAdd + Add fusion.
1649   auto is_supported_add_input = [](const auto* node_view) -> bool {
1650     // Currently only support Conv2D and MatMul
1651     if (IsConv2DOrMatMul(*node_view->node())) return true;
1652     if (IsBiasAdd(*node_view->node())) {
1653       if (node_view->NumRegularFanins() < 2) return false;
1654       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
1655       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
1656       return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
1657              IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
1658     }
1659     return false;
1660   };
1661 
1662   auto is_supported_add = [&](const auto* node_view) -> bool {
1663     const auto* node_def = node_view->node();
1664     if (IsAdd(*node_def)) {
1665       if (node_view->NumRegularFanins() < 2) return false;
1666       const auto& add_fanin_0 = node_view->GetRegularFanin(0);
1667       const auto& add_fanin_1 = node_view->GetRegularFanin(1);
1668       return is_supported_add_input(add_fanin_0.node_view()) ||
1669              is_supported_add_input(add_fanin_1.node_view());
1670     }
1671     return false;
1672   };
1673 
1674   bool ret = false;
1675   for (int i = 0; i < node_view->NumRegularFanins(); i++) {
1676     const auto& fanin_i = node_view->GetRegularFanin(i);
1677     ret = is_supported_add(fanin_i.node_view());
1678     if (ret) break;
1679   }
1680 
1681   return ret;
1682 }
1683 #endif
1684 
1685 // Check if a node is a candidate to one of the patterns that require inferred
1686 // shapes:
1687 //   (1) Splitting FusedBatchNorm into primitives.
1688 //   (2) Fusing side input and/or activation into FusedBatchNorm.
1689 //   (3) Fusing Conv2D biasadd and relu on GPU
1690 //   (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
1691 bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
1692   // Candidate for a FusedBatchNorm splitting.
1693   const auto* node_view = ctx.graph_view.GetNode(node_index);
1694   const auto* node_def = node_view->node();
1695   const auto is_batch_norm_candidate = [&]() -> bool {
1696     if (!IsFusedBatchNorm(*node_def)) return false;
1697     if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
1698 
1699     bool is_training = true;
1700     if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
1701     if (is_training) return false;
1702 
1703     return true;
1704   };
1705 
1706   const auto is_relu_biasadd_conv2d_candidate = [&]() -> bool {
1707     if (!IsRelu(*node_def)) return false;
1708     if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
1709 
1710     if (node_view->NumRegularFanins() < 1) return false;
1711     const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
1712     const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
1713     const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
1714 
1715     if (!IsBiasAdd(*relu_fanin_0_node_def)) return false;
1716     if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT)
1717       return false;
1718 
1719     if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false;
1720 
1721     const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
1722     const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
1723 
1724     if (!IsConv2D(*biasadd_fanin_0_node_def)) return false;
1725     if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT)
1726       return false;
1727 
1728     return true;
1729   };
1730 
1731   // Candidate for a FusedBatchNorm fusion.
1732   const auto is_batch_norm_fusion_candidate = [&]() -> bool {
1733     if (!IsRelu(*node_def)) return false;
1734 
1735     if (node_view->NumRegularFanins() < 1) return false;
1736     const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
1737     const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
1738     const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
1739 
1740     if (IsFusedBatchNorm(*relu_fanin_0_node_def)) {
1741       // FusedBatchNorm + Relu.
1742       return true;
1743 
1744     } else if (IsAdd(*relu_fanin_0_node_def)) {
1745       // FusedBatchNorm + Add + Relu.
1746 
1747       if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
1748       const auto& add_regular_fanin_0 =
1749           relu_fanin_0_node_view->GetRegularFanin(0);
1750       if (IsFusedBatchNorm(*add_regular_fanin_0.node_view()->node()))
1751         return true;
1752       const auto& add_regular_fanin_1 =
1753           relu_fanin_0_node_view->GetRegularFanin(1);
1754       if (IsFusedBatchNorm(*add_regular_fanin_1.node_view()->node()))
1755         return true;
1756     }
1757 
1758     return false;
1759   };
1760 
1761 #ifdef INTEL_MKL
1762   (void)is_relu_biasadd_conv2d_candidate;  // To fix unused variable error.
1763   return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
1764          IsContractionWithAdd(ctx, node_index);
1765 #else
1766   return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
1767          is_batch_norm_fusion_candidate();
1768 #endif  // INTEL_MKL
1769 }
1770 
1771 }  // namespace
1772 
1773 Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
1774                           GraphDef* optimized_graph) {
1775   GrapplerItem mutable_item = item;
1776   Status status;
1777   RemapperContext ctx(&mutable_item, &status);
1778   TF_RETURN_IF_ERROR(status);
1779   // Processing graph in reverse-topological sorted order allows to remap
1780   // longer chains of dependent ops in one pass.
1781   TF_RETURN_IF_ERROR(
1782       ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
1783 
1784   const int num_nodes = item.graph.node_size();
1785   // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd
1786   // and Activation nodes that were fused into a Conv2D node.
1787   std::vector<bool> invalidated_nodes(num_nodes);
1788   std::vector<bool> nodes_to_delete(num_nodes);
1789 
1790   // _Fused{...} kernels do not have registered gradient function, so we must
1791   // not perform rewrite if the graph will be differentiated later.
1792   bool allow_non_differentiable_rewrites =
1793       item.optimization_options().allow_non_differentiable_rewrites;
1794 
1795   for (int i = num_nodes - 1; i >= 0; --i) {
1796     // Check if node was invalidated by one of the previous remaps.
1797     if (invalidated_nodes[i] || nodes_to_delete[i]) {
1798       continue;
1799     }
1800 
1801     // Infer properties lazily in case they are not needed.
1802     if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
1803       const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
1804       TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
1805           assume_valid_feeds,
1806           /*aggressive_shape_inference=*/false,
1807           /*include_input_tensor_values=*/true,
1808           /*include_output_tensor_values=*/false));
1809       ctx.inferred_graph_properties = true;
1810     }
1811 
1812 #ifdef INTEL_MKL
1813     ContractionWithBiasAddAndAdd contract_with_bias_and_add;
1814     ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
1815 
1816     if (!item.optimization_options().is_eager_mode) {
1817       // Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
1818       if (FindContractionWithBiasAndAddActivation(
1819               ctx, i, &contract_with_bias_and_add_activation)) {
1820         TF_RETURN_IF_ERROR(
1821             AddFusedContractionNode(&ctx, contract_with_bias_and_add_activation,
1822                                     &invalidated_nodes, &nodes_to_delete));
1823         continue;
1824       }
1825 
1826       // Remap Conv2D+BiasAdd+Add into the _FusedConv2D.
1827       if (FindContractionWithBiasAddAndAdd(ctx, i,
1828                                            &contract_with_bias_and_add)) {
1829         TF_RETURN_IF_ERROR(
1830             AddFusedContractionNode(&ctx, contract_with_bias_and_add,
1831                                     &invalidated_nodes, &nodes_to_delete));
1832         continue;
1833       }
1834     }
1835 #endif  //! INTEL_MKL
1836 
1837     // Infer properties lazily in case they are not needed.
1838     if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
1839       const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
1840       TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
1841           assume_valid_feeds,
1842           /*aggressive_shape_inference=*/false,
1843           /*include_input_tensor_values=*/true,
1844           /*include_output_tensor_values=*/false));
1845       ctx.inferred_graph_properties = true;
1846     }
1847 
1848     // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
1849     // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
1850     ContractionWithBiasAdd contract_with_bias;
1851     if (allow_non_differentiable_rewrites &&
1852         FindContractionWithBias(ctx, i, &contract_with_bias)) {
1853       TF_RETURN_IF_ERROR(AddFusedContractionNode(
1854           &ctx, contract_with_bias, &invalidated_nodes, &nodes_to_delete));
1855       continue;
1856     }
1857 
1858     // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
1859     // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
1860     ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
1861     if (allow_non_differentiable_rewrites &&
1862         FindContractionWithBiasAndActivation(
1863             ctx, i, &contract_with_bias_and_activation)) {
1864       TF_RETURN_IF_ERROR(
1865           AddFusedContractionNode(&ctx, contract_with_bias_and_activation,
1866                                   &invalidated_nodes, &nodes_to_delete));
1867       continue;
1868     }
1869 
1870     // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
1871     // it for MatMul as well, but in practice this pattern does not appear in
1872     // real Tensorflow graphs.
1873 
1874     // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
1875     ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
1876     if (allow_non_differentiable_rewrites &&
1877         FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
1878       TF_RETURN_IF_ERROR(
1879           AddFusedConv2DNode(&ctx, contract_with_squeeze_and_bias,
1880                              &invalidated_nodes, &nodes_to_delete));
1881       continue;
1882     }
1883 
1884 // TODO(intel-tf):
1885 // Remove this once TF-MKL supports _FusedConv2D with these operations.
1886 #ifndef INTEL_MKL
1887     // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
1888     ContractionWithBatchNorm contract_with_batch_norm;
1889     if (allow_non_differentiable_rewrites &&
1890         FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
1891       TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
1892                                             &invalidated_nodes,
1893                                             &nodes_to_delete));
1894       continue;
1895     }
1896 
1897     // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
1898     ContractionWithBatchNormAndActivation
1899         contract_with_batch_norm_and_activation;
1900     if (allow_non_differentiable_rewrites &&
1901         FindConv2DWithBatchNormAndActivation(
1902             ctx, i, &contract_with_batch_norm_and_activation)) {
1903       TF_RETURN_IF_ERROR(
1904           AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation,
1905                              &invalidated_nodes, &nodes_to_delete));
1906       continue;
1907     }
1908 #endif  // !INTEL_MKL
1909 
1910     // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
1911     FusedBatchNormEx fused_batch_norm_ex;
1912     if (allow_non_differentiable_rewrites &&
1913         FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
1914       TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
1915           &ctx, fused_batch_norm_ex, &invalidated_nodes, &nodes_to_delete));
1916       continue;
1917     }
1918 
1919     // During inference, most of the inputs to FusedBatchNorm are constant, and
1920     // we can therefore replace the op with a much cheaper set of primitives.
1921     FusedBatchNorm fused_batch_norm;
1922     if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
1923       TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
1924       continue;
1925     }
1926   }
1927 
1928   // Remove invalidated nodes.
1929   utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
1930   for (int i = 0; i < num_nodes; ++i) {
1931     if (nodes_to_delete[i]) {
1932       mutation->RemoveNode(ctx.graph_view.GetNode(i));
1933     }
1934   }
1935   TF_RETURN_IF_ERROR(mutation->Apply());
1936 
1937   *optimized_graph = std::move(mutable_item.graph);
1938 
1939   return Status::OK();
1940 }
1941 
1942 void Remapper::Feedback(Cluster* cluster, const GrapplerItem& item,
1943                         const GraphDef& optimized_graph, double result) {
1944   // Nothing to do for RemapperOptimizer.
1945 }
1946 
1947 }  // namespace grappler
1948 }  // namespace tensorflow
1949