/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/remapper.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/grappler/utils/pattern_utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {
namespace grappler {

// Supported patterns:
//
// Conv2D + ... -> _FusedConv2D
//   (1) Conv2D + BiasAdd + <Activation>
//   (2) Conv2D + FusedBatchNorm + <Activation>
//   (3) Conv2D + Squeeze + BiasAdd
//
// MatMul + ... -> _FusedMatMul:
//   (1) MatMul + BiasAdd + <Activation>
//
// DepthwiseConv2dNative + ... -> _FusedDepthwiseConv2dNative:
//   (1) DepthwiseConv2dNative + BiasAdd + <Activation>
//
// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
//   (1) FusedBatchNorm + <Activation>
//   (2) FusedBatchNorm + SideInput + <Activation>
//
// Sigmoid + Mul -> _MklSwish  // This fusion only works on Intel CPUs.
//
// In all cases, the supported activation functions are Relu, Relu6, Elu, and
// LeakyRelu (with oneDNN enabled, Tanh and Sigmoid are additionally supported
// for some patterns; see IsSupportedActivation below).
//
// On CPU, both Conv2D and MatMul are implemented as tensor contractions, so
// all of these patterns are named "ContractionWith...".
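//
// For example, with the first Conv2D pattern the subgraph
//   out = Relu(BiasAdd(Conv2D(input, filter), bias))
// is collapsed into a single node
//   out = _FusedConv2D(input, filter, bias)  // fused_ops = ["BiasAdd", "Relu"]
// so the fused kernel can evaluate the whole chain in one pass over the data.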
namespace {

constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedConv3D[] = "_FusedConv3D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
constexpr char kFusedBatchNormGradEx[] = "_FusedBatchNormGradEx";
constexpr char kTensorToHashBucket[] = "_TensorToHashBucketFast";

constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";

constexpr char kWidth[] = "width";
constexpr char kFill[] = "fill";

constexpr int kMissingIndex = -1;

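// Mutable state shared by all of the pattern matchers below: the graph view
// being rewritten, its inferred shape/type properties, and the set of node
// names that must be preserved.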
struct RemapperContext {
  explicit RemapperContext(GrapplerItem* item, Status* status,
                           RewriterConfig::CpuLayout cpu_layout_conversion,
                           bool xla_auto_clustering_on)
      : nodes_to_preserve(item->NodesToPreserve()),
        graph_view(&item->graph, status),
        graph_properties(*item),
        inferred_graph_properties(false),
        cpu_layout_conversion(cpu_layout_conversion),
        xla_auto_clustering_on(xla_auto_clustering_on) {}

  std::unordered_set<string> nodes_to_preserve;
  utils::MutableGraphView graph_view;
  GraphProperties graph_properties;
  bool inferred_graph_properties;
  RewriterConfig::CpuLayout cpu_layout_conversion;
  bool xla_auto_clustering_on;
};

// FusedBatchNorm that can be replaced with a cheaper set of primitives.
struct FusedBatchNorm {
  FusedBatchNorm() = default;
  explicit FusedBatchNorm(int fused_batch_norm)
      : fused_batch_norm(fused_batch_norm) {}

  int fused_batch_norm = kMissingIndex;
};

// FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
  FusedBatchNormEx() = default;

  int fused_batch_norm = kMissingIndex;
  int side_input = kMissingIndex;
  int activation = kMissingIndex;
  // Add node that will be invalidated by fusing the side input into the fused
  // batch norm.
  int invalidated = kMissingIndex;
};

// FusedBatchNormGrad with fused side output and/or activation.
struct FusedBatchNormGradEx {
  int fused_batch_norm_grad = kMissingIndex;
  int activation_grad = kMissingIndex;
  int side_input_grad = kMissingIndex;
  // Add node of the forward pass to access its "offset" input.
  int fwd_fused_batch_norm = kMissingIndex;
};

// TensorToHashBucket that can be replaced with AsString + StringToHashBucket.
// We also include the fanin node of AsString ("pre_as_string") to determine
// the device.
struct TensorToHashBucket {
  TensorToHashBucket() = default;
  explicit TensorToHashBucket(int op1, int op2, int op3)
      : pre_as_string(op1), as_string(op2), string_to_hash_bucket(op3) {}

  int pre_as_string = kMissingIndex;
  int as_string = kMissingIndex;
  int string_to_hash_bucket = kMissingIndex;
};

// Pad followed by Conv3D/FusedConv3D.
struct PadWithConv3D {
  PadWithConv3D() = default;
  PadWithConv3D(int contraction_idx, int pad_idx, int padding_const_idx)
      : contraction_idx(contraction_idx),
        pad_idx(pad_idx),
        padding_const_idx(padding_const_idx) {}

  int contraction_idx = kMissingIndex;
  int pad_idx = kMissingIndex;
  int padding_const_idx = kMissingIndex;
};

// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
  ContractionWithBiasAdd() = default;
  ContractionWithBiasAdd(int contraction, int bias_add, int bias_port)
      : contraction(contraction), bias_add(bias_add), bias_port(bias_port) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int bias_port = 1;
};

// Contraction node followed by a BiasAdd and Activation.
struct ContractionWithBiasAddAndActivation {
  ContractionWithBiasAddAndActivation() = default;
  ContractionWithBiasAddAndActivation(int contraction, int bias_add,
                                      int activation, int bias_port)
      : contraction(contraction),
        bias_add(bias_add),
        activation(activation),
        bias_port(bias_port) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int activation = kMissingIndex;
  int bias_port = 1;
};

// Contraction node followed by a Squeeze and BiasAdd.
struct ContractionWithSqueezeAndBiasAdd {
  ContractionWithSqueezeAndBiasAdd() = default;
  ContractionWithSqueezeAndBiasAdd(int contraction, int squeeze, int bias_add)
      : contraction(contraction), squeeze(squeeze), bias_add(bias_add) {}

  int contraction = kMissingIndex;
  int squeeze = kMissingIndex;
  int bias_add = kMissingIndex;
};

// Contraction node followed by a FusedBatchNorm.
struct ContractionWithBatchNorm {
  ContractionWithBatchNorm() = default;
  ContractionWithBatchNorm(int contraction, int fused_batch_norm,
                           float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  float epsilon = 0.0;
};

// Contraction node followed by a FusedBatchNorm and Activation.
struct ContractionWithBatchNormAndActivation {
  ContractionWithBatchNormAndActivation() = default;
  ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm,
                                        int activation, float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        activation(activation),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  int activation = kMissingIndex;
  float epsilon = 0.0;
};

// Contraction node followed by a BiasAdd and Add.
struct ContractionWithBiasAddAndAdd {
  ContractionWithBiasAddAndAdd() = default;
  ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
                               int port_id, int bias_port)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id),
        bias_port(bias_port) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  int port_id = 0;
  int bias_port = 1;
};

// Contraction node followed by a BiasAdd, Add, and Relu (plus Tanh and
// Sigmoid for MatMul when oneDNN is enabled).
struct ContractionWithBiasAndAddActivation {
  ContractionWithBiasAndAddActivation() = default;
  ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
                                      int port_id, int activation,
                                      int bias_port)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id),
        activation(activation),
        bias_port(bias_port) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  int port_id = 0;
  int activation = kMissingIndex;
  int bias_port = 1;
};

bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
  return ctx.nodes_to_preserve.count(node->name()) > 0;
}

bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
                      const string& type_attr = "T") {
  DataType lhs_attr = GetDataTypeFromAttr(*lhs, type_attr);
  DataType rhs_attr = GetDataTypeFromAttr(*rhs, type_attr);

  return lhs_attr != DT_INVALID && rhs_attr != DT_INVALID &&
         lhs_attr == rhs_attr;
}

bool HasDataType(const NodeDef* node, const DataType& expected,
                 const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*node, type_attr);
  return dtype == expected;
}

bool IsCpuCompatibleDataType(const NodeDef* contraction,
                             const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
  // In stock TF builds without oneDNN, this is always false.
  bool is_one_dnn_enabled = IsMKLEnabled();

  if (is_one_dnn_enabled) {
    return (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
            IsMatMul(*contraction) || IsConv3D(*contraction) ||
            IsAnyBatchMatMul(*contraction)) &&
           (dtype == DT_FLOAT || dtype == DT_BFLOAT16);
  }
  if (IsConv2D(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_DOUBLE;
  } else if (IsMatMul(*contraction)) {
    return dtype == DT_FLOAT;
  } else {
    return false;
  }
}

bool IsGpuCompatibleDataType(const NodeDef* contraction,
                             const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
  if (IsConv2D(*contraction) || IsMatMul(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_HALF;
  } else {
    return false;
  }
}

bool IsCpuCompatibleDataFormat(const RemapperContext& ctx,
                               const NodeDef* conv_node) {
  const string& data_format = conv_node->attr().at(kDataFormat).s();
  if (IsConv2D(*conv_node)) {
    return data_format == "NHWC" || (IsMKLEnabled() && data_format == "NCHW") ||
           (ctx.cpu_layout_conversion == RewriterConfig::NHWC_TO_NCHW &&
            data_format == "NCHW");
  } else if (IsConv3D(*conv_node)) {
    return data_format == "NDHWC" || (IsMKLEnabled() && data_format == "NCDHW");
  } else {
    return false;
  }
}

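// Reads the TF_USE_CUBLASLT environment variable once and caches the result;
// fused MatMul on GPU is gated on this flag.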
bool BlasLtMatmulEnabled() {
  static bool is_enabled = [] {
    bool is_enabled = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_USE_CUBLASLT", /*default_val=*/false, &is_enabled));
    return is_enabled;
  }();
  return is_enabled;
}

bool IsGpuCompatibleDataFormat(const RemapperContext& ctx,
                               const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  const string& data_format = conv2d->attr().at(kDataFormat).s();
  return data_format == "NHWC" || data_format == "NCHW";
}

bool IsCpuCompatibleConv2D(const RemapperContext& ctx, const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
         IsCpuCompatibleDataFormat(ctx, conv2d);
}

bool IsCpuCompatibleConv3D(const RemapperContext& ctx, const NodeDef* conv3d) {
  DCHECK(IsConv3D(*conv3d)) << "Expected Conv3D op";
  return NodeIsOnCpu(conv3d) && IsCpuCompatibleDataType(conv3d) &&
         IsCpuCompatibleDataFormat(ctx, conv3d);
}

bool IsGpuCompatibleConv2D(const RemapperContext& ctx, const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
         IsGpuCompatibleDataFormat(ctx, conv2d);
}

bool IsGpuCompatibleMatMul(const RemapperContext& ctx, const NodeDef* matmul) {
  DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
  return BlasLtMatmulEnabled() && NodeIsOnGpu(matmul) &&
         IsGpuCompatibleDataType(matmul);
}

bool IsCpuCompatibleMatMul(const RemapperContext& ctx, const NodeDef* matmul) {
  DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
  return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
}

bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
  DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
      << "Expected DepthwiseConv2dNative op";
  return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
}

// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
template <typename Pattern>
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
  const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
  if (IsConv2D(node)) {
    return IsCpuCompatibleConv2D(ctx, &node);
  } else if (IsDepthwiseConv2dNative(node)) {
    return (IsMKLEnabled() && IsCpuCompatibleDepthwiseConv2dNative(&node));
  } else if (IsMatMul(node)) {
    return IsCpuCompatibleMatMul(ctx, &node);
  } else if (IsConv3D(node)) {
    return (IsMKLEnabled() && IsCpuCompatibleConv3D(ctx, &node));
  } else {
    return false;
  }
}

bool IsSupportedActivation(const NodeDef& node) {
  bool is_default_supported =
      IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
  bool is_mkl_specific = IsMKLEnabled() && (IsTanh(node) || IsSigmoid(node));
  return (is_default_supported || is_mkl_specific);
}

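// Whether cuDNN runtime kernel fusion can be used: requires cuDNN >= 8.4, the
// cuDNN frontend API, and every visible GPU being Ampere (compute capability
// >= 8.0) or newer.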
bool RuntimeFusionEnabled(const Cluster* cluster) {
  static bool is_enabled = [&] {
#if CUDNN_VERSION >= 8400
    // The cuDNN runtime fusion feature is recommended only for Ampere GPUs or
    // later. On pre-Ampere GPUs the overhead of runtime compilation is very
    // large, and fewer cases are supported.
    if (!cluster) return false;
    auto devices = cluster->GetDevices();
    int num_gpus = 0;
    int num_ampere = 0;
    for (const auto& d : devices) {
      if (d.second.type() == "GPU") {
        num_gpus++;
        auto cc_it = d.second.environment().find("architecture");
        if (cc_it != d.second.environment().end()) {
          double compute_capability = 0.0;
          if (absl::SimpleAtod(cc_it->second, &compute_capability) &&
              compute_capability >= 8.0) {
            num_ampere++;
          }
        }
      }
    }
    bool runtime_fusion_enabled = CudnnUseRuntimeFusion() &&
                                  CudnnUseFrontend() && num_gpus > 0 &&
                                  num_gpus == num_ampere;

    if (CudnnUseRuntimeFusion() && !runtime_fusion_enabled) {
      VLOG(1) << "Enabling Cudnn with runtime compilation requires the "
              << "Cudnn frontend and Ampere GPUs or later, but we got "
              << "Cudnn frontend is "
              << (CudnnUseFrontend() ? "enabled" : "disabled") << " and "
              << num_ampere << " Ampere GPU(s) out of total " << num_gpus
              << " GPU(s)";
    }

    return runtime_fusion_enabled;
#else
    return false;
#endif
  }();
  return is_enabled;
}

// Checks if we can rewrite a pattern to the `_FusedConv2D` on a GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAddAndActivation& matched,
                     const Cluster* cluster) {
#if TENSORFLOW_USE_ROCM
  // ROCm does not support _FusedConv2D.
  return false;
#endif
  // The TF->XLA bridge does not support `_Fused[Conv2D|MatMul]`, so we avoid
  // creating this op. Furthermore, XLA already does this fusion internally, so
  // there is no true benefit from doing this optimization if XLA is going to
  // compile the unfused operations anyway.
  if (ctx.xla_auto_clustering_on) return false;

  const GraphDef* graph = ctx.graph_view.graph();

  // We rely on cuDNN for fused convolution and cublasLt for fused matmul.
  const NodeDef& activation_node = graph->node(matched.activation);
  if (!IsSupportedActivation(activation_node)) return false;

  const NodeDef& contraction_node = graph->node(matched.contraction);
  if (IsConv2D(contraction_node)) {
    const std::vector<OpInfo::TensorProperties>& input_props =
        ctx.graph_properties.GetInputProperties(contraction_node.name());
    const TensorShapeProto& filter_shape =
        input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();

    // FusedConv2D on GPU with 1x1 convolution is marginally faster than
    // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
    // and significantly slower in large scale benchmarks.
    bool is_spatial_conv = Rank(filter_shape) == 4 &&          //
                           IsKnown(filter_shape.dim(0)) &&     //
                           IsKnown(filter_shape.dim(1)) &&     //
                           filter_shape.dim(0).size() != 1 &&  //
                           filter_shape.dim(1).size() != 1;

    // The cuDNN runtime-compiled kernels support the relu6, elu, and leakyrelu
    // activations, but require the fp16 dtype and even in_channels and
    // out_channels.
    bool act_requires_fp16 = IsRelu6(activation_node) ||
                             IsElu(activation_node) ||
                             IsLeakyRelu(activation_node);
    DataType dtype = GetDataTypeFromAttr(activation_node, "T");
    bool is_fp16 = dtype == DT_HALF;
    bool valid_channels = Rank(filter_shape) == 4 &&              //
                          IsKnown(filter_shape.dim(2)) &&         //
                          IsKnown(filter_shape.dim(3)) &&         //
                          filter_shape.dim(2).size() % 2 == 0 &&  //
                          filter_shape.dim(3).size() % 2 == 0;
    bool is_supported_conv =
        is_spatial_conv &&
        (!act_requires_fp16 ||
         (is_fp16 && valid_channels && RuntimeFusionEnabled(cluster)));

    return is_supported_conv && IsGpuCompatibleConv2D(ctx, &contraction_node);
  } else if (IsMatMul(contraction_node)) {
    return IsGpuCompatibleMatMul(ctx, &contraction_node);
  }

  return false;
}

// Checks if we can rewrite a pattern to the `_FusedMatMul` on a GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAdd& matched,
                     const Cluster* cluster) {
#if TENSORFLOW_USE_ROCM
  // ROCm does not support _FusedMatMul.
  return false;
#endif
  // The TF->XLA bridge does not support `_FusedMatMul`, so we avoid creating
  // this op. Furthermore, XLA already does this fusion internally, so there
  // is no true benefit from doing this optimization if XLA is going to compile
  // the unfused operations anyway.
  if (ctx.xla_auto_clustering_on) return false;

  const GraphDef* graph = ctx.graph_view.graph();
  const NodeDef& contraction_node = graph->node(matched.contraction);
  if (!IsMatMul(contraction_node)) return false;

  return IsGpuCompatibleMatMul(ctx, &contraction_node);
}

bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithSqueezeAndBiasAdd& matched,
                     const Cluster* cluster) {
  return false;
}

// Returns true if the given pattern is supported on the assigned device.
template <typename Pattern>
bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched,
                        Cluster* cluster = nullptr) {
  return IsCpuCompatible(ctx, matched) ||
         IsGpuCompatible(ctx, matched, cluster);
}

inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
  return node_view.NumControllingFanins() > 0 ||
         node_view.NumControlledFanouts() > 0;
}

// Returns true if at most one fanout reads output at port 0 (output used
// once).
inline bool HasAtMostOneFanoutAtPort0(const utils::MutableNodeView& node_view) {
  return node_view.GetRegularFanout(0).size() <= 1;
}

// Returns true if at most one fanout reads actual tensor data at output port 0
// (output used once for any data computation).
inline bool HasAtMostOneDataFanoutAtPort0(
    const utils::MutableNodeView& node_view) {
  const auto predicate = [](const auto& fanout) -> bool {
    const NodeDef* node = fanout.node_view()->node();
    return !IsShape(*node) && !IsRank(*node);
  };
  return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
}

bool IsConvOrMatMul(const NodeDef& node) {
  return IsConv2D(node) || IsDepthwiseConv2dNative(node) || IsMatMul(node) ||
         IsConv3D(node);
}

// Returns true if one input to Add is Conv2D/3D or DepthwiseConv2dNative or
// MatMul, and the other input is semantically equivalent to BiasAdd.
bool IsBiasSemanticAdd(const RemapperContext& ctx,
                       const utils::MutableNodeView& node_view,
                       int& bias_port) {
  if (!IsMKLEnabled()) return false;

  const auto* node_def = node_view.node();
  if (!NodeIsOnCpu(node_def)) return false;
  if (!IsAdd(*node_def) || node_view.NumRegularFanins() != 2) return false;

  const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
  if (props.size() < 2) return false;

  const auto& regular_fanin_0 = node_view.GetRegularFanin(0);
  const auto* node_view_0 = regular_fanin_0.node_view();
  const auto* node_def_0 = node_view_0->node();
  const auto& regular_fanin_1 = node_view.GetRegularFanin(1);
  const auto* node_view_1 = regular_fanin_1.node_view();
  const auto* node_def_1 = node_view_1->node();

  if (!IsConvOrMatMul(*node_def_0) && !IsConvOrMatMul(*node_def_1))
    return false;

  auto is_channel_last_format = [](const NodeDef& node) -> bool {
    if (node.attr().contains("data_format")) {
      const string data_format = node.attr().at("data_format").s();
      return (data_format == "NHWC" || data_format == "NDHWC");
    }
    return true;
  };

  // Currently supported data formats are NHWC and NDHWC.
  if (!is_channel_last_format(*node_def_0) ||
      !is_channel_last_format(*node_def_1))
    return false;

  const TensorShapeProto& prot0_shape = props[0].shape();
  const TensorShapeProto& prot1_shape = props[1].shape();

  if (prot0_shape.unknown_rank() || prot1_shape.unknown_rank() ||
      prot0_shape.dim_size() < 1 || prot1_shape.dim_size() < 1 ||
      !IsKnown(prot0_shape.dim(prot0_shape.dim_size() - 1)) ||
      !IsKnown(prot1_shape.dim(prot1_shape.dim_size() - 1)))
    return false;

  // Helper function to check whether Add/AddV2 can be replaced with BiasAdd.
  const auto is_supported_shape =
      [&](const TensorShapeProto& shape,
          const TensorShapeProto& bcast_shape) -> bool {
    int conv_channel_dim = shape.dim(shape.dim_size() - 1).size();

    if (shape.dim_size() == 4 && bcast_shape.dim_size() > 4) return false;
    if (shape.dim_size() == 5 && bcast_shape.dim_size() > 5) return false;

    if (shape.dim_size() < 2) return false;
    // Check that the conv node's channel dim equals the bias operand's last
    // dim.
    if (conv_channel_dim != bcast_shape.dim(bcast_shape.dim_size() - 1).size())
      return false;

    // Check that the bias operand's dims are all 1 except the channel dim.
    for (int i = 0; i < bcast_shape.dim_size() - 1; i++) {
      if (1 != bcast_shape.dim(i).size()) return false;
    }
    return true;
  };

  if (ShapesSymbolicallyEqual(prot0_shape, prot1_shape) ||
      !ShapesBroadcastable(prot0_shape, prot1_shape))
    return false;

  if (IsConvOrMatMul(*node_def_0)) {
    bias_port = 1;
    return (is_supported_shape(prot0_shape, prot1_shape));
  } else if (IsConvOrMatMul(*node_def_1)) {
    bias_port = 0;
    return (is_supported_shape(prot1_shape, prot0_shape));
  }
  return false;
}

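// Matches a contraction node (Conv2D/3D, MatMul, or DepthwiseConv2dNative)
// whose output feeds a BiasAdd (or a bias-semantic Add) as its only consumer,
// recording the node indices of the pattern in `matched`.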
bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
                             ContractionWithBiasAdd* matched,
                             bool check_device_compatible = true) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be a BiasAdd.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  int bias_port = 1;
  if (!IsBiasAdd(*node_def) && !IsBiasSemanticAdd(ctx, *node_view, bias_port))
    return false;

  // Input to the BiasAdd must be a Conv2D/3D or a MatMul.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(1 - bias_port);
  const auto* contraction_node_view = regular_fanin_0.node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Conv2D/3D, MatMul, or DepthwiseConv2D.
  bool is_contraction = IsConv2D(*contraction_node_def) ||
                        (IsConv3D(*contraction_node_def) && IsMKLEnabled()) ||
                        IsMatMul(*contraction_node_def) ||
                        IsDepthwiseConv2dNative(*contraction_node_def);

  if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
      HasControlFaninOrFanout(*contraction_node_view) ||
      !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
      IsInPreserveSet(ctx, contraction_node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
                                       node_index, bias_port};
  if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
    return false;

  // We successfully found a {Conv2D, MatMul}+BiasAdd pattern.
  *matched = pattern;

  return true;
}

bool FindContractionWithBiasAndActivation(
    const RemapperContext& ctx, Cluster* cluster, int node_index,
    ContractionWithBiasAddAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be an activation node.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // And the input to the activation node must match the
  // ContractionWithBiasAdd pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* bias_add_node_view = regular_fanin_0.node_view();
  const auto* bias_add_node_def = bias_add_node_view->node();

  ContractionWithBiasAdd base;
  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), &base,
                               /*check_device_compatible=*/false) ||
      !HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;

  // Get the contraction node.
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(1 - base.bias_port).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only MatMul + BiasAdd + (Tanh or Sigmoid) is enabled.
  if (!IsMatMul(*contraction_node_def) &&
      (IsTanh(*node_def) || IsSigmoid(*node_def)))
    return false;

  // Currently, only (Conv2D | Conv3D | MatMul) + BiasAdd + LeakyRelu is
  // enabled.
  if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def) ||
        (IsConv3D(*contraction_node_def) && IsMKLEnabled())) &&
      IsLeakyRelu(*node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAddAndActivation pattern{
      base.contraction, base.bias_add, node_index, base.bias_port};
  if (!IsDeviceCompatible(ctx, pattern, cluster)) return false;

  // We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
  *matched = pattern;

  return true;
}

bool FindConvWithSqueezeAndBias(const RemapperContext& ctx, int node_index,
                                ContractionWithSqueezeAndBiasAdd* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be a BiasAdd.
  const auto* node_def = node_view->node();
  if (!IsBiasAdd(*node_def)) return false;

  // Input to the BiasAdd must be a Squeeze.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* squeeze_node_view = regular_fanin_0.node_view();
  const auto* squeeze_node_def = squeeze_node_view->node();

  if (!IsSqueeze(*squeeze_node_def) ||
      !HaveSameDataType(node_def, squeeze_node_def, "T") ||
      HasControlFaninOrFanout(*squeeze_node_view) ||
      !HasAtMostOneFanoutAtPort0(*squeeze_node_view) ||
      IsInPreserveSet(ctx, squeeze_node_def))
    return false;

  // Input to the Squeeze must be a Conv2D/3D.
  if (squeeze_node_view->NumRegularFanins() < 1) return false;
  const auto& squeeze_regular_fanin_0 = squeeze_node_view->GetRegularFanin(0);
  const auto* conv_node_view = squeeze_regular_fanin_0.node_view();
  const auto* conv_node_def = conv_node_view->node();

  if (!(IsConv2D(*conv_node_def) ||
        (IsConv3D(*conv_node_def) && IsMKLEnabled())) ||
      !HaveSameDataType(node_def, conv_node_def, "T") ||
      HasControlFaninOrFanout(*conv_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv_node_view) ||
      IsInPreserveSet(ctx, conv_node_def))
    return false;

  // Squeeze must not squeeze output channel dimension.
  std::vector<int32> dims;
  if (!TryGetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims)) return false;
  for (auto dim : dims) {
    if ((dim == 3 && IsConv2D(*conv_node_def)) ||
        (dim == 4 && IsConv3D(*conv_node_def)))
      return false;
  }

  // Check that data type and data format are supported on assigned device.
  const ContractionWithSqueezeAndBiasAdd pattern{
      conv_node_view->node_index(), squeeze_node_view->node_index(),
      node_index};
  if (!IsDeviceCompatible(ctx, pattern)) return false;

  // We successfully found a Conv2D+Squeeze+BiasAdd pattern.
  *matched = pattern;

  return true;
}

bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
                             ContractionWithBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // Root of the pattern must be a FusedBatchNorm.
  if (!IsFusedBatchNorm(*node_def)) return false;

  // FusedBatchNormV2 and V3 have an extra type parameter.
  // Conv2D + FusedBatchNormV2/V3 fusion is currently not supported for bf16.
  // TODO(intel-tf): Enable the fusion for bf16.
  bool dtypeU_is_float = HasDataType(node_def, DT_FLOAT, "U");
  bool dtypeT_is_bf16 = HasDataType(node_def, DT_BFLOAT16, "T");
  if (node_view->GetOp() != "FusedBatchNorm" &&
      (!dtypeU_is_float || dtypeT_is_bf16)) {
    return false;
  }

  // Check that batch normalization is in inference mode.
  const auto* training_attr = node_view->GetAttr(kIsTraining);
  if (training_attr != nullptr && training_attr->b()) return false;

  // Check that only the 0th output is consumed by other nodes.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view) ||
      !node_view->GetRegularFanout(1).empty() ||  // batch_mean
      !node_view->GetRegularFanout(2).empty() ||  // batch_variance
      !node_view->GetRegularFanout(3).empty() ||  // reserve_space_1
      !node_view->GetRegularFanout(4).empty())    // reserve_space_2
    return false;

  // Input to the FusedBatchNorm must be a Conv2D.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();

  if (!IsConv2D(*conv2d_node_def) || !NodeIsOnCpu(conv2d_node_def) ||
      !HaveSameDataType(node_def, conv2d_node_def) ||
      !IsCpuCompatibleDataType(conv2d_node_def) ||
      !IsCpuCompatibleDataFormat(ctx, conv2d_node_def) ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm pattern.
  matched->contraction = conv2d_node_view->node_index();
  matched->fused_batch_norm = node_index;
  if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;

  return true;
}

bool FindConv2DWithBatchNormAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBatchNormAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // TODO(intel-tf): Test and enable this activation in the kernel op before
  // enabling it here.
  if (IsSigmoid(*node_def)) return false;

  // And the input to the activation node must match the Conv2DWithBatchNorm
  // pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* batch_norm_node_view = regular_fanin_0.node_view();

  ContractionWithBatchNorm base;
  if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base))
    return false;

  const auto* fused_batch_norm_node_view =
      ctx.graph_view.GetNode(base.fused_batch_norm);
  const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
  if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
      !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
      IsInPreserveSet(ctx, fused_batch_norm_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
  matched->contraction = base.contraction;
  matched->fused_batch_norm = base.fused_batch_norm;
  matched->activation = node_index;
  matched->epsilon = base.epsilon;

  return true;
}

// As AddN has multiple inputs, this function tries to find the Conv2D + Bias
// pattern at a specific input port.
bool FindContractionWithBiasInPort(const RemapperContext& ctx,
                                   const utils::MutableNodeView& add_node_view,
                                   const NodeDef& add_node_def, int port_id,
                                   ContractionWithBiasAdd* base) {
  // Input to AddN must match the ContractionWithBiasAdd pattern.
  if (add_node_view.NumRegularFanins() < port_id + 1) return false;
  const auto& bias_add_node_view =
      add_node_view.GetRegularFanin(port_id).node_view();
  if (bias_add_node_view == nullptr) return false;
  const auto* bias_add_node_def = bias_add_node_view->node();

  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base,
                               /*check_device_compatible=*/false))
    return false;
  if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(&add_node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;
  return true;
}

bool IsAddWithNoBroadcast(const RemapperContext& ctx, const NodeDef& node) {
  if (!IsAdd(node)) return false;

  // Add supports broadcasting, so only match when the two input shapes are
  // symbolically equal, i.e. no broadcasting actually happens.
  const auto& props = ctx.graph_properties.GetInputProperties(node.name());
  if (props.size() == 2 &&
      ShapesSymbolicallyEqual(props[0].shape(), props[1].shape())) {
    return true;
  }
  return false;
}

bool FindPadWithConv3D(const RemapperContext& ctx, int node_index,
                       PadWithConv3D* matched) {
  if (!IsMKLEnabled()) return false;
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // The optimization is only for CPU.
  if (!NodeIsOnCpu(node_def)) return false;
  // Root of the pattern must be a Conv3D or _FusedConv3D.
  if (!(IsConv3D(*node_def) || node_def->op() == kFusedConv3D)) return false;
  if (!(HasDataType(node_def, DT_FLOAT) || HasDataType(node_def, DT_BFLOAT16)))
    return false;

  // Input to Conv3D/_FusedConv3D must be a Pad.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* pad_node_view = regular_fanin_0.node_view();
  const auto* pad_node_def = pad_node_view->node();
  const auto& padding_const = pad_node_view->GetRegularFanin(1);
  const auto* padding_const_node_view = padding_const.node_view();

  if (!(pad_node_def->op() == "Pad") ||
      !HaveSameDataType(node_def, pad_node_def))
    return false;
  const PadWithConv3D pattern{node_view->node_index(),
                              pad_node_view->node_index(),
                              padding_const_node_view->node_index()};

  // Successfully found a Pad+{Conv3D, _FusedConv3D} pattern.
  *matched = pattern;
  return true;
}

bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
                                      const utils::MutableNodeView& node_view,
                                      ContractionWithBiasAddAndAdd* matched) {
  // Fusion with AddN is supported only when it has two inputs.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(node_view) || node_view.NumRegularFanins() != 2)
    return false;

  // Root of the pattern must be an AddN or an Add with the same input shapes
  // (no broadcasting).
  const auto* node_def = node_view.node();
  if (!IsAddN(*node_def) && !IsAddWithNoBroadcast(ctx, *node_def)) return false;

  if (!NodeIsOnCpu(node_def)) return false;

  // MKL AddN ops only support float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;

  ContractionWithBiasAdd base;
  matched->port_id = 0;

  // Find the conv+bias pattern at a specific port.
  if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                     matched->port_id, &base)) {
    matched->port_id = 1;
    if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                       matched->port_id, &base)) {
      return false;
    }
  }

  // We successfully found a {Conv2D,Conv3D}+BiasAdd+{AddN,Add} pattern.
  matched->contraction = base.contraction;
  matched->bias_add = base.bias_add;
  matched->add = node_view.node_index();
  matched->bias_port = base.bias_port;

  return true;
}

bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
                                      int node_index,
                                      ContractionWithBiasAddAndAdd* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  return FindContractionWithBiasAddAndAdd(ctx, *node_view, matched);
}

bool FindContractionWithBiasAndAddActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBiasAndAddActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (node_def == nullptr) return false;
  if (!IsSupportedActivation(*node_def)) return false;

  if (!NodeIsOnCpu(node_def)) return false;

  // Currently, the Contraction + Bias + Add + Tanh pattern is not supported.
  if (IsTanh(*node_def)) return false;

  // TODO(intel-tf): Test and enable this activation in the kernel op before
  // enabling it here.
  if (IsSigmoid(*node_def)) return false;

  // The MKL activation op only supports float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;

  // And the input to the activation must match the
  // ContractionWithBiasAddAndAdd pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* add_node_view = regular_fanin_0.node_view();

  ContractionWithBiasAddAndAdd base;

  if (!FindContractionWithBiasAddAndAdd(ctx, *add_node_view, &base)) {
    return false;
  }

  // Get the contraction node.
  const auto* bias_add_node_view =
      add_node_view->GetRegularFanin(base.port_id).node_view();
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(0).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only Conv2D/3D + BiasAdd + Add + LeakyRelu is enabled.
  if (!(IsConv2D(*contraction_node_def) || IsConv3D(*contraction_node_def)) &&
      IsLeakyRelu(*node_def))
    return false;
  // Conv3D fusion is only available with oneDNN enabled.
  if (IsConv3D(*contraction_node_def) && !IsMKLEnabled()) return false;

  // We successfully found a Conv2D+BiasAdd+AddN+Activation pattern
  // or a Conv3D+BiasAdd+AddN+Activation pattern.
  const ContractionWithBiasAndAddActivation pattern{
      base.contraction, base.bias_add, base.add,
      base.port_id,     node_index,    base.bias_port};
  *matched = pattern;

  return true;
}

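// Returns true only if every constant named in `nodes_map` that appears in
// `values_map` is a scalar Const node (float, bfloat16, or half) whose value
// matches the expected value to within an absolute tolerance of 1e-2.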
inline bool VerifyConstants(RemapperContext* ctx,
                            std::map<string, int>* nodes_map,
                            std::map<string, float>* values_map) {
  using utils::MutableNodeView;
  for (auto it = values_map->begin(); it != values_map->end(); ++it) {
    int node_idx = nodes_map->at(it->first);
    MutableNodeView* node_view = ctx->graph_view.GetNode(node_idx);
    NodeDef* node_def = node_view->node();
    Tensor const_tensor;
    if (node_def != nullptr && node_def->op() == "Const" &&
        const_tensor.FromProto(node_def->attr().at("value").tensor())) {
      if (const_tensor.NumElements() == 1) {
        DataType dtype = const_tensor.dtype();
        float const_value;
        if (dtype == DT_FLOAT) {
          const_value = const_tensor.flat<float>()(0);
        } else if (dtype == DT_BFLOAT16) {
          const_value = const_tensor.flat<bfloat16>()(0);
        } else if (dtype == DT_HALF) {
          const_value = const_tensor.flat<Eigen::half>()(0);
        } else {
          return false;
        }
        if (std::abs(const_value - it->second) > 1e-2) return false;
      } else {
        return false;
      }
    } else {
      return false;
    }
  }
  return true;
}

// Gelu in the Python API generates a number of nodes in the graph. Depending
// on the parameter `approximate={True/False}`, different sets of ops are
// generated. We distinguish them as `GeluExact`, which uses Erf, and
// `GeluApproximate`, which uses Tanh.
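//
// For reference, GeluExact computes
//   0.5 * x * (1 + Erf(x / sqrt(2)))
// while GeluApproximate replaces the Erf term with
//   Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)),
// 0.044715 being the standard empirical coefficient (the "empirical_const"
// node in the patterns below). This is why the exact form is verified against
// the constants 0.707106 (i.e. 1/sqrt(2)), 1.0, and 0.5.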
FindMatMulBiasAddAndGelu(RemapperContext * ctx,int node_index,std::map<string,int> * matched_nodes_map,std::set<int> * remove_node_indices,bool * is_gelu_approximate)1118 bool FindMatMulBiasAddAndGelu(RemapperContext* ctx, int node_index,
1119                               std::map<string, int>* matched_nodes_map,
1120                               std::set<int>* remove_node_indices,
1121                               bool* is_gelu_approximate) {
1122   // Gelu fusion is enabled with oneDNN or cublasLt library.
1123   if (!IsMKLEnabled() && !BlasLtMatmulEnabled()) return false;
1124 
1125   using utils::MatchingDirection;
1126   using utils::NodeStatus;
1127   // clang-format off
1128   utils::OpTypePattern gelu_exact_pattern =
1129     {"Mul", "output", NodeStatus::kReplace,
1130       {
1131         {"Mul", "erf_plus_one_times_one_half", NodeStatus::kRemove,
1132           {
1133             {"AddV2", "erf_plus_one", NodeStatus::kRemove,
1134               {
1135                 {"Erf", "erf", NodeStatus::kRemove,
1136                   {
1137                     {"Mul", "bias_add_times_square_root_one_half", NodeStatus::kRemove,
1138                       {
1139                         {"BiasAdd", "bias_add", NodeStatus::kRemove},
1140                         {"Const", "square_root_one_half", NodeStatus::kRemain}
1141                       }
1142                     }
1143                   }
1144                 },
1145                 {"Const", "one", NodeStatus::kRemain}
1146               }
1147             },
1148             {"Const", "one_half", NodeStatus::kRemain}
1149           }
1150         },
1151         {"BiasAdd", "bias_add", NodeStatus::kRemove,
1152           {
1153             {"MatMul", "matmul", NodeStatus::kRemove},
1154             {"*", "bias", NodeStatus::kRemain}
1155           }
1156         }
1157       }
1158     };
1159   // clang-format on
1160 
1161   // Gelu approximate uses Pow(x, 3). On GPU, it is a single Pow() node, but on
1162   // CPU, it is optimized by arithmetic optimizer as Mul(x, Square(x)) with an
1163   // arifact of control dependency. So we try to match pattern at second pass of
1164   // remapper which reccieves _FusedMatMul (MatMul + BiasAdd) with control
1165   // dependency removed.
1166   // clang-format off
1167   utils::OpTypePattern subgraph_gpu =
1168     {"Mul", "mul", NodeStatus::kRemove,
1169       {
1170         {"Pow", "pow", NodeStatus::kRemove,
1171           {
1172             {"_FusedMatMul", "matmul", NodeStatus::kRemove},
1173             {"Const", "three", NodeStatus::kRemain}
1174           }
1175         },
1176         {"Const", "empirical_const", NodeStatus::kRemain}
1177       }
1178     };
1179   utils::OpTypePattern subgraph_cpu =
1180     {"Mul", "mul", NodeStatus::kRemove,
1181       {
1182         {"Mul", "empirical_const_times_matmul", NodeStatus::kRemove,
1183           {
1184             {"Const", "empirical_const", NodeStatus::kRemain},
1185             {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1186           }
1187         },
1188         {"Square", "square", NodeStatus::kRemove,
1189           {
1190             {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1191           }
1192         }
1193       }
1194     };
1195   // clang-format on
1196 
1197   utils::MutableNodeView* node_view = ctx->graph_view.GetNode(node_index);
1198   const NodeDef* node_def = node_view->node();
1199   bool root_on_gpu = NodeIsOnGpu(node_def);
1200   utils::OpTypePattern* subgraph_pattern =
1201       root_on_gpu ? &subgraph_gpu : &subgraph_cpu;
1202 
1203   // clang-format off
1204   utils::OpTypePattern gelu_approximate_pattern =
1205     {"Mul", "output", NodeStatus::kReplace,
1206       {
1207         {"Mul", "tanh_plus_one_times_one_half", NodeStatus::kRemove,
1208           {
1209             {"AddV2", "tanh_plus_one", NodeStatus::kRemove,
1210               {
1211                 {"Tanh", "tanh", NodeStatus::kRemove,
1212                   {
1213                     {"Mul", "matmul_plus_mul_times_square_root_two_over_pi", NodeStatus::kRemove,
1214                       {
1215                         {"AddV2", "matmul_plus_mul", NodeStatus::kRemove,
1216                           {
1217                             {"_FusedMatMul", "matmul", NodeStatus::kRemove},
1218                             *subgraph_pattern
1219                           }
1220                         },
1221                         {"Const", "square_root_two_over_pi", NodeStatus::kRemain}
1222                       }
1223                     }
1224                   }
1225                 },
1226                 {"Const", "one", NodeStatus::kRemain}
1227               }
1228             },
1229             {"Const", "one_half", NodeStatus::kRemain}
1230           }
1231         },
1232         {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1233       }
1234     };
1235   // clang-format on
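  // The subgraph above encodes the tanh-based GELU approximation. As a
  // sketch, with x = _FusedMatMul(a, b, bias), the matched expression is
  //
  //   gelu_approximate(x)
  //     = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  //
  // where x^3 shows up as Pow(x, 3) on GPU and as Mul(x, Square(x)) on CPU,
  // and sqrt(2/pi) ~= 0.797884 matches the constants verified below.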
1236 
1237   bool found_gelu_exact = false;
1238   bool found_gelu_approximate = false;
1239   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1240       &(ctx->graph_view));
1241   // Find GeluExact
1242   matched_nodes_map->clear();
1243   remove_node_indices->clear();
1244   found_gelu_exact = graph_matcher.GetMatchedNodes(
1245       gelu_exact_pattern, ctx->nodes_to_preserve, node_view, matched_nodes_map,
1246       remove_node_indices);
1247   // Find GeluApproximate
1248   if (!found_gelu_exact) {
1249     matched_nodes_map->clear();
1250     remove_node_indices->clear();
1251     found_gelu_approximate = graph_matcher.GetMatchedNodes(
1252         gelu_approximate_pattern, ctx->nodes_to_preserve, node_view,
1253         matched_nodes_map, remove_node_indices);
1254   }
1255 
1256   // The pattern matcher does subgraph matching based on op types only. It
1257   // also sanity-checks nodes tagged as `kRemove`, i.e., that they have no
1258   // consumers outside the matched nodes. To replace the subgraph, we need
1259   // additional checks, for example, whether the key ops are placed on CPU or
1260   // GPU, have the desired data type, whether consts hold the desired values,
1261   // etc. For the following fusion, MatMul + BiasAdd + Gelu (disintegrated into
1262   // smaller ops), we check that (i) the MatMul op is CpuCompatible or
1263   // GpuCompatible, and (ii) the const nodes have the desired values.
1264   if (found_gelu_exact) {
1265     if (!IsMKLEnabled()) return false;
1266     // Check if the MatMul to be fused is CPU-compatible.
1267     // TODO(kaixih@nvidia): Add GPU support when cublasLt supports the exact
1268     // form.
1269     NodeDef* matmul_node =
1270         ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();
1271     if (!IsCpuCompatibleMatMul(*ctx, matmul_node)) {
1272       matched_nodes_map->clear();
1273       remove_node_indices->clear();
1274       return false;
1275     }
1276     // Check if the matched constants have desired values.
1278     std::map<string, float> values_map = {
1279         {"square_root_one_half", 0.707106}, {"one", 1.0}, {"one_half", 0.5}};
1280     if (!VerifyConstants(ctx, matched_nodes_map, &values_map)) return false;
1282   } else if (found_gelu_approximate) {
1283     NodeDef* matmul_node =
1284         ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();
1285 
1286     // matmul_node is already the _FusedMatMul and we don't need to check its
1287     // data type again.
1288     if (!IsMKLEnabled() && !NodeIsOnGpu(matmul_node)) return false;
1289 
1290     // Check if _FusedMatMul contains only BiasAdd
1291     auto fused_ops = matmul_node->attr().at("fused_ops").list().s();
1292     if (fused_ops.size() == 1) {
1293       if (fused_ops.at(0) != "BiasAdd") return false;
1294     } else {
1295       return false;
1296     }
1297     // Check if the matched constants have desired values.
1298     std::map<string, float> values_map = {{"square_root_two_over_pi", 0.797884},
1299                                           {"one", 1.0},
1300                                           {"one_half", 0.5},
1301                                           {"empirical_const", 0.044715}};
1302     if (NodeIsOnGpu(matmul_node)) {
1303       values_map["three"] = 3.0;
1304     }
1305 
1306     if (!VerifyConstants(ctx, matched_nodes_map, &values_map)) return false;
1307   } else {
1308     return false;
1309   }
1310   *is_gelu_approximate = found_gelu_approximate;
1311   return (found_gelu_exact || found_gelu_approximate);
1312 }
1313 
1314 bool FindMulAndMaximum(RemapperContext* ctx, int node_index,
1315                        std::map<string, int>* matched_nodes_map,
1316                        std::set<int>* remove_node_indices) {
1317   using utils::MatchingDirection;
1318   using utils::NodeStatus;
1319 
1320   // clang-format off
1321   // Convert Mul+Maximum to LeakyRelu
1322   // maximum(x, alpha * x) = LeakyRelu(x)
1323   utils::OpTypePattern mulmax_pattern{
1324     "Maximum", "max_to_leakyrelu", NodeStatus::kReplace,
1325     {
1326       { "Mul", "mul", NodeStatus::kRemove,
1327         {
1328           { "*", "input", NodeStatus::kRemain},
1329           { "Const|Cast", "alpha", NodeStatus::kRemain}
1330         }
1331       },
1332       { "*", "input", NodeStatus::kRemain}
1333     }
1334   };
1335   // clang-format on
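  // A worked example of the identity with a scalar alpha = 0.2:
  //   maximum(-3.0, 0.2 * -3.0) = -0.6 = LeakyRelu(-3.0)   (x < 0 branch)
  //   maximum( 3.0, 0.2 *  3.0) =  3.0 = LeakyRelu( 3.0)   (x >= 0 branch)
  // The alpha >= 0 requirement of LeakyRelu is verified further below.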
1336   // Check for allowed datatypes
1337   auto* max_node_def = ctx->graph_view.GetNode(node_index)->node();
1338   if (!HasDataType(max_node_def, DT_HALF) &&
1339       !HasDataType(max_node_def, DT_BFLOAT16) &&
1340       !HasDataType(max_node_def, DT_FLOAT) &&
1341       !HasDataType(max_node_def, DT_DOUBLE))
1342     return false;
1343 
1344   // The current implementation only supports CPU
1345   // when oneDNN is enabled.
1346   // TODO(intel-tf): This restriction will be removed once fully tested on GPU.
1347   if (!NodeIsOnCpu(max_node_def) && !IsMKLEnabled()) return false;
1348 
1349   bool found_op_type_match = false;
1350   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1351       &(ctx->graph_view));
1352   matched_nodes_map->clear();
1353   remove_node_indices->clear();
1354 
1355   found_op_type_match = graph_matcher.GetMatchedNodes(
1356       mulmax_pattern, {}, ctx->graph_view.GetNode(node_index),
1357       matched_nodes_map, remove_node_indices);
1358 
1359   // Check if the value of alpha >= 0 as required for LeakyRelu
1360   if (found_op_type_match) {
1361     const auto* alpha_node_view =
1362         ctx->graph_view.GetNode(matched_nodes_map->at("alpha"));
1363     const auto* alpha_node_def = alpha_node_view->node();
1364 
1365     float alpha_val;
1366     Tensor alpha_tensor;
1367     if (alpha_node_def->op() == "Cast") {
1368       const auto& regular_fanin_0 = alpha_node_view->GetRegularFanin(0);
1369       const auto* regular_node_view = regular_fanin_0.node_view();
1370       const auto* const_node = regular_node_view->node();
1371       if (const_node != nullptr && const_node->op() == "Const" &&
1372           alpha_tensor.FromProto(const_node->attr().at("value").tensor())) {
1373         // Only fusing if the const is a scalar value
1374         if (alpha_tensor.shape().dims() > 0) {
1375           return false;
1376         }
1377         alpha_val = alpha_tensor.flat<float>()(0);
1378       } else {
1379         return false;
1380       }
1381     } else if (alpha_node_def->op() == "Const" &&
1382                alpha_tensor.FromProto(
1383                    alpha_node_def->attr().at("value").tensor())) {
1384       // Only fusing if the const is a scalar value
1385       if (alpha_tensor.shape().dims() > 0) {
1386         return false;
1387       }
1388       alpha_val = alpha_tensor.flat<float>()(0);
1389     } else {
1390       return false;
1391     }
1392 
1393     if (alpha_val < 0) {
1394       return false;
1395     }
1396   }
1397   return found_op_type_match;
1398 }
1399 
1400 bool FindSigmoidAndMul(RemapperContext* ctx, int node_index,
1401                        std::map<string, int>* matched_nodes_map,
1402                        std::set<int>* remove_node_indices) {
1403   // Swish fusion is enabled only with the oneDNN library.
1404   if (!IsMKLEnabled()) return false;
1405 
1406   using utils::MatchingDirection;
1407   using utils::NodeStatus;
1408   // clang-format off
1409   // Convert Sigmoid+Mul to Swish
1410   // Mul(x, Sigmoid(x)) --> _MklSwish(x)
1411 
1412   utils::OpTypePattern sigmoidmul_pattern{
1413     "Mul", "mul_to_swish", NodeStatus::kReplace,
1414     {
1415       { "Sigmoid", "sigmoid", NodeStatus::kRemove,
1416         {
1417           { "*", "input", NodeStatus::kRemain}
1418         }
1419       },
1420       { "*", "input", NodeStatus::kRemain}
1421     }
1422   };
1423   // clang-format on
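  // In other words, the matched subgraph computes the Swish (a.k.a. SiLU)
  // activation, swish(x) = x * sigmoid(x); e.g. swish(0) = 0 * 0.5 = 0.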
1424   // check for data types
1425   auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
1426   if (!HasDataType(mul_node_def, DT_FLOAT) &&
1427       !HasDataType(mul_node_def, DT_BFLOAT16))
1428     return false;
1429 
1430   if (!NodeIsOnCpu(mul_node_def)) return false;
1431 
1432   bool found_op_type_match = false;
1433   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1434       &(ctx->graph_view));
1435   matched_nodes_map->clear();
1436   remove_node_indices->clear();
1437   found_op_type_match = graph_matcher.GetMatchedNodes(
1438       sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index),
1439       matched_nodes_map, remove_node_indices);
1440 
1441   return found_op_type_match;
1442 }
1443 
1444 // The Keras LayerNormalization API uses multiple TensorFlow ops. The current
1445 // fusion pattern covers only the case where LayerNormalization uses
1446 // FusedBatchNormV3. We further restrict it to 2D or 3D tensor inputs to the
1447 // Keras LayerNormalization API.
1448 bool FindMklLayerNorm(RemapperContext* ctx, int node_index,
1449                       std::map<string, int>* matched_nodes_map,
1450                       std::set<int>* remove_node_indices) {
1451   if (!IsMKLEnabled()) return false;
1452 
1453   // The following pattern will be searched in the graph with additional
1454   // constraints. Here * means any type of op.
1455   // clang-format off
1456   //              Subgraph for fusion
1457   //              -------------------
1458   //
1459   //     *(input)  *  * Const  *  Const                       FusedOp
1460   //          \    |   \  |    |  /        Const              -------
1461   //           \   |    \ |    | /  Const   /
1462   //           Reshape  Fill   Fill  /     /         *(input) *(gamma)  *(beta)
1463   //              \      /      /   /     /                \     |      /
1464   //               \    /      /   /     /                  \    |     /
1465   //          F u s e d B a t c h N o r m V 3              _MklLayerNorm
1466   //                 \
1467   //                  \   *
1468   //                   \ /
1469   //                 Reshape
1470   //                    \   *(gamma)
1471   //                     \ /
1472   //                     Mul
1473   //             *(beta) /
1474   //                \   /
1475   //                AddV2(output)
1476   // clang-format on
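  // Conceptually, the matched subgraph evaluates (NumPy-style pseudocode, a
  // sketch rather than the literal graph):
  //
  //   y = fused_batch_norm_v3(reshape(input, pre_shape),
  //                           scale=fill(dims, 1.0), offset=fill(dims, 0.0),
  //                           mean=[], variance=[], is_training=True)
  //   output = reshape(y, post_shape) * gamma + beta
  //
  // which is layer normalization of `input` with learned gamma and beta.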
1477   using utils::MatchingDirection;
1478   using utils::NodeStatus;
1479   // clang-format off
1480   utils::OpTypePattern layer_norm_pattern =
1481     {"AddV2", "output", NodeStatus::kReplace,
1482       {
1483         {"*", "beta", NodeStatus::kRemain},
1484         {"Mul", "scale", NodeStatus::kRemove,
1485           {
1486             {"Reshape", "post_reshape", NodeStatus::kRemove,
1487               {
1488                 {"FusedBatchNormV3", "fused_batch_norm", NodeStatus::kRemove,
1489                   {
1490                     {"Reshape", "pre_reshape", NodeStatus::kRemove,
1491                       {
1492                         {"*", "input", NodeStatus::kRemain},
1493                         {"*", "pre_shape", NodeStatus::kRemain}
1494                       }
1495                     },
1496                     {"Fill", "fill_scale", NodeStatus::kRemove,
1497                       {
1498                         {"*", "dims_fill_scale", NodeStatus::kRemain},
1499                         {"Const", "unit_gamma", NodeStatus::kRemain}
1500                       }
1501                     },
1502                     {"Fill", "fill_offset", NodeStatus::kRemove,
1503                       {
1504                         {"*", "dims_fill_offset", NodeStatus::kRemain},
1505                         {"Const", "zero_beta", NodeStatus::kRemain}
1506                       }
1507                     },
1508                     {"Const", "empty", NodeStatus::kRemain},
1509                     {"Const", "empty", NodeStatus::kRemain}
1510                   }
1511                 },
1512                 {"*", "post_shape", NodeStatus::kRemain}
1513               }
1514             },
1515             {"*", "gamma", NodeStatus::kRemain}
1516           }
1517         }
1518       }
1519     };  // clang-format on
1520 
1521   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1522       &(ctx->graph_view));
1523   bool found_op_type_match = false;
1524   matched_nodes_map->clear();
1525   remove_node_indices->clear();
1526   found_op_type_match =
1527       graph_matcher.GetMatchedNodes(layer_norm_pattern, ctx->nodes_to_preserve,
1528                                     ctx->graph_view.GetNode(node_index),
1529                                     matched_nodes_map, remove_node_indices);
1530 
1531   // Additional check for LayerNorm
1532   if (found_op_type_match) {
1533     // LayerNorm uses FusedBatchNorm in training mode.
1534     NodeDef* fused_batch_norm_node =
1535         ctx->graph_view.GetNode(matched_nodes_map->at("fused_batch_norm"))
1536             ->node();
1537     bool is_training = false;
1538     if (!TryGetNodeAttr(*fused_batch_norm_node, kIsTraining, &is_training) ||
1539         !is_training)
1540       return false;
1541 
1542     if (!NodeIsOnCpu(fused_batch_norm_node)) return false;
1543 
1544     // FusedBatchNorm node should have mean/variance as empty constant
1545     NodeDef* empty_const_node =
1546         ctx->graph_view.GetNode(matched_nodes_map->at("empty"))->node();
1547     Tensor const_tensor;
1548     if (empty_const_node != nullptr && empty_const_node->op() == "Const" &&
1549         const_tensor.FromProto(empty_const_node->attr().at("value").tensor())) {
1550       if (const_tensor.NumElements() != 0) return false;
1551     } else {
1552       return false;
1553     }
1554 
1555     // TODO(intel-tf): Relax the restriction of 2D/3D tensor once kernel
1556     // supports that.
1557     if (!ctx->inferred_graph_properties) {
1558       Status s = ctx->graph_properties.InferStatically(
1559           /*assume_valid_feeds=*/true,
1560           /*aggressive_shape_inference=*/false,
1561           /*include_input_tensor_values=*/true,
1562           /*include_output_tensor_values=*/false);
1563       if (!s.ok()) return false;
1564       ctx->inferred_graph_properties = true;
1565     }
1566     NodeDef* input_node_def =
1567         ctx->graph_view.GetNode(matched_nodes_map->at("input"))->node();
1568     auto input_props =
1569         ctx->graph_properties.GetOutputProperties(input_node_def->name());
1570     NodeDef* output_node_def =
1571         ctx->graph_view.GetNode(matched_nodes_map->at("output"))->node();
1572     auto output_props =
1573         ctx->graph_properties.GetOutputProperties(output_node_def->name());
1574     if (ShapesSymbolicallyEqual(input_props[0].shape(),
1575                                 output_props[0].shape())) {
1576       int rank = Rank(input_props[0].shape());
1577       if (rank < 2 || rank > 3) return false;
1578     } else {
1579       return false;
1580     }
1581   }
1582   return found_op_type_match;
1583 }
1584 
1585 bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
1586                         FusedBatchNorm* matched) {
1587   const auto* node_view = ctx.graph_view.GetNode(node_index);
1588   const auto* node_def = node_view->node();
1589   if (!IsFusedBatchNorm(*node_def)) return false;
1590   if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
1591 
1592   // Check that the node is in inference mode.
1593   bool is_training = true;
1594   if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
1595   if (is_training) return false;
1596 
1597   const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
1598 
1599   // a. Scaling factor can be const folded:
1600   //      scaling_factor = (variance + epsilon).rsqrt() * scale
1601   bool const_scaling_factor =
1602       props.size() == 5 &&     // [x, scale, offset, mean, variance]
1603       props[1].has_value() &&  // scale
1604       props[4].has_value();    // variance aka estimated variance
1605 
1606   // b. Or input can be const folded into some other expression.
1607   auto const_inputs = std::count_if(
1608       props.begin(), props.end(),
1609       [](const OpInfo::TensorProperties& props) { return props.has_value(); });
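  // A sketch of the algebra behind case (a): once scale and variance are
  // constant, the inference-time batch norm folds into elementwise primitives,
  //
  //   scaling_factor = rsqrt(variance + epsilon) * scale
  //   output = x * scaling_factor + (offset - mean * scaling_factor)
  //
  // i.e. one Mul and one Add per element.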
1610 
1611   // TODO(bsteiner): use the cost model to compare the cost of fused batch
1612   // norm against that of the optimized form.
1613   bool can_remap = const_scaling_factor || const_inputs >= 4;
1614   if (!can_remap) return false;
1615 
1616   // The optimized version only generates the first output.
1617   if (node_view->GetRegularFanouts().size() > 1) {
1618     return false;
1619   }
1620 
1621   // We found a fused batch norm node that can be replaced with primitive ops.
1622   matched->fused_batch_norm = node_index;
1623 
1624   return true;
1625 }
1626 
1627 // NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
1628 // `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
1629 bool BatchnormSpatialPersistentEnabled() {
1630 #if CUDNN_VERSION >= 7402
1631   static bool is_enabled = [] {
1632     bool is_enabled = false;
1633     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
1634         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
1635         /*default_val=*/false, &is_enabled));
1636     return is_enabled;
1637   }();
1638   return is_enabled;
1639 #else
1640   return false;
1641 #endif
1642 }
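// This mode is opt-in; for example, it can be enabled from the shell with:
//
//   export TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1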
1643 
1644 bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
1645                           FusedBatchNormEx* matched) {
1646   // Root of the pattern must be a Relu.
1647   // TODO(ezhulenev): Forward control dependencies.
1648   const auto* node_view = ctx.graph_view.GetNode(node_index);
1649   const auto* node_def = node_view->node();
1650   // TODO(lyandy): Forward controls for patterns with control dependencies.
1651   if (!IsRelu(*node_def) || HasControlFaninOrFanout(*node_view)) return false;
1652 
1653   // Returns true iff the node is a compatible FusedBatchNorm node.
1654   const auto valid_batch_norm =
1655       [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
1656     const auto* fused_batch_norm_node_def = fused_batch_norm.node();
1657     if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;
1658 
1659     // We fuse FusedBatchNorm on GPU or oneDNN CPU.
1660     if (!IsMKLEnabled() && !NodeIsOnGpu(fused_batch_norm_node_def))
1661       return false;
1662 
1663     DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
1664 
1665     if (NodeIsOnGpu(fused_batch_norm_node_def)) {
1666       // GPU supports float and half.
1667       // This condition is checked before `IsMKLEnabled()` because the node
1668       // should still be processed when it is on GPU and oneDNN CPU is enabled.
1669       if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
1670     } else {
1671       // Bfloat16 is available only with oneDNN.
1672       // Half is not available with oneDNN.
1673       if (IsMKLEnabled() && t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16)
1674         return false;
1675     }
1676 
1677     // Get the FusedBatchNorm training mode.
1678     bool is_training;
1679     if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
1680              .ok())
1681       return false;
1682     string data_format;
1683     if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
1684              .ok())
1685       return false;
1686     if (data_format != "NHWC" && data_format != "NCHW") return false;
1687 
1688     // In training mode we rely on cuDNN for computing FusedBatchNorm with side
1689     // inputs and activation, and it has its own limitations. In inference mode
1690     // we have a custom CUDA kernel that does not have these constraints.
1691     if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
1692       // cuDNN only supports NHWC data layout.
1693       if (data_format != "NHWC") return false;
1694 
1695       // Data type must be DT_HALF.
1696       if (t_dtype != DT_HALF) return false;
1697 
1698       // Channel dimension must be a multiple of 4.
1699       const auto& props = ctx.graph_properties.GetInputProperties(
1700           fused_batch_norm_node_def->name());
1701 
1702       const bool valid_channel_dim = !props.empty() &&
1703                                      props[0].shape().dim_size() == 4 &&
1704                                      props[0].shape().dim(3).size() % 4 == 0;
1705       if (!valid_channel_dim) return false;
1706 
1707       // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
1708       if (!BatchnormSpatialPersistentEnabled()) return false;
1709     }
1710 
1711     // FusedBatchNormV2 and V3 have an extra type parameter.
1712     if ((fused_batch_norm_node_def->op() != "FusedBatchNorm") &&
1713         !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
1714       return false;
1715 
1716     // Check that only one node consumes the 0-th output of a FusedBatchNorm.
1717     if (HasControlFaninOrFanout(fused_batch_norm) ||
1718         !HasAtMostOneDataFanoutAtPort0(fused_batch_norm) ||
1719         IsInPreserveSet(ctx, fused_batch_norm_node_def))
1720       return false;
1721 
1722     return true;
1723   };
1724 
1725   if (node_view->NumRegularFanins() < 1) return false;
1726   const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
1727   const auto* relu_fanin_0_node_view = regular_fanin_0.node_view();
1728   const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
1729 
1730   // Input to a Relu can be a FusedBatchNorm.
1731   if (valid_batch_norm(*relu_fanin_0_node_view)) {
1732     matched->activation = node_index;
1733     matched->fused_batch_norm = regular_fanin_0.node_index();
1734     return true;
1735   }
1736 
1737   // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
1738   if (IsAdd(*relu_fanin_0_node_def)) {
1739     // Currently there is no CPU implementation for "FusedBatchNorm +
1740     // SideInput + <Activation>".
1741     if (IsMKLEnabled() && !NodeIsOnGpu(node_def)) return false;
1742 
1743     // Check that only Relu node consumes the output of an Add node.
1744     if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
1745         !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
1746         IsInPreserveSet(ctx, relu_fanin_0_node_def))
1747       return false;
1748 
1749     // Add node supports broadcasting, FusedBatchNormEx does not.
1750     const auto& props =
1751         ctx.graph_properties.GetInputProperties(relu_fanin_0_node_def->name());
1752     if (props.size() < 2 ||
1753         !ShapesSymbolicallyEqual(props[0].shape(), props[1].shape()))
1754       return false;
1755 
1756     if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
1757     const auto& add_regular_fanin_0 =
1758         relu_fanin_0_node_view->GetRegularFanin(0);
1759     const auto& add_regular_fanin_1 =
1760         relu_fanin_0_node_view->GetRegularFanin(1);
1761 
1762     if (valid_batch_norm(*add_regular_fanin_0.node_view())) {
1763       matched->activation = node_index;
1764       matched->side_input = add_regular_fanin_1.node_index();
1765       matched->fused_batch_norm = add_regular_fanin_0.node_index();
1766       matched->invalidated = regular_fanin_0.node_index();
1767       return true;
1768     }
1769 
1770     if (valid_batch_norm(*add_regular_fanin_1.node_view())) {
1771       matched->activation = node_index;
1772       matched->side_input = add_regular_fanin_0.node_index();
1773       matched->fused_batch_norm = add_regular_fanin_1.node_index();
1774       matched->invalidated = regular_fanin_0.node_index();
1775       return true;
1776     }
1777   }
1778 
1779   return false;
1780 }
1781 
1782 bool FindFusedBatchNormGradEx(const RemapperContext& ctx, int node_index,
1783                               FusedBatchNormGradEx* matched) {
1784   // Root of the pattern must be a FusedBatchNormGrad.
1785   const utils::MutableNodeView* node_view = ctx.graph_view.GetNode(node_index);
1786 
1787   // Returns true iff the node is a compatible FusedBatchNormGrad node.
1788   const auto valid_batch_norm_grad =
1789       [&](const utils::MutableNodeView& fused_batch_norm_grad) -> bool {
1790     const NodeDef* node_def = fused_batch_norm_grad.node();
1791     if (!IsFusedBatchNormGrad(*node_def) ||
1792         HasControlFaninOrFanout(fused_batch_norm_grad))
1793       return false;
1794 
1795     // We fuse FusedBatchNormGrad on GPU.
1796     if (!NodeIsOnGpu(node_def)) return false;
1797 
1798     // We fuse FusedBatchNormGrad only for the training mode.
1799     bool is_training;
1800     if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok() || !is_training)
1801       return false;
1802 
1803     // Data type must be DT_HALF.
1804     DataType t_dtype = GetDataTypeFromAttr(*node_def, "T");
1805     if (t_dtype != DT_HALF) return false;
1806 
1807     // We rely on cuDNN for computing FusedBatchNormGrad with side
1808     // outputs and activation. cuDNN only supports NHWC data layout.
1809     string data_format;
1810     if (!GetNodeAttr(*node_def, kDataFormat, &data_format).ok()) return false;
1811     if (data_format != "NHWC") return false;
1812 
1813     // Channel dimension must be a multiple of 4.
1814     const auto& props =
1815         ctx.graph_properties.GetInputProperties(node_def->name());
1816     const bool valid_channel_dim = !props.empty() &&
1817                                    props[0].shape().dim_size() == 4 &&
1818                                    props[0].shape().dim(3).size() % 4 == 0;
1819     if (!valid_channel_dim) return false;
1820 
1821     // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
1822     if (!BatchnormSpatialPersistentEnabled()) return false;
1823 
1824     // FusedBatchNormV2 and V3 have an extra type parameter.
1825     if (node_def->op() != "FusedBatchNorm" &&
1826         !HasDataType(node_def, DT_FLOAT, "U"))
1827       return false;
1828 
1829     return true;
1830   };
1831 
1832   if (ctx.xla_auto_clustering_on) return false;
1833 
1834   if (!valid_batch_norm_grad(*node_view)) return false;
1835 
1836   if (node_view->NumRegularFanins() < 1) return false;
1837 
1838   const utils::MutableFanoutView& regular_fanin_0 =
1839       node_view->GetRegularFanin(0);
1840   const utils::MutableNodeView* relugrad_node_view =
1841       regular_fanin_0.node_view();
1842   const NodeDef* relugrad_node_def = relugrad_node_view->node();
1843   bool is_relugrad = IsReluGrad(*relugrad_node_def);
1844 
1845   if (!is_relugrad || HasControlFaninOrFanout(*relugrad_node_view) ||
1846       IsInPreserveSet(ctx, relugrad_node_def))
1847     return false;
1848 
1849   if (relugrad_node_view->NumRegularFanins() < 1) return false;
1850   // Find its corresponding forward node. We need the node to determine if the
1851   // type is bn+add+act or bn+act. Also, we need to access its "offset" input.
1852   const utils::MutableFanoutView& fanin_1 =
1853       relugrad_node_view->GetRegularFanin(1);
1854   const utils::MutableNodeView* fwd_node_view = fanin_1.node_view();
1855   FusedBatchNormEx fwd_matched;
1856   FindFusedBatchNormEx(ctx, fwd_node_view->node_index(), &fwd_matched);
1857   bool fwd_bn_act_used = fwd_matched.activation != kMissingIndex &&
1858                          fwd_matched.side_input == kMissingIndex;
1859   bool fwd_bn_add_act_used = fwd_matched.activation != kMissingIndex &&
1860                              fwd_matched.side_input != kMissingIndex;
1861 
1862   // Check that only 1 node consumes the output of the ReluGrad node.
1863   if (fwd_bn_act_used && relugrad_node_view->GetRegularFanout(0).size() == 1) {
1864     matched->activation_grad = regular_fanin_0.node_index();
1865     matched->fused_batch_norm_grad = node_index;
1866     matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;
1867     return true;
1868   }
1869 
1870   // Check that only 2 nodes consume the output of the ReluGrad node.
1871   if (fwd_bn_add_act_used &&
1872       relugrad_node_view->GetRegularFanout(0).size() == 2) {
1873     // In a graph where the Add node has two BatchNorm nodes as inputs, we need
1874     // to make sure that only the backward BatchNorm corresponding to the
1875     // to-be-fused forward BatchNorm is fused. We use the reserve-space edge to
1876     // find the directly corresponding forward BatchNorm node.
1877     const utils::MutableFanoutView& fwd_batch_norm_node =
1878         node_view->GetRegularFanin(5);
1879     if (fwd_matched.fused_batch_norm != fwd_batch_norm_node.node_index()) {
1880       return false;
1881     }
1882 
1883     const std::vector<utils::MutableFaninView>& fanouts_at_port_0 =
1884         relugrad_node_view->GetRegularFanouts()[0];
1885     const utils::MutableNodeView* fanout_0_node_view =
1886         ctx.graph_view.GetNode(fanouts_at_port_0[0].node_view()->GetName());
1887     const utils::MutableNodeView* fanout_1_node_view =
1888         ctx.graph_view.GetNode(fanouts_at_port_0[1].node_view()->GetName());
1889     const NodeDef* fanout_0_node_def = fanout_0_node_view->node();
1890     const NodeDef* fanout_1_node_def = fanout_1_node_view->node();
1891     const NodeDef* node_def = node_view->node();
1892 
1893     matched->activation_grad = regular_fanin_0.node_index();
1894     matched->fused_batch_norm_grad = node_index;
1895     matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;
1896 
1897     if (fanout_0_node_def == node_def) {
1898       matched->side_input_grad = fanout_1_node_view->node_index();
1899       return true;
1900     }
1901 
1902     if (fanout_1_node_def == node_def) {
1903       matched->side_input_grad = fanout_0_node_view->node_index();
1904       return true;
1905     }
1906   }
1907 
1908   return false;
1909 }
1910 
1911 bool FindTensorToHashBucket(const RemapperContext& ctx, int node_index,
1912                             TensorToHashBucket* matched) {
1913   // Root of the pattern must be a StringToHashBucketFast.
1914   const auto* node_view = ctx.graph_view.GetNode(node_index);
1915   const auto* node_def = node_view->node();
1916 
1917   if (!IsStringToHashBucketFast(*node_def) ||
1918       HasControlFaninOrFanout(*node_view)) {
1919     return false;
1920   }
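  // In TF Python terms (a sketch with hypothetical names, not the rewrite
  // code itself), the matched two-op chain is what
  //
  //   tf.strings.to_hash_bucket_fast(tf.strings.as_string(ids), num_buckets)
  //
  // produces for an integer `ids` tensor; the remapper replaces it with a
  // single _TensorToHashBucketFast node.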
1921 
1922   // Input to the StringToHashBucketFast must be AsString.
1923   if (node_view->NumRegularFanins() < 1) return false;
1924 
1925   const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
1926   const auto* as_string_node_view = regular_fanin_0.node_view();
1927   const auto* as_string_node_def = as_string_node_view->node();
1928   bool is_as_string = IsAsString(*as_string_node_def);
1929 
1930   if (!is_as_string || HasControlFaninOrFanout(*as_string_node_view) ||
1931       !HasAtMostOneFanoutAtPort0(*as_string_node_view) ||
1932       IsInPreserveSet(ctx, as_string_node_def))
1933     return false;
1934 
1935   // DataType of AsString must be int8/16/32/64 and width/fill attrs must be
1936   // default values.
1937   if (!HasDataType(as_string_node_def, DT_INT8) &&
1938       !HasDataType(as_string_node_def, DT_INT16) &&
1939       !HasDataType(as_string_node_def, DT_INT32) &&
1940       !HasDataType(as_string_node_def, DT_INT64)) {
1941     return false;
1942   }
1943 
1944   int width;
1945   if (!GetNodeAttr(*as_string_node_def, kWidth, &width).ok() || width != -1) {
1946     return false;
1947   }
1948 
1949   string fill;
1950   if (!GetNodeAttr(*as_string_node_def, kFill, &fill).ok() || !fill.empty()) {
1951     return false;
1952   }
1953 
1954   // An input to the AsString must exist to determine the device.
1955   if (as_string_node_view->NumRegularFanins() < 1) return false;
1956 
1957   const auto& fanin_0 = as_string_node_view->GetRegularFanin(0);
1958   const auto* pre_node_view = fanin_0.node_view();
1959 
1960   // We successfully found an AsString + StringToHashBucketFast pattern.
1961   const TensorToHashBucket pattern{pre_node_view->node_index(),
1962                                    as_string_node_view->node_index(),
1963                                    node_index};
1964 
1965   *matched = pattern;
1966 
1967   return true;
1968 }
1969 
1970 bool FindFusedBatchMatMul(RemapperContext* ctx, int node_index,
1971                           std::map<string, int>* matched_nodes_map,
1972                           std::set<int>* remove_node_indices) {
1973   if (!IsMKLEnabled()) return false;
1974 
1975   using utils::MatchingDirection;
1976   using utils::NodeStatus;
1977   // clang-format off
1978   utils::OpTypePattern fusion_pattern1 =
1979     {"AddV2", "output", NodeStatus::kReplace,
1980       {
1981         {"Mul", "mul", NodeStatus::kRemove,
1982           {
1983             {"BatchMatMulV2", "batch_matmul", NodeStatus::kRemove},
1984             {"*", "multiplicand", NodeStatus::kRemain}
1985           }
1986         },
1987         {"*", "addend", NodeStatus::kRemain}
1988       }
1989     };
1990 
1991   utils::OpTypePattern fusion_pattern2 =
1992     {"AddV2", "output", NodeStatus::kReplace,
1993       {
1994         {"*", "addend", NodeStatus::kRemain},
1995         {"Mul", "mul", NodeStatus::kRemove,
1996           {
1997             {"BatchMatMulV2", "batch_matmul", NodeStatus::kRemove},
1998             {"*", "multiplicand", NodeStatus::kRemain}
1999           }
2000         }
2001       }
2002     };
2003   // clang-format on
2004 
2005   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
2006       &(ctx->graph_view));
2007   bool found_op_type_match = false;
2008   matched_nodes_map->clear();
2009   remove_node_indices->clear();
2010   found_op_type_match =
2011       graph_matcher.GetMatchedNodes(fusion_pattern1, ctx->nodes_to_preserve,
2012                                     ctx->graph_view.GetNode(node_index),
2013                                     matched_nodes_map, remove_node_indices);
2014 
2015   if (!found_op_type_match) {
2016     matched_nodes_map->clear();
2017     remove_node_indices->clear();
2018     found_op_type_match =
2019         graph_matcher.GetMatchedNodes(fusion_pattern2, ctx->nodes_to_preserve,
2020                                       ctx->graph_view.GetNode(node_index),
2021                                       matched_nodes_map, remove_node_indices);
2022   }
2023 
2024   // oneDNN is not optimized for all shapes with regard to binary post-op
2025   // fusion. For now, allow only the limited cases that are optimized: (i) the
2026   // multiplicand is a scalar, (ii) the BatchMatMulV2 output is a 4D tensor,
2027   // and (iii) the addend is a 4D tensor with second dim_size = 1.
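  // For example (hypothetical shapes), the fusion would be allowed for
  //   batch_matmul output: [8, 12, 64, 64], multiplicand: [] (a scalar),
  //   addend: [8, 1, 64, 64],
  // but rejected if the multiplicand has more than one element, the
  // BatchMatMulV2 output is not 4D, or the addend's dim(1) != 1.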
2028   if (!found_op_type_match) return false;
2029   if (!ctx->inferred_graph_properties) {
2030     Status s = ctx->graph_properties.InferStatically(
2031         /*assume_valid_feeds=*/true,
2032         /*aggressive_shape_inference=*/false,
2033         /*include_input_tensor_values=*/false,
2034         /*include_output_tensor_values=*/true);
2035     if (!s.ok()) return false;
2036     ctx->inferred_graph_properties = true;
2037   }
2038   NodeDef* multiplicand_node_def =
2039       ctx->graph_view.GetNode(matched_nodes_map->at("multiplicand"))->node();
2040   auto multiplicand_props =
2041       ctx->graph_properties.GetOutputProperties(multiplicand_node_def->name());
2042   if (NumCoefficients(multiplicand_props[0].shape()) != 1) return false;
2043 
2044   NodeDef* batch_matmul_node_def =
2045       ctx->graph_view.GetNode(matched_nodes_map->at("batch_matmul"))->node();
2046   if (!IsCpuCompatibleMatMul(*ctx, batch_matmul_node_def)) return false;
2047 
2048   auto batch_matmul_props =
2049       ctx->graph_properties.GetOutputProperties(batch_matmul_node_def->name());
2050   if (Rank(batch_matmul_props[0].shape()) != 4) return false;
2051 
2052   NodeDef* addend_node_def =
2053       ctx->graph_view.GetNode(matched_nodes_map->at("addend"))->node();
2054   auto addend_props =
2055       ctx->graph_properties.GetOutputProperties(addend_node_def->name());
2056   auto addend_shape = addend_props[0].shape();
2057   if (!(Rank(addend_shape) == 4 && addend_shape.dim(1).size() == 1))
2058     return false;
2059   return found_op_type_match;
2060 }
2061 
2062 void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
2063                           const NodeDef* activation = nullptr) {
2064   DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
2065 
2066   auto* attr = fused_conv2d->mutable_attr();
2067   auto& src_attr = conv2d.attr();
2068 
2069   (*attr)["T"] = src_attr.at("T");
2070   (*attr)["strides"] = src_attr.at("strides");
2071   (*attr)["padding"] = src_attr.at("padding");
2072   (*attr)["explicit_paddings"] = src_attr.at("explicit_paddings");
2073   (*attr)["dilations"] = src_attr.at("dilations");
2074   (*attr)["data_format"] = src_attr.at("data_format");
2075   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
2076   // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha
2077   if (activation != nullptr && IsLeakyRelu(*activation)) {
2078     auto& activation_attr = activation->attr();
2079     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2080   }
2081 }
2082 
2083 void CopyConv3DAttributes(const NodeDef& conv3d, NodeDef* fused_conv3d,
2084                           const NodeDef* activation = nullptr) {
2085   DCHECK(IsConv3D(conv3d)) << "Input node must be a Conv3D";
2086 
2087   auto* attr = fused_conv3d->mutable_attr();
2088   auto& src_attr = conv3d.attr();
2089 
2090   (*attr)["T"] = src_attr.at("T");
2091   (*attr)["strides"] = src_attr.at("strides");
2092   (*attr)["padding"] = src_attr.at("padding");
2093   (*attr)["dilations"] = src_attr.at("dilations");
2094   (*attr)["data_format"] = src_attr.at("data_format");
2095   // Copy LeakyRelu's attr alpha to FusedConv3D's attr leakyrelu_alpha
2096   if (activation != nullptr && IsLeakyRelu(*activation)) {
2097     auto& activation_attr = activation->attr();
2098     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2099   }
2100 }
2101 
2102 void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
2103                                          NodeDef* fused_dw_conv2d,
2104                                          const NodeDef* activation = nullptr) {
2105   DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
2106       << "Input node must be a DepthwiseConv2dNative";
2107 
2108   auto* attr = fused_dw_conv2d->mutable_attr();
2109   auto& src_attr = dw_conv2d.attr();
2110 
2111   (*attr)["T"] = src_attr.at("T");
2112   (*attr)["strides"] = src_attr.at("strides");
2113   (*attr)["padding"] = src_attr.at("padding");
2114   (*attr)["dilations"] = src_attr.at("dilations");
2115   (*attr)["data_format"] = src_attr.at("data_format");
2116   // Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr leakyrelu_alpha
2117   if (activation != nullptr && IsLeakyRelu(*activation)) {
2118     auto& activation_attr = activation->attr();
2119     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2120   }
2121 }
2122 
2123 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
2124                                   NodeDef* fused_batch_norm_ex) {
2125   DCHECK(IsFusedBatchNorm(fused_batch_norm))
2126       << "Input node must be a FusedBatchNorm";
2127 
2128   auto* attr = fused_batch_norm_ex->mutable_attr();
2129   auto src_attr = fused_batch_norm.attr();
2130 
2131   (*attr)["T"] = src_attr.at("T");
2132   (*attr)["is_training"] = src_attr.at("is_training");
2133   (*attr)["data_format"] = src_attr.at("data_format");
2134   (*attr)["epsilon"] = src_attr.at("epsilon");
2135   (*attr)["exponential_avg_factor"] = src_attr.at("exponential_avg_factor");
2136 
2137   // FusedBatchNormV2 and V3 have an extra type parameter.
2138   if (fused_batch_norm.op() != "FusedBatchNorm") {
2139     SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
2140   } else {
2141     if (!IsMKLEnabled())
2142       SetAttrValue(src_attr.at("T"), &(*attr)["U"]);
2143     else
2144       SetAttrValue(DT_FLOAT, &(*attr)["U"]);
2145   }
2146 }
2147 
2148 void CopyFusedBatchNormGradAttributes(const NodeDef& fused_batch_norm_grad,
2149                                       NodeDef* fused_batch_norm_grad_ex) {
2150   DCHECK(IsFusedBatchNormGrad(fused_batch_norm_grad))
2151       << "Input node must be a FusedBatchNormGrad";
2152 
2153   auto* attr = fused_batch_norm_grad_ex->mutable_attr();
2154   auto src_attr = fused_batch_norm_grad.attr();
2155 
2156   (*attr)["T"] = src_attr.at("T");
2157   (*attr)["is_training"] = src_attr.at("is_training");
2158   (*attr)["data_format"] = src_attr.at("data_format");
2159   (*attr)["epsilon"] = src_attr.at("epsilon");
2160 
2161   // FusedBatchNormV2 and V3 have an extra type parameter.
2162   if (fused_batch_norm_grad.op() != "FusedBatchNormGrad") {
2163     SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
2164   } else {
2165     SetAttrValue(DT_FLOAT, &(*attr)["U"]);
2166   }
2167 }
2168 
2169 void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
2170                           const NodeDef* activation = nullptr) {
2171   DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
2172 
2173   auto* attr = fused_matmul->mutable_attr();
2174   auto& src_attr = matmul.attr();
2175 
2176   (*attr)["T"] = src_attr.at("T");
2177   (*attr)["transpose_a"] = src_attr.at("transpose_a");
2178   (*attr)["transpose_b"] = src_attr.at("transpose_b");
2179   // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
2180   if (activation != nullptr && IsLeakyRelu(*activation)) {
2181     auto& activation_attr = activation->attr();
2182     (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2183   }
2184 }
2185 
2186 void CopyBatchMatMulAttributes(const NodeDef& batchmatmul,
2187                                NodeDef* fused_batch_matmul) {
2188   DCHECK(IsAnyBatchMatMul(batchmatmul)) << "Input node must be a BatchMatMul";
2189 
2190   auto* attr = fused_batch_matmul->mutable_attr();
2191   auto& src_attr = batchmatmul.attr();
2192 
2193   (*attr)["T"] = src_attr.at("T");
2194   (*attr)["adj_x"] = src_attr.at("adj_x");
2195   (*attr)["adj_y"] = src_attr.at("adj_y");
2196 }
2197 
2198 void SetFusedOpAttributes(NodeDef* fused,
2199                           const absl::Span<const absl::string_view> fused_ops,
2200                           int num_args = 1, float epsilon = 0.0) {
2201   auto* attr = fused->mutable_attr();
2202   SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
2203   SetAttrValue(num_args, &(*attr)["num_args"]);
2204   SetAttrValue(epsilon, &(*attr)["epsilon"]);  // required only for BatchNorm
2205 }
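// For example, the BiasAdd-only fusions below call
//
//   SetFusedOpAttributes(&fused_op, {"BiasAdd"});
//
// keeping the defaults num_args = 1 and epsilon = 0.0, while the
// Conv2D + FusedBatchNorm fusions pass num_args = 4 and the matched epsilon.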
2206 
2207 Status AddFusedContractionNode(RemapperContext* ctx,
2208                                const ContractionWithBiasAdd& matched,
2209                                std::vector<bool>* invalidated_nodes,
2210                                std::vector<bool>* nodes_to_delete) {
2211   DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2212 
2213   const GraphDef* graph = ctx->graph_view.graph();
2214   const NodeDef& contraction = graph->node(matched.contraction);
2215   const NodeDef& bias_add = graph->node(matched.bias_add);
2216   VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd:"
2217           << " bias_add=" << bias_add.name()
2218           << " contraction=" << contraction.name();
2219 
2220   NodeDef fused_op;
2221   fused_op.set_name(bias_add.name());
2222   fused_op.set_device(contraction.device());
2223   fused_op.add_input(contraction.input(0));               // 0: input
2224   fused_op.add_input(contraction.input(1));               // 1: filter
2225   fused_op.add_input(bias_add.input(matched.bias_port));  // 2: bias
2226   if (IsConv2D(contraction)) {
2227     fused_op.set_op(kFusedConv2D);
2228     CopyConv2DAttributes(contraction, &fused_op);
2229   } else if (IsDepthwiseConv2dNative(contraction)) {
2230     fused_op.set_op(kFusedDepthwiseConv2dNative);
2231     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
2232   } else if (IsMatMul(contraction)) {
2233     fused_op.set_op(kFusedMatMul);
2234     CopyMatMulAttributes(contraction, &fused_op);
2235   } else if (IsConv3D(contraction)) {
2236     fused_op.set_op(kFusedConv3D);
2237     CopyConv3DAttributes(contraction, &fused_op);
2238   }
2239 
2240   SetFusedOpAttributes(&fused_op, {"BiasAdd"});
2241   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2242   Status status;
2243   mutation->AddNode(std::move(fused_op), &status);
2244   TF_RETURN_IF_ERROR(status);
2245   TF_RETURN_IF_ERROR(mutation->Apply());
2246 
2247   (*invalidated_nodes)[matched.bias_add] = true;
2248   (*nodes_to_delete)[matched.contraction] = true;
2249 
2250   return OkStatus();
2251 }
2252 
2253 Status AddFusedContractionNode(
2254     RemapperContext* ctx, const ContractionWithBiasAddAndActivation& matched,
2255     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2256   DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2257 
2258   const GraphDef* graph = ctx->graph_view.graph();
2259   const NodeDef& contraction = graph->node(matched.contraction);
2260   const NodeDef& bias_add = graph->node(matched.bias_add);
2261   const NodeDef& activation = graph->node(matched.activation);
2262 
2263   VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
2264           << activation.op() << ":"
2265           << " activation=" << activation.name()
2266           << " bias_add=" << bias_add.name()
2267           << " contraction=" << contraction.name();
2268 
2269   NodeDef fused_op;
2270   fused_op.set_name(activation.name());
2271   fused_op.set_device(contraction.device());
2272   fused_op.add_input(contraction.input(0));               // 0: input
2273   fused_op.add_input(contraction.input(1));               // 1: filter
2274   fused_op.add_input(bias_add.input(matched.bias_port));  // 2: bias
2275 
2276   if (IsConv2D(contraction)) {
2277     fused_op.set_op(kFusedConv2D);
2278     // LeakyRelu has a special attribute alpha.
2279     CopyConv2DAttributes(contraction, &fused_op, &activation);
2280   } else if (IsDepthwiseConv2dNative(contraction)) {
2281     fused_op.set_op(kFusedDepthwiseConv2dNative);
2282     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
2283   } else if (IsMatMul(contraction)) {
2284     fused_op.set_op(kFusedMatMul);
2285     CopyMatMulAttributes(contraction, &fused_op, &activation);
2286   } else if (IsConv3D(contraction)) {
2287     fused_op.set_op(kFusedConv3D);
2288     CopyConv3DAttributes(contraction, &fused_op, &activation);
2289   }
2290 
2291   SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});
2292 
2293   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2294   Status status;
2295   mutation->AddNode(std::move(fused_op), &status);
2296   TF_RETURN_IF_ERROR(status);
2297   TF_RETURN_IF_ERROR(mutation->Apply());
2298 
2299   (*nodes_to_delete)[matched.contraction] = true;
2300   (*nodes_to_delete)[matched.bias_add] = true;
2301   (*invalidated_nodes)[matched.activation] = true;
2302 
2303   return OkStatus();
2304 }
2305 
2306 Status AddFusedConvNode(RemapperContext* ctx,
2307                         const ContractionWithSqueezeAndBiasAdd& matched,
2308                         std::vector<bool>* invalidated_nodes,
2309                         std::vector<bool>* nodes_to_delete) {
2310   DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2311 
2312   const GraphDef* graph = ctx->graph_view.graph();
2313   const NodeDef& contraction = graph->node(matched.contraction);
2314 
2315   const NodeDef& bias_add = graph->node(matched.bias_add);
2316   const NodeDef& squeeze = graph->node(matched.squeeze);
2317   VLOG(2) << "Fuse Conv2D/3D with Squeeze and BiasAdd:"
2318           << " bias_add=" << bias_add.name() << " squeeze=" << squeeze.name()
2319           << " conv=" << contraction.name();
2320 
2321   // Replace the Conv2D/3D node with a fused Conv2D/3D. The matched pattern
2322   // guarantees that it has a single consumer (only the Squeeze node).
2323   NodeDef fused_conv;
2324   fused_conv.set_name(contraction.name());
2325   fused_conv.set_device(contraction.device());
2326   fused_conv.add_input(contraction.input(0));  // 0: input
2327   fused_conv.add_input(contraction.input(1));  // 1: filter
2328   fused_conv.add_input(bias_add.input(1));     // 2: bias
2329 
2330   if (IsConv2D(contraction)) {
2331     fused_conv.set_op(kFusedConv2D);
2332     CopyConv2DAttributes(contraction, &fused_conv);
2333   } else if (IsConv3D(contraction)) {
2334     fused_conv.set_op(kFusedConv3D);
2335     CopyConv3DAttributes(contraction, &fused_conv);
2336   }
2337 
2338   SetFusedOpAttributes(&fused_conv, {"BiasAdd"});
2339 
2340   // Replace BiasAdd node with a Squeeze.
2341   NodeDef remapped_squeeze = squeeze;
2342   remapped_squeeze.set_name(bias_add.name());
2343   remapped_squeeze.set_input(0, contraction.name());
2344 
2345   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2346   Status status;
2347   mutation->AddNode(std::move(fused_conv), &status);
2348   TF_RETURN_IF_ERROR(status);
2349   mutation->AddNode(std::move(remapped_squeeze), &status);
2350   TF_RETURN_IF_ERROR(status);
2351   TF_RETURN_IF_ERROR(mutation->Apply());
2352 
2353   (*invalidated_nodes)[matched.contraction] = true;
2354   (*invalidated_nodes)[matched.bias_add] = true;
2355   (*nodes_to_delete)[matched.squeeze] = true;
2356 
2357   return OkStatus();
2358 }
2359 
2360 Status AddFusedConv2DNode(RemapperContext* ctx,
2361                           const ContractionWithBatchNorm& matched,
2362                           std::vector<bool>* invalidated_nodes,
2363                           std::vector<bool>* nodes_to_delete) {
2364   const GraphDef* graph = ctx->graph_view.graph();
2365   const NodeDef& contraction = graph->node(matched.contraction);
2366   DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
2367   const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2368   VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm="
2369           << fused_batch_norm.name() << " conv2d=" << contraction.name();
2370 
2371   NodeDef fused_conv2d;
2372   fused_conv2d.set_name(fused_batch_norm.name());
2373   fused_conv2d.set_op(kFusedConv2D);
2374   fused_conv2d.set_device(contraction.device());
2375   fused_conv2d.add_input(contraction.input(0));       // 0: input
2376   fused_conv2d.add_input(contraction.input(1));       // 1: filter
2377   fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
2378   fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
2379   fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
2380   fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance
2381 
2382   CopyConv2DAttributes(contraction, &fused_conv2d);
2383   SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"},
2384                        /*num_args=*/4, /*epsilon=*/matched.epsilon);
2385 
2386   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2387   Status status;
2388   mutation->AddNode(std::move(fused_conv2d), &status);
2389   TF_RETURN_IF_ERROR(status);
2390   TF_RETURN_IF_ERROR(mutation->Apply());
2391 
2392   (*invalidated_nodes)[matched.fused_batch_norm] = true;
2393   (*nodes_to_delete)[matched.contraction] = true;
2394 
2395   return OkStatus();
2396 }
2397 
2398 Status AddFusedConv2DNode(RemapperContext* ctx,
2399                           const ContractionWithBatchNormAndActivation& matched,
2400                           std::vector<bool>* invalidated_nodes,
2401                           std::vector<bool>* nodes_to_delete) {
2402   const GraphDef* graph = ctx->graph_view.graph();
2403   const NodeDef& contraction = graph->node(matched.contraction);
2404 
2405   DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
2406 
2407   const NodeDef& activation = graph->node(matched.activation);
2408   const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2409   VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
2410           << ": activation=" << activation.name()
2411           << " batch_norm=" << fused_batch_norm.name()
2412           << " conv2d=" << contraction.name();
2413 
2414   NodeDef fused_conv2d;
2415   fused_conv2d.set_name(activation.name());
2416   fused_conv2d.set_op(kFusedConv2D);
2417   fused_conv2d.set_device(contraction.device());
2418   fused_conv2d.add_input(contraction.input(0));       // 0: input
2419   fused_conv2d.add_input(contraction.input(1));       // 1: filter
2420   fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
2421   fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
2422   fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
2423   fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance
2424 
2425   CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
2426   SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()},
2427                        /*num_args=*/4, /*epsilon=*/matched.epsilon);
2428 
2429   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2430   Status status;
2431   mutation->AddNode(std::move(fused_conv2d), &status);
2432   TF_RETURN_IF_ERROR(status);
2433   TF_RETURN_IF_ERROR(mutation->Apply());
2434 
2435   (*invalidated_nodes)[matched.activation] = true;
2436   (*nodes_to_delete)[matched.contraction] = true;
2437   (*nodes_to_delete)[matched.fused_batch_norm] = true;
2438 
2439   return OkStatus();
2440 }
2441 
2442 Status AddFusedContractionNode(RemapperContext* ctx,
2443                                const ContractionWithBiasAddAndAdd& matched,
2444                                std::vector<bool>* invalidated_nodes,
2445                                std::vector<bool>* nodes_to_delete) {
2446   const GraphDef* graph = ctx->graph_view.graph();
2447   const NodeDef& contraction = graph->node(matched.contraction);
2448   const NodeDef& bias_add = graph->node(matched.bias_add);
2449 
2450   // oneDNN version only supports fusion for Conv2D/3D and MatMul
2451   DCHECK(IsConv2D(contraction) || IsMatMul(contraction) ||
2452          IsConv3D(contraction));
2453 
2454   NodeDef contraction_node;
2455   const NodeDef& add = graph->node(matched.add);
2456   contraction_node.set_name(add.name());
2457   contraction_node.set_device(contraction.device());
2458   contraction_node.add_input(
2459       contraction.input(0));  // 0: input(conv) / a (matmul)
2460   contraction_node.add_input(
2461       contraction.input(1));  // 1: filter(conv) / b (matmul)
2462   contraction_node.add_input(bias_add.input(matched.bias_port));  // 2: bias
2463 
2464   // The Add op has two inputs: one is the conv+bias / matmul+bias pattern
2465   // matched above; the other input to the Add is fused here.
2466   contraction_node.add_input(add.input(1 - matched.port_id));
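  // matched.port_id is the index at which the bias-add output feeds the Add,
  // so (1 - matched.port_id) selects the Add's other operand.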
2467 
2468   if (IsConv2D(contraction)) {
2469     contraction_node.set_op(kFusedConv2D);
2470     CopyConv2DAttributes(contraction, &contraction_node);
2471   } else if (IsMatMul(contraction)) {
2472     contraction_node.set_op(kFusedMatMul);
2473     CopyMatMulAttributes(contraction, &contraction_node);
2474   } else if (IsConv3D(contraction)) {
2475     contraction_node.set_op(kFusedConv3D);
2476     CopyConv3DAttributes(contraction, &contraction_node);
2477   }
2478 
2479   SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
2480 
2481   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2482   Status status;
2483   mutation->AddNode(std::move(contraction_node), &status);
2484   TF_RETURN_IF_ERROR(status);
2485   TF_RETURN_IF_ERROR(mutation->Apply());
2486 
2487   (*invalidated_nodes)[matched.add] = true;
2488   (*nodes_to_delete)[matched.contraction] = true;
2489   (*nodes_to_delete)[matched.bias_add] = true;
2490 
2491   return OkStatus();
2492 }
2493 
2494 Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched,
2495                           std::vector<bool>* invalidated_nodes,
2496                           std::vector<bool>* nodes_to_delete) {
2497   const GraphDef* graph = ctx->graph_view.graph();
2498   const NodeDef& contraction = graph->node(matched.contraction_idx);
2499   const NodeDef& pad_node_def = graph->node(matched.pad_idx);
2500   const NodeDef& padding_const_node_def =
2501       graph->node(matched.padding_const_idx);
2502   VLOG(2) << "Fuse " << pad_node_def.op()
2503           << " with contraction: contraction=" << contraction.name();
2504 
2505   NodeDef fused_node;
2506   fused_node.set_name(contraction.name());
2507   fused_node.set_device(contraction.device());
2508   fused_node.add_input(pad_node_def.input(0));  // 0: input
2509   fused_node.add_input(contraction.input(1));   // 1: filter
2510   fused_node.set_op(kFusedConv3D);
2511 
2512   auto* attr = fused_node.mutable_attr();
2513   auto& src_attr = contraction.attr();
2514   (*attr)["T"] = src_attr.at("T");
2515   (*attr)["strides"] = src_attr.at("strides");
2516   (*attr)["data_format"] = src_attr.at("data_format");
2517   (*attr)["padding"] = src_attr.at("padding");
2518   (*attr)["dilations"] = src_attr.at("dilations");
2519 
2520   if (contraction.op() == kFusedConv3D) {
2521     fused_node.add_input(contraction.input(2));  // 2: bias
2522     (*attr)["fused_ops"] = src_attr.at("fused_ops");
2523     (*attr)["num_args"] = src_attr.at("num_args");
2524   } else {
2525     SetAttrValue(0, &(*attr)["num_args"]);
2526   }
2527 
2528   Tensor const_tensor;
2529   if (padding_const_node_def.op() == "Const" &&
2530       const_tensor.FromProto(
2531           padding_const_node_def.attr().at("value").tensor())) {
2532     auto const_value = const_tensor.flat<int32>();
2533     std::vector<int32> paddings;
2534     for (int i = 0; i < const_value.size(); ++i) {
2535       paddings.push_back(const_value(i));
2536     }
2537     SetAttrValue(paddings, &(*attr)["padding_list"]);
2538   }
2539 
2540   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2541   Status status;
2542   mutation->AddNode(std::move(fused_node), &status);
2543   TF_RETURN_IF_ERROR(status);
2544   TF_RETURN_IF_ERROR(mutation->Apply());
2545 
2546   (*invalidated_nodes)[matched.contraction_idx] = true;
2547   (*nodes_to_delete)[matched.pad_idx] = true;
2548   return OkStatus();
2549 }
2550 
2551 Status AddFusedContractionNode(
2552     RemapperContext* ctx, const ContractionWithBiasAndAddActivation& matched,
2553     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2554   const GraphDef* graph = ctx->graph_view.graph();
2555   // oneDNN version only supports fusion for Conv2D and Conv3D.
2556   const NodeDef& contraction = graph->node(matched.contraction);
2557   DCHECK(IsConv2D(contraction) || IsConv3D(contraction));
2558   const NodeDef& activation = graph->node(matched.activation);
2559 
2560   NodeDef fused_conv;
2561   fused_conv.set_name(activation.name());
2562   fused_conv.set_device(contraction.device());
2563   fused_conv.add_input(contraction.input(0));  // 0: input
2564   fused_conv.add_input(contraction.input(1));  // 1: filter
2565   const NodeDef& bias_add = graph->node(matched.bias_add);
2566   fused_conv.add_input(bias_add.input(matched.bias_port));  // 2: bias
2567 
2568   if (IsConv2D(contraction)) {
2569     fused_conv.set_op(kFusedConv2D);
2570     CopyConv2DAttributes(contraction, &fused_conv);
2571   } else if (IsConv3D(contraction)) {
2572     fused_conv.set_op(kFusedConv3D);
2573     CopyConv3DAttributes(contraction, &fused_conv);
2574   }
2575 
2576   // Add OP has two inputs, one is conv+bias pattern matched previously,
2577   // the other input to add is fused here.
2578   const NodeDef& add = graph->node(matched.add);
2579   fused_conv.add_input(add.input(1 - matched.port_id));
2580 
2581   SetFusedOpAttributes(&fused_conv, {"BiasAdd", "Add", activation.op()}, 2);
2582 
2583   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2584   Status status;
2585   mutation->AddNode(std::move(fused_conv), &status);
2586   TF_RETURN_IF_ERROR(status);
2587   TF_RETURN_IF_ERROR(mutation->Apply());
2588 
2589   (*invalidated_nodes)[matched.activation] = true;
2590   (*nodes_to_delete)[matched.add] = true;
2591   (*nodes_to_delete)[matched.bias_add] = true;
2592   (*nodes_to_delete)[matched.contraction] = true;
2593 
2594   return OkStatus();
2595 }
2596 
2597 Status AddFusedMatMulBiasAddAndGelu(
2598     RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2599     const std::set<int>& remove_node_indices,
2600     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete,
2601     bool is_gelu_approximate) {
2602   auto* output_node =
2603       ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
2604   auto* matmul_node =
2605       ctx->graph_view.GetNode(matched_nodes_map.at("matmul"))->node();
2606 
2607   NodeDef fused_node;
2608   // The fused node takes the name of the terminal node of the fusion.
2609   fused_node.set_name(output_node->name());
2610   fused_node.set_op("_FusedMatMul");
2611   fused_node.set_device(matmul_node->device());
2612   fused_node.add_input(matmul_node->input(0));
2613   fused_node.add_input(matmul_node->input(1));
2614   if (is_gelu_approximate) {
2615     fused_node.add_input(matmul_node->input(2));
2616   } else {
2617     auto* bias_add_node =
2618         ctx->graph_view.GetNode(matched_nodes_map.at("bias_add"))->node();
2619     fused_node.add_input(bias_add_node->input(1));
2620   }
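  // Note: in the approximate pattern the matched matmul node is expected to
  // already carry the bias as its third input, while the exact pattern reads
  // it from the separate BiasAdd node matched above.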
2621   CopyMatMulAttributes(*matmul_node, &fused_node);
2622   if (is_gelu_approximate)
2623     SetFusedOpAttributes(&fused_node, {"BiasAdd", "GeluApproximate"});
2624   else
2625     SetFusedOpAttributes(&fused_node, {"BiasAdd", "GeluExact"});
2626 
2627   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2628   Status status;
2629   mutation->AddNode(std::move(fused_node), &status);
2630   TF_RETURN_IF_ERROR(status);
2631   TF_RETURN_IF_ERROR(mutation->Apply());
2632   (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
2633 
2634   for (const auto& node_idx : remove_node_indices) {
2635     (*nodes_to_delete)[node_idx] = true;
2636   }
2637   return OkStatus();
2638 }
2639 
2640 Status AddMklLayerNorm(RemapperContext* ctx,
2641                        const std::map<string, int>& matched_nodes_map,
2642                        const std::set<int>& remove_node_indices,
2643                        std::vector<bool>* invalidated_nodes,
2644                        std::vector<bool>* nodes_to_delete) {
2645   auto* pre_reshape_node =
2646       ctx->graph_view.GetNode(matched_nodes_map.at("pre_reshape"))->node();
2647   auto* scale_node =
2648       ctx->graph_view.GetNode(matched_nodes_map.at("gamma"))->node();
2649   auto* output_node =
2650       ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
2651 
2652   NodeDef fused_node;
2653   fused_node.set_name(output_node->name());
2654   fused_node.set_op("_MklLayerNorm");
2655   fused_node.set_device(output_node->device());
2656   fused_node.add_input(pre_reshape_node->input(0));
2657   fused_node.add_input(scale_node->name());
2658   fused_node.add_input(output_node->input(0));
2659   auto* attr = fused_node.mutable_attr();
2660   auto& src_attr = output_node->attr();
2661   (*attr)["T"] = src_attr.at("T");
2662 
2663   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2664   Status status;
2665   mutation->AddNode(std::move(fused_node), &status);
2666   TF_RETURN_IF_ERROR(status);
2667   TF_RETURN_IF_ERROR(mutation->Apply());
2668   (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
2669 
2670   for (const auto& node_idx : remove_node_indices) {
2671     (*nodes_to_delete)[node_idx] = true;
2672   }
2673   return OkStatus();
2674 }
2675 
2676 Status ReplaceMulMaximumWithLeakyRelu(
2677     RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2678     const std::set<int>& remove_node_indices,
2679     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2680   const NodeDef* maximum =
2681       ctx->graph_view.GetNode(matched_nodes_map.at("max_to_leakyrelu"))->node();
2682   const NodeDef* input =
2683       ctx->graph_view.GetNode(matched_nodes_map.at("input"))->node();
2684   const auto* alpha_node_view =
2685       ctx->graph_view.GetNode(matched_nodes_map.at("alpha"));
2686   const auto* alpha_node_def = alpha_node_view->node();
2687 
2688   NodeDef fused_op;
2689   fused_op.set_name(maximum->name());
2690   fused_op.set_op("LeakyRelu");
2691   fused_op.set_device(maximum->device());
2692   fused_op.add_input(input->name());
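  // Maximum(x, alpha * x) equals LeakyRelu(x) with slope alpha as long as
  // 0 <= alpha <= 1, which the pattern matcher is expected to have verified.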
2693 
2694   auto* attr = fused_op.mutable_attr();
2695   (*attr)["T"] = maximum->attr().at("T");
2696 
2697   // BF16 inserts a Cast before the const alpha, so when a Cast is present we
2698   // go through it to reach the Const node and retrieve the value of alpha.
2699   float alpha_val;
2700   Tensor alpha_tensor;
2701   if (alpha_node_def->op() == "Cast") {
2702     const auto& regular_fanin_0 = alpha_node_view->GetRegularFanin(0);
2703     const auto* regular_node_view = regular_fanin_0.node_view();
2704     const auto* const_node = regular_node_view->node();
2705     if (const_node != nullptr && const_node->op() == "Const" &&
2706         alpha_tensor.FromProto(const_node->attr().at("value").tensor())) {
2707       alpha_val = alpha_tensor.flat<float>()(0);
2708       SetAttrValue(alpha_val, &(*attr)["alpha"]);
2709     }
2710   } else if (alpha_node_def->op() == "Const" &&
2711              alpha_tensor.FromProto(
2712                  alpha_node_def->attr().at("value").tensor())) {
2713     alpha_val = alpha_tensor.flat<float>()(0);
2714     SetAttrValue(alpha_val, &(*attr)["alpha"]);
2715   }
2716 
2717   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2718   Status status;
2719   mutation->AddNode(std::move(fused_op), &status);
2720   TF_RETURN_IF_ERROR(status);
2721   TF_RETURN_IF_ERROR(mutation->Apply());
2722 
2723   (*invalidated_nodes)[matched_nodes_map.at("max_to_leakyrelu")] = true;
2724 
2725   for (const auto& node_index : remove_node_indices) {
2726     (*nodes_to_delete)[node_index] = true;
2727   }
2728 
2729   return OkStatus();
2730 }
2731 
2732 Status ReplaceSigmoidMulWithSwish(
2733     RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2734     const std::set<int>& remove_node_indices,
2735     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2736   const NodeDef* mul =
2737       ctx->graph_view.GetNode(matched_nodes_map.at("mul_to_swish"))->node();
2738   const NodeDef* sigmoid =
2739       ctx->graph_view.GetNode(matched_nodes_map.at("sigmoid"))->node();
2740 
2741   NodeDef fused_op;
2742   fused_op.set_name(mul->name());
2743   fused_op.set_op("_MklSwish");
2744   fused_op.set_device(mul->device());
2745   fused_op.add_input(sigmoid->input(0));
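  // Swish(x) = x * Sigmoid(x): both fused ops read the same tensor, so the
  // Sigmoid's input is the only data input the fused kernel needs.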
2746 
2747   auto* attr = fused_op.mutable_attr();
2748   (*attr)["T"] = mul->attr().at("T");
2749 
2750   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2751   Status status;
2752   mutation->AddNode(std::move(fused_op), &status);
2753   TF_RETURN_IF_ERROR(status);
2754   TF_RETURN_IF_ERROR(mutation->Apply());
2755 
2756   (*invalidated_nodes)[matched_nodes_map.at("mul_to_swish")] = true;
2757 
2758   for (const auto& node_index : remove_node_indices) {
2759     (*nodes_to_delete)[node_index] = true;
2760   }
2761   return OkStatus();
2762 }
2763 
2764 Status AddFusedBatchNormExNode(RemapperContext* ctx,
2765                                const FusedBatchNormEx& matched,
2766                                std::vector<bool>* invalidated_nodes,
2767                                std::vector<bool>* nodes_to_delete) {
2768   const GraphDef* graph = ctx->graph_view.graph();
2769   const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2770   const NodeDef& activation = graph->node(matched.activation);
2771 
2772   VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
2773           << " activation=" << activation.name() << " side_input="
2774           << (matched.side_input != kMissingIndex
2775                   ? graph->node(matched.side_input).name()
2776                   : "<none>")
2777           << " invalidated="
2778           << (matched.invalidated != kMissingIndex
2779                   ? graph->node(matched.invalidated).name()
2780                   : "<none>")
2781           << " fused_batch_norm=" << fused_batch_norm.name();
2782 
2783   // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
2784   NodeDef fused_op;
2785   fused_op.set_op(kFusedBatchNormEx);
2786   fused_op.set_name(fused_batch_norm.name());
2787   fused_op.set_device(fused_batch_norm.device());
2788 
2789   fused_op.add_input(fused_batch_norm.input(0));  // 0: input
2790   fused_op.add_input(fused_batch_norm.input(1));  // 1: scale
2791   fused_op.add_input(fused_batch_norm.input(2));  // 2: offset
2792   fused_op.add_input(fused_batch_norm.input(3));  // 3: estimated_mean
2793   fused_op.add_input(fused_batch_norm.input(4));  // 4: estimated_var
2794 
2795   CopyFusedBatchNormAttributes(fused_batch_norm, &fused_op);
2796 
2797   auto* attrs = fused_op.mutable_attr();
2798   SetAttrValue(activation.op(), &(*attrs)["activation_mode"]);
2799 
2800   if (matched.side_input != kMissingIndex) {
2801     SetAttrValue(1, &(*attrs)["num_side_inputs"]);
2802     const NodeDef& side_input = graph->node(matched.side_input);
2803     fused_op.add_input(side_input.name());  // 5: side_input
2804   } else {
2805     SetAttrValue(0, &(*attrs)["num_side_inputs"]);
2806   }
2807 
2808   // Turn activation node into Identity node.
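  // The Identity keeps the activation's name alive and forwards the output of
  // _FusedBatchNormEx (which now applies the activation), so downstream
  // consumers of the activation do not need to be rewired.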
2809   NodeDef identity_op;
2810   identity_op.set_op("Identity");
2811   identity_op.set_name(activation.name());
2812   identity_op.set_device(fused_batch_norm.device());
2813   identity_op.add_input(fused_batch_norm.name());
2814   (*identity_op.mutable_attr())["T"] = attrs->at("T");
2815 
2816   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2817   Status status;
2818   mutation->AddNode(std::move(fused_op), &status);
2819   TF_RETURN_IF_ERROR(status);
2820   mutation->AddNode(std::move(identity_op), &status);
2821   TF_RETURN_IF_ERROR(status);
2822   TF_RETURN_IF_ERROR(mutation->Apply());
2823 
2824   (*invalidated_nodes)[matched.fused_batch_norm] = true;
2825   (*invalidated_nodes)[matched.activation] = true;
2826   if (matched.side_input != kMissingIndex) {
2827     (*nodes_to_delete)[matched.invalidated] = true;
2828   }
2829 
2830   return OkStatus();
2831 }
2832 
2833 Status AddFusedBatchNormGradExNode(RemapperContext* ctx,
2834                                    const FusedBatchNormGradEx& matched,
2835                                    std::vector<bool>* invalidated_nodes,
2836                                    std::vector<bool>* nodes_to_delete) {
2837   const GraphDef* graph = ctx->graph_view.graph();
2838   const NodeDef& fused_batch_norm_grad =
2839       graph->node(matched.fused_batch_norm_grad);
2840   const NodeDef& activation_grad = graph->node(matched.activation_grad);
2841   const NodeDef& fwd_fused_batch_norm =
2842       graph->node(matched.fwd_fused_batch_norm);
2843 
2844   VLOG(2) << "Fuse FusedBatchNormGrad with " << activation_grad.op() << ": "
2845           << " fused_batch_norm_grad=" << fused_batch_norm_grad.name()
2846           << " side_input="
2847           << (matched.side_input_grad != kMissingIndex
2848                   ? graph->node(matched.side_input_grad).name()
2849                   : "<none>")
2850           << " activation=" << activation_grad.name()
2851           << " corresponding FusedBatchNorm=" << fwd_fused_batch_norm.name();
2852 
2853   NodeDef fused_op;
2854   fused_op.set_op(kFusedBatchNormGradEx);
2855   fused_op.set_name(fused_batch_norm_grad.name());
2856   fused_op.set_device(fused_batch_norm_grad.device());
2857 
2858   fused_op.add_input(activation_grad.input(0));        // 0: y_backprop
2859   fused_op.add_input(fused_batch_norm_grad.input(1));  // 1: x
2860   fused_op.add_input(fused_batch_norm_grad.input(2));  // 2: scale
2861   fused_op.add_input(fused_batch_norm_grad.input(3));  // 3: reserve_space_1
2862   fused_op.add_input(fused_batch_norm_grad.input(4));  // 4: reserve_space_2
2863   fused_op.add_input(fused_batch_norm_grad.input(5));  // 5: reserve_space_3
2864   fused_op.add_input(fwd_fused_batch_norm.input(2));   // 6: offset
2865   fused_op.add_input(activation_grad.input(1));        // 7: y
2866 
2867   CopyFusedBatchNormGradAttributes(fused_batch_norm_grad, &fused_op);
2868 
2869   auto* attrs = fused_op.mutable_attr();
2870   // Only the Relu activation mode is supported.
2871   SetAttrValue("Relu", &(*attrs)["activation_mode"]);
2872 
2873   if (matched.side_input_grad != kMissingIndex) {
2874     SetAttrValue(1, &(*attrs)["num_side_inputs"]);
2875   } else {
2876     SetAttrValue(0, &(*attrs)["num_side_inputs"]);
2877   }
2878 
2879   NodeDef identity_op;
2880   identity_op.set_op("Identity");
2881   identity_op.set_name(activation_grad.name());
2882   identity_op.set_device(fused_batch_norm_grad.device());
2883   identity_op.add_input(absl::StrCat(fused_batch_norm_grad.name(), ":5"));
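  // Output port 5 of _FusedBatchNormGradEx is assumed to carry the side-input
  // gradient, which replaces the matched activation-grad node below.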
2884   (*identity_op.mutable_attr())["T"] = attrs->at("T");
2885 
2886   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2887   Status status;
2888   mutation->AddNode(std::move(fused_op), &status);
2889   TF_RETURN_IF_ERROR(status);
2890   if (matched.side_input_grad != kMissingIndex) {
2891     mutation->AddNode(std::move(identity_op), &status);
2892     TF_RETURN_IF_ERROR(status);
2893   }
2894   TF_RETURN_IF_ERROR(mutation->Apply());
2895 
2896   (*invalidated_nodes)[matched.fused_batch_norm_grad] = true;
2897   if (matched.side_input_grad != kMissingIndex) {
2898     (*invalidated_nodes)[matched.activation_grad] = true;
2899   } else {
2900     (*nodes_to_delete)[matched.activation_grad] = true;
2901   }
2902 
2903   return OkStatus();
2904 }
2905 
2906 Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
2907   const GraphDef* graph = ctx->graph_view.graph();
2908   const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
2909   VLOG(2) << "Optimizing fused batch norm node "
2910           << SummarizeNodeDef(fused_node);
2911 
2912   const string& x = fused_node.input(0);
2913   string scale = fused_node.input(1);
2914   string offset = fused_node.input(2);
2915   string mean = fused_node.input(3);
2916   string variance = fused_node.input(4);
2917 
2918   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2919   Status status;
2920 
2921   string x_format = fused_node.attr().at(kDataFormat).s();
2922   if (x_format == "NCHW" || x_format == "NCDHW") {
2923     // Need to reshape the last 4 inputs (scale, offset, mean, variance).
2924     NodeDef new_shape;
2925     const string new_shape_name =
2926         AddPrefixToNodeName(x_format + "Shape", fused_node.name());
2927     new_shape.set_name(new_shape_name);
2928     new_shape.set_op("Const");
2929     new_shape.set_device(fused_node.device());
2930     *new_shape.add_input() = AsControlDependency(scale);
2931     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
2932     if (x_format == "NCHW") {
2933       Tensor t(DT_INT32, {4});
2934       t.flat<int32>()(0) = 1;
2935       t.flat<int32>()(1) = -1;
2936       t.flat<int32>()(2) = 1;
2937       t.flat<int32>()(3) = 1;
2938       t.AsProtoTensorContent(
2939           (*new_shape.mutable_attr())["value"].mutable_tensor());
2940     } else {
2941       Tensor t(DT_INT32, {5});
2942       t.flat<int32>()(0) = 1;
2943       t.flat<int32>()(1) = -1;
2944       t.flat<int32>()(2) = 1;
2945       t.flat<int32>()(3) = 1;
2946       t.flat<int32>()(4) = 1;
2947       t.AsProtoTensorContent(
2948           (*new_shape.mutable_attr())["value"].mutable_tensor());
2949     }
2950     mutation->AddNode(std::move(new_shape), &status);
2951     TF_RETURN_IF_ERROR(status);
2952 
2953     NodeDef reshaped_scale;
2954     reshaped_scale.set_name(
2955         AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
2956     reshaped_scale.set_op("Reshape");
2957     reshaped_scale.set_device(fused_node.device());
2958     *reshaped_scale.add_input() = scale;
2959     *reshaped_scale.add_input() = new_shape_name;
2960     (*reshaped_scale.mutable_attr())["T"] = fused_node.attr().at("T");
2961     (*reshaped_scale.mutable_attr())["Tshape"].set_type(DT_INT32);
2962     scale = reshaped_scale.name();
2963     mutation->AddNode(std::move(reshaped_scale), &status);
2964     TF_RETURN_IF_ERROR(status);
2965 
2966     NodeDef reshaped_offset;
2967     reshaped_offset.set_name(
2968         AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
2969     reshaped_offset.set_op("Reshape");
2970     reshaped_offset.set_device(fused_node.device());
2971     *reshaped_offset.add_input() = offset;
2972     *reshaped_offset.add_input() = new_shape_name;
2973     (*reshaped_offset.mutable_attr())["T"] = fused_node.attr().at("T");
2974     (*reshaped_offset.mutable_attr())["Tshape"].set_type(DT_INT32);
2975     offset = reshaped_offset.name();
2976     mutation->AddNode(std::move(reshaped_offset), &status);
2977     TF_RETURN_IF_ERROR(status);
2978 
2979     NodeDef reshaped_mean;
2980     reshaped_mean.set_name(
2981         AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
2982     reshaped_mean.set_op("Reshape");
2983     reshaped_mean.set_device(fused_node.device());
2984     *reshaped_mean.add_input() = mean;
2985     *reshaped_mean.add_input() = new_shape_name;
2986     (*reshaped_mean.mutable_attr())["T"] = fused_node.attr().at("T");
2987     (*reshaped_mean.mutable_attr())["Tshape"].set_type(DT_INT32);
2988     mean = reshaped_mean.name();
2989     mutation->AddNode(std::move(reshaped_mean), &status);
2990     TF_RETURN_IF_ERROR(status);
2991 
2992     NodeDef reshaped_variance;
2993     reshaped_variance.set_name(
2994         AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
2995     reshaped_variance.set_op("Reshape");
2996     reshaped_variance.set_device(fused_node.device());
2997     *reshaped_variance.add_input() = variance;
2998     *reshaped_variance.add_input() = new_shape_name;
2999     (*reshaped_variance.mutable_attr())["T"] = fused_node.attr().at("T");
3000     (*reshaped_variance.mutable_attr())["Tshape"].set_type(DT_INT32);
3001     variance = reshaped_variance.name();
3002     mutation->AddNode(std::move(reshaped_variance), &status);
3003     TF_RETURN_IF_ERROR(status);
3004   }
3005 
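  // Rebuild the inference-mode batch norm from primitives using the identity
  //   y = (x - mean) * rsqrt(variance + epsilon) * scale + offset
  //     = x * scaled + (offset - mean * scaled),
  // where scaled = rsqrt(variance + epsilon) * scale. The nodes created below
  // build this expression bottom-up.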
3006   float epsilon = 0.0f;
3007   if (fused_node.attr().count("epsilon")) {
3008     epsilon = fused_node.attr().at("epsilon").f();
3009   }
3010   DataType dtype = fused_node.attr().at("T").type();
3011   Tensor value(dtype, TensorShape());
3012   value.scalar<float>()() = epsilon;
3013   NodeDef variance_epsilon;
3014   const string variance_epsilon_name =
3015       AddPrefixToNodeName("Const", fused_node.name());
3016   TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
3017       variance_epsilon_name, TensorValue(&value), &variance_epsilon));
3018   variance_epsilon.set_device(fused_node.device());
3019   mutation->AddNode(std::move(variance_epsilon), &status);
3020   TF_RETURN_IF_ERROR(status);
3021 
3022   NodeDef variance_plus_epsilon;
3023   const string variance_plus_epsilon_name =
3024       AddPrefixToNodeName("VarPlusEpsilon", fused_node.name());
3025   variance_plus_epsilon.set_name(variance_plus_epsilon_name);
3026   variance_plus_epsilon.set_op("Add");
3027   (*variance_plus_epsilon.mutable_attr())["T"].set_type(dtype);
3028   variance_plus_epsilon.set_device(fused_node.device());
3029   *variance_plus_epsilon.add_input() = variance;
3030   *variance_plus_epsilon.add_input() = variance_epsilon_name;
3031   mutation->AddNode(std::move(variance_plus_epsilon), &status);
3032   TF_RETURN_IF_ERROR(status);
3033 
3034   NodeDef inv;
3035   const string inv_name = AddPrefixToNodeName("Inv", fused_node.name());
3036   inv.set_name(inv_name);
3037   inv.set_op("Rsqrt");
3038   inv.set_device(fused_node.device());
3039   (*inv.mutable_attr())["T"].set_type(dtype);
3040   *inv.add_input() = variance_plus_epsilon_name;
3041   mutation->AddNode(std::move(inv), &status);
3042   TF_RETURN_IF_ERROR(status);
3043 
3044   NodeDef scaled;
3045   const string scaled_name = AddPrefixToNodeName("Scaled", fused_node.name());
3046   scaled.set_name(scaled_name);
3047   scaled.set_op("Mul");
3048   scaled.set_device(fused_node.device());
3049   (*scaled.mutable_attr())["T"].set_type(dtype);
3050   *scaled.add_input() = inv_name;
3051   *scaled.add_input() = scale;
3052   mutation->AddNode(std::move(scaled), &status);
3053   TF_RETURN_IF_ERROR(status);
3054 
3055   NodeDef a;
3056   const string a_name = AddPrefixToNodeName("Mul", fused_node.name());
3057   a.set_name(a_name);
3058   a.set_op("Mul");
3059   a.set_device(fused_node.device());
3060   (*a.mutable_attr())["T"].set_type(dtype);
3061   *a.add_input() = x;
3062   *a.add_input() = scaled_name;
3063   mutation->AddNode(std::move(a), &status);
3064   TF_RETURN_IF_ERROR(status);
3065 
3066   NodeDef b;
3067   const string b_name = AddPrefixToNodeName("Mul2", fused_node.name());
3068   b.set_name(b_name);
3069   b.set_op("Mul");
3070   b.set_device(fused_node.device());
3071   (*b.mutable_attr())["T"].set_type(dtype);
3072   *b.add_input() = mean;
3073   *b.add_input() = scaled_name;
3074   mutation->AddNode(std::move(b), &status);
3075   TF_RETURN_IF_ERROR(status);
3076 
3077   NodeDef c;
3078   const string c_name = AddPrefixToNodeName("Offset", fused_node.name());
3079   c.set_name(c_name);
3080   c.set_op("Sub");
3081   c.set_device(fused_node.device());
3082   (*c.mutable_attr())["T"].set_type(dtype);
3083   *c.add_input() = offset;
3084   *c.add_input() = b_name;
3085   mutation->AddNode(std::move(c), &status);
3086   TF_RETURN_IF_ERROR(status);
3087 
3088   NodeDef r;
3089   r.set_name(fused_node.name());
3090   r.set_op("Add");
3091   r.set_device(fused_node.device());
3092   (*r.mutable_attr())["T"].set_type(dtype);
3093   *r.add_input() = a_name;
3094   *r.add_input() = c_name;
3095   mutation->AddNode(std::move(r), &status);
3096   TF_RETURN_IF_ERROR(status);
3097 
3098   return mutation->Apply();
3099 }
3100 
3101 Status AddTensorToHashBucketNode(RemapperContext* ctx,
3102                                  const TensorToHashBucket& matched,
3103                                  std::vector<bool>* invalidated_nodes,
3104                                  std::vector<bool>* nodes_to_delete) {
3105   const GraphDef* graph = ctx->graph_view.graph();
3106   const NodeDef& pre_as_string = graph->node(matched.pre_as_string);
3107   const NodeDef& as_string = graph->node(matched.as_string);
3108   const NodeDef& string_to_hash_bucket =
3109       graph->node(matched.string_to_hash_bucket);
3110   VLOG(2) << "Fuse AsString with StringToHashBucketFast:"
3111           << " as_string=" << as_string.name()
3112           << " string_to_hash_bucket=" << string_to_hash_bucket.name()
3113           << " on device=" << pre_as_string.device();
3114 
3115   NodeDef fused_op;
3116   fused_op.set_name(string_to_hash_bucket.name());
3117   fused_op.set_device(pre_as_string.device());
3118   fused_op.add_input(as_string.input(0));  // 0: input
3119   fused_op.set_op(kTensorToHashBucket);
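  // The fused op hashes the numeric input directly, so the intermediate
  // string tensor produced by AsString never has to be materialized.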
3120 
3121   auto* attr = fused_op.mutable_attr();
3122   auto& src_attr0 = as_string.attr();
3123   auto& src_attr1 = string_to_hash_bucket.attr();
3124   (*attr)["T"] = src_attr0.at("T");
3125   (*attr)["num_buckets"] = src_attr1.at("num_buckets");
3126 
3127   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3128   Status status;
3129   mutation->AddNode(std::move(fused_op), &status);
3130   TF_RETURN_IF_ERROR(status);
3131   TF_RETURN_IF_ERROR(mutation->Apply());
3132 
3133   (*invalidated_nodes)[matched.string_to_hash_bucket] = true;
3134   (*nodes_to_delete)[matched.as_string] = true;
3135 
3136   return OkStatus();
3137 }
3138 
3139 Status AddFusedBatchMatMul(RemapperContext* ctx,
3140                            const std::map<string, int>& matched_nodes_map,
3141                            const std::set<int>& remove_node_indices,
3142                            std::vector<bool>* invalidated_nodes,
3143                            std::vector<bool>* nodes_to_delete) {
3144   auto* output_node =
3145       ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
3146   auto* batch_matmul_node =
3147       ctx->graph_view.GetNode(matched_nodes_map.at("batch_matmul"))->node();
3148   auto* multiplicand_node =
3149       ctx->graph_view.GetNode(matched_nodes_map.at("multiplicand"))->node();
3150   auto* addend_node =
3151       ctx->graph_view.GetNode(matched_nodes_map.at("addend"))->node();
3152 
3153   NodeDef fused_node;
3154   fused_node.set_name(output_node->name());
3155   fused_node.set_op("_MklFusedBatchMatMulV2");
3156   fused_node.set_device(batch_matmul_node->device());
3157   fused_node.add_input(batch_matmul_node->input(0));
3158   fused_node.add_input(batch_matmul_node->input(1));
3159   fused_node.add_input(multiplicand_node->name());
3160   fused_node.add_input(addend_node->name());
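  // Fused computation: BatchMatMulV2(in0, in1) * multiplicand + addend,
  // matching the fused_ops {"Mul", "Add"} set below.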
3161 
3162   CopyBatchMatMulAttributes(*batch_matmul_node, &fused_node);
3163   SetFusedOpAttributes(&fused_node, {"Mul", "Add"}, /*num_args=*/2);
3164 
3165   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3166   Status status;
3167   mutation->AddNode(std::move(fused_node), &status);
3168   TF_RETURN_IF_ERROR(status);
3169   TF_RETURN_IF_ERROR(mutation->Apply());
3170   (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
3171 
3172   for (const auto& node_idx : remove_node_indices) {
3173     (*nodes_to_delete)[node_idx] = true;
3174   }
3175   return OkStatus();
3176 }
3177 
3178 // This function supports the patterns below, which require inferred
3179 // shapes:
3180 // 1. Contraction + Add.
3181 // 2. Contraction + Add + Activation.
3182 // 3. Contraction + BiasAdd/BiasSemanticAdd + Add.
3183 // 4. Contraction + BiasAdd/BiasSemanticAdd + Add + Activation.
3184 // Contraction candidate: MatMul, Conv2D, Conv3D, DepthwiseConv2dNative.
3185 bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
3186   const auto* node_view = ctx.graph_view.GetNode(node_index);
3187 
3188   auto is_supported_add_input = [](const auto* node_view) -> bool {
3189     if (IsConvOrMatMul(*node_view->node())) return true;
3190     // IsAdd will verify BiasSemanticAdd.
3191     if (IsBiasAdd(*node_view->node()) || IsAdd(*node_view->node())) {
3192       if (node_view->NumRegularFanins() < 2) return false;
3193       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
3194       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
3195       return IsConvOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
3196              IsConvOrMatMul(*bias_add_fanin_1.node_view()->node());
3197     }
3198     return false;
3199   };
3200 
3201   auto is_supported_add = [&](const auto* node_view) -> bool {
3202     const auto* node_def = node_view->node();
3203     if (IsAdd(*node_def)) {
3204       if (node_view->NumRegularFanins() < 2) return false;
3205       const auto& add_fanin_0 = node_view->GetRegularFanin(0);
3206       const auto& add_fanin_1 = node_view->GetRegularFanin(1);
3207       return is_supported_add_input(add_fanin_0.node_view()) ||
3208              is_supported_add_input(add_fanin_1.node_view());
3209     }
3210     return false;
3211   };
3212 
3213   // Handle the Contraction + Add and
3214   // Contraction + BiasAdd/BiasSemanticAdd + Add patterns.
3215   if (is_supported_add(node_view)) {
3216     return true;
3217   }
3218   // Handle the Contraction + Add + Activation and
3219   // Contraction + BiasAdd/BiasSemanticAdd + Add + Activation patterns.
3220   if (IsSupportedActivation(*node_view->node())) {
3221     for (int i = 0; i < node_view->NumRegularFanins(); i++) {
3222       const auto& fanin_i = node_view->GetRegularFanin(i);
3223       if (is_supported_add(fanin_i.node_view())) return true;
3224     }
3225   }
3226 
3227   return false;
3228 }
3229 
3230 bool FindSoftplusAndTanhAndMul(RemapperContext* ctx, int node_index,
3231                                std::map<string, int>* matched_nodes_map,
3232                                std::set<int>* remove_node_indices) {
3233   // Mish fusion is enabled only with the oneDNN library.
3234   if (!IsMKLEnabled()) return false;
3235 
3236   using utils::MatchingDirection;
3237   using utils::NodeStatus;
3238   // clang-format off
3239   //                Convert Softplus+Tanh+Mul to Mish
3240   //          From Graph                          To Graph
3241   //          -----------                         ---------
3242   //    Conv2D  <-  Filter(const)           Conv2D  <-  Filter(const)
3243   //      !                                   !
3244   //      V                                   V
3245   //    BiasAdd <-  bias(const)             BiasAdd <-  bias(const)
3246   //      !                                   !
3247   //      V                                   !
3248   //  ---- ----                               !
3249   //  !       !                               !
3250   //  !       V                               !
3251   //  !    Softplus                           !
3252   //  !       !                               !
3253   //  !       V                               !
3254   //  !     Tanh                              !
3255   //  !       !                               !
3256   //  !       V                               !
3257   //  ---   ---                               !
3258   //     !  !                                 !
3259   //     !  !                                 !
3260   //     V  V                                 V
3261   //      Mul                           _MklFusedMish
3262   //      !                                   !
3263   //      V                                   V
3264 
3265   utils::OpTypePattern softplustanhmul_pattern {
3266     "Mul", "mul_to_mish", NodeStatus::kReplace,
3267     {
3268       {
3269         "Tanh", "tanh", NodeStatus::kRemove,
3270         {
3271           {
3272             "Softplus", "softplus", NodeStatus::kRemove,
3273             {
3274               {"*", "input", NodeStatus::kRemain}
3275             }
3276           }
3277         }
3278       },
3279       {"*", "input", NodeStatus::kRemain}
3280     }
3281   };
3282   // clang-format on
3283 
3284   // Check for supported data types.
3285   auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
3286   if (!HasDataType(mul_node_def, DT_FLOAT) &&
3287       !HasDataType(mul_node_def, DT_BFLOAT16))
3288     return false;
3289 
3290   if (!NodeIsOnCpu(mul_node_def)) return false;
3291 
3292   bool found_op_type_match = false;
3293   utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
3294       &(ctx->graph_view));
3295   matched_nodes_map->clear();
3296   remove_node_indices->clear();
3297   found_op_type_match = graph_matcher.GetMatchedNodes(
3298       softplustanhmul_pattern, {}, ctx->graph_view.GetNode(node_index),
3299       matched_nodes_map, remove_node_indices);
3300 
3301   return found_op_type_match;
3302 }
3303 
3304 Status ReplaceSoftplusTanhAndMulWithMish(
3305     RemapperContext* ctx, const std::map<string, int>* matched_nodes_map,
3306     const std::set<int>* remove_node_indices,
3307     std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
3308   // Fuse Softplus + Tanh + Mul to Mish
3309   auto* old_mul_node =
3310       ctx->graph_view.GetNode(matched_nodes_map->at("mul_to_mish"))->node();
3311   auto* softplus_node =
3312       ctx->graph_view.GetNode(matched_nodes_map->at("softplus"))->node();
3313 
3314   NodeDef fused_node;
3315   fused_node.set_name(old_mul_node->name());
3316   fused_node.set_op("_MklFusedMish");
3317   fused_node.set_device(old_mul_node->device());
3318   fused_node.add_input(softplus_node->input(0));
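  // Mish(x) = x * tanh(softplus(x)); x is the Softplus input, so it is the
  // only tensor the fused kernel needs.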
3319 
3320   auto* fused_node_attr = fused_node.mutable_attr();
3321   (*fused_node_attr)["T"] = old_mul_node->attr().at("T");
3322 
3323   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3324   Status status;
3325   mutation->AddNode(std::move(fused_node), &status);
3326   TF_RETURN_IF_ERROR(status);
3327   TF_RETURN_IF_ERROR(mutation->Apply());
3328   (*invalidated_nodes)[matched_nodes_map->at("mul_to_mish")] = true;
3329 
3330   for (const auto& node_index : *remove_node_indices) {
3331     (*nodes_to_delete)[node_index] = true;
3332   }
3333 
3334   return OkStatus();
3335 }
3336 
3337 // Check if a node is a candidate for one of the patterns that require inferred
3338 // shapes:
3339 //   (1) Splitting FusedBatchNorm into primitives.
3340 //   (2) Fusing side input and/or activation into FusedBatchNorm.
3341 //   (3) Fusing Conv2D + BiasAdd + Relu on GPU.
3342 //   (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
3343 //   (5) Fusing side output and/or activation into FusedBatchNormGrad.
3344 bool RequiresInferredShapes(const RemapperContext& ctx, int node_index,
3345                             const Cluster* cluster) {
3346   // Candidate for a FusedBatchNorm splitting.
3347   const auto* node_view = ctx.graph_view.GetNode(node_index);
3348   const auto* node_def = node_view->node();
3349   const auto is_batch_norm_candidate = [&]() -> bool {
3350     if (!IsFusedBatchNorm(*node_def)) return false;
3351     if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
3352 
3353     bool is_training = true;
3354     if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
3355     if (is_training) return false;
3356 
3357     return true;
3358   };
3359 
3360   const auto is_act_biasadd_conv_candidate = [&]() -> bool {
3361     if (!IsSupportedActivation(*node_def)) return false;
3362 
3363     if (!RuntimeFusionEnabled(cluster) && !IsRelu(*node_def)) return false;
3364 
3365     const auto is_compatible_dtype = [&](const NodeDef& node) -> bool {
3366       // The cuDNN fusion with relu6, elu, and leakyrelu is realized by
3367       // runtime-compiled kernels, which only support fp16.
3368       bool fp16_only =
3369           IsRelu6(*node_def) || IsElu(*node_def) || IsLeakyRelu(*node_def);
3370       DataType dtype = GetDataTypeFromAttr(node, "T");
3371       return dtype == DT_HALF || (!fp16_only && dtype == DT_FLOAT);
3372     };
3373     if (!is_compatible_dtype(*node_def)) return false;
3374 
3375     if (node_view->NumRegularFanins() < 1) return false;
3376     const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
3377     const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
3378     const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
3379 
3380     if (!IsBiasAdd(*relu_fanin_0_node_def) && !IsAdd(*relu_fanin_0_node_def))
3381       return false;
3382     if (!is_compatible_dtype(*relu_fanin_0_node_def)) return false;
3383 
3384     if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false;
3385 
3386     const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
3387     const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
3388 
3389     if (!IsConv2D(*biasadd_fanin_0_node_def) &&
3390         !IsConv3D(*biasadd_fanin_0_node_def))
3391       return false;
3392     if (!is_compatible_dtype(*biasadd_fanin_0_node_def)) return false;
3393     return true;
3394   };
3395 
3396   // Candidate for a FusedBatchNorm fusion.
3397   const auto is_batch_norm_fusion_candidate = [&]() -> bool {
3398     if (!IsRelu(*node_def)) return false;
3399 
3400     if (node_view->NumRegularFanins() < 1) return false;
3401     const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
3402     const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
3403     const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
3404 
3405     if (IsFusedBatchNorm(*relu_fanin_0_node_def)) {
3406       // FusedBatchNorm + Relu.
3407       return true;
3408 
3409     } else if (IsAdd(*relu_fanin_0_node_def)) {
3410       // FusedBatchNorm + Add + Relu.
3411 
3412       if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
3413       const auto& add_regular_fanin_0 =
3414           relu_fanin_0_node_view->GetRegularFanin(0);
3415       if (IsFusedBatchNorm(*add_regular_fanin_0.node_view()->node()))
3416         return true;
3417       const auto& add_regular_fanin_1 =
3418           relu_fanin_0_node_view->GetRegularFanin(1);
3419       if (IsFusedBatchNorm(*add_regular_fanin_1.node_view()->node()))
3420         return true;
3421     }
3422 
3423     return false;
3424   };
3425 
3426   // Candidate for a FusedBatchNormGrad fusion.
3427   const auto is_batch_norm_grad_fusion_candidate = [&]() -> bool {
3428     if (!IsFusedBatchNormGrad(*node_def)) return false;
3429 
3430     if (node_view->NumRegularFanins() < 1) return false;
3431     const auto& bn_fanin_0 = node_view->GetRegularFanin(0);
3432     const auto* bn_fanin_0_node_view = bn_fanin_0.node_view();
3433     const auto* bn_fanin_0_node_def = bn_fanin_0_node_view->node();
3434 
3435     if (IsReluGrad(*bn_fanin_0_node_def)) {
3436       // ReluGrad + FusedBatchNormGrad.
3437       return true;
3438     }
3439 
3440     return false;
3441   };
3442 
3443   if (IsMKLEnabled())
3444     return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
3445            IsContractionWithAdd(ctx, node_index) ||
3446            is_act_biasadd_conv_candidate();
3447 
3448   return is_act_biasadd_conv_candidate() || is_batch_norm_candidate() ||
3449          is_batch_norm_fusion_candidate() ||
3450          is_batch_norm_grad_fusion_candidate();
3451 }
3452 }  // namespace
3453 
3454 Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
3455                           GraphDef* optimized_graph) {
3456   GrapplerItem mutable_item = item;
3457   Status status;
3458   RemapperContext ctx(&mutable_item, &status, cpu_layout_conversion_,
3459                       xla_auto_clustering_on_);
3460   TF_RETURN_IF_ERROR(status);
3461   // Processing the graph in reverse topological order allows us to remap
3462   // longer chains of dependent ops in one pass.
3463   TF_RETURN_IF_ERROR(
3464       ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
3465 
3466   const int num_nodes = item.graph.node_size();
3467   // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd
3468   // and Activation nodes that were fused into a Conv2D node.
3469   std::vector<bool> invalidated_nodes(num_nodes);
3470   std::vector<bool> nodes_to_delete(num_nodes);
3471 
3472   // _Fused{...} kernels do not have a registered gradient function, so we must
3473   // not perform the rewrite if the graph will be differentiated later.
3474   bool allow_non_differentiable_rewrites =
3475       item.optimization_options().allow_non_differentiable_rewrites;
3476 
3477   for (int i = num_nodes - 1; i >= 0; --i) {
3478     // Check if node was invalidated by one of the previous remaps.
3479     if (invalidated_nodes[i] || nodes_to_delete[i]) {
3480       continue;
3481     }
3482 
3483     // Infer properties lazily in case they are not needed.
3484     if (!ctx.inferred_graph_properties &&
3485         RequiresInferredShapes(ctx, i, cluster)) {
3486       const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3487       TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
3488           assume_valid_feeds,
3489           /*aggressive_shape_inference=*/false,
3490           /*include_input_tensor_values=*/true,
3491           /*include_output_tensor_values=*/false));
3492       ctx.inferred_graph_properties = true;
3493     }
3494 
3495     ContractionWithBiasAddAndAdd contract_with_bias_and_add;
3496     ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
3497 
3498     if (IsMKLEnabled()) {
3499       // Remap Conv2D+BiasAdd+Add+Relu into _FusedConv2D,
3500       // or Conv3D+BiasAdd+Add+Relu into _FusedConv3D.
3501       if (FindContractionWithBiasAndAddActivation(
3502               ctx, i, &contract_with_bias_and_add_activation)) {
3503         TF_RETURN_IF_ERROR(
3504             AddFusedContractionNode(&ctx, contract_with_bias_and_add_activation,
3505                                     &invalidated_nodes, &nodes_to_delete));
3506         continue;
3507       }
3508 
3509       // Remap {Conv2D,Conv3D}+BiasAdd+Add into the _FusedConv2D/3D.
3510       if (FindContractionWithBiasAddAndAdd(ctx, i,
3511                                            &contract_with_bias_and_add)) {
3512         TF_RETURN_IF_ERROR(
3513             AddFusedContractionNode(&ctx, contract_with_bias_and_add,
3514                                     &invalidated_nodes, &nodes_to_delete));
3515         continue;
3516       }
3517 
3518       PadWithConv3D pad_with_conv3d;
3519       // Remap Pad+{Conv3D,_FusedConv3D} into the _FusedConv3D.
3520       if (FindPadWithConv3D(ctx, i, &pad_with_conv3d)) {
3521         TF_RETURN_IF_ERROR(AddFusedConv3DNode(
3522             &ctx, pad_with_conv3d, &invalidated_nodes, &nodes_to_delete));
3523         continue;
3524       }
3525 
3526       std::map<string, int> matched_nodes_map;
3527       std::set<int> remove_node_indices;
3528 
3529       // Softplus + Tanh + Mul to Mish conversion
3530       matched_nodes_map.clear();
3531       remove_node_indices.clear();
3532       if (FindSoftplusAndTanhAndMul(&ctx, i, &matched_nodes_map,
3533                                     &remove_node_indices)) {
3534         TF_RETURN_IF_ERROR(ReplaceSoftplusTanhAndMulWithMish(
3535             &ctx, &matched_nodes_map, &remove_node_indices, &invalidated_nodes,
3536             &nodes_to_delete));
3537         continue;
3538       }
3539 
3540       // Remap BatchMatMul+Mul+AddV2 into the _FusedBatchMatMul.
3541       matched_nodes_map.clear();
3542       remove_node_indices.clear();
3543       if (FindFusedBatchMatMul(&ctx, i, &matched_nodes_map,
3544                                &remove_node_indices)) {
3545         TF_RETURN_IF_ERROR(
3546             AddFusedBatchMatMul(&ctx, matched_nodes_map, remove_node_indices,
3547                                 &invalidated_nodes, &nodes_to_delete));
3548         continue;
3549       }
3550 
3551       // Remap the Maximum(x, alpha * x) pattern into LeakyRelu(x).
3552       std::map<string, int> mulmax_matched_nodes_map;
3553       std::set<int> mulmax_remove_node_indices;
3554       if (FindMulAndMaximum(&ctx, i, &mulmax_matched_nodes_map,
3555                             &mulmax_remove_node_indices)) {
3556         TF_RETURN_IF_ERROR(ReplaceMulMaximumWithLeakyRelu(
3557             &ctx, mulmax_matched_nodes_map, mulmax_remove_node_indices,
3558             &invalidated_nodes, &nodes_to_delete));
3559         continue;
3560       }
3561 
3562       // Remap the Mul(x, Sigmoid(x)) pattern into Swish(x).
3563       std::map<string, int> sigmoidmul_matched_nodes_map;
3564       std::set<int> sigmoidmul_remove_node_indices;
3565       if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map,
3566                             &sigmoidmul_remove_node_indices)) {
3567         TF_RETURN_IF_ERROR(ReplaceSigmoidMulWithSwish(
3568             &ctx, sigmoidmul_matched_nodes_map, sigmoidmul_remove_node_indices,
3569             &invalidated_nodes, &nodes_to_delete));
3570         continue;
3571       }
3572 
3573       // Remap the smaller ops from the layer-norm Python API into _MklLayerNorm.
3574       matched_nodes_map.clear();
3575       remove_node_indices.clear();
3576       if (FindMklLayerNorm(&ctx, i, &matched_nodes_map, &remove_node_indices)) {
3577         TF_RETURN_IF_ERROR(
3578             AddMklLayerNorm(&ctx, matched_nodes_map, remove_node_indices,
3579                             &invalidated_nodes, &nodes_to_delete));
3580         continue;
3581       }
3582     }
3583 
3584     // Remap the MatMul + BiasAdd + Gelu subgraph into _FusedMatMul.
3585     std::map<string, int> matched_nodes_map;
3586     std::set<int> remove_node_indices;
3587     bool is_gelu_approximate = false;
3588     if (FindMatMulBiasAddAndGelu(&ctx, i, &matched_nodes_map,
3589                                  &remove_node_indices, &is_gelu_approximate)) {
3590       TF_RETURN_IF_ERROR(AddFusedMatMulBiasAddAndGelu(
3591           &ctx, matched_nodes_map, remove_node_indices, &invalidated_nodes,
3592           &nodes_to_delete, is_gelu_approximate));
3593       continue;
3594     }
3595 
3596     // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
3597     // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
3598     ContractionWithBiasAdd contract_with_bias;
3599     if (allow_non_differentiable_rewrites &&
3600         FindContractionWithBias(ctx, i, &contract_with_bias)) {
3601       TF_RETURN_IF_ERROR(AddFusedContractionNode(
3602           &ctx, contract_with_bias, &invalidated_nodes, &nodes_to_delete));
3603       continue;
3604     }
3605 
3606     // Remap {Conv2D,DepthwiseConv2D,MatMul,Conv3D}+BiasAdd+Activation into the
3607     // _Fused{Conv2D,DepthwiseConv2dNative,MatMul,Conv3D}.
3608     ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
3609     if (allow_non_differentiable_rewrites &&
3610         FindContractionWithBiasAndActivation(
3611             ctx, cluster, i, &contract_with_bias_and_activation)) {
3612       TF_RETURN_IF_ERROR(
3613           AddFusedContractionNode(&ctx, contract_with_bias_and_activation,
3614                                   &invalidated_nodes, &nodes_to_delete));
3615       continue;
3616     }
3617 
3618     // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
3619     // it for MatMul as well, but in practice this pattern does not appear in
3620     // real TensorFlow graphs.
3621 
3622     // Remap {Conv2D, Conv3D}+Squeeze+BiasAdd into the {_FusedConv2D,
3623     // _FusedConv3D}+Squeeze.
3624     ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
3625     if (allow_non_differentiable_rewrites &&
3626         FindConvWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
3627       TF_RETURN_IF_ERROR(AddFusedConvNode(&ctx, contract_with_squeeze_and_bias,
3628                                           &invalidated_nodes,
3629                                           &nodes_to_delete));
3630       continue;
3631     }
3632 
3633 #ifndef DNNL_AARCH64_USE_ACL
3634     // Remap Conv2D+FusedBatchNorm into _FusedConv2D.
3635     ContractionWithBatchNorm contract_with_batch_norm;
3636     if (allow_non_differentiable_rewrites &&
3637         FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
3638       TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
3639                                             &invalidated_nodes,
3640                                             &nodes_to_delete));
3641       continue;
3642     }
3643 
3644     // Remap Conv2D+FusedBatchNorm+Activation into _FusedConv2D.
3645     ContractionWithBatchNormAndActivation
3646         contract_with_batch_norm_and_activation;
3647     if (allow_non_differentiable_rewrites &&
3648         FindConv2DWithBatchNormAndActivation(
3649             ctx, i, &contract_with_batch_norm_and_activation)) {
3650       TF_RETURN_IF_ERROR(
3651           AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation,
3652                              &invalidated_nodes, &nodes_to_delete));
3653       continue;
3654     }
3655 #endif  // !DNNL_AARCH64_USE_ACL
3656 
3657     // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
3658     FusedBatchNormEx fused_batch_norm_ex;
3659     if (allow_non_differentiable_rewrites &&
3660         FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
3661       TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
3662           &ctx, fused_batch_norm_ex, &invalidated_nodes, &nodes_to_delete));
3663       continue;
3664     }
3665 
3666     FusedBatchNormGradEx fused_batch_norm_grad_ex;
3667     if (allow_non_differentiable_rewrites &&
3668         FindFusedBatchNormGradEx(ctx, i, &fused_batch_norm_grad_ex)) {
3669       TF_RETURN_IF_ERROR(
3670           AddFusedBatchNormGradExNode(&ctx, fused_batch_norm_grad_ex,
3671                                       &invalidated_nodes, &nodes_to_delete));
3672       continue;
3673     }
3674 
3675     TensorToHashBucket tensor_to_hash_bucket;
3676     if (allow_non_differentiable_rewrites &&
3677         FindTensorToHashBucket(ctx, i, &tensor_to_hash_bucket)) {
3678       TF_RETURN_IF_ERROR(AddTensorToHashBucketNode(
3679           &ctx, tensor_to_hash_bucket, &invalidated_nodes, &nodes_to_delete));
3680       continue;
3681     }
3682 
3683     // During inference, most of the inputs to FusedBatchNorm are constant, and
3684     // we can therefore replace the op with a much cheaper set of primitives.
3685     FusedBatchNorm fused_batch_norm;
3686     if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
3687       TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
3688       continue;
3689     }
3690   }
3691 
3692   // Remove invalidated nodes.
3693   utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
3694   for (int i = 0; i < num_nodes; ++i) {
3695     if (nodes_to_delete[i]) {
3696       mutation->RemoveNode(ctx.graph_view.GetNode(i));
3697     }
3698   }
3699   TF_RETURN_IF_ERROR(mutation->Apply());
3700 
3701   *optimized_graph = std::move(mutable_item.graph);
3702 
3703   return OkStatus();
3704 }
3705 
3706 }  // namespace grappler
3707 }  // namespace tensorflow
3708