1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/remapper.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/framework/versions.pb.h"
20 #include "tensorflow/core/grappler/costs/graph_properties.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
25 #include "tensorflow/core/grappler/utils.h"
26 #include "tensorflow/core/grappler/utils/graph_view.h"
27 #include "tensorflow/core/grappler/utils/pattern_utils.h"
28 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
33 #include "tensorflow/core/util/env_var.h"
34 #include "tensorflow/core/util/use_cudnn.h"
35 #include "tensorflow/core/util/util.h"
36
37 #if GOOGLE_CUDA
38 #include "third_party/gpus/cudnn/cudnn.h"
39 #endif // GOOGLE_CUDA
40
41 namespace tensorflow {
42 namespace grappler {
43
44 // Supported patterns:
45 //
46 // Conv2D + ... -> _FusedConv2D
47 // (1) Conv2D + BiasAdd + <Activation>
48 // (2) Conv2D + FusedBatchNorm + <Activation>
49 // (3) Conv2D + Squeeze + BiasAdd
50 //
51 // MatMul + ... -> _FusedMatMul:
52 // (1) MatMul + BiasAdd + <Activation>
53 //
54 // DepthwiseConv2dNative + ... -> _FusedDepthwiseConv2dNative:
55 // (1) DepthwiseConv2dNative + BiasAdd + <Activation>
56 //
57 // FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
58 // (1) FusedBatchNorm + <Activation>
59 // (2) FusedBatchNorm + SideInput + <Activation>
60 //
61 // Sigmoid + Mul -> _MklSwish // This fusion only works on Intel CPU.
62 //
63 //
64 // In all cases, the supported activation functions are Relu, Relu6, Elu, and LeakyRelu (plus Tanh and Sigmoid when oneDNN is enabled).
65 //
66 // Both Conv2D and MatMul are implemented as Tensor contractions (on CPU), so
67 // all the patterns are named "ContractionWith...".
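//
// For example (illustrative), the pattern
//   Conv2D(input, filter) -> BiasAdd -> Relu
// is rewritten into a single
//   _FusedConv2D(input, filter, bias) node with attr fused_ops = ["BiasAdd", "Relu"],
// so the bias addition and the activation run inside one fused kernel.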
68 namespace {
69
70 constexpr char kFusedConv2D[] = "_FusedConv2D";
71 constexpr char kFusedConv3D[] = "_FusedConv3D";
72 constexpr char kFusedMatMul[] = "_FusedMatMul";
73 constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
74 constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
75 constexpr char kFusedBatchNormGradEx[] = "_FusedBatchNormGradEx";
76 constexpr char kTensorToHashBucket[] = "_TensorToHashBucketFast";
77
78 constexpr char kDataFormat[] = "data_format";
79 constexpr char kIsTraining[] = "is_training";
80
81 constexpr char kWidth[] = "width";
82 constexpr char kFill[] = "fill";
83
84 constexpr int kMissingIndex = -1;
85
86 struct RemapperContext {
87 explicit RemapperContext(GrapplerItem* item, Status* status,
88 RewriterConfig::CpuLayout cpu_layout_conversion,
89 bool xla_auto_clustering_on)
90 : nodes_to_preserve(item->NodesToPreserve()),
91 graph_view(&item->graph, status),
92 graph_properties(*item),
93 inferred_graph_properties(false),
94 cpu_layout_conversion(cpu_layout_conversion),
95 xla_auto_clustering_on(xla_auto_clustering_on) {}
96
97 std::unordered_set<string> nodes_to_preserve;
98 utils::MutableGraphView graph_view;
99 GraphProperties graph_properties;
100 bool inferred_graph_properties;
101 RewriterConfig::CpuLayout cpu_layout_conversion;
102 bool xla_auto_clustering_on;
103 };
104
105 // FusedBatchNorm that can be replaced with a cheaper set of primitives.
106 struct FusedBatchNorm {
107 FusedBatchNorm() = default;
108 explicit FusedBatchNorm(int fused_batch_norm)
109 : fused_batch_norm(fused_batch_norm) {}
110
111 int fused_batch_norm = kMissingIndex;
112 };
113
114 // FusedBatchNorm[$is_training] with fused side input and/or activation.
115 struct FusedBatchNormEx {
116 FusedBatchNormEx() = default;
117
118 int fused_batch_norm = kMissingIndex;
119 int side_input = kMissingIndex;
120 int activation = kMissingIndex;
121 // Add node that will be invalidated by fusing side input and fused batch norm
122 int invalidated = kMissingIndex;
123 };
124
125 // FusedBatchNormGrad with fused side output and/or activation.
126 struct FusedBatchNormGradEx {
127 int fused_batch_norm_grad = kMissingIndex;
128 int activation_grad = kMissingIndex;
129 int side_input_grad = kMissingIndex;
130 // Add node of the forward pass to access its "offset" input.
131 int fwd_fused_batch_norm = kMissingIndex;
132 };
133
134 // TensorToHashBucket that can be replaced with AsString + StringToHashBucket.
135 // We also include the fanin node of AsString ("pre_as_string") to determine the
136 // device.
137 struct TensorToHashBucket {
138 TensorToHashBucket() = default;
139 explicit TensorToHashBucket(int op1, int op2, int op3)
140 : pre_as_string(op1), as_string(op2), string_to_hash_bucket(op3) {}
141
142 int pre_as_string = kMissingIndex;
143 int as_string = kMissingIndex;
144 int string_to_hash_bucket = kMissingIndex;
145 };
146
147 // Pad followed by Conv3D/FusedConv3D
148 struct PadWithConv3D {
149 PadWithConv3D() = default;
150 PadWithConv3D(int contraction_idx, int pad_idx, int padding_const_idx)
151 : contraction_idx(contraction_idx),
152 pad_idx(pad_idx),
153 padding_const_idx(padding_const_idx) {}
154
155 int contraction_idx = kMissingIndex;
156 int pad_idx = kMissingIndex;
157 int padding_const_idx = kMissingIndex;
158 };
159
160 // Contraction node followed by a BiasAdd.
161 struct ContractionWithBiasAdd {
162 ContractionWithBiasAdd() = default;
163 ContractionWithBiasAdd(int contraction, int bias_add, int bias_port)
164 : contraction(contraction), bias_add(bias_add), bias_port(bias_port) {}
165
166 int contraction = kMissingIndex;
167 int bias_add = kMissingIndex;
168 int bias_port = 1;
169 };
170
171 // Contraction node followed by a BiasAdd and Activation.
172 struct ContractionWithBiasAddAndActivation {
173 ContractionWithBiasAddAndActivation() = default;
174 ContractionWithBiasAddAndActivation(int contraction, int bias_add,
175 int activation, int bias_port)
176 : contraction(contraction),
177 bias_add(bias_add),
178 activation(activation),
179 bias_port(bias_port) {}
180
181 int contraction = kMissingIndex;
182 int bias_add = kMissingIndex;
183 int activation = kMissingIndex;
184 int bias_port = 1;
185 };
186
187 // Contraction node followed by a Squeeze and BiasAdd.
188 struct ContractionWithSqueezeAndBiasAdd {
189 ContractionWithSqueezeAndBiasAdd() = default;
190 ContractionWithSqueezeAndBiasAdd(int contraction, int squeeze, int bias_add)
191 : contraction(contraction), squeeze(squeeze), bias_add(bias_add) {}
192
193 int contraction = kMissingIndex;
194 int squeeze = kMissingIndex;
195 int bias_add = kMissingIndex;
196 };
197
198 // Contraction node followed by a FusedBatchNorm.
199 struct ContractionWithBatchNorm {
200 ContractionWithBatchNorm() = default;
201 ContractionWithBatchNorm(int contraction, int fused_batch_norm,
202 float epsilon = 0.0)
203 : contraction(contraction),
204 fused_batch_norm(fused_batch_norm),
205 epsilon(epsilon) {}
206
207 int contraction = kMissingIndex;
208 int fused_batch_norm = kMissingIndex;
209 float epsilon = 0.0;
210 };
211
212 // Contraction node followed by a FusedBatchNorm and Activation.
213 struct ContractionWithBatchNormAndActivation {
214 ContractionWithBatchNormAndActivation() = default;
215 ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm,
216 int activation, float epsilon = 0.0)
217 : contraction(contraction),
218 fused_batch_norm(fused_batch_norm),
219 activation(activation),
220 epsilon(epsilon) {}
221
222 int contraction = kMissingIndex;
223 int fused_batch_norm = kMissingIndex;
224 int activation = kMissingIndex;
225 float epsilon = 0.0;
226 };
227
228 // Contraction node followed by a BiasAdd and Add.
229 struct ContractionWithBiasAddAndAdd {
230 ContractionWithBiasAddAndAdd() = default;
231 ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
232 int port_id, int bias_port)
233 : contraction(contraction),
234 bias_add(bias_add),
235 add(add),
236 port_id(port_id),
237 bias_port(bias_port) {}
238
239 int contraction = kMissingIndex;
240 int bias_add = kMissingIndex;
241 int add = kMissingIndex;
242 int port_id = 0;
243 int bias_port = 1;
244 };
245
246 // Contraction node followed by a BiasAdd, Add and Relu.
247 // Plus Tanh and Sigmoid for MatMul in MKL
248 struct ContractionWithBiasAndAddActivation {
249 ContractionWithBiasAndAddActivation() = default;
250 ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
251 int port_id, int activation,
252 int bias_port)
253 : contraction(contraction),
254 bias_add(bias_add),
255 add(add),
256 port_id(port_id),
257 activation(activation),
258 bias_port(bias_port) {}
259
260 int contraction = kMissingIndex;
261 int bias_add = kMissingIndex;
262 int add = kMissingIndex;
263 int port_id = 0;
264 int activation = kMissingIndex;
265 int bias_port = 1;
266 };
267
268 bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
269 return ctx.nodes_to_preserve.count(node->name()) > 0;
270 }
271
272 bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
273 const string& type_attr = "T") {
274 DataType lhs_attr = GetDataTypeFromAttr(*lhs, type_attr);
275 DataType rhs_attr = GetDataTypeFromAttr(*rhs, type_attr);
276
277 return lhs_attr != DT_INVALID && rhs_attr != DT_INVALID &&
278 lhs_attr == rhs_attr;
279 }
280
281 bool HasDataType(const NodeDef* node, const DataType& expected,
282 const string& type_attr = "T") {
283 DataType dtype = GetDataTypeFromAttr(*node, type_attr);
284 return dtype == expected;
285 }
286
287 bool IsCpuCompatibleDataType(const NodeDef* contraction,
288 const string& type_attr = "T") {
289 DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
290 // In a stock TF build without oneDNN this is always `false`.
291 bool is_one_dnn_enabled = IsMKLEnabled();
292
293 if (is_one_dnn_enabled) {
294 return (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
295 IsMatMul(*contraction) || IsConv3D(*contraction) ||
296 IsAnyBatchMatMul(*contraction)) &&
297 (dtype == DT_FLOAT || dtype == DT_BFLOAT16);
298 }
299 if (IsConv2D(*contraction)) {
300 return dtype == DT_FLOAT || dtype == DT_DOUBLE;
301 } else if (IsMatMul(*contraction)) {
302 return dtype == DT_FLOAT;
303 } else {
304 return false;
305 }
306 }
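// For example, under the checks above a float Conv2D on CPU is fusable with or
// without oneDNN, while a bfloat16 MatMul is fusable only when oneDNN is
// enabled (stock TF restricts CPU MatMul fusion to float).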
307
308 bool IsGpuCompatibleDataType(const NodeDef* contraction,
309 const string& type_attr = "T") {
310 DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
311 if (IsConv2D(*contraction) || IsMatMul(*contraction)) {
312 return dtype == DT_FLOAT || dtype == DT_HALF;
313 } else {
314 return false;
315 }
316 }
317
318 bool IsCpuCompatibleDataFormat(const RemapperContext& ctx,
319 const NodeDef* conv_node) {
320 const string& data_format = conv_node->attr().at(kDataFormat).s();
321 if (IsConv2D(*conv_node)) {
322 return data_format == "NHWC" || (IsMKLEnabled() && data_format == "NCHW") ||
323 (ctx.cpu_layout_conversion == RewriterConfig::NHWC_TO_NCHW &&
324 data_format == "NCHW");
325 } else if (IsConv3D(*conv_node)) {
326 return data_format == "NDHWC" || (IsMKLEnabled() && data_format == "NCDHW");
327 } else {
328 return false;
329 }
330 }
331
332 bool BlasLtMatmulEnabled() {
333 static bool is_enabled = [] {
334 bool is_enabled = false;
335 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
336 "TF_USE_CUBLASLT", /*default_val=*/false, &is_enabled));
337 return is_enabled;
338 }();
339 return is_enabled;
340 }
341
342 bool IsGpuCompatibleDataFormat(const RemapperContext& ctx,
343 const NodeDef* conv2d) {
344 DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
345 const string& data_format = conv2d->attr().at(kDataFormat).s();
346 return data_format == "NHWC" || data_format == "NCHW";
347 }
348
349 bool IsCpuCompatibleConv2D(const RemapperContext& ctx, const NodeDef* conv2d) {
350 DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
351 return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
352 IsCpuCompatibleDataFormat(ctx, conv2d);
353 }
354
355 bool IsCpuCompatibleConv3D(const RemapperContext& ctx, const NodeDef* conv3d) {
356 DCHECK(IsConv3D(*conv3d)) << "Expected Conv3D op";
357 return NodeIsOnCpu(conv3d) && IsCpuCompatibleDataType(conv3d) &&
358 IsCpuCompatibleDataFormat(ctx, conv3d);
359 }
360
361 bool IsGpuCompatibleConv2D(const RemapperContext& ctx, const NodeDef* conv2d) {
362 DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
363 return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
364 IsGpuCompatibleDataFormat(ctx, conv2d);
365 }
366
367 bool IsGpuCompatibleMatMul(const RemapperContext& ctx, const NodeDef* matmul) {
368 DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
369 return BlasLtMatmulEnabled() && NodeIsOnGpu(matmul) &&
370 IsGpuCompatibleDataType(matmul);
371 }
372
373 bool IsCpuCompatibleMatMul(const RemapperContext& ctx, const NodeDef* matmul) {
374 DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
375 return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
376 }
377
378 bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
379 DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
380 << "Expected DepthwiseConv2dNative op";
381 return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
382 }
383
384 // Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
385 template <typename Pattern>
386 bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
387 const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
388 if (IsConv2D(node)) {
389 return IsCpuCompatibleConv2D(ctx, &node);
390 } else if (IsDepthwiseConv2dNative(node)) {
391 return (IsMKLEnabled() && IsCpuCompatibleDepthwiseConv2dNative(&node));
392 } else if (IsMatMul(node)) {
393 return IsCpuCompatibleMatMul(ctx, &node);
394 } else if (IsConv3D(node)) {
395 return (IsMKLEnabled() && IsCpuCompatibleConv3D(ctx, &node));
396 } else {
397 return false;
398 }
399 }
400
401 bool IsSupportedActivation(const NodeDef& node) {
402 bool is_default_supported =
403 IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
404 bool is_mkl_specific = IsMKLEnabled() && (IsTanh(node) || IsSigmoid(node));
405 return (is_default_supported || is_mkl_specific);
406 }
407
408 bool RuntimeFusionEnabled(const Cluster* cluster) {
409 static bool is_enabled = [&] {
410 #if CUDNN_VERSION >= 8400
411 // The cuDNN runtime fusion feature is recommended for Ampere GPUs or later.
412 // For pre-Ampere GPUs, the overhead of runtime compilation would be very
413 // large, and the set of supported cases is more limited.
414 if (!cluster) return false;
415 auto devices = cluster->GetDevices();
416 int num_gpus = 0;
417 int num_ampere = 0;
418 for (const auto& d : devices) {
419 if (d.second.type() == "GPU") {
420 num_gpus++;
421 auto cc_it = d.second.environment().find("architecture");
422 if (cc_it != d.second.environment().end()) {
423 double compute_capability = 0.0;
424 if (absl::SimpleAtod(cc_it->second, &compute_capability) &&
425 compute_capability >= 8.0) {
426 num_ampere++;
427 }
428 }
429 }
430 }
431 bool runtime_fusion_enabled = CudnnUseRuntimeFusion() &&
432 CudnnUseFrontend() && num_gpus > 0 &&
433 num_gpus == num_ampere;
434
435 if (CudnnUseRuntimeFusion() && !runtime_fusion_enabled) {
436 VLOG(1) << "Enabling Cudnn with runtime compilation requires the "
437 << "Cudnn frontend and Ampere GPUs or later, but we got "
438 << "Cudnn frontend is "
439 << (CudnnUseFrontend() ? "enabled" : "disabled") << " and "
440 << num_ampere << " Ampere GPU(s) out of total " << num_gpus
441 << " GPU(s)";
442 }
443
444 return runtime_fusion_enabled;
445 #else
446 return false;
447 #endif
448 }();
449 return is_enabled;
450 }
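// Note: callers below use RuntimeFusionEnabled() to decide whether the
// fp16-only activation epilogues (Relu6/Elu/LeakyRelu), which rely on cuDNN
// runtime-compiled kernels, may be fused into convolutions.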
451
452 // Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
453 bool IsGpuCompatible(const RemapperContext& ctx,
454 const ContractionWithBiasAddAndActivation& matched,
455 const Cluster* cluster) {
456 #if TENSORFLOW_USE_ROCM
457 // ROCm does not support _FusedConv2D
458 return false;
459 #endif
460 // The TF->XLA bridge does not support `_Fused[Conv2D|MatMul]` so we avoid
461 // creating this op. Furthermore, XLA already does this fusion internally so
462 // there is no true benefit from doing this optimization if XLA is going to
463 // compile the unfused operations anyway.
464 if (ctx.xla_auto_clustering_on) return false;
465
466 const GraphDef* graph = ctx.graph_view.graph();
467
468 // We rely on cuDNN for fused convolution and cublasLt for fused matmul.
469 const NodeDef& activation_node = graph->node(matched.activation);
470 if (!IsSupportedActivation(activation_node)) return false;
471
472 const NodeDef& contraction_node = graph->node(matched.contraction);
473 if (IsConv2D(contraction_node)) {
474 const std::vector<OpInfo::TensorProperties>& input_props =
475 ctx.graph_properties.GetInputProperties(contraction_node.name());
476 const TensorShapeProto& filter_shape =
477 input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();
478
479 // FusedConv2D on GPU with 1x1 convolution is marginally faster than
480 // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
481 // and significantly slower in large scale benchmarks.
482 bool is_spatial_conv = Rank(filter_shape) == 4 && //
483 IsKnown(filter_shape.dim(0)) && //
484 IsKnown(filter_shape.dim(1)) && //
485 filter_shape.dim(0).size() != 1 && //
486 filter_shape.dim(1).size() != 1;
487
488 // The cuDNN runtime-compiled kernels support the relu6, elu, and leakyrelu
489 // activations, but they require in_channels and out_channels to be even and
490 // the dtype to be fp16.
491 bool act_requires_fp16 = IsRelu6(activation_node) ||
492 IsElu(activation_node) ||
493 IsLeakyRelu(activation_node);
494 DataType dtype = GetDataTypeFromAttr(activation_node, "T");
495 bool is_fp16 = dtype == DT_HALF;
496 bool valid_channels = Rank(filter_shape) == 4 && //
497 IsKnown(filter_shape.dim(2)) && //
498 IsKnown(filter_shape.dim(3)) && //
499 filter_shape.dim(2).size() % 2 == 0 && //
500 filter_shape.dim(3).size() % 2 == 0;
501 bool is_supported_conv =
502 is_spatial_conv &&
503 (!act_requires_fp16 ||
504 (is_fp16 && valid_channels && RuntimeFusionEnabled(cluster)));
505
506 return is_supported_conv && IsGpuCompatibleConv2D(ctx, &contraction_node);
507 } else if (IsMatMul(contraction_node)) {
508 return IsGpuCompatibleMatMul(ctx, &contraction_node);
509 }
510
511 return false;
512 }
513
514 // Checks if we can rewrite a pattern to the `_FusedMatMul` on GPU device.
515 bool IsGpuCompatible(const RemapperContext& ctx,
516 const ContractionWithBiasAdd& matched,
517 const Cluster* cluster) {
518 #if TENSORFLOW_USE_ROCM
519 // ROCm does not support _FusedMatMul
520 return false;
521 #endif
522 // The TF->XLA bridge does not support `_FusedMatMul` so we avoid creating
523 // this op. Furthermore, XLA already does this fusion internally so there
524 // is no true benefit from doing this optimization if XLA is going to compile
525 // the unfused operations anyway.
526 if (ctx.xla_auto_clustering_on) return false;
527
528 const GraphDef* graph = ctx.graph_view.graph();
529 const NodeDef& contraction_node = graph->node(matched.contraction);
530 if (!IsMatMul(contraction_node)) return false;
531
532 return IsGpuCompatibleMatMul(ctx, &contraction_node);
533 }
534
535 bool IsGpuCompatible(const RemapperContext& ctx,
536 const ContractionWithSqueezeAndBiasAdd& matched,
537 const Cluster* cluster) {
538 return false;
539 }
540
541 // Returns true if the given pattern is supported on the assigned device.
542 template <typename Pattern>
543 bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched,
544 Cluster* cluster = nullptr) {
545 return IsCpuCompatible(ctx, matched) ||
546 IsGpuCompatible(ctx, matched, cluster);
547 }
548
549 inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
550 return node_view.NumControllingFanins() > 0 ||
551 node_view.NumControlledFanouts() > 0;
552 }
553
554 // Returns true if at most one fanout reads output at port 0 (output used once).
555 inline bool HasAtMostOneFanoutAtPort0(const utils::MutableNodeView& node_view) {
556 return node_view.GetRegularFanout(0).size() <= 1;
557 }
558
559 // Returns true if at most one fanout reads actual tensor data at output port 0
560 // (output used once for any data computation).
561 inline bool HasAtMostOneDataFanoutAtPort0(
562 const utils::MutableNodeView& node_view) {
563 const auto predicate = [](const auto& fanout) -> bool {
564 const NodeDef* node = fanout.node_view()->node();
565 return !IsShape(*node) && !IsRank(*node);
566 };
567 return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
568 }
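// For example, a contraction whose output feeds one BiasAdd and one Shape (or
// Rank) node still qualifies here, because Shape/Rank read only metadata
// rather than the tensor data itself.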
569
570 bool IsConvOrMatMul(const NodeDef& node) {
571 return IsConv2D(node) || IsDepthwiseConv2dNative(node) || IsMatMul(node) ||
572 IsConv3D(node);
573 }
574
575 // Returns true if one input to Add is Conv2D/3D or DepthwiseConv2dNative or
576 // MatMul, and the other input is semantically equivalent to BiasAdd.
577 bool IsBiasSemanticAdd(const RemapperContext& ctx,
578 const utils::MutableNodeView& node_view,
579 int& bias_port) {
580 if (!IsMKLEnabled()) return false;
581
582 const auto* node_def = node_view.node();
583 if (!NodeIsOnCpu(node_def)) return false;
584 if (!IsAdd(*node_def) || node_view.NumRegularFanins() != 2) return false;
585
586 const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
587 if (props.size() < 2) return false;
588
589 const auto& regular_fanin_0 = node_view.GetRegularFanin(0);
590 const auto* node_view_0 = regular_fanin_0.node_view();
591 const auto* node_def_0 = node_view_0->node();
592 const auto& regular_fanin_1 = node_view.GetRegularFanin(1);
593 const auto* node_view_1 = regular_fanin_1.node_view();
594 const auto* node_def_1 = node_view_1->node();
595
596 if (!IsConvOrMatMul(*node_def_0) && !IsConvOrMatMul(*node_def_1))
597 return false;
598
599 auto is_channel_last_format = [](const NodeDef& node) -> bool {
600 if (node.attr().contains("data_format")) {
601 const string data_format = node.attr().at("data_format").s();
602 return (data_format == "NHWC" || data_format == "NDHWC");
603 }
604 return true;
605 };
606
607 // Currently supported data formats are NHWC and NDHWC.
608 if (!is_channel_last_format(*node_def_0) ||
609 !is_channel_last_format(*node_def_1))
610 return false;
611
612 const TensorShapeProto& prot0_shape = props[0].shape();
613 const TensorShapeProto& prot1_shape = props[1].shape();
614
615 if (prot0_shape.unknown_rank() || prot1_shape.unknown_rank() ||
616 prot0_shape.dim_size() < 1 || prot1_shape.dim_size() < 1 ||
617 !IsKnown(prot0_shape.dim(prot0_shape.dim_size() - 1)) ||
618 !IsKnown(prot1_shape.dim(prot1_shape.dim_size() - 1)))
619 return false;
620
621 // Helper function to check whether Add/AddV2 could be replaced with BiasAdd.
622 const auto is_supported_shape =
623 [&](const TensorShapeProto& shape,
624 const TensorShapeProto& bcast_shape) -> bool {
625 int conv_channel_dim;
626 conv_channel_dim = shape.dim(shape.dim_size() - 1).size();
627
628 if (shape.dim_size() == 4 && bcast_shape.dim_size() > 4) return false;
629 if (shape.dim_size() == 5 && bcast_shape.dim_size() > 5) return false;
630
631 if (shape.dim_size() < 2) return false;
632 // Check that the conv node's channel dim is equal to the add node's last
633 // (channel) dim.
634 if (conv_channel_dim != bcast_shape.dim(bcast_shape.dim_size() - 1).size())
635 return false;
636
637 // Check that the add node's dims are all 1 except the channel dim.
638 for (int i = 0; i < bcast_shape.dim_size() - 1; i++) {
639 if (1 != bcast_shape.dim(i).size()) return false;
640 }
641 return true;
642 };
643
644 if (ShapesSymbolicallyEqual(prot0_shape, prot1_shape) ||
645 !ShapesBroadcastable(prot0_shape, prot1_shape))
646 return false;
647
648 if (IsConvOrMatMul(*node_def_0)) {
649 bias_port = 1;
650 return (is_supported_shape(prot0_shape, prot1_shape));
651 } else if (IsConvOrMatMul(*node_def_1)) {
652 bias_port = 0;
653 return (is_supported_shape(prot1_shape, prot0_shape));
654 }
655 return false;
656 }
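// Illustrative example of the check above: Add(Conv2D with output shape
// [N,H,W,C], Const with shape [1,1,1,C] or [C]) behaves exactly like BiasAdd,
// so it can participate in the contraction fusions; bias_port records which
// input supplies the bias-like operand.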
657
658 bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
659 ContractionWithBiasAdd* matched,
660 bool check_device_compatible = true) {
661 const auto* node_view = ctx.graph_view.GetNode(node_index);
662 // Root of the pattern must be a BiasAdd.
663 // TODO(lyandy): Forward controls for patterns with control dependencies.
664 if (HasControlFaninOrFanout(*node_view)) return false;
665
666 const auto* node_def = node_view->node();
667 int bias_port = 1;
668 if (!IsBiasAdd(*node_def) && !IsBiasSemanticAdd(ctx, *node_view, bias_port))
669 return false;
670
671 // Input to the BiasAdd must be a Conv2D/3D, MatMul, or DepthwiseConv2dNative.
672 if (node_view->NumRegularFanins() < 1) return false;
673 const auto& regular_fanin_0 = node_view->GetRegularFanin(1 - bias_port);
674 const auto* contraction_node_view = regular_fanin_0.node_view();
675 const auto* contraction_node_def = contraction_node_view->node();
676
677 // Conv2D/3D, MatMul or DepthwiseConv2D
678 bool is_contraction = IsConv2D(*contraction_node_def) ||
679 (IsConv3D(*contraction_node_def) && IsMKLEnabled()) ||
680 IsMatMul(*contraction_node_def) ||
681 IsDepthwiseConv2dNative(*contraction_node_def);
682
683 if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
684 HasControlFaninOrFanout(*contraction_node_view) ||
685 !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
686 IsInPreserveSet(ctx, contraction_node_def))
687 return false;
688
689 // Check that data type and data format are supported on assigned device.
690 const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
691 node_index, bias_port};
692 if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
693 return false;
694
695 // We successfully found a {Conv2D, MatMul}+BiasAdd pattern.
696 *matched = pattern;
697
698 return true;
699 }
700
701 bool FindContractionWithBiasAndActivation(
702 const RemapperContext& ctx, Cluster* cluster, int node_index,
703 ContractionWithBiasAddAndActivation* matched) {
704 const auto* node_view = ctx.graph_view.GetNode(node_index);
705 // Root of the pattern must be an activation node.
706 // TODO(lyandy): Forward controls for patterns with control dependencies.
707 if (HasControlFaninOrFanout(*node_view)) return false;
708
709 const auto* node_def = node_view->node();
710 if (!IsSupportedActivation(*node_def)) return false;
711
712 // And input to the activation node must match ContractionWithBiasAdd pattern.
713 if (node_view->NumRegularFanins() < 1) return false;
714 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
715 const auto* bias_add_node_view = regular_fanin_0.node_view();
716 const auto* bias_add_node_def = bias_add_node_view->node();
717
718 ContractionWithBiasAdd base;
719 if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), &base,
720 /*check_device_compatible=*/false) ||
721 !HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
722 !HaveSameDataType(node_def, bias_add_node_def) ||
723 IsInPreserveSet(ctx, bias_add_node_def))
724 return false;
725
726 // Get the contraction node
727 const auto* contraction_node_view =
728 bias_add_node_view->GetRegularFanin(1 - base.bias_port).node_view();
729 const auto* contraction_node_def = contraction_node_view->node();
730
731 // Currently, only matmul + bias + (tanh or Sigmoid) is enabled
732 if (!IsMatMul(*contraction_node_def) &&
733 (IsTanh(*node_def) || IsSigmoid(*node_def)))
734 return false;
735
736 // Currently, only (conv | matmul) + bias + leakyrelu is enabled
737 if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def) ||
738 (IsConv3D(*contraction_node_def) && IsMKLEnabled())) &&
739 IsLeakyRelu(*node_def))
740 return false;
741
742 // Check that data type and data format are supported on assigned device.
743 const ContractionWithBiasAddAndActivation pattern{
744 base.contraction, base.bias_add, node_index, base.bias_port};
745 if (!IsDeviceCompatible(ctx, pattern, cluster)) return false;
746
747 // We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
748 *matched = pattern;
749
750 return true;
751 }
752
753 bool FindConvWithSqueezeAndBias(const RemapperContext& ctx, int node_index,
754 ContractionWithSqueezeAndBiasAdd* matched) {
755 const auto* node_view = ctx.graph_view.GetNode(node_index);
756 // TODO(lyandy): Forward controls for patterns with control dependencies.
757 if (HasControlFaninOrFanout(*node_view)) return false;
758
759 // Root of the pattern must be a BiasAdd.
760 const auto* node_def = node_view->node();
761 if (!IsBiasAdd(*node_def)) return false;
762
763 // Input to the BiasAdd must be a Squeeze.
764 if (node_view->NumRegularFanins() < 1) return false;
765 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
766 const auto* squeeze_node_view = regular_fanin_0.node_view();
767 const auto* squeeze_node_def = squeeze_node_view->node();
768
769 if (!IsSqueeze(*squeeze_node_def) ||
770 !HaveSameDataType(node_def, squeeze_node_def, "T") ||
771 HasControlFaninOrFanout(*squeeze_node_view) ||
772 !HasAtMostOneFanoutAtPort0(*squeeze_node_view) ||
773 IsInPreserveSet(ctx, squeeze_node_def))
774 return false;
775
776 // Input to the Squeeze must be a Conv2D/3D.
777 if (squeeze_node_view->NumRegularFanins() < 1) return false;
778 const auto& squeeze_regular_fanin_0 = squeeze_node_view->GetRegularFanin(0);
779 const auto* conv_node_view = squeeze_regular_fanin_0.node_view();
780 const auto* conv_node_def = conv_node_view->node();
781
782 if (!(IsConv2D(*conv_node_def) ||
783 (IsConv3D(*conv_node_def) && IsMKLEnabled())) ||
784 !HaveSameDataType(node_def, conv_node_def, "T") ||
785 HasControlFaninOrFanout(*conv_node_view) ||
786 !HasAtMostOneFanoutAtPort0(*conv_node_view) ||
787 IsInPreserveSet(ctx, conv_node_def))
788 return false;
789
790 // Squeeze must not squeeze output channel dimension.
791 std::vector<int32> dims;
792 if (!TryGetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims)) return false;
793 for (auto dim : dims) {
794 if ((dim == 3 && IsConv2D(*conv_node_def)) ||
795 (dim == 4 && IsConv3D(*conv_node_def)))
796 return false;
797 }
798
799 // Check that data type and data format are supported on assigned device.
800 const ContractionWithSqueezeAndBiasAdd pattern{
801 conv_node_view->node_index(), squeeze_node_view->node_index(),
802 node_index};
803 if (!IsDeviceCompatible(ctx, pattern)) return false;
804
805 // We successfully found a Conv2D+Squeeze+BiasAdd pattern.
806 *matched = pattern;
807
808 return true;
809 }
810
811 bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
812 ContractionWithBatchNorm* matched) {
813 const auto* node_view = ctx.graph_view.GetNode(node_index);
814 const auto* node_def = node_view->node();
815 // Root of the pattern must be a FusedBatchNorm.
816 if (!IsFusedBatchNorm(*node_def)) return false;
817
818 // FusedBatchNormV2 and V3 have an extra type parameter.
819 // Conv2D + FusedBatchNormV2/V3 fusion is currently not supported for bf16.
820 // TODO(intel-tf): enable the fusion for bf16
821 bool dtypeU_is_float = HasDataType(node_def, DT_FLOAT, "U");
822 bool dtypeT_is_bf16 = HasDataType(node_def, DT_BFLOAT16, "T");
823 if (node_view->GetOp() != "FusedBatchNorm" &&
824 (!dtypeU_is_float || dtypeT_is_bf16)) {
825 return false;
826 }
827
828 // Check that batch normalization is in inference mode.
829 const auto* training_attr = node_view->GetAttr(kIsTraining);
830 if (training_attr != nullptr && training_attr->b()) return false;
831
832 // Check that only 0th output is consumed by other nodes.
833 // TODO(lyandy): Forward controls for patterns with control dependencies.
834 if (HasControlFaninOrFanout(*node_view) ||
835 !node_view->GetRegularFanout(1).empty() || // batch_mean
836 !node_view->GetRegularFanout(2).empty() || // batch_variance
837 !node_view->GetRegularFanout(3).empty() || // reserve_space_1
838 !node_view->GetRegularFanout(4).empty()) // reserve_space_2
839 return false;
840
841 // Input to the FusedBatchNorm must be a Conv2D.
842 if (node_view->NumRegularFanins() < 1) return false;
843 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
844 const auto* conv2d_node_view = regular_fanin_0.node_view();
845 const auto* conv2d_node_def = conv2d_node_view->node();
846
847 if (!IsConv2D(*conv2d_node_def) || !NodeIsOnCpu(conv2d_node_def) ||
848 !HaveSameDataType(node_def, conv2d_node_def) ||
849 !IsCpuCompatibleDataType(conv2d_node_def) ||
850 !IsCpuCompatibleDataFormat(ctx, conv2d_node_def) ||
851 HasControlFaninOrFanout(*conv2d_node_view) ||
852 !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
853 IsInPreserveSet(ctx, conv2d_node_def))
854 return false;
855
856 // We successfully found a Conv2D+FusedBatchNorm pattern.
857 matched->contraction = conv2d_node_view->node_index();
858 matched->fused_batch_norm = node_index;
859 if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;
860
861 return true;
862 }
863
864 bool FindConv2DWithBatchNormAndActivation(
865 const RemapperContext& ctx, int node_index,
866 ContractionWithBatchNormAndActivation* matched) {
867 const auto* node_view = ctx.graph_view.GetNode(node_index);
868 // TODO(lyandy): Forward controls for patterns with control dependencies.
869 if (HasControlFaninOrFanout(*node_view)) return false;
870
871 // Root of the pattern must be an activation node.
872 const auto* node_def = node_view->node();
873 if (!IsSupportedActivation(*node_def)) return false;
874
875 // TODO(intel-tf): Need to test and enable this activation in the kernel op
876 // before enabling it here.
877 if (IsSigmoid(*node_def)) return false;
878
879 // And input to the activation node must match Conv2DWithBatchNorm pattern.
880 if (node_view->NumRegularFanins() < 1) return false;
881 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
882 const auto* batch_norm_node_view = regular_fanin_0.node_view();
883
884 ContractionWithBatchNorm base;
885 if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base))
886 return false;
887
888 const auto* fused_batch_norm_node_view =
889 ctx.graph_view.GetNode(base.fused_batch_norm);
890 const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
891 if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
892 !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
893 IsInPreserveSet(ctx, fused_batch_norm_node_def))
894 return false;
895
896 // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
897 matched->contraction = base.contraction;
898 matched->fused_batch_norm = base.fused_batch_norm;
899 matched->activation = node_index;
900 matched->epsilon = base.epsilon;
901
902 return true;
903 }
904
905 // As AddN has multiple inputs, this function tries to find the Conv2D + Bias
906 // pattern at a specific input port.
907 bool FindContractionWithBiasInPort(const RemapperContext& ctx,
908 const utils::MutableNodeView& add_node_view,
909 const NodeDef& add_node_def, int port_id,
910 ContractionWithBiasAdd* base) {
911 // Input to AddN must match ContractionWithBiasAdd pattern.
912 if (add_node_view.NumRegularFanins() < port_id + 1) return false;
913 const auto& bias_add_node_view =
914 add_node_view.GetRegularFanin(port_id).node_view();
915 if (bias_add_node_view == nullptr) return false;
916 const auto* bias_add_node_def = bias_add_node_view->node();
917
918 if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base,
919 /*check_device_compatible=*/false))
920 return false;
921 if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
922 !HaveSameDataType(&add_node_def, bias_add_node_def) ||
923 IsInPreserveSet(ctx, bias_add_node_def))
924 return false;
925 return true;
926 }
927
928 bool IsAddWithNoBroadcast(const RemapperContext& ctx, const NodeDef& node) {
929 if (!IsAdd(node)) return false;
930
931 // Check if this is a case of broadcasting - the Add node supports broadcasting.
932 const auto& props = ctx.graph_properties.GetInputProperties(node.name());
933 if (props.size() == 2 &&
934 ShapesSymbolicallyEqual(props[0].shape(), props[1].shape())) {
935 return true;
936 }
937 return false;
938 }
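// Note: despite the name, this returns true only when the two input shapes are
// symbolically equal, i.e. when the Add does not actually need to broadcast.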
939
940 bool FindPadWithConv3D(const RemapperContext& ctx, int node_index,
941 PadWithConv3D* matched) {
942 if (!IsMKLEnabled()) return false;
943 const auto* node_view = ctx.graph_view.GetNode(node_index);
944 const auto* node_def = node_view->node();
945 // The optimization is only for CPU
946 if (!NodeIsOnCpu(node_def)) return false;
947 // Root of the pattern must be a Conv3D or _FusedConv3D
948 if (!(IsConv3D(*node_def) || node_def->op() == kFusedConv3D)) return false;
949 if (!(HasDataType(node_def, DT_FLOAT) || HasDataType(node_def, DT_BFLOAT16)))
950 return false;
951
952 // Input to Conv3D/_FusedConv3D must be a Pad
953 if (node_view->NumRegularFanins() < 1) return false;
954 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
955 const auto* pad_node_view = regular_fanin_0.node_view();
956 const auto* pad_node_def = pad_node_view->node();
957 const auto& padding_const = pad_node_view->GetRegularFanin(1);
958 const auto* padding_const_node_view = padding_const.node_view();
959
960 if (!(pad_node_def->op() == "Pad") ||
961 !HaveSameDataType(node_def, pad_node_def))
962 return false;
963 const PadWithConv3D pattern{node_view->node_index(),
964 pad_node_view->node_index(),
965 padding_const_node_view->node_index()};
966
967 // Successfully found a Pad+{Conv3D, _FusedConv3D} pattern.
968 *matched = pattern;
969 return true;
970 }
971
972 bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
973 const utils::MutableNodeView& node_view,
974 ContractionWithBiasAddAndAdd* matched) {
975 // Fusion with AddN is supported only when it has two inputs.
976 // TODO(lyandy): Forward controls for patterns with control dependencies.
977 if (HasControlFaninOrFanout(node_view) || node_view.NumRegularFanins() != 2)
978 return false;
979
980 // Root of the pattern must be an AddN or an Add with the same input shapes
981 // (no broadcasting).
982 const auto* node_def = node_view.node();
983 if (!IsAddN(*node_def) && !IsAddWithNoBroadcast(ctx, *node_def)) return false;
984
985 if (!NodeIsOnCpu(node_def)) return false;
986
987 // MKL AddN ops only support float and bfloat16 data types.
988 if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
989 return false;
990
991 ContractionWithBiasAdd base;
992 matched->port_id = 0;
993
994 // Find the conv+bias pattern in specific port.
995 if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
996 matched->port_id, &base)) {
997 matched->port_id = 1;
998 if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
999 matched->port_id, &base)) {
1000 return false;
1001 }
1002 }
1003
1004 // We successfully found a {Conv2D,Conv3D}+BiasAdd+{AddN,Add} pattern.
1005 matched->contraction = base.contraction;
1006 matched->bias_add = base.bias_add;
1007 matched->add = node_view.node_index();
1008 matched->bias_port = base.bias_port;
1009
1010 return true;
1011 }
1012
1013 bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
1014 int node_index,
1015 ContractionWithBiasAddAndAdd* matched) {
1016 const auto* node_view = ctx.graph_view.GetNode(node_index);
1017 return FindContractionWithBiasAddAndAdd(ctx, *node_view, matched);
1018 }
1019
1020 bool FindContractionWithBiasAndAddActivation(
1021 const RemapperContext& ctx, int node_index,
1022 ContractionWithBiasAndAddActivation* matched) {
1023 const auto* node_view = ctx.graph_view.GetNode(node_index);
1024 // TODO(lyandy): Forward controls for patterns with control dependencies.
1025 if (HasControlFaninOrFanout(*node_view)) return false;
1026
1027 // Root of the pattern must be an activation node.
1028 const auto* node_def = node_view->node();
1029 if (node_def == nullptr) return false;
1030 if (!IsSupportedActivation(*node_def)) return false;
1031
1032 if (!NodeIsOnCpu(node_def)) return false;
1033
1034 // Currently, Contraction + Bias + Add + Tanh pattern is not supported
1035 if (IsTanh(*node_def)) return false;
1036
1037 // TODO(intel-tf): Need to test and enable this activation in the kernel op
1038 // before enabling it here.
1039 if (IsSigmoid(*node_def)) return false;
1040
1041 // MKL activation op only supports float and bfloat16 data types.
1042 if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
1043 return false;
1044
1045 // And input to activation must match ContractionWithBiasAddAndAdd pattern.
1046 if (node_view->NumRegularFanins() < 1) return false;
1047 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
1048 const auto* add_node_view = regular_fanin_0.node_view();
1049
1050 ContractionWithBiasAddAndAdd base;
1051
1052 if (!FindContractionWithBiasAddAndAdd(ctx, *add_node_view, &base)) {
1053 return false;
1054 }
1055
1056 // Get the contraction node
1057 const auto* bias_add_node_view =
1058 add_node_view->GetRegularFanin(base.port_id).node_view();
1059 const auto* contraction_node_view =
1060 bias_add_node_view->GetRegularFanin(0).node_view();
1061 const auto* contraction_node_def = contraction_node_view->node();
1062
1063 // Currently, only conv + bias + add + leakyrelu is enabled
1064 if (!(IsConv2D(*contraction_node_def) || IsConv3D(*contraction_node_def)) &&
1065 IsLeakyRelu(*node_def))
1066 return false;
1067 // Conv3D fusion is available with oneDNN enabled
1068 if (IsConv3D(*contraction_node_def) && !IsMKLEnabled()) return false;
1069
1070 // We successfully found a Conv2D+BiasAdd+AddN+activation pattern
1071 // or Conv3D+BiasAdd+AddN+activation pattern
1072 const ContractionWithBiasAndAddActivation pattern{
1073 base.contraction, base.bias_add, base.add,
1074 base.port_id, node_index, base.bias_port};
1075 *matched = pattern;
1076
1077 return true;
1078 }
1079
1080 inline bool VerifyConstants(RemapperContext* ctx,
1081 std::map<string, int>* nodes_map,
1082 std::map<string, float>* values_map) {
1083 using utils::MutableNodeView;
1084 for (auto it = values_map->begin(); it != values_map->end(); ++it) {
1085 int node_idx = nodes_map->at(it->first);
1086 MutableNodeView* node_view = ctx->graph_view.GetNode(node_idx);
1087 NodeDef* node_def = node_view->node();
1088 Tensor const_tensor;
1089 if (node_def != nullptr && node_def->op() == "Const" &&
1090 const_tensor.FromProto(node_def->attr().at("value").tensor())) {
1091 if (const_tensor.NumElements() == 1) {
1092 DataType dtype = const_tensor.dtype();
1093 float const_value;
1094 if (dtype == DT_FLOAT) {
1095 const_value = const_tensor.flat<float>()(0);
1096 } else if (dtype == DT_BFLOAT16) {
1097 const_value = const_tensor.flat<bfloat16>()(0);
1098 } else if (dtype == DT_HALF) {
1099 const_value = const_tensor.flat<Eigen::half>()(0);
1100 } else {
1101 return false;
1102 }
1103 if (std::abs(const_value - it->second) > 1e-2) return false;
1104 } else {
1105 return false;
1106 }
1107 } else {
1108 return false;
1109 }
1110 }
1111 return true;
1112 }
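// Example usage (see the Gelu matchers below): VerifyConstants is called with
// a values_map such as {"square_root_one_half": 0.707106, "one": 1.0,
// "one_half": 0.5} to confirm that each matched Const holds a scalar within
// 1e-2 of the expected value before the subgraph is replaced.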
1113
1114 // Gelu in the Python API generates a number of nodes in the graph. Depending
1115 // on the parameter `approximate={True/False}`, different types of ops are
1116 // generated. We distinguish them as `GeluExact`, which uses Erf, and
1117 // `GeluApproximate`, which uses Tanh.
1118 bool FindMatMulBiasAddAndGelu(RemapperContext* ctx, int node_index,
1119 std::map<string, int>* matched_nodes_map,
1120 std::set<int>* remove_node_indices,
1121 bool* is_gelu_approximate) {
1122 // Gelu fusion is enabled with oneDNN or cublasLt library.
1123 if (!IsMKLEnabled() && !BlasLtMatmulEnabled()) return false;
1124
1125 using utils::MatchingDirection;
1126 using utils::NodeStatus;
1127 // clang-format off
1128 utils::OpTypePattern gelu_exact_pattern =
1129 {"Mul", "output", NodeStatus::kReplace,
1130 {
1131 {"Mul", "erf_plus_one_times_one_half", NodeStatus::kRemove,
1132 {
1133 {"AddV2", "erf_plus_one", NodeStatus::kRemove,
1134 {
1135 {"Erf", "erf", NodeStatus::kRemove,
1136 {
1137 {"Mul", "bias_add_times_square_root_one_half", NodeStatus::kRemove,
1138 {
1139 {"BiasAdd", "bias_add", NodeStatus::kRemove},
1140 {"Const", "square_root_one_half", NodeStatus::kRemain}
1141 }
1142 }
1143 }
1144 },
1145 {"Const", "one", NodeStatus::kRemain}
1146 }
1147 },
1148 {"Const", "one_half", NodeStatus::kRemain}
1149 }
1150 },
1151 {"BiasAdd", "bias_add", NodeStatus::kRemove,
1152 {
1153 {"MatMul", "matmul", NodeStatus::kRemove},
1154 {"*", "bias", NodeStatus::kRemain}
1155 }
1156 }
1157 }
1158 };
1159 // clang-format on
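// The pattern above corresponds to the exact GELU formula
//   0.5 * x * (1 + erf(x / sqrt(2))), with x = BiasAdd(MatMul(...), bias),
// where sqrt(1/2) ~= 0.7071; the constants are verified against these values
// in VerifyConstants() below.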
1160
1161 // Gelu approximate uses Pow(x, 3). On GPU, it is a single Pow() node, but on
1162 // CPU, it is optimized by the arithmetic optimizer as Mul(x, Square(x)) with
1163 // an artifact of a control dependency. So we try to match the pattern in the
1164 // second pass of the remapper, which receives _FusedMatMul (MatMul + BiasAdd)
1165 // with the control dependency removed.
1166 // clang-format off
1167 utils::OpTypePattern subgraph_gpu =
1168 {"Mul", "mul", NodeStatus::kRemove,
1169 {
1170 {"Pow", "pow", NodeStatus::kRemove,
1171 {
1172 {"_FusedMatMul", "matmul", NodeStatus::kRemove},
1173 {"Const", "three", NodeStatus::kRemain}
1174 }
1175 },
1176 {"Const", "empirical_const", NodeStatus::kRemain}
1177 }
1178 };
1179 utils::OpTypePattern subgraph_cpu =
1180 {"Mul", "mul", NodeStatus::kRemove,
1181 {
1182 {"Mul", "empirical_const_times_matmul", NodeStatus::kRemove,
1183 {
1184 {"Const", "empirical_const", NodeStatus::kRemain},
1185 {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1186 }
1187 },
1188 {"Square", "square", NodeStatus::kRemove,
1189 {
1190 {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1191 }
1192 }
1193 }
1194 };
1195 // clang-format on
1196
1197 utils::MutableNodeView* node_view = ctx->graph_view.GetNode(node_index);
1198 const NodeDef* node_def = node_view->node();
1199 bool root_on_gpu = NodeIsOnGpu(node_def);
1200 utils::OpTypePattern* subgraph_pattern =
1201 root_on_gpu ? &subgraph_gpu : &subgraph_cpu;
1202
1203 // clang-format off
1204 utils::OpTypePattern gelu_approximate_pattern =
1205 {"Mul", "output", NodeStatus::kReplace,
1206 {
1207 {"Mul", "tanh_plus_one_times_one_half", NodeStatus::kRemove,
1208 {
1209 {"AddV2", "tanh_plus_one", NodeStatus::kRemove,
1210 {
1211 {"Tanh", "tanh", NodeStatus::kRemove,
1212 {
1213 {"Mul", "matmul_plus_mul_times_square_root_two_over_pi", NodeStatus::kRemove,
1214 {
1215 {"AddV2", "matmul_plus_mul", NodeStatus::kRemove,
1216 {
1217 {"_FusedMatMul", "matmul", NodeStatus::kRemove},
1218 *subgraph_pattern
1219 }
1220 },
1221 {"Const", "square_root_two_over_pi", NodeStatus::kRemain}
1222 }
1223 }
1224 }
1225 },
1226 {"Const", "one", NodeStatus::kRemain}
1227 }
1228 },
1229 {"Const", "one_half", NodeStatus::kRemain}
1230 }
1231 },
1232 {"_FusedMatMul", "matmul", NodeStatus::kRemove}
1233 }
1234 };
1235 // clang-format on
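// The pattern above corresponds to the tanh-based GELU approximation
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// where sqrt(2/pi) ~= 0.797884; x^3 shows up as Pow(x, 3) on GPU and as
// Mul(x, Square(x)) on CPU, which is why two subgraph variants are matched.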
1236
1237 bool found_gelu_exact = false;
1238 bool found_gelu_approximate = false;
1239 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1240 &(ctx->graph_view));
1241 // Find GeluExact
1242 matched_nodes_map->clear();
1243 remove_node_indices->clear();
1244 found_gelu_exact = graph_matcher.GetMatchedNodes(
1245 gelu_exact_pattern, ctx->nodes_to_preserve, node_view, matched_nodes_map,
1246 remove_node_indices);
1247 // Find GeluApproximate
1248 if (!found_gelu_exact) {
1249 matched_nodes_map->clear();
1250 remove_node_indices->clear();
1251 found_gelu_approximate = graph_matcher.GetMatchedNodes(
1252 gelu_approximate_pattern, ctx->nodes_to_preserve, node_view,
1253 matched_nodes_map, remove_node_indices);
1254 }
1255
1256 // Pattern matcher does subgraph matching based on op types only. The matcher
1257 // also does a sanity check on nodes tagged as `kRemove`, i.e., they do not
1258 // have any consumer outside the matched nodes. In order to replace the
1259 // subgraph, we need additional checks, for example, if the key ops have been
1260 // placed on CPU or GPU, desired data type, const has desired value etc. For
1261 // the following fusion: MatMul + BiasAdd + Gelu (disintegrated into smaller
1262 // ops), we check if (i) the MatMul op is CpuCompatible or GpuCompatible, (ii)
1263 // const nodes have desired values.
1264 if (found_gelu_exact) {
1265 if (!IsMKLEnabled()) return false;
1266 // Check if the MatMul to be fused is CPU compatible
1267 // TODO(kaixih@nvidia): Add GPU support when cublasLt supports the exact
1268 // form.
1269 NodeDef* matmul_node =
1270 ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();
1271 if (!IsCpuCompatibleMatMul(*ctx, matmul_node)) {
1272 matched_nodes_map->clear();
1273 remove_node_indices->clear();
1274 return false;
1275 }
1276 // Check if the matched constants have desired values.
1277 if (found_gelu_exact) {
1278 std::map<string, float> values_map = {
1279 {"square_root_one_half", 0.707106}, {"one", 1.0}, {"one_half", 0.5}};
1280 if (!VerifyConstants(ctx, matched_nodes_map, &values_map)) return false;
1281 }
1282 } else if (found_gelu_approximate) {
1283 NodeDef* matmul_node =
1284 ctx->graph_view.GetNode(matched_nodes_map->at("matmul"))->node();
1285
1286 // matmul_node is already the _FusedMatMul and we don't need to check its
1287 // data type again.
1288 if (!IsMKLEnabled() && !NodeIsOnGpu(matmul_node)) return false;
1289
1290 // Check if _FusedMatMul contains only BiasAdd
1291 auto fused_ops = matmul_node->attr().at("fused_ops").list().s();
1292 if (fused_ops.size() == 1) {
1293 if (fused_ops.at(0) != "BiasAdd") return false;
1294 } else {
1295 return false;
1296 }
1297 // Check if the matched constants have desired values.
1298 std::map<string, float> values_map = {{"square_root_two_over_pi", 0.797884},
1299 {"one", 1.0},
1300 {"one_half", 0.5},
1301 {"empirical_const", 0.044715}};
1302 if (NodeIsOnGpu(matmul_node)) {
1303 values_map["three"] = 3.0;
1304 }
1305
1306 if (!VerifyConstants(ctx, matched_nodes_map, &values_map)) return false;
1307 } else {
1308 return false;
1309 }
1310 *is_gelu_approximate = found_gelu_approximate ? true : false;
1311 return (found_gelu_exact || found_gelu_approximate);
1312 }
1313
1314 bool FindMulAndMaximum(RemapperContext* ctx, int node_index,
1315 std::map<string, int>* matched_nodes_map,
1316 std::set<int>* remove_node_indices) {
1317 using utils::MatchingDirection;
1318 using utils::NodeStatus;
1319
1320 // clang-format off
1321 // Convert Mul+Maximum to LeakyRelu
1322 // maximum(x, alpha * x) = LeakyRelu(x)
1323 utils::OpTypePattern mulmax_pattern{
1324 "Maximum", "max_to_leakyrelu", NodeStatus::kReplace,
1325 {
1326 { "Mul", "mul", NodeStatus::kRemove,
1327 {
1328 { "*", "input", NodeStatus::kRemain},
1329 { "Const|Cast", "alpha", NodeStatus::kRemain}
1330 }
1331 },
1332 { "*", "input", NodeStatus::kRemain}
1333 }
1334 };
1335 // clang-format on
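// The pattern above rewrites maximum(x, alpha * x) into LeakyRelu(x), using
// the matched scalar constant as alpha (hence the alpha >= 0 check below).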
1336 // Check for allowed datatypes
1337 auto* max_node_def = ctx->graph_view.GetNode(node_index)->node();
1338 if (!HasDataType(max_node_def, DT_HALF) &&
1339 !HasDataType(max_node_def, DT_BFLOAT16) &&
1340 !HasDataType(max_node_def, DT_FLOAT) &&
1341 !HasDataType(max_node_def, DT_DOUBLE))
1342 return false;
1343
1344 // Current implementation has support only
1345 // for CPU when oneDNN is enabled.
1346 // TODO(intel-tf): This will be removed when fully tested with GPU
1347 if (!NodeIsOnCpu(max_node_def) && !IsMKLEnabled()) return false;
1348
1349 bool found_op_type_match = false;
1350 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1351 &(ctx->graph_view));
1352 matched_nodes_map->clear();
1353 remove_node_indices->clear();
1354
1355 found_op_type_match = graph_matcher.GetMatchedNodes(
1356 mulmax_pattern, {}, ctx->graph_view.GetNode(node_index),
1357 matched_nodes_map, remove_node_indices);
1358
1359 // Check if the value of alpha >= 0 as required for LeakyRelu
1360 if (found_op_type_match) {
1361 const auto* alpha_node_view =
1362 ctx->graph_view.GetNode(matched_nodes_map->at("alpha"));
1363 const auto* alpha_node_def = alpha_node_view->node();
1364
1365 float alpha_val;
1366 Tensor alpha_tensor;
1367 if (alpha_node_def->op() == "Cast") {
1368 const auto& regular_fanin_0 = alpha_node_view->GetRegularFanin(0);
1369 const auto* regular_node_view = regular_fanin_0.node_view();
1370 const auto* const_node = regular_node_view->node();
1371 if (const_node != nullptr && const_node->op() == "Const" &&
1372 alpha_tensor.FromProto(const_node->attr().at("value").tensor())) {
1373 // Only fusing if the const is a scalar value
1374 if (alpha_tensor.shape().dims() > 0) {
1375 return false;
1376 }
1377 alpha_val = alpha_tensor.flat<float>()(0);
1378 } else {
1379 return false;
1380 }
1381 } else if (alpha_node_def->op() == "Const" &&
1382 alpha_tensor.FromProto(
1383 alpha_node_def->attr().at("value").tensor())) {
1384 // Only fusing if the const is a scalar value
1385 if (alpha_tensor.shape().dims() > 0) {
1386 return false;
1387 }
1388 alpha_val = alpha_tensor.flat<float>()(0);
1389 } else {
1390 return false;
1391 }
1392
1393 if (alpha_val < 0) {
1394 return false;
1395 }
1396 }
1397 return found_op_type_match;
1398 }
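
// Example rewrite (sketch) for the pattern matched above: a subgraph like
//   y = Maximum(Mul(alpha, x), x)
// is replaced by a single
//   y = LeakyRelu(x) with attribute alpha
// (see ReplaceMulMaximumWithLeakyRelu below).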
1399
1400 bool FindSigmoidAndMul(RemapperContext* ctx, int node_index,
1401 std::map<string, int>* matched_nodes_map,
1402 std::set<int>* remove_node_indices) {
1403   // Sigmoid+Mul (Swish) fusion is enabled only with oneDNN library.
1404 if (!IsMKLEnabled()) return false;
1405
1406 using utils::MatchingDirection;
1407 using utils::NodeStatus;
1408 // clang-format off
1409 // Convert Sigmoid+Mul to Swish
1410 // Mul(x, Sigmoid(x)) --> _MklSwish(x)
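  // Swish (also known as SiLU) is defined as swish(x) = x * sigmoid(x).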
1411
1412 utils::OpTypePattern sigmoidmul_pattern{
1413 "Mul", "mul_to_swish", NodeStatus::kReplace,
1414 {
1415 { "Sigmoid", "sigmoid", NodeStatus::kRemove,
1416 {
1417 { "*", "input", NodeStatus::kRemain}
1418 }
1419 },
1420 { "*", "input", NodeStatus::kRemain}
1421 }
1422 };
1423 // clang-format on
1424 // check for data types
1425 auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
1426 if (!HasDataType(mul_node_def, DT_FLOAT) &&
1427 !HasDataType(mul_node_def, DT_BFLOAT16))
1428 return false;
1429
1430 if (!NodeIsOnCpu(mul_node_def)) return false;
1431
1432 bool found_op_type_match = false;
1433 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1434 &(ctx->graph_view));
1435 matched_nodes_map->clear();
1436 remove_node_indices->clear();
1437 found_op_type_match = graph_matcher.GetMatchedNodes(
1438 sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index),
1439 matched_nodes_map, remove_node_indices);
1440
1441 return found_op_type_match;
1442 }
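
// Example rewrite (sketch): a subgraph
//   s = Sigmoid(x); y = Mul(x, s)
// is replaced by a single node
//   y = _MklSwish(x)
// (see ReplaceSigmoidMulWithSwish below).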
1443
1444 // Keras LayerNormalization API uses multiple TensorFlow ops. The current
1445 // fusion pattern covers only the case when LayerNormalization uses
1446 // FusedBatchNormV3. We further restrict it to 2D or 3D tensor inputs to the
1447 // Keras LayerNormalization API.
1448 bool FindMklLayerNorm(RemapperContext* ctx, int node_index,
1449 std::map<string, int>* matched_nodes_map,
1450 std::set<int>* remove_node_indices) {
1451 if (!IsMKLEnabled()) return false;
1452
1453   // The following pattern will be searched in the graph with additional
1454   // constraints. Here * means any type of op.
1455 // clang-format off
1456 // Subgraph for fusion
1457 // -------------------
1458 //
1459 // *(input) * * Const * Const FusedOp
1460 // \ | \ | | / Const -------
1461 // \ | \ | | / Const /
1462 // Reshape Fill Fill / / *(input) *(gamma) *(beta)
1463 // \ / / / / \ | /
1464 // \ / / / / \ | /
1465 // F u s e d B a t c h N o r m V 3 _MklLayerNorm
1466 // \
1467 // \ *
1468 // \ /
1469 // Reshape
1470 // \ *(gamma)
1471 // \ /
1472 // Mul
1473 // *(beta) /
1474 // \ /
1475 // AddV2(output)
1476 // clang-format on
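  // In the pattern above, FusedBatchNormV3 receives an all-ones scale and an
  // all-zeros offset produced by the Fill nodes ("unit_gamma"/"zero_beta");
  // the real gamma and beta are applied by the trailing Mul and AddV2, which
  // become the gamma/beta inputs of _MklLayerNorm.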
1477 using utils::MatchingDirection;
1478 using utils::NodeStatus;
1479 // clang-format off
1480 utils::OpTypePattern layer_norm_pattern =
1481 {"AddV2", "output", NodeStatus::kReplace,
1482 {
1483 {"*", "beta", NodeStatus::kRemain},
1484 {"Mul", "scale", NodeStatus::kRemove,
1485 {
1486 {"Reshape", "post_reshape", NodeStatus::kRemove,
1487 {
1488 {"FusedBatchNormV3", "fused_batch_norm", NodeStatus::kRemove,
1489 {
1490 {"Reshape", "pre_reshape", NodeStatus::kRemove,
1491 {
1492 {"*", "input", NodeStatus::kRemain},
1493 {"*", "pre_shape", NodeStatus::kRemain}
1494 }
1495 },
1496 {"Fill", "fill_scale", NodeStatus::kRemove,
1497 {
1498 {"*", "dims_fill_scale", NodeStatus::kRemain},
1499 {"Const", "unit_gamma", NodeStatus::kRemain}
1500 }
1501 },
1502 {"Fill", "fill_offset", NodeStatus::kRemove,
1503 {
1504 {"*", "dims_fill_offset", NodeStatus::kRemain},
1505 {"Const", "zero_beta", NodeStatus::kRemain}
1506 }
1507 },
1508 {"Const", "empty", NodeStatus::kRemain},
1509 {"Const", "empty", NodeStatus::kRemain}
1510 }
1511 },
1512 {"*", "post_shape", NodeStatus::kRemain}
1513 }
1514 },
1515 {"*", "gamma", NodeStatus::kRemain}
1516 }
1517 }
1518 }
1519 }; // clang-format on
1520
1521 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
1522 &(ctx->graph_view));
1523 bool found_op_type_match = false;
1524 matched_nodes_map->clear();
1525 remove_node_indices->clear();
1526 found_op_type_match =
1527 graph_matcher.GetMatchedNodes(layer_norm_pattern, ctx->nodes_to_preserve,
1528 ctx->graph_view.GetNode(node_index),
1529 matched_nodes_map, remove_node_indices);
1530
1531 // Additional check for LayerNorm
1532 if (found_op_type_match) {
1533 // LayerNorm uses FusedBatchNorm in training mode.
1534 NodeDef* fused_batch_norm_node =
1535 ctx->graph_view.GetNode(matched_nodes_map->at("fused_batch_norm"))
1536 ->node();
1537 bool is_training = false;
1538 if (!TryGetNodeAttr(*fused_batch_norm_node, kIsTraining, &is_training) ||
1539 !is_training)
1540 return false;
1541
1542 if (!NodeIsOnCpu(fused_batch_norm_node)) return false;
1543
1544 // FusedBatchNorm node should have mean/variance as empty constant
1545 NodeDef* empty_const_node =
1546 ctx->graph_view.GetNode(matched_nodes_map->at("empty"))->node();
1547 Tensor const_tensor;
1548 if (empty_const_node != nullptr && empty_const_node->op() == "Const" &&
1549 const_tensor.FromProto(empty_const_node->attr().at("value").tensor())) {
1550 if (const_tensor.NumElements() != 0) return false;
1551 } else {
1552 return false;
1553 }
1554
1555 // TODO(intel-tf): Relax the restriction of 2D/3D tensor once kernel
1556 // supports that.
1557 if (!ctx->inferred_graph_properties) {
1558 Status s = ctx->graph_properties.InferStatically(
1559 /*assume_valid_feeds=*/true,
1560 /*aggressive_shape_inference=*/false,
1561 /*include_input_tensor_values=*/true,
1562 /*include_output_tensor_values=*/false);
1563 if (!s.ok()) return false;
1564 ctx->inferred_graph_properties = true;
1565 }
1566 NodeDef* input_node_def =
1567 ctx->graph_view.GetNode(matched_nodes_map->at("input"))->node();
1568 auto input_props =
1569 ctx->graph_properties.GetOutputProperties(input_node_def->name());
1570 NodeDef* output_node_def =
1571 ctx->graph_view.GetNode(matched_nodes_map->at("output"))->node();
1572 auto output_props =
1573 ctx->graph_properties.GetOutputProperties(output_node_def->name());
1574 if (ShapesSymbolicallyEqual(input_props[0].shape(),
1575 output_props[0].shape())) {
1576 int rank = Rank(input_props[0].shape());
1577 if (rank < 2 || rank > 3) return false;
1578 } else {
1579 return false;
1580 }
1581 }
1582 return found_op_type_match;
1583 }
1584
1585 bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
1586 FusedBatchNorm* matched) {
1587 const auto* node_view = ctx.graph_view.GetNode(node_index);
1588 const auto* node_def = node_view->node();
1589 if (!IsFusedBatchNorm(*node_def)) return false;
1590 if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
1591
1592 // Check that the node is in inference mode.
1593 bool is_training = true;
1594 if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
1595 if (is_training) return false;
1596
1597 const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
1598
1599 // a. Scaling factor can be const folded:
1600 // scaling_factor = (variance + epsilon).rsqrt() * scale
1601 bool const_scaling_factor =
1602 props.size() == 5 && // [x, scale, offset, mean, variance]
1603 props[1].has_value() && // scale
1604 props[4].has_value(); // variance aka estimated variance
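  // In inference mode the batch norm output can be rewritten as
  //   y = x * scaling_factor + (offset - mean * scaling_factor),
  // so a constant scale and variance let the scaling factor be const folded.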
1605
1606 // b. Or input can be const folded into some other expression.
1607 auto const_inputs = std::count_if(
1608 props.begin(), props.end(),
1609 [](const OpInfo::TensorProperties& props) { return props.has_value(); });
1610
1611 // TODO(bsteiner): use the cost model to compare the cost of fused batch
1612 // norm against that of the optimized form.
1613 bool can_remap = const_scaling_factor || const_inputs >= 4;
1614 if (!can_remap) return false;
1615
1616 // The optimized version only generates the first output.
1617 if (node_view->GetRegularFanouts().size() > 1) {
1618 return false;
1619 }
1620
1621 // We found a fused batch norm node that can be replaced with primitive ops.
1622 matched->fused_batch_norm = node_index;
1623
1624 return true;
1625 }
1626
1627 // NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
1628 // `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
1629 bool BatchnormSpatialPersistentEnabled() {
1630 #if CUDNN_VERSION >= 7402
1631 static bool is_enabled = [] {
1632 bool is_enabled = false;
1633 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
1634 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
1635 /*default_val=*/false, &is_enabled));
1636 return is_enabled;
1637 }();
1638 return is_enabled;
1639 #else
1640 return false;
1641 #endif
1642 }
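
// For example, setting TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 in the
// environment opts in to the persistent cuDNN batch-norm algorithm that the
// GPU training-mode fusions below require.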
1643
1644 bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
1645 FusedBatchNormEx* matched) {
1646 // Root of the pattern must be a Relu.
1647 // TODO(ezhulenev): Forward control dependencies.
1648 const auto* node_view = ctx.graph_view.GetNode(node_index);
1649 const auto* node_def = node_view->node();
1650 // TODO(lyandy): Forward controls for patterns with control dependencies.
1651 if (!IsRelu(*node_def) || HasControlFaninOrFanout(*node_view)) return false;
1652
1653 // Returns true iff the node is a compatible FusedBatchNorm node.
1654 const auto valid_batch_norm =
1655 [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
1656 const auto* fused_batch_norm_node_def = fused_batch_norm.node();
1657 if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;
1658
1659 // We fuse FusedBatchNorm on GPU or oneDNN CPU.
1660 if (!IsMKLEnabled() && !NodeIsOnGpu(fused_batch_norm_node_def))
1661 return false;
1662
1663 DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
1664
1665 if (NodeIsOnGpu(fused_batch_norm_node_def)) {
1666 // GPU supports float and half.
1667       // Check this before `IsMKLEnabled()` because this node should still be
1668       // processed when it is on GPU even if oneDNN is enabled for CPU.
1669 if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
1670 } else {
1671 // Bfloat16 is available only with oneDNN.
1672 // Half is not available with oneDNN.
1673 if (IsMKLEnabled() && t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16)
1674 return false;
1675 }
1676
1677 // Get the FusedBatchNorm training mode.
1678 bool is_training;
1679 if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
1680 .ok())
1681 return false;
1682 string data_format;
1683 if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
1684 .ok())
1685 return false;
1686 if (data_format != "NHWC" && data_format != "NCHW") return false;
1687
1688 // In training mode we rely on cuDNN for computing FusedBatchNorm with side
1689 // inputs and activation, and it has its own limitations. In inference mode
1690     // we have a custom CUDA kernel that doesn't have these constraints.
1691 if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
1692 // cuDNN only supports NHWC data layout.
1693 if (data_format != "NHWC") return false;
1694
1695 // Data type must be DT_HALF.
1696 if (t_dtype != DT_HALF) return false;
1697
1698 // Channel dimension must be a multiple of 4.
1699 const auto& props = ctx.graph_properties.GetInputProperties(
1700 fused_batch_norm_node_def->name());
1701
1702 const bool valid_channel_dim = !props.empty() &&
1703 props[0].shape().dim_size() == 4 &&
1704 props[0].shape().dim(3).size() % 4 == 0;
1705 if (!valid_channel_dim) return false;
1706
1707 // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
1708 if (!BatchnormSpatialPersistentEnabled()) return false;
1709 }
1710
1711 // FusedBatchNormV2 and V3 have an extra type parameter.
1712 if ((fused_batch_norm_node_def->op() != "FusedBatchNorm") &&
1713 !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
1714 return false;
1715
1716 // Check that only one node consumes the 0-th output of a FusedBatchNorm.
1717 if (HasControlFaninOrFanout(fused_batch_norm) ||
1718 !HasAtMostOneDataFanoutAtPort0(fused_batch_norm) ||
1719 IsInPreserveSet(ctx, fused_batch_norm_node_def))
1720 return false;
1721
1722 return true;
1723 };
1724
1725 if (node_view->NumRegularFanins() < 1) return false;
1726 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
1727 const auto* relu_fanin_0_node_view = regular_fanin_0.node_view();
1728 const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
1729
1730 // Input to a Relu can be a FusedBatchNorm.
1731 if (valid_batch_norm(*relu_fanin_0_node_view)) {
1732 matched->activation = node_index;
1733 matched->fused_batch_norm = regular_fanin_0.node_index();
1734 return true;
1735 }
1736
1737 // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
1738 if (IsAdd(*relu_fanin_0_node_def)) {
1739 // Currently no CPU implementation for "FusedBatchNorm + SideInput +
1740     // <Activation>"
1741 if (IsMKLEnabled() && !NodeIsOnGpu(node_def)) return false;
1742
1743 // Check that only Relu node consumes the output of an Add node.
1744 if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
1745 !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
1746 IsInPreserveSet(ctx, relu_fanin_0_node_def))
1747 return false;
1748
1749 // Add node supports broadcasting, FusedBatchNormEx does not.
1750 const auto& props =
1751 ctx.graph_properties.GetInputProperties(relu_fanin_0_node_def->name());
1752 if (props.size() < 2 ||
1753 !ShapesSymbolicallyEqual(props[0].shape(), props[1].shape()))
1754 return false;
1755
1756 if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
1757 const auto& add_regular_fanin_0 =
1758 relu_fanin_0_node_view->GetRegularFanin(0);
1759 const auto& add_regular_fanin_1 =
1760 relu_fanin_0_node_view->GetRegularFanin(1);
1761
1762 if (valid_batch_norm(*add_regular_fanin_0.node_view())) {
1763 matched->activation = node_index;
1764 matched->side_input = add_regular_fanin_1.node_index();
1765 matched->fused_batch_norm = add_regular_fanin_0.node_index();
1766 matched->invalidated = regular_fanin_0.node_index();
1767 return true;
1768 }
1769
1770 if (valid_batch_norm(*add_regular_fanin_1.node_view())) {
1771 matched->activation = node_index;
1772 matched->side_input = add_regular_fanin_0.node_index();
1773 matched->fused_batch_norm = add_regular_fanin_1.node_index();
1774 matched->invalidated = regular_fanin_0.node_index();
1775 return true;
1776 }
1777 }
1778
1779 return false;
1780 }
1781
1782 bool FindFusedBatchNormGradEx(const RemapperContext& ctx, int node_index,
1783 FusedBatchNormGradEx* matched) {
1784 // Root of the pattern must be a FusedBatchNormGrad.
1785 const utils::MutableNodeView* node_view = ctx.graph_view.GetNode(node_index);
1786
1787 // Returns true iff the node is a compatible FusedBatchNormGrad node.
1788 const auto valid_batch_norm_grad =
1789 [&](const utils::MutableNodeView& fused_batch_norm_grad) -> bool {
1790 const NodeDef* node_def = fused_batch_norm_grad.node();
1791 if (!IsFusedBatchNormGrad(*node_def) ||
1792 HasControlFaninOrFanout(fused_batch_norm_grad))
1793 return false;
1794
1795 // We fuse FusedBatchNormGrad on GPU.
1796 if (!NodeIsOnGpu(node_def)) return false;
1797
1798 // We fuse FusedBatchNormGrad only for the training mode.
1799 bool is_training;
1800 if (!GetNodeAttr(*node_def, kIsTraining, &is_training).ok() || !is_training)
1801 return false;
1802
1803 // Data type must be DT_HALF.
1804 DataType t_dtype = GetDataTypeFromAttr(*node_def, "T");
1805 if (t_dtype != DT_HALF) return false;
1806
1807 // We rely on cuDNN for computing FusedBatchNormGrad with side
1808 // outputs and activation. cuDNN only supports NHWC data layout.
1809 string data_format;
1810 if (!GetNodeAttr(*node_def, kDataFormat, &data_format).ok()) return false;
1811 if (data_format != "NHWC") return false;
1812
1813 // Channel dimension must be a multiple of 4.
1814 const auto& props =
1815 ctx.graph_properties.GetInputProperties(node_def->name());
1816 const bool valid_channel_dim = !props.empty() &&
1817 props[0].shape().dim_size() == 4 &&
1818 props[0].shape().dim(3).size() % 4 == 0;
1819 if (!valid_channel_dim) return false;
1820
1821 // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
1822 if (!BatchnormSpatialPersistentEnabled()) return false;
1823
1824 // FusedBatchNormV2 and V3 have an extra type parameter.
1825 if (node_def->op() != "FusedBatchNorm" &&
1826 !HasDataType(node_def, DT_FLOAT, "U"))
1827 return false;
1828
1829 return true;
1830 };
1831
1832 if (ctx.xla_auto_clustering_on) return false;
1833
1834 if (!valid_batch_norm_grad(*node_view)) return false;
1835
1836 if (node_view->NumRegularFanins() < 1) return false;
1837
1838 const utils::MutableFanoutView& regular_fanin_0 =
1839 node_view->GetRegularFanin(0);
1840 const utils::MutableNodeView* relugrad_node_view =
1841 regular_fanin_0.node_view();
1842 const NodeDef* relugrad_node_def = relugrad_node_view->node();
1843 bool is_relugrad = IsReluGrad(*relugrad_node_def);
1844
1845 if (!is_relugrad || HasControlFaninOrFanout(*relugrad_node_view) ||
1846 IsInPreserveSet(ctx, relugrad_node_def))
1847 return false;
1848
1849 if (relugrad_node_view->NumRegularFanins() < 1) return false;
1850 // Find its corresponding forward node. We need the node to determine if the
1851 // type is bn+add+act or bn+act. Also, we need to access its "offset" input.
1852 const utils::MutableFanoutView& fanin_1 =
1853 relugrad_node_view->GetRegularFanin(1);
1854 const utils::MutableNodeView* fwd_node_view = fanin_1.node_view();
1855 FusedBatchNormEx fwd_matched;
1856 FindFusedBatchNormEx(ctx, fwd_node_view->node_index(), &fwd_matched);
1857 bool fwd_bn_act_used = fwd_matched.activation != kMissingIndex &&
1858 fwd_matched.side_input == kMissingIndex;
1859 bool fwd_bn_add_act_used = fwd_matched.activation != kMissingIndex &&
1860 fwd_matched.side_input != kMissingIndex;
1861
1862 // Check that only 1 node consumes the output of the ReluGrad node.
1863 if (fwd_bn_act_used && relugrad_node_view->GetRegularFanout(0).size() == 1) {
1864 matched->activation_grad = regular_fanin_0.node_index();
1865 matched->fused_batch_norm_grad = node_index;
1866 matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;
1867 return true;
1868 }
1869
1870 // Check that only 2 nodes consume the output of the ReluGrad node.
1871 if (fwd_bn_add_act_used &&
1872 relugrad_node_view->GetRegularFanout(0).size() == 2) {
1873     // In a graph where the Add node has two BatchNorm nodes as its inputs, we
1874     // need to make sure that only the backward BatchNorm corresponding to the
1875     // to-be-fused forward BatchNorm is fused. We use the reserve-space edge to
1876     // get the directly corresponding forward BatchNorm node.
1877 const utils::MutableFanoutView& fwd_batch_norm_node =
1878 node_view->GetRegularFanin(5);
1879 if (fwd_matched.fused_batch_norm != fwd_batch_norm_node.node_index()) {
1880 return false;
1881 }
1882
1883 const std::vector<utils::MutableFaninView>& fanouts_at_port_0 =
1884 relugrad_node_view->GetRegularFanouts()[0];
1885 const utils::MutableNodeView* fanout_0_node_view =
1886 ctx.graph_view.GetNode(fanouts_at_port_0[0].node_view()->GetName());
1887 const utils::MutableNodeView* fanout_1_node_view =
1888 ctx.graph_view.GetNode(fanouts_at_port_0[1].node_view()->GetName());
1889 const NodeDef* fanout_0_node_def = fanout_0_node_view->node();
1890 const NodeDef* fanout_1_node_def = fanout_1_node_view->node();
1891 const NodeDef* node_def = node_view->node();
1892
1893 matched->activation_grad = regular_fanin_0.node_index();
1894 matched->fused_batch_norm_grad = node_index;
1895 matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;
1896
1897 if (fanout_0_node_def == node_def) {
1898 matched->side_input_grad = fanout_1_node_view->node_index();
1899 return true;
1900 }
1901
1902 if (fanout_1_node_def == node_def) {
1903 matched->side_input_grad = fanout_0_node_view->node_index();
1904 return true;
1905 }
1906 }
1907
1908 return false;
1909 }
1910
1911 bool FindTensorToHashBucket(const RemapperContext& ctx, int node_index,
1912 TensorToHashBucket* matched) {
1913 // Root of the pattern must be a StringToHashBucketFast.
1914 const auto* node_view = ctx.graph_view.GetNode(node_index);
1915 const auto* node_def = node_view->node();
1916
1917 if (!IsStringToHashBucketFast(*node_def) ||
1918 HasControlFaninOrFanout(*node_view)) {
1919 return false;
1920 }
1921
1922 // Input to the StringToHashBucketFast must be AsString.
1923 if (node_view->NumRegularFanins() < 1) return false;
1924
1925 const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
1926 const auto* as_string_node_view = regular_fanin_0.node_view();
1927 const auto* as_string_node_def = as_string_node_view->node();
1928 bool is_as_string = IsAsString(*as_string_node_def);
1929
1930 if (!is_as_string || HasControlFaninOrFanout(*as_string_node_view) ||
1931 !HasAtMostOneFanoutAtPort0(*as_string_node_view) ||
1932 IsInPreserveSet(ctx, as_string_node_def))
1933 return false;
1934
1935 // DataType of AsString must be int8/16/32/64 and width/fill attrs must be
1936 // default values.
1937 if (!HasDataType(as_string_node_def, DT_INT8) &&
1938 !HasDataType(as_string_node_def, DT_INT16) &&
1939 !HasDataType(as_string_node_def, DT_INT32) &&
1940 !HasDataType(as_string_node_def, DT_INT64)) {
1941 return false;
1942 }
1943
1944 int width;
1945 if (!GetNodeAttr(*as_string_node_def, kWidth, &width).ok() || width != -1) {
1946 return false;
1947 }
1948
1949 string fill;
1950 if (!GetNodeAttr(*as_string_node_def, kFill, &fill).ok() || !fill.empty()) {
1951 return false;
1952 }
1953
1954 // An input to the AsString must exist to determine the device.
1955 if (as_string_node_view->NumRegularFanins() < 1) return false;
1956
1957 const auto& fanin_0 = as_string_node_view->GetRegularFanin(0);
1958 const auto* pre_node_view = fanin_0.node_view();
1959
1960   // We successfully found an AsString + StringToHashBucketFast pattern.
1961 const TensorToHashBucket pattern{pre_node_view->node_index(),
1962 as_string_node_view->node_index(),
1963 node_index};
1964
1965 *matched = pattern;
1966
1967 return true;
1968 }
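
// Example rewrite (sketch): AsString(int_input) followed by
// StringToHashBucketFast(num_buckets) is collapsed into a single fused op
// that hashes the integer input directly, skipping the intermediate string
// tensor.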
1969
1970 bool FindFusedBatchMatMul(RemapperContext* ctx, int node_index,
1971 std::map<string, int>* matched_nodes_map,
1972 std::set<int>* remove_node_indices) {
1973 if (!IsMKLEnabled()) return false;
1974
1975 using utils::MatchingDirection;
1976 using utils::NodeStatus;
1977 // clang-format off
1978 utils::OpTypePattern fusion_pattern1 =
1979 {"AddV2", "output", NodeStatus::kReplace,
1980 {
1981 {"Mul", "mul", NodeStatus::kRemove,
1982 {
1983 {"BatchMatMulV2", "batch_matmul", NodeStatus::kRemove},
1984 {"*", "multiplicand", NodeStatus::kRemain}
1985 }
1986 },
1987 {"*", "addend", NodeStatus::kRemain}
1988 }
1989 };
1990
1991 utils::OpTypePattern fusion_pattern2 =
1992 {"AddV2", "output", NodeStatus::kReplace,
1993 {
1994 {"*", "addend", NodeStatus::kRemain},
1995 {"Mul", "mul", NodeStatus::kRemove,
1996 {
1997 {"BatchMatMulV2", "batch_matmul", NodeStatus::kRemove},
1998 {"*", "multiplicand", NodeStatus::kRemain}
1999 }
2000 }
2001 }
2002 };
2003 // clang-format on
2004
2005 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
2006 &(ctx->graph_view));
2007 bool found_op_type_match = false;
2008 matched_nodes_map->clear();
2009 remove_node_indices->clear();
2010 found_op_type_match =
2011 graph_matcher.GetMatchedNodes(fusion_pattern1, ctx->nodes_to_preserve,
2012 ctx->graph_view.GetNode(node_index),
2013 matched_nodes_map, remove_node_indices);
2014
2015 if (!found_op_type_match) {
2016 matched_nodes_map->clear();
2017 remove_node_indices->clear();
2018 found_op_type_match =
2019 graph_matcher.GetMatchedNodes(fusion_pattern2, ctx->nodes_to_preserve,
2020 ctx->graph_view.GetNode(node_index),
2021 matched_nodes_map, remove_node_indices);
2022 }
2023
2024   // oneDNN is not optimized for all shapes with regard to binary post-op
2025   // fusion. For now, allow only the limited cases that are optimized: (i) the
2026   // multiplicand is a scalar, (ii) the BatchMatMulV2 output is a 4D tensor,
2027   // and (iii) the addend is a 4D tensor with second dim_size = 1.
2028 if (!found_op_type_match) return false;
2029 if (!ctx->inferred_graph_properties) {
2030 Status s = ctx->graph_properties.InferStatically(
2031 /*assume_valid_feeds=*/true,
2032 /*aggressive_shape_inference=*/false,
2033 /*include_input_tensor_values=*/false,
2034 /*include_output_tensor_values=*/true);
2035 if (!s.ok()) return false;
2036 ctx->inferred_graph_properties = true;
2037 }
2038 NodeDef* multiplicand_node_def =
2039 ctx->graph_view.GetNode(matched_nodes_map->at("multiplicand"))->node();
2040 auto multiplicand_props =
2041 ctx->graph_properties.GetOutputProperties(multiplicand_node_def->name());
2042 if (NumCoefficients(multiplicand_props[0].shape()) != 1) return false;
2043
2044 NodeDef* batch_matmul_node_def =
2045 ctx->graph_view.GetNode(matched_nodes_map->at("batch_matmul"))->node();
2046 if (!IsCpuCompatibleMatMul(*ctx, batch_matmul_node_def)) return false;
2047
2048 auto batch_matmul_props =
2049 ctx->graph_properties.GetOutputProperties(batch_matmul_node_def->name());
2050 if (Rank(batch_matmul_props[0].shape()) != 4) return false;
2051
2052 NodeDef* addend_node_def =
2053 ctx->graph_view.GetNode(matched_nodes_map->at("addend"))->node();
2054 auto addend_props =
2055 ctx->graph_properties.GetOutputProperties(addend_node_def->name());
2056 auto addend_shape = addend_props[0].shape();
2057 if (!(Rank(addend_shape) == 4 && addend_shape.dim(1).size() == 1))
2058 return false;
2059 return found_op_type_match;
2060 }
2061
2062 void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
2063 const NodeDef* activation = nullptr) {
2064 DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
2065
2066 auto* attr = fused_conv2d->mutable_attr();
2067 auto& src_attr = conv2d.attr();
2068
2069 (*attr)["T"] = src_attr.at("T");
2070 (*attr)["strides"] = src_attr.at("strides");
2071 (*attr)["padding"] = src_attr.at("padding");
2072 (*attr)["explicit_paddings"] = src_attr.at("explicit_paddings");
2073 (*attr)["dilations"] = src_attr.at("dilations");
2074 (*attr)["data_format"] = src_attr.at("data_format");
2075 (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
2076 // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha
2077 if (activation != nullptr && IsLeakyRelu(*activation)) {
2078 auto& activation_attr = activation->attr();
2079 (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2080 }
2081 }
2082
2083 void CopyConv3DAttributes(const NodeDef& conv3d, NodeDef* fused_conv3d,
2084 const NodeDef* activation = nullptr) {
2085 DCHECK(IsConv3D(conv3d)) << "Input node must be a Conv3D";
2086
2087 auto* attr = fused_conv3d->mutable_attr();
2088 auto& src_attr = conv3d.attr();
2089
2090 (*attr)["T"] = src_attr.at("T");
2091 (*attr)["strides"] = src_attr.at("strides");
2092 (*attr)["padding"] = src_attr.at("padding");
2093 (*attr)["dilations"] = src_attr.at("dilations");
2094 (*attr)["data_format"] = src_attr.at("data_format");
2095 // Copy LeakyRelu's attr alpha to FusedConv3D's attr leakyrelu_alpha
2096 if (activation != nullptr && IsLeakyRelu(*activation)) {
2097 auto& activation_attr = activation->attr();
2098 (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2099 }
2100 }
2101
2102 void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
2103 NodeDef* fused_dw_conv2d,
2104 const NodeDef* activation = nullptr) {
2105 DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
2106 << "Input node must be a DepthwiseConv2dNative";
2107
2108 auto* attr = fused_dw_conv2d->mutable_attr();
2109 auto& src_attr = dw_conv2d.attr();
2110
2111 (*attr)["T"] = src_attr.at("T");
2112 (*attr)["strides"] = src_attr.at("strides");
2113 (*attr)["padding"] = src_attr.at("padding");
2114 (*attr)["dilations"] = src_attr.at("dilations");
2115 (*attr)["data_format"] = src_attr.at("data_format");
2116 // Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr leakyrelu_alpha
2117 if (activation != nullptr && IsLeakyRelu(*activation)) {
2118 auto& activation_attr = activation->attr();
2119 (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2120 }
2121 }
2122
2123 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
2124 NodeDef* fused_batch_norm_ex) {
2125 DCHECK(IsFusedBatchNorm(fused_batch_norm))
2126 << "Input node must be a FusedBatchNorm";
2127
2128 auto* attr = fused_batch_norm_ex->mutable_attr();
2129 auto src_attr = fused_batch_norm.attr();
2130
2131 (*attr)["T"] = src_attr.at("T");
2132 (*attr)["is_training"] = src_attr.at("is_training");
2133 (*attr)["data_format"] = src_attr.at("data_format");
2134 (*attr)["epsilon"] = src_attr.at("epsilon");
2135 (*attr)["exponential_avg_factor"] = src_attr.at("exponential_avg_factor");
2136
2137 // FusedBatchNormV2 and V3 have an extra type parameter.
2138 if (fused_batch_norm.op() != "FusedBatchNorm") {
2139 SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
2140 } else {
2141 if (!IsMKLEnabled())
2142 SetAttrValue(src_attr.at("T"), &(*attr)["U"]);
2143 else
2144 SetAttrValue(DT_FLOAT, &(*attr)["U"]);
2145 }
2146 }
2147
2148 void CopyFusedBatchNormGradAttributes(const NodeDef& fused_batch_norm_grad,
2149 NodeDef* fused_batch_norm_grad_ex) {
2150 DCHECK(IsFusedBatchNormGrad(fused_batch_norm_grad))
2151 << "Input node must be a FusedBatchNormGrad";
2152
2153 auto* attr = fused_batch_norm_grad_ex->mutable_attr();
2154 auto src_attr = fused_batch_norm_grad.attr();
2155
2156 (*attr)["T"] = src_attr.at("T");
2157 (*attr)["is_training"] = src_attr.at("is_training");
2158 (*attr)["data_format"] = src_attr.at("data_format");
2159 (*attr)["epsilon"] = src_attr.at("epsilon");
2160
2161 // FusedBatchNormV2 and V3 have an extra type parameter.
2162 if (fused_batch_norm_grad.op() != "FusedBatchNormGrad") {
2163 SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
2164 } else {
2165 SetAttrValue(DT_FLOAT, &(*attr)["U"]);
2166 }
2167 }
2168
2169 void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
2170 const NodeDef* activation = nullptr) {
2171 DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
2172
2173 auto* attr = fused_matmul->mutable_attr();
2174 auto& src_attr = matmul.attr();
2175
2176 (*attr)["T"] = src_attr.at("T");
2177 (*attr)["transpose_a"] = src_attr.at("transpose_a");
2178 (*attr)["transpose_b"] = src_attr.at("transpose_b");
2179 // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
2180 if (activation != nullptr && IsLeakyRelu(*activation)) {
2181 auto& activation_attr = activation->attr();
2182 (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
2183 }
2184 }
2185
2186 void CopyBatchMatMulAttributes(const NodeDef& batchmatmul,
2187 NodeDef* fused_batch_matmul) {
2188 DCHECK(IsAnyBatchMatMul(batchmatmul)) << "Input node must be a BatchMatMul";
2189
2190 auto* attr = fused_batch_matmul->mutable_attr();
2191 auto& src_attr = batchmatmul.attr();
2192
2193 (*attr)["T"] = src_attr.at("T");
2194 (*attr)["adj_x"] = src_attr.at("adj_x");
2195 (*attr)["adj_y"] = src_attr.at("adj_y");
2196 }
2197
2198 void SetFusedOpAttributes(NodeDef* fused,
2199 const absl::Span<const absl::string_view> fused_ops,
2200 int num_args = 1, float epsilon = 0.0) {
2201 auto* attr = fused->mutable_attr();
2202 SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
2203 SetAttrValue(num_args, &(*attr)["num_args"]);
2204 SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm
2205 }
2206
2207 Status AddFusedContractionNode(RemapperContext* ctx,
2208 const ContractionWithBiasAdd& matched,
2209 std::vector<bool>* invalidated_nodes,
2210 std::vector<bool>* nodes_to_delete) {
2211 DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2212
2213 const GraphDef* graph = ctx->graph_view.graph();
2214 const NodeDef& contraction = graph->node(matched.contraction);
2215 const NodeDef& bias_add = graph->node(matched.bias_add);
2216 VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd: "
2217 << " bias_add=" << bias_add.name()
2218 << " contraction=" << contraction.name();
2219
2220 NodeDef fused_op;
2221 fused_op.set_name(bias_add.name());
2222 fused_op.set_device(contraction.device());
2223 fused_op.add_input(contraction.input(0)); // 0: input
2224 fused_op.add_input(contraction.input(1)); // 1: filter
2225 fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
2226 if (IsConv2D(contraction)) {
2227 fused_op.set_op(kFusedConv2D);
2228 CopyConv2DAttributes(contraction, &fused_op);
2229 } else if (IsDepthwiseConv2dNative(contraction)) {
2230 fused_op.set_op(kFusedDepthwiseConv2dNative);
2231 CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
2232 } else if (IsMatMul(contraction)) {
2233 fused_op.set_op(kFusedMatMul);
2234 CopyMatMulAttributes(contraction, &fused_op);
2235 } else if (IsConv3D(contraction)) {
2236 fused_op.set_op(kFusedConv3D);
2237 CopyConv3DAttributes(contraction, &fused_op);
2238 }
2239
2240 SetFusedOpAttributes(&fused_op, {"BiasAdd"});
2241 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2242 Status status;
2243 mutation->AddNode(std::move(fused_op), &status);
2244 TF_RETURN_IF_ERROR(status);
2245 TF_RETURN_IF_ERROR(mutation->Apply());
2246
2247 (*invalidated_nodes)[matched.bias_add] = true;
2248 (*nodes_to_delete)[matched.contraction] = true;
2249
2250 return OkStatus();
2251 }
2252
2253 Status AddFusedContractionNode(
2254 RemapperContext* ctx, const ContractionWithBiasAddAndActivation& matched,
2255 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2256 DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2257
2258 const GraphDef* graph = ctx->graph_view.graph();
2259 const NodeDef& contraction = graph->node(matched.contraction);
2260 const NodeDef& bias_add = graph->node(matched.bias_add);
2261 const NodeDef& activation = graph->node(matched.activation);
2262
2263 VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
2264 << activation.op() << ":"
2265 << " activation=" << activation.name()
2266 << " bias_add=" << bias_add.name()
2267 << " contraction=" << contraction.name();
2268
2269 NodeDef fused_op;
2270 fused_op.set_name(activation.name());
2271 fused_op.set_device(contraction.device());
2272 fused_op.add_input(contraction.input(0)); // 0: input
2273 fused_op.add_input(contraction.input(1)); // 1: filter
2274 fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
2275
2276 if (IsConv2D(contraction)) {
2277 fused_op.set_op(kFusedConv2D);
2278 // leaky relu has a special attribute alpha
2279 CopyConv2DAttributes(contraction, &fused_op, &activation);
2280 } else if (IsDepthwiseConv2dNative(contraction)) {
2281 fused_op.set_op(kFusedDepthwiseConv2dNative);
2282 CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
2283 } else if (IsMatMul(contraction)) {
2284 fused_op.set_op(kFusedMatMul);
2285 CopyMatMulAttributes(contraction, &fused_op, &activation);
2286 } else if (IsConv3D(contraction)) {
2287 fused_op.set_op(kFusedConv3D);
2288 CopyConv3DAttributes(contraction, &fused_op, &activation);
2289 }
2290
2291 SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});
2292
2293 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2294 Status status;
2295 mutation->AddNode(std::move(fused_op), &status);
2296 TF_RETURN_IF_ERROR(status);
2297 TF_RETURN_IF_ERROR(mutation->Apply());
2298
2299 (*nodes_to_delete)[matched.contraction] = true;
2300 (*nodes_to_delete)[matched.bias_add] = true;
2301 (*invalidated_nodes)[matched.activation] = true;
2302
2303 return OkStatus();
2304 }
2305
2306 Status AddFusedConvNode(RemapperContext* ctx,
2307 const ContractionWithSqueezeAndBiasAdd& matched,
2308 std::vector<bool>* invalidated_nodes,
2309 std::vector<bool>* nodes_to_delete) {
2310 DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";
2311
2312 const GraphDef* graph = ctx->graph_view.graph();
2313 const NodeDef& contraction = graph->node(matched.contraction);
2314
2315 const NodeDef& bias_add = graph->node(matched.bias_add);
2316 const NodeDef& squeeze = graph->node(matched.squeeze);
2317 VLOG(2) << "Fuse Conv2D/3D with Squeeze and BiasAdd: "
2318 << " bias_add=" << bias_add.name() << " squeeze=" << squeeze.name()
2319 << " conv=" << contraction.name();
2320
2321 // Replace Conv2D/3D node with a fused Conv2D/3D. Matched pattern guarantees
2322   // that it has a single consumer (only the squeeze node).
2323 NodeDef fused_conv;
2324 fused_conv.set_name(contraction.name());
2325 fused_conv.set_device(contraction.device());
2326 fused_conv.add_input(contraction.input(0)); // 0: input
2327 fused_conv.add_input(contraction.input(1)); // 1: filter
2328 fused_conv.add_input(bias_add.input(1)); // 2: bias
2329
2330 if (IsConv2D(contraction)) {
2331 fused_conv.set_op(kFusedConv2D);
2332 CopyConv2DAttributes(contraction, &fused_conv);
2333 } else if (IsConv3D(contraction)) {
2334 fused_conv.set_op(kFusedConv3D);
2335 CopyConv3DAttributes(contraction, &fused_conv);
2336 }
2337
2338 SetFusedOpAttributes(&fused_conv, {"BiasAdd"});
2339
2340 // Replace BiasAdd node with a Squeeze.
2341 NodeDef remapped_squeeze = squeeze;
2342 remapped_squeeze.set_name(bias_add.name());
2343 remapped_squeeze.set_input(0, contraction.name());
2344
2345 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2346 Status status;
2347 mutation->AddNode(std::move(fused_conv), &status);
2348 TF_RETURN_IF_ERROR(status);
2349 mutation->AddNode(std::move(remapped_squeeze), &status);
2350 TF_RETURN_IF_ERROR(status);
2351 TF_RETURN_IF_ERROR(mutation->Apply());
2352
2353 (*invalidated_nodes)[matched.contraction] = true;
2354 (*invalidated_nodes)[matched.bias_add] = true;
2355 (*nodes_to_delete)[matched.squeeze] = true;
2356
2357 return OkStatus();
2358 }
2359
2360 Status AddFusedConv2DNode(RemapperContext* ctx,
2361 const ContractionWithBatchNorm& matched,
2362 std::vector<bool>* invalidated_nodes,
2363 std::vector<bool>* nodes_to_delete) {
2364 const GraphDef* graph = ctx->graph_view.graph();
2365 const NodeDef& contraction = graph->node(matched.contraction);
2366 DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
2367 const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2368 VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm="
2369 << fused_batch_norm.name() << " conv2d=" << contraction.name();
2370
2371 NodeDef fused_conv2d;
2372 fused_conv2d.set_name(fused_batch_norm.name());
2373 fused_conv2d.set_op(kFusedConv2D);
2374 fused_conv2d.set_device(contraction.device());
2375 fused_conv2d.add_input(contraction.input(0)); // 0: input
2376 fused_conv2d.add_input(contraction.input(1)); // 1: filter
2377 fused_conv2d.add_input(fused_batch_norm.input(1)); // 2: scale
2378 fused_conv2d.add_input(fused_batch_norm.input(2)); // 3: offset
2379 fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean
2380 fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance
2381
2382 CopyConv2DAttributes(contraction, &fused_conv2d);
2383 SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"},
2384 /*num_args=*/4, /*epsilon=*/matched.epsilon);
2385
2386 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2387 Status status;
2388 mutation->AddNode(std::move(fused_conv2d), &status);
2389 TF_RETURN_IF_ERROR(status);
2390 TF_RETURN_IF_ERROR(mutation->Apply());
2391
2392 (*invalidated_nodes)[matched.fused_batch_norm] = true;
2393 (*nodes_to_delete)[matched.contraction] = true;
2394
2395 return OkStatus();
2396 }
2397
2398 Status AddFusedConv2DNode(RemapperContext* ctx,
2399 const ContractionWithBatchNormAndActivation& matched,
2400 std::vector<bool>* invalidated_nodes,
2401 std::vector<bool>* nodes_to_delete) {
2402 const GraphDef* graph = ctx->graph_view.graph();
2403 const NodeDef& contraction = graph->node(matched.contraction);
2404
2405 DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
2406
2407 const NodeDef& activation = graph->node(matched.activation);
2408 const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2409 VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
2410 << ": activation=" << activation.name()
2411 << " batch_norm=" << fused_batch_norm.name()
2412 << " conv2d=" << contraction.name();
2413
2414 NodeDef fused_conv2d;
2415 fused_conv2d.set_name(activation.name());
2416 fused_conv2d.set_op(kFusedConv2D);
2417 fused_conv2d.set_device(contraction.device());
2418 fused_conv2d.add_input(contraction.input(0)); // 0: input
2419 fused_conv2d.add_input(contraction.input(1)); // 1: filter
2420 fused_conv2d.add_input(fused_batch_norm.input(1)); // 2: scale
2421 fused_conv2d.add_input(fused_batch_norm.input(2)); // 3: offset
2422 fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean
2423 fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance
2424
2425 CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
2426 SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()},
2427 /*num_args=*/4, /*epsilon=*/matched.epsilon);
2428
2429 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2430 Status status;
2431 mutation->AddNode(std::move(fused_conv2d), &status);
2432 TF_RETURN_IF_ERROR(status);
2433 TF_RETURN_IF_ERROR(mutation->Apply());
2434
2435 (*invalidated_nodes)[matched.activation] = true;
2436 (*nodes_to_delete)[matched.contraction] = true;
2437 (*nodes_to_delete)[matched.fused_batch_norm] = true;
2438
2439 return OkStatus();
2440 }
2441
2442 Status AddFusedContractionNode(RemapperContext* ctx,
2443 const ContractionWithBiasAddAndAdd& matched,
2444 std::vector<bool>* invalidated_nodes,
2445 std::vector<bool>* nodes_to_delete) {
2446 const GraphDef* graph = ctx->graph_view.graph();
2447 const NodeDef& contraction = graph->node(matched.contraction);
2448 const NodeDef& bias_add = graph->node(matched.bias_add);
2449
2450 // oneDNN version only supports fusion for Conv2D/3D and MatMul
2451 DCHECK(IsConv2D(contraction) || IsMatMul(contraction) ||
2452 IsConv3D(contraction));
2453
2454 NodeDef contraction_node;
2455 const NodeDef& add = graph->node(matched.add);
2456 contraction_node.set_name(add.name());
2457 contraction_node.set_device(contraction.device());
2458 contraction_node.add_input(
2459 contraction.input(0)); // 0: input(conv) / a (matmul)
2460 contraction_node.add_input(
2461 contraction.input(1)); // 1: filter(conv) / b (matmul)
2462 contraction_node.add_input(bias_add.input(matched.bias_port)); // 2: bias
2463
2464   // The Add op has two inputs: one is the conv+bias / matmul+bias pattern
2465   // matched previously; the other input to Add is fused here.
2466 contraction_node.add_input(add.input(1 - matched.port_id));
2467
2468 if (IsConv2D(contraction)) {
2469 contraction_node.set_op(kFusedConv2D);
2470 CopyConv2DAttributes(contraction, &contraction_node);
2471 } else if (IsMatMul(contraction)) {
2472 contraction_node.set_op(kFusedMatMul);
2473 CopyMatMulAttributes(contraction, &contraction_node);
2474 } else if (IsConv3D(contraction)) {
2475 contraction_node.set_op(kFusedConv3D);
2476 CopyConv3DAttributes(contraction, &contraction_node);
2477 }
2478
2479 SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
2480
2481 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2482 Status status;
2483 mutation->AddNode(std::move(contraction_node), &status);
2484 TF_RETURN_IF_ERROR(status);
2485 TF_RETURN_IF_ERROR(mutation->Apply());
2486
2487 (*invalidated_nodes)[matched.add] = true;
2488 (*nodes_to_delete)[matched.contraction] = true;
2489 (*nodes_to_delete)[matched.bias_add] = true;
2490
2491 return OkStatus();
2492 }
2493
2494 Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched,
2495 std::vector<bool>* invalidated_nodes,
2496 std::vector<bool>* nodes_to_delete) {
2497 const GraphDef* graph = ctx->graph_view.graph();
2498 const NodeDef& contraction = graph->node(matched.contraction_idx);
2499 const NodeDef& pad_node_def = graph->node(matched.pad_idx);
2500 const NodeDef& padding_const_node_def =
2501 graph->node(matched.padding_const_idx);
2502 VLOG(2) << "Fuse " << pad_node_def.op() << " with contraction: "
2503 << " contraction=" << contraction.name();
2504
2505 NodeDef fused_node;
2506 fused_node.set_name(contraction.name());
2507 fused_node.set_device(contraction.device());
2508 fused_node.add_input(pad_node_def.input(0)); // 0: input
2509 fused_node.add_input(contraction.input(1)); // 1: filter
2510 fused_node.set_op(kFusedConv3D);
2511
2512 auto* attr = fused_node.mutable_attr();
2513 auto& src_attr = contraction.attr();
2514 (*attr)["T"] = src_attr.at("T");
2515 (*attr)["strides"] = src_attr.at("strides");
2516 (*attr)["data_format"] = src_attr.at("data_format");
2517 (*attr)["padding"] = src_attr.at("padding");
2518 (*attr)["dilations"] = src_attr.at("dilations");
2519
2520 if (contraction.op() == kFusedConv3D) {
2521 fused_node.add_input(contraction.input(2)); // 2: bias
2522 (*attr)["fused_ops"] = src_attr.at("fused_ops");
2523 (*attr)["num_args"] = src_attr.at("num_args");
2524 } else {
2525 SetAttrValue(0, &(*attr)["num_args"]);
2526 }
2527
2528 Tensor const_tensor;
2529 if (padding_const_node_def.op() == "Const" &&
2530 const_tensor.FromProto(
2531 padding_const_node_def.attr().at("value").tensor())) {
2532 auto const_value = const_tensor.flat<int32>();
2533 std::vector<int32> paddings;
2534     for (int i = 0; i < const_value.size(); ++i) {
2535       paddings.push_back(const_value(i));
2536     }
2537     SetAttrValue(paddings, &(*attr)["padding_list"]);
2538 }
2539
2540 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2541 Status status;
2542 mutation->AddNode(std::move(fused_node), &status);
2543 TF_RETURN_IF_ERROR(status);
2544 TF_RETURN_IF_ERROR(mutation->Apply());
2545
2546 (*invalidated_nodes)[matched.contraction_idx] = true;
2547 (*nodes_to_delete)[matched.pad_idx] = true;
2548 return OkStatus();
2549 }
2550
2551 Status AddFusedContractionNode(
2552 RemapperContext* ctx, const ContractionWithBiasAndAddActivation& matched,
2553 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2554 const GraphDef* graph = ctx->graph_view.graph();
2555   // oneDNN version only supports fusion for Conv2D and Conv3D
2556 const NodeDef& contraction = graph->node(matched.contraction);
2557 DCHECK(IsConv2D(contraction) || IsConv3D(contraction));
2558 const NodeDef& activation = graph->node(matched.activation);
2559
2560 NodeDef fused_conv;
2561 fused_conv.set_name(activation.name());
2562 fused_conv.set_device(contraction.device());
2563 fused_conv.add_input(contraction.input(0)); // 0: input
2564 fused_conv.add_input(contraction.input(1)); // 1: filter
2565 const NodeDef& bias_add = graph->node(matched.bias_add);
2566 fused_conv.add_input(bias_add.input(matched.bias_port)); // 2: bias
2567
2568 if (IsConv2D(contraction)) {
2569 fused_conv.set_op(kFusedConv2D);
2570 CopyConv2DAttributes(contraction, &fused_conv);
2571 } else if (IsConv3D(contraction)) {
2572 fused_conv.set_op(kFusedConv3D);
2573 CopyConv3DAttributes(contraction, &fused_conv);
2574 }
2575
2576   // The Add op has two inputs: one is the conv+bias pattern matched
2577   // previously; the other input to Add is fused here.
2578 const NodeDef& add = graph->node(matched.add);
2579 fused_conv.add_input(add.input(1 - matched.port_id));
2580
2581 SetFusedOpAttributes(&fused_conv, {"BiasAdd", "Add", activation.op()}, 2);
2582
2583 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2584 Status status;
2585 mutation->AddNode(std::move(fused_conv), &status);
2586 TF_RETURN_IF_ERROR(status);
2587 TF_RETURN_IF_ERROR(mutation->Apply());
2588
2589 (*invalidated_nodes)[matched.activation] = true;
2590 (*nodes_to_delete)[matched.add] = true;
2591 (*nodes_to_delete)[matched.bias_add] = true;
2592 (*nodes_to_delete)[matched.contraction] = true;
2593
2594 return OkStatus();
2595 }
2596
2597 Status AddFusedMatMulBiasAddAndGelu(
2598 RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2599 const std::set<int>& remove_node_indices,
2600 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete,
2601 bool is_gelu_approximate) {
2602 auto* output_node =
2603 ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
2604 auto* matmul_node =
2605 ctx->graph_view.GetNode(matched_nodes_map.at("matmul"))->node();
2606
2607 NodeDef fused_node;
2608   // Fused node should have the name of the terminal node of the fusion.
2609 fused_node.set_name(output_node->name());
2610 fused_node.set_op("_FusedMatMul");
2611 fused_node.set_device(matmul_node->device());
2612 fused_node.add_input(matmul_node->input(0));
2613 fused_node.add_input(matmul_node->input(1));
2614 if (is_gelu_approximate) {
2615 fused_node.add_input(matmul_node->input(2));
2616 } else {
2617 auto* bias_add_node =
2618 ctx->graph_view.GetNode(matched_nodes_map.at("bias_add"))->node();
2619 fused_node.add_input(bias_add_node->input(1));
2620 }
2621 CopyMatMulAttributes(*matmul_node, &fused_node);
2622 if (is_gelu_approximate)
2623 SetFusedOpAttributes(&fused_node, {"BiasAdd", "GeluApproximate"});
2624 else
2625 SetFusedOpAttributes(&fused_node, {"BiasAdd", "GeluExact"});
2626
2627 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2628 Status status;
2629 mutation->AddNode(std::move(fused_node), &status);
2630 TF_RETURN_IF_ERROR(status);
2631 TF_RETURN_IF_ERROR(mutation->Apply());
2632 (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
2633
2634 for (const auto& node_idx : remove_node_indices) {
2635 (*nodes_to_delete)[node_idx] = true;
2636 }
2637 return OkStatus();
2638 }
2639
2640 Status AddMklLayerNorm(RemapperContext* ctx,
2641 const std::map<string, int>& matched_nodes_map,
2642 const std::set<int>& remove_node_indices,
2643 std::vector<bool>* invalidated_nodes,
2644 std::vector<bool>* nodes_to_delete) {
2645 auto* pre_reshape_node =
2646 ctx->graph_view.GetNode(matched_nodes_map.at("pre_reshape"))->node();
2647 auto* scale_node =
2648 ctx->graph_view.GetNode(matched_nodes_map.at("gamma"))->node();
2649 auto* output_node =
2650 ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
2651
2652 NodeDef fused_node;
2653 fused_node.set_name(output_node->name());
2654 fused_node.set_op("_MklLayerNorm");
2655 fused_node.set_device(output_node->device());
2656 fused_node.add_input(pre_reshape_node->input(0));
2657 fused_node.add_input(scale_node->name());
2658 fused_node.add_input(output_node->input(0));
2659 auto* attr = fused_node.mutable_attr();
2660 auto& src_attr = output_node->attr();
2661 (*attr)["T"] = src_attr.at("T");
2662
2663 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2664 Status status;
2665 mutation->AddNode(std::move(fused_node), &status);
2666 TF_RETURN_IF_ERROR(status);
2667 TF_RETURN_IF_ERROR(mutation->Apply());
2668 (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
2669
2670 for (const auto& node_idx : remove_node_indices) {
2671 (*nodes_to_delete)[node_idx] = true;
2672 }
2673 return OkStatus();
2674 }
2675
2676 Status ReplaceMulMaximumWithLeakyRelu(
2677 RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2678 const std::set<int>& remove_node_indices,
2679 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2680 const NodeDef* maximum =
2681 ctx->graph_view.GetNode(matched_nodes_map.at("max_to_leakyrelu"))->node();
2682 const NodeDef* input =
2683 ctx->graph_view.GetNode(matched_nodes_map.at("input"))->node();
2684 const auto* alpha_node_view =
2685 ctx->graph_view.GetNode(matched_nodes_map.at("alpha"));
2686 const auto* alpha_node_def = alpha_node_view->node();
2687
2688 NodeDef fused_op;
2689 fused_op.set_name(maximum->name());
2690 fused_op.set_op("LeakyRelu");
2691 fused_op.set_device(maximum->device());
2692 fused_op.add_input(input->name());
2693
2694 auto* attr = fused_op.mutable_attr();
2695 (*attr)["T"] = maximum->attr().at("T");
2696
2697   // BF16 adds a Cast before the const alpha, so the Const node is accessed
2698   // through the Cast node to retrieve the value of alpha.
2699 float alpha_val;
2700 Tensor alpha_tensor;
2701 if (alpha_node_def->op() == "Cast") {
2702 const auto& regular_fanin_0 = alpha_node_view->GetRegularFanin(0);
2703 const auto* regular_node_view = regular_fanin_0.node_view();
2704 const auto* const_node = regular_node_view->node();
2705 if (const_node != nullptr && const_node->op() == "Const" &&
2706 alpha_tensor.FromProto(const_node->attr().at("value").tensor())) {
2707 alpha_val = alpha_tensor.flat<float>()(0);
2708 SetAttrValue(alpha_val, &(*attr)["alpha"]);
2709 }
2710 } else if (alpha_node_def->op() == "Const" &&
2711 alpha_tensor.FromProto(
2712 alpha_node_def->attr().at("value").tensor())) {
2713 alpha_val = alpha_tensor.flat<float>()(0);
2714 SetAttrValue(alpha_val, &(*attr)["alpha"]);
2715 }
2716
2717 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2718 Status status;
2719 mutation->AddNode(std::move(fused_op), &status);
2720 TF_RETURN_IF_ERROR(status);
2721 TF_RETURN_IF_ERROR(mutation->Apply());
2722
2723 (*invalidated_nodes)[matched_nodes_map.at("max_to_leakyrelu")] = true;
2724
2725 for (const auto& node_index : remove_node_indices) {
2726 (*nodes_to_delete)[node_index] = true;
2727 }
2728
2729 return OkStatus();
2730 }
2731
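// Replaces the matched Mul(x, Sigmoid(x)) subgraph with a single _MklSwish
// node that reuses the name of the original Mul.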
2732 Status ReplaceSigmoidMulWithSwish(
2733 RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
2734 const std::set<int>& remove_node_indices,
2735 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
2736 const NodeDef* mul =
2737 ctx->graph_view.GetNode(matched_nodes_map.at("mul_to_swish"))->node();
2738 const NodeDef* sigmoid =
2739 ctx->graph_view.GetNode(matched_nodes_map.at("sigmoid"))->node();
2740
2741 NodeDef fused_op;
2742 fused_op.set_name(mul->name());
2743 fused_op.set_op("_MklSwish");
2744 fused_op.set_device(mul->device());
2745 fused_op.add_input(sigmoid->input(0));
2746
2747 auto* attr = fused_op.mutable_attr();
2748 (*attr)["T"] = mul->attr().at("T");
2749
2750 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2751 Status status;
2752 mutation->AddNode(std::move(fused_op), &status);
2753 TF_RETURN_IF_ERROR(status);
2754 TF_RETURN_IF_ERROR(mutation->Apply());
2755
2756 (*invalidated_nodes)[matched_nodes_map.at("mul_to_swish")] = true;
2757
2758 for (const auto& node_index : remove_node_indices) {
2759 (*nodes_to_delete)[node_index] = true;
2760 }
2761 return OkStatus();
2762 }
2763
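// Fuses FusedBatchNorm with an optional side input and an activation into a
// single _FusedBatchNormEx node, and turns the activation node into an
// Identity so its consumers keep reading the same tensor name.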
2764 Status AddFusedBatchNormExNode(RemapperContext* ctx,
2765 const FusedBatchNormEx& matched,
2766 std::vector<bool>* invalidated_nodes,
2767 std::vector<bool>* nodes_to_delete) {
2768 const GraphDef* graph = ctx->graph_view.graph();
2769 const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
2770 const NodeDef& activation = graph->node(matched.activation);
2771
2772 VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
2773 << " activation=" << activation.name() << " side_input="
2774 << (matched.side_input != kMissingIndex
2775 ? graph->node(matched.side_input).name()
2776 : "<none>")
2777 << " invalidated="
2778 << (matched.invalidated != kMissingIndex
2779 ? graph->node(matched.invalidated).name()
2780 : "<none>")
2781 << " fused_batch_norm=" << fused_batch_norm.name();
2782
2783 // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
2784 NodeDef fused_op;
2785 fused_op.set_op(kFusedBatchNormEx);
2786 fused_op.set_name(fused_batch_norm.name());
2787 fused_op.set_device(fused_batch_norm.device());
2788
2789 fused_op.add_input(fused_batch_norm.input(0)); // 0: input
2790 fused_op.add_input(fused_batch_norm.input(1)); // 1: scale
2791 fused_op.add_input(fused_batch_norm.input(2)); // 2: offset
2792 fused_op.add_input(fused_batch_norm.input(3)); // 3: estimated_mean
2793 fused_op.add_input(fused_batch_norm.input(4)); // 4: estimated_var
2794
2795 CopyFusedBatchNormAttributes(fused_batch_norm, &fused_op);
2796
2797 auto* attrs = fused_op.mutable_attr();
2798 SetAttrValue(activation.op(), &(*attrs)["activation_mode"]);
2799
2800 if (matched.side_input != kMissingIndex) {
2801 SetAttrValue(1, &(*attrs)["num_side_inputs"]);
2802 const NodeDef& side_input = graph->node(matched.side_input);
2803 fused_op.add_input(side_input.name()); // 5: side_input
2804 } else {
2805 SetAttrValue(0, &(*attrs)["num_side_inputs"]);
2806 }
2807
2808 // Turn activation node into Identity node.
2809 NodeDef identity_op;
2810 identity_op.set_op("Identity");
2811 identity_op.set_name(activation.name());
2812 identity_op.set_device(fused_batch_norm.device());
2813 identity_op.add_input(fused_batch_norm.name());
2814 (*identity_op.mutable_attr())["T"] = attrs->at("T");
2815
2816 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2817 Status status;
2818 mutation->AddNode(std::move(fused_op), &status);
2819 TF_RETURN_IF_ERROR(status);
2820 mutation->AddNode(std::move(identity_op), &status);
2821 TF_RETURN_IF_ERROR(status);
2822 TF_RETURN_IF_ERROR(mutation->Apply());
2823
2824 (*invalidated_nodes)[matched.fused_batch_norm] = true;
2825 (*invalidated_nodes)[matched.activation] = true;
2826 if (matched.side_input != kMissingIndex) {
2827 (*nodes_to_delete)[matched.invalidated] = true;
2828 }
2829
2830 return OkStatus();
2831 }
2832
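// Fuses FusedBatchNormGrad with the matched ReluGrad (and optional side
// input gradient) into a _FusedBatchNormGradEx node. When a side input
// gradient is present, the activation gradient node is rewritten into an
// Identity that forwards output 5 of the fused node.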
2833 Status AddFusedBatchNormGradExNode(RemapperContext* ctx,
2834 const FusedBatchNormGradEx& matched,
2835 std::vector<bool>* invalidated_nodes,
2836 std::vector<bool>* nodes_to_delete) {
2837 const GraphDef* graph = ctx->graph_view.graph();
2838 const NodeDef& fused_batch_norm_grad =
2839 graph->node(matched.fused_batch_norm_grad);
2840 const NodeDef& activation_grad = graph->node(matched.activation_grad);
2841 const NodeDef& fwd_fused_batch_norm =
2842 graph->node(matched.fwd_fused_batch_norm);
2843
2844 VLOG(2) << "Fuse FusedBatchNormGrad with " << activation_grad.op() << ": "
2845 << " fused_batch_norm_grad=" << fused_batch_norm_grad.name()
2846 << " side_input="
2847 << (matched.side_input_grad != kMissingIndex
2848 ? graph->node(matched.side_input_grad).name()
2849 : "<none>")
2850 << " activation=" << activation_grad.name()
2851 << " corresponding FusedBatchNorm=" << fwd_fused_batch_norm.name();
2852
2853 NodeDef fused_op;
2854 fused_op.set_op(kFusedBatchNormGradEx);
2855 fused_op.set_name(fused_batch_norm_grad.name());
2856 fused_op.set_device(fused_batch_norm_grad.device());
2857
2858 fused_op.add_input(activation_grad.input(0)); // 0: y_backprop
2859 fused_op.add_input(fused_batch_norm_grad.input(1)); // 1: x
2860 fused_op.add_input(fused_batch_norm_grad.input(2)); // 2: scale
2861 fused_op.add_input(fused_batch_norm_grad.input(3)); // 3: reserve_space_1
2862 fused_op.add_input(fused_batch_norm_grad.input(4)); // 4: reserve_space_2
2863 fused_op.add_input(fused_batch_norm_grad.input(5)); // 5: reserve_space_3
2864 fused_op.add_input(fwd_fused_batch_norm.input(2)); // 6: offset
2865 fused_op.add_input(activation_grad.input(1)); // 7: y
2866
2867 CopyFusedBatchNormGradAttributes(fused_batch_norm_grad, &fused_op);
2868
2869 auto* attrs = fused_op.mutable_attr();
2870 // Only support Relu mode.
2871 SetAttrValue("Relu", &(*attrs)["activation_mode"]);
2872
2873 if (matched.side_input_grad != kMissingIndex) {
2874 SetAttrValue(1, &(*attrs)["num_side_inputs"]);
2875 } else {
2876 SetAttrValue(0, &(*attrs)["num_side_inputs"]);
2877 }
2878
2879 NodeDef identity_op;
2880 identity_op.set_op("Identity");
2881 identity_op.set_name(activation_grad.name());
2882 identity_op.set_device(fused_batch_norm_grad.device());
2883 identity_op.add_input(absl::StrCat(fused_batch_norm_grad.name(), ":5"));
2884 (*identity_op.mutable_attr())["T"] = attrs->at("T");
2885
2886 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2887 Status status;
2888 mutation->AddNode(std::move(fused_op), &status);
2889 TF_RETURN_IF_ERROR(status);
2890 if (matched.side_input_grad != kMissingIndex) {
2891 mutation->AddNode(std::move(identity_op), &status);
2892 TF_RETURN_IF_ERROR(status);
2893 }
2894 TF_RETURN_IF_ERROR(mutation->Apply());
2895
2896 (*invalidated_nodes)[matched.fused_batch_norm_grad] = true;
2897 if (matched.side_input_grad != kMissingIndex) {
2898 (*invalidated_nodes)[matched.activation_grad] = true;
2899 } else {
2900 (*nodes_to_delete)[matched.activation_grad] = true;
2901 }
2902
2903 return OkStatus();
2904 }
2905
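// Splits an inference-time FusedBatchNorm into cheaper primitives:
//   y = x * (scale * rsqrt(variance + epsilon)) +
//       (offset - mean * (scale * rsqrt(variance + epsilon))).
// For NCHW/NCDHW inputs, the per-channel parameters are reshaped first so the
// elementwise arithmetic broadcasts correctly.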
2906 Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
2907 const GraphDef* graph = ctx->graph_view.graph();
2908 const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
2909 VLOG(2) << "Optimizing fused batch norm node "
2910 << SummarizeNodeDef(fused_node);
2911
2912 const string& x = fused_node.input(0);
2913 string scale = fused_node.input(1);
2914 string offset = fused_node.input(2);
2915 string mean = fused_node.input(3);
2916 string variance = fused_node.input(4);
2917
2918 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
2919 Status status;
2920
2921 string x_format = fused_node.attr().at(kDataFormat).s();
2922 if (x_format == "NCHW" || x_format == "NCDHW") {
2923 // Need to reshape the last 4 inputs
2924 NodeDef new_shape;
2925 const string new_shape_name =
2926 AddPrefixToNodeName(x_format + "Shape", fused_node.name());
2927 new_shape.set_name(new_shape_name);
2928 new_shape.set_op("Const");
2929 new_shape.set_device(fused_node.device());
2930 *new_shape.add_input() = AsControlDependency(scale);
2931 (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
2932 if (x_format == "NCHW") {
2933 Tensor t(DT_INT32, {4});
2934 t.flat<int32>()(0) = 1;
2935 t.flat<int32>()(1) = -1;
2936 t.flat<int32>()(2) = 1;
2937 t.flat<int32>()(3) = 1;
2938 t.AsProtoTensorContent(
2939 (*new_shape.mutable_attr())["value"].mutable_tensor());
2940 } else {
2941 Tensor t(DT_INT32, {5});
2942 t.flat<int32>()(0) = 1;
2943 t.flat<int32>()(1) = -1;
2944 t.flat<int32>()(2) = 1;
2945 t.flat<int32>()(3) = 1;
2946 t.flat<int32>()(4) = 1;
2947 t.AsProtoTensorContent(
2948 (*new_shape.mutable_attr())["value"].mutable_tensor());
2949 }
2950 mutation->AddNode(std::move(new_shape), &status);
2951 TF_RETURN_IF_ERROR(status);
2952
2953 NodeDef reshaped_scale;
2954 reshaped_scale.set_name(
2955 AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
2956 reshaped_scale.set_op("Reshape");
2957 reshaped_scale.set_device(fused_node.device());
2958 *reshaped_scale.add_input() = scale;
2959 *reshaped_scale.add_input() = new_shape_name;
2960 (*reshaped_scale.mutable_attr())["T"] = fused_node.attr().at("T");
2961 (*reshaped_scale.mutable_attr())["Tshape"].set_type(DT_INT32);
2962 scale = reshaped_scale.name();
2963 mutation->AddNode(std::move(reshaped_scale), &status);
2964 TF_RETURN_IF_ERROR(status);
2965
2966 NodeDef reshaped_offset;
2967 reshaped_offset.set_name(
2968 AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
2969 reshaped_offset.set_op("Reshape");
2970 reshaped_offset.set_device(fused_node.device());
2971 *reshaped_offset.add_input() = offset;
2972 *reshaped_offset.add_input() = new_shape_name;
2973 (*reshaped_offset.mutable_attr())["T"] = fused_node.attr().at("T");
2974 (*reshaped_offset.mutable_attr())["Tshape"].set_type(DT_INT32);
2975 offset = reshaped_offset.name();
2976 mutation->AddNode(std::move(reshaped_offset), &status);
2977 TF_RETURN_IF_ERROR(status);
2978
2979 NodeDef reshaped_mean;
2980 reshaped_mean.set_name(
2981 AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
2982 reshaped_mean.set_op("Reshape");
2983 reshaped_mean.set_device(fused_node.device());
2984 *reshaped_mean.add_input() = mean;
2985 *reshaped_mean.add_input() = new_shape_name;
2986 (*reshaped_mean.mutable_attr())["T"] = fused_node.attr().at("T");
2987 (*reshaped_mean.mutable_attr())["Tshape"].set_type(DT_INT32);
2988 mean = reshaped_mean.name();
2989 mutation->AddNode(std::move(reshaped_mean), &status);
2990 TF_RETURN_IF_ERROR(status);
2991
2992 NodeDef reshaped_variance;
2993 reshaped_variance.set_name(
2994 AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
2995 reshaped_variance.set_op("Reshape");
2996 reshaped_variance.set_device(fused_node.device());
2997 *reshaped_variance.add_input() = variance;
2998 *reshaped_variance.add_input() = new_shape_name;
2999 (*reshaped_variance.mutable_attr())["T"] = fused_node.attr().at("T");
3000 (*reshaped_variance.mutable_attr())["Tshape"].set_type(DT_INT32);
3001 variance = reshaped_variance.name();
3002 mutation->AddNode(std::move(reshaped_variance), &status);
3003 TF_RETURN_IF_ERROR(status);
3004 }
3005
3006 float epsilon = 0.0f;
3007 if (fused_node.attr().count("epsilon")) {
3008 epsilon = fused_node.attr().at("epsilon").f();
3009 }
3010 DataType dtype = fused_node.attr().at("T").type();
3011 Tensor value(dtype, TensorShape());
3012 value.scalar<float>()() = epsilon;
3013 NodeDef variance_epsilon;
3014 const string variance_epsilon_name =
3015 AddPrefixToNodeName("Const", fused_node.name());
3016 TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
3017 variance_epsilon_name, TensorValue(&value), &variance_epsilon));
3018 variance_epsilon.set_device(fused_node.device());
3019 mutation->AddNode(std::move(variance_epsilon), &status);
3020 TF_RETURN_IF_ERROR(status);
3021
3022 NodeDef variance_plus_epsilon;
3023 const string variance_plus_epsilon_name =
3024 AddPrefixToNodeName("VarPlusEpsilon", fused_node.name());
3025 variance_plus_epsilon.set_name(variance_plus_epsilon_name);
3026 variance_plus_epsilon.set_op("Add");
3027 (*variance_plus_epsilon.mutable_attr())["T"].set_type(dtype);
3028 variance_plus_epsilon.set_device(fused_node.device());
3029 *variance_plus_epsilon.add_input() = variance;
3030 *variance_plus_epsilon.add_input() = variance_epsilon_name;
3031 mutation->AddNode(std::move(variance_plus_epsilon), &status);
3032 TF_RETURN_IF_ERROR(status);
3033
3034 NodeDef inv;
3035 const string inv_name = AddPrefixToNodeName("Inv", fused_node.name());
3036 inv.set_name(inv_name);
3037 inv.set_op("Rsqrt");
3038 inv.set_device(fused_node.device());
3039 (*inv.mutable_attr())["T"].set_type(dtype);
3040 *inv.add_input() = variance_plus_epsilon_name;
3041 mutation->AddNode(std::move(inv), &status);
3042 TF_RETURN_IF_ERROR(status);
3043
3044 NodeDef scaled;
3045 const string scaled_name = AddPrefixToNodeName("Scaled", fused_node.name());
3046 scaled.set_name(scaled_name);
3047 scaled.set_op("Mul");
3048 scaled.set_device(fused_node.device());
3049 (*scaled.mutable_attr())["T"].set_type(dtype);
3050 *scaled.add_input() = inv_name;
3051 *scaled.add_input() = scale;
3052 mutation->AddNode(std::move(scaled), &status);
3053 TF_RETURN_IF_ERROR(status);
3054
3055 NodeDef a;
3056 const string a_name = AddPrefixToNodeName("Mul", fused_node.name());
3057 a.set_name(a_name);
3058 a.set_op("Mul");
3059 a.set_device(fused_node.device());
3060 (*a.mutable_attr())["T"].set_type(dtype);
3061 *a.add_input() = x;
3062 *a.add_input() = scaled_name;
3063 mutation->AddNode(std::move(a), &status);
3064 TF_RETURN_IF_ERROR(status);
3065
3066 NodeDef b;
3067 const string b_name = AddPrefixToNodeName("Mul2", fused_node.name());
3068 b.set_name(b_name);
3069 b.set_op("Mul");
3070 b.set_device(fused_node.device());
3071 (*b.mutable_attr())["T"].set_type(dtype);
3072 *b.add_input() = mean;
3073 *b.add_input() = scaled_name;
3074 mutation->AddNode(std::move(b), &status);
3075 TF_RETURN_IF_ERROR(status);
3076
3077 NodeDef c;
3078 const string c_name = AddPrefixToNodeName("Offset", fused_node.name());
3079 c.set_name(c_name);
3080 c.set_op("Sub");
3081 c.set_device(fused_node.device());
3082 (*c.mutable_attr())["T"].set_type(dtype);
3083 *c.add_input() = offset;
3084 *c.add_input() = b_name;
3085 mutation->AddNode(std::move(c), &status);
3086 TF_RETURN_IF_ERROR(status);
3087
3088 NodeDef r;
3089 r.set_name(fused_node.name());
3090 r.set_op("Add");
3091 r.set_device(fused_node.device());
3092 (*r.mutable_attr())["T"].set_type(dtype);
3093 *r.add_input() = a_name;
3094 *r.add_input() = c_name;
3095 mutation->AddNode(std::move(r), &status);
3096 TF_RETURN_IF_ERROR(status);
3097
3098 return mutation->Apply();
3099 }
3100
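// Fuses AsString + StringToHashBucketFast into a single
// _TensorToHashBucketFast node, copying the input dtype and bucket count.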
3101 Status AddTensorToHashBucketNode(RemapperContext* ctx,
3102 const TensorToHashBucket& matched,
3103 std::vector<bool>* invalidated_nodes,
3104 std::vector<bool>* nodes_to_delete) {
3105 const GraphDef* graph = ctx->graph_view.graph();
3106 const NodeDef& pre_as_string = graph->node(matched.pre_as_string);
3107 const NodeDef& as_string = graph->node(matched.as_string);
3108 const NodeDef& string_to_hash_bucket =
3109 graph->node(matched.string_to_hash_bucket);
3110 VLOG(2) << "Fuse AsString with StringToHashBucketFast:"
3111 << " as_string=" << as_string.name()
3112 << " string_to_hash_bucket=" << string_to_hash_bucket.name()
3113 << " on device=" << pre_as_string.device();
3114
3115 NodeDef fused_op;
3116 fused_op.set_name(string_to_hash_bucket.name());
3117 fused_op.set_device(pre_as_string.device());
3118 fused_op.add_input(as_string.input(0)); // 0: input
3119 fused_op.set_op(kTensorToHashBucket);
3120
3121 auto* attr = fused_op.mutable_attr();
3122 auto& src_attr0 = as_string.attr();
3123 auto& src_attr1 = string_to_hash_bucket.attr();
3124 (*attr)["T"] = src_attr0.at("T");
3125 (*attr)["num_buckets"] = src_attr1.at("num_buckets");
3126
3127 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3128 Status status;
3129 mutation->AddNode(std::move(fused_op), &status);
3130 TF_RETURN_IF_ERROR(status);
3131 TF_RETURN_IF_ERROR(mutation->Apply());
3132
3133 (*invalidated_nodes)[matched.string_to_hash_bucket] = true;
3134 (*nodes_to_delete)[matched.as_string] = true;
3135
3136 return OkStatus();
3137 }
3138
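// Replaces the matched BatchMatMul + Mul + AddV2 subgraph with a single
// _MklFusedBatchMatMulV2 node that carries the multiplicand and addend as
// extra fused arguments.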
3139 Status AddFusedBatchMatMul(RemapperContext* ctx,
3140 const std::map<string, int>& matched_nodes_map,
3141 const std::set<int>& remove_node_indices,
3142 std::vector<bool>* invalidated_nodes,
3143 std::vector<bool>* nodes_to_delete) {
3144 auto* output_node =
3145 ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node();
3146 auto* batch_matmul_node =
3147 ctx->graph_view.GetNode(matched_nodes_map.at("batch_matmul"))->node();
3148 auto* multiplicand_node =
3149 ctx->graph_view.GetNode(matched_nodes_map.at("multiplicand"))->node();
3150 auto* addend_node =
3151 ctx->graph_view.GetNode(matched_nodes_map.at("addend"))->node();
3152
3153 NodeDef fused_node;
3154 fused_node.set_name(output_node->name());
3155 fused_node.set_op("_MklFusedBatchMatMulV2");
3156 fused_node.set_device(batch_matmul_node->device());
3157 fused_node.add_input(batch_matmul_node->input(0));
3158 fused_node.add_input(batch_matmul_node->input(1));
3159 fused_node.add_input(multiplicand_node->name());
3160 fused_node.add_input(addend_node->name());
3161
3162 CopyBatchMatMulAttributes(*batch_matmul_node, &fused_node);
3163 SetFusedOpAttributes(&fused_node, {"Mul", "Add"}, /*num_args=*/2);
3164
3165 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3166 Status status;
3167 mutation->AddNode(std::move(fused_node), &status);
3168 TF_RETURN_IF_ERROR(status);
3169 TF_RETURN_IF_ERROR(mutation->Apply());
3170 (*invalidated_nodes)[matched_nodes_map.at("output")] = true;
3171
3172 for (const auto& node_idx : remove_node_indices) {
3173 (*nodes_to_delete)[node_idx] = true;
3174 }
3175 return OkStatus();
3176 }
3177
3178 // This function supports the following patterns that require inferred
3179 // shapes:
3180 // 1. Contraction + Add.
3181 // 2. Contraction + Add + Activation.
3182 // 3. Contraction + BiasAdd/BiasSemanticAdd + Add.
3183 // 4. Contraction + BiasAdd/BiasSemanticAdd + Add + Activation.
3184 // Contraction candidate: MatMul, Conv2D, Conv3D, DepthwiseConv2dNative.
3185 bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
3186 const auto* node_view = ctx.graph_view.GetNode(node_index);
3187
3188 auto is_supported_add_input = [](const auto* node_view) -> bool {
3189 if (IsConvOrMatMul(*node_view->node())) return true;
3190 // IsAdd will verify BiasSemanticAdd.
3191 if (IsBiasAdd(*node_view->node()) || IsAdd(*node_view->node())) {
3192 if (node_view->NumRegularFanins() < 2) return false;
3193 const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
3194 const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
3195 return IsConvOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
3196 IsConvOrMatMul(*bias_add_fanin_1.node_view()->node());
3197 }
3198 return false;
3199 };
3200
3201 auto is_supported_add = [&](const auto* node_view) -> bool {
3202 const auto* node_def = node_view->node();
3203 if (IsAdd(*node_def)) {
3204 if (node_view->NumRegularFanins() < 2) return false;
3205 const auto& add_fanin_0 = node_view->GetRegularFanin(0);
3206 const auto& add_fanin_1 = node_view->GetRegularFanin(1);
3207 return is_supported_add_input(add_fanin_0.node_view()) ||
3208 is_supported_add_input(add_fanin_1.node_view());
3209 }
3210 return false;
3211 };
3212
3213 // Dealing with the Contraction + Add or Contraction + BiasAdd/BiasSemanticAdd
3214 // + Add patterns.
3215 if (is_supported_add(node_view)) {
3216 return true;
3217 }
3218 // Dealing with the Contraction + Add + Activation or Contraction +
3219 // BiasAdd/BiasSemanticAdd + Add + Activation pattern.
3220 if (IsSupportedActivation(*node_view->node())) {
3221 for (int i = 0; i < node_view->NumRegularFanins(); i++) {
3222 const auto& fanin_i = node_view->GetRegularFanin(i);
3223 if (is_supported_add(fanin_i.node_view())) return true;
3224 }
3225 }
3226
3227 return false;
3228 }
3229
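// Matches the x * Tanh(Softplus(x)) subgraph (the Mish activation) so that it
// can be replaced by a single _MklFusedMish node. Enabled only with oneDNN
// and only for float/bfloat16 tensors on CPU.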
3230 bool FindSoftplusAndTanhAndMul(RemapperContext* ctx, int node_index,
3231 std::map<string, int>* matched_nodes_map,
3232 std::set<int>* remove_node_indices) {
3233 // Mish fusion is enabled only with oneDNN library.
3234 if (!IsMKLEnabled()) return false;
3235
3236 using utils::MatchingDirection;
3237 using utils::NodeStatus;
3238 // clang-format off
3239 // Convert Softplus+Tanh+Mul to Mish
3240 // From Graph To Graph
3241 // ----------- ---------
3242 // Conv2D <- Filter(const) Conv2D <- Filter(const)
3243 // ! !
3244 // V V
3245 // BiasAdd <- bias(const) BiasAdd <- bias(const)
3246 // ! !
3247 // V !
3248 // ---- ---- !
3249 // ! ! !
3250 // ! V !
3251 // ! Softplus !
3252 // ! ! !
3253 // ! V !
3254 // ! Tanh !
3255 // ! ! !
3256 // ! V !
3257 // --- --- !
3258 // ! ! !
3259 // ! ! !
3260 // V V V
3261 // Mul _MklFusedMish
3262 // ! !
3263 // V V
3264
3265 utils::OpTypePattern softplustanhmul_pattern {
3266 "Mul", "mul_to_mish", NodeStatus::kReplace,
3267 {
3268 {
3269 "Tanh", "tanh", NodeStatus::kRemove,
3270 {
3271 {
3272 "Softplus", "softplus", NodeStatus::kRemove,
3273 {
3274 {"*", "input", NodeStatus::kRemain}
3275 }
3276 }
3277 }
3278 },
3279 {"*", "input", NodeStatus::kRemain}
3280 }
3281 };
3282 // clang-format on
3283
3284 // Check for supported data types.
3285 auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
3286 if (!HasDataType(mul_node_def, DT_FLOAT) &&
3287 !HasDataType(mul_node_def, DT_BFLOAT16))
3288 return false;
3289
3290 if (!NodeIsOnCpu(mul_node_def)) return false;
3291
3292 bool found_op_type_match = false;
3293 utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
3294 &(ctx->graph_view));
3295 matched_nodes_map->clear();
3296 remove_node_indices->clear();
3297 found_op_type_match = graph_matcher.GetMatchedNodes(
3298 softplustanhmul_pattern, {}, ctx->graph_view.GetNode(node_index),
3299 matched_nodes_map, remove_node_indices);
3300
3301 return found_op_type_match;
3302 }
3303
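// Replaces the subgraph matched by FindSoftplusAndTanhAndMul with a single
// _MklFusedMish node that reuses the name of the original Mul.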
3304 Status ReplaceSoftplusTanhAndMulWithMish(
3305 RemapperContext* ctx, const std::map<string, int>* matched_nodes_map,
3306 const std::set<int>* remove_node_indices,
3307 std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
3308 // Fuse Softplus + Tanh + Mul to Mish
3309 auto* old_mul_node =
3310 ctx->graph_view.GetNode(matched_nodes_map->at("mul_to_mish"))->node();
3311 auto* softplus_node =
3312 ctx->graph_view.GetNode(matched_nodes_map->at("softplus"))->node();
3313
3314 NodeDef fused_node;
3315 fused_node.set_name(old_mul_node->name());
3316 fused_node.set_op("_MklFusedMish");
3317 fused_node.set_device(old_mul_node->device());
3318 fused_node.add_input(softplus_node->input(0));
3319
3320 auto* fused_node_attr = fused_node.mutable_attr();
3321 (*fused_node_attr)["T"] = old_mul_node->attr().at("T");
3322
3323 utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
3324 Status status;
3325 mutation->AddNode(std::move(fused_node), &status);
3326 TF_RETURN_IF_ERROR(status);
3327 TF_RETURN_IF_ERROR(mutation->Apply());
3328 (*invalidated_nodes)[matched_nodes_map->at("mul_to_mish")] = true;
3329
3330 for (const auto& node_index : *remove_node_indices) {
3331 (*nodes_to_delete)[node_index] = true;
3332 }
3333
3334 return OkStatus();
3335 }
3336
3337 // Check if a node is a candidate to one of the patterns that require inferred
3338 // shapes:
3339 // (1) Splitting FusedBatchNorm into primitives.
3340 // (2) Fusing side input and/or activation into FusedBatchNorm.
3341 // (3) Fusing Conv2D with BiasAdd and Relu on GPU.
3342 // (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
3343 // (5) Fusing side output and/or activation into FusedBatchNormGrad.
3344 bool RequiresInferredShapes(const RemapperContext& ctx, int node_index,
3345 const Cluster* cluster) {
3346 // Candidate for a FusedBatchNorm splitting.
3347 const auto* node_view = ctx.graph_view.GetNode(node_index);
3348 const auto* node_def = node_view->node();
3349 const auto is_batch_norm_candidate = [&]() -> bool {
3350 if (!IsFusedBatchNorm(*node_def)) return false;
3351 if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;
3352
3353 bool is_training = true;
3354 if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
3355 if (is_training) return false;
3356
3357 return true;
3358 };
3359
3360 const auto is_act_biasadd_conv_candidate = [&]() -> bool {
3361 if (!IsSupportedActivation(*node_def)) return false;
3362
3363 if (!RuntimeFusionEnabled(cluster) && !IsRelu(*node_def)) return false;
3364
3365 const auto is_compatible_dtype = [&](const NodeDef& node) -> bool {
3366 // The cuDNN fusion with relu6, elu, leakyrelu is realized by the runtime
3367 // compiled kernels which only support fp16.
3368 bool fp16_only =
3369 IsRelu6(*node_def) || IsElu(*node_def) || IsLeakyRelu(*node_def);
3370 DataType dtype = GetDataTypeFromAttr(node, "T");
3371 return dtype == DT_HALF || (!fp16_only && dtype == DT_FLOAT);
3372 };
3373 if (!is_compatible_dtype(*node_def)) return false;
3374
3375 if (node_view->NumRegularFanins() < 1) return false;
3376 const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
3377 const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
3378 const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
3379
3380 if (!IsBiasAdd(*relu_fanin_0_node_def) && !IsAdd(*relu_fanin_0_node_def))
3381 return false;
3382 if (!is_compatible_dtype(*relu_fanin_0_node_def)) return false;
3383
3384 if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false;
3385
3386 const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
3387 const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
3388
3389 if (!IsConv2D(*biasadd_fanin_0_node_def) &&
3390 !IsConv3D(*biasadd_fanin_0_node_def))
3391 return false;
3392 if (!is_compatible_dtype(*biasadd_fanin_0_node_def)) return false;
3393 return true;
3394 };
3395
3396 // Candidate for a FusedBatchNorm fusion.
3397 const auto is_batch_norm_fusion_candidate = [&]() -> bool {
3398 if (!IsRelu(*node_def)) return false;
3399
3400 if (node_view->NumRegularFanins() < 1) return false;
3401 const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
3402 const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
3403 const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
3404
3405 if (IsFusedBatchNorm(*relu_fanin_0_node_def)) {
3406 // FusedBatchNorm + Relu.
3407 return true;
3408
3409 } else if (IsAdd(*relu_fanin_0_node_def)) {
3410 // FusedBatchNorm + Add + Relu.
3411
3412 if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
3413 const auto& add_regular_fanin_0 =
3414 relu_fanin_0_node_view->GetRegularFanin(0);
3415 if (IsFusedBatchNorm(*add_regular_fanin_0.node_view()->node()))
3416 return true;
3417 const auto& add_regular_fanin_1 =
3418 relu_fanin_0_node_view->GetRegularFanin(1);
3419 if (IsFusedBatchNorm(*add_regular_fanin_1.node_view()->node()))
3420 return true;
3421 }
3422
3423 return false;
3424 };
3425
3426 // Candidate for a FusedBatchNormGrad fusion.
3427 const auto is_batch_norm_grad_fusion_candidate = [&]() -> bool {
3428 if (!IsFusedBatchNormGrad(*node_def)) return false;
3429
3430 if (node_view->NumRegularFanins() < 1) return false;
3431 const auto& bn_fanin_0 = node_view->GetRegularFanin(0);
3432 const auto* bn_fanin_0_node_view = bn_fanin_0.node_view();
3433 const auto* bn_fanin_0_node_def = bn_fanin_0_node_view->node();
3434
3435 if (IsReluGrad(*bn_fanin_0_node_def)) {
3436 // ReluGrad + FusedBatchNormGrad.
3437 return true;
3438 }
3439
3440 return false;
3441 };
3442
3443 if (IsMKLEnabled())
3444 return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
3445 IsContractionWithAdd(ctx, node_index) ||
3446 is_act_biasadd_conv_candidate();
3447
3448 return is_act_biasadd_conv_candidate() || is_batch_norm_candidate() ||
3449 is_batch_norm_fusion_candidate() ||
3450 is_batch_norm_grad_fusion_candidate();
3451 }
3452 } // namespace
3453
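// Walks the graph in reverse topological order, applies the fusion rewrites
// above to each eligible node, and finally deletes the nodes that were folded
// into fused ops.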
3454 Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
3455 GraphDef* optimized_graph) {
3456 GrapplerItem mutable_item = item;
3457 Status status;
3458 RemapperContext ctx(&mutable_item, &status, cpu_layout_conversion_,
3459 xla_auto_clustering_on_);
3460 TF_RETURN_IF_ERROR(status);
3461 // Processing the graph in reverse topological order allows remapping
3462 // longer chains of dependent ops in one pass.
3463 TF_RETURN_IF_ERROR(
3464 ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
3465
3466 const int num_nodes = item.graph.node_size();
3467 // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd
3468 // and Activation nodes that were fused into a Conv2D node.
3469 std::vector<bool> invalidated_nodes(num_nodes);
3470 std::vector<bool> nodes_to_delete(num_nodes);
3471
3472 // _Fused{...} kernels do not have a registered gradient function, so we must
3473 // not perform the rewrite if the graph will be differentiated later.
3474 bool allow_non_differentiable_rewrites =
3475 item.optimization_options().allow_non_differentiable_rewrites;
3476
3477 for (int i = num_nodes - 1; i >= 0; --i) {
3478 // Check if node was invalidated by one of the previous remaps.
3479 if (invalidated_nodes[i] || nodes_to_delete[i]) {
3480 continue;
3481 }
3482
3483 // Infer properties lazily in case they are not needed.
3484 if (!ctx.inferred_graph_properties &&
3485 RequiresInferredShapes(ctx, i, cluster)) {
3486 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3487 TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
3488 assume_valid_feeds,
3489 /*aggressive_shape_inference=*/false,
3490 /*include_input_tensor_values=*/true,
3491 /*include_output_tensor_values=*/false));
3492 ctx.inferred_graph_properties = true;
3493 }
3494
3495 ContractionWithBiasAddAndAdd contract_with_bias_and_add;
3496 ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
3497
3498 if (IsMKLEnabled()) {
3499 // Remap Conv2D+BiasAdd+Add+Relu into _FusedConv2D, or
3500 // Conv3D+BiasAdd+Add+Relu into _FusedConv3D.
3501 if (FindContractionWithBiasAndAddActivation(
3502 ctx, i, &contract_with_bias_and_add_activation)) {
3503 TF_RETURN_IF_ERROR(
3504 AddFusedContractionNode(&ctx, contract_with_bias_and_add_activation,
3505 &invalidated_nodes, &nodes_to_delete));
3506 continue;
3507 }
3508
3509 // Remap {Conv2D,Conv3D}+BiasAdd+Add into the _FusedConv2D/3D.
3510 if (FindContractionWithBiasAddAndAdd(ctx, i,
3511 &contract_with_bias_and_add)) {
3512 TF_RETURN_IF_ERROR(
3513 AddFusedContractionNode(&ctx, contract_with_bias_and_add,
3514 &invalidated_nodes, &nodes_to_delete));
3515 continue;
3516 }
3517
3518 PadWithConv3D pad_with_conv3d;
3519 // Remap Pad+{Conv3D,_FusedConv3D} into the _FusedConv3D.
3520 if (FindPadWithConv3D(ctx, i, &pad_with_conv3d)) {
3521 TF_RETURN_IF_ERROR(AddFusedConv3DNode(
3522 &ctx, pad_with_conv3d, &invalidated_nodes, &nodes_to_delete));
3523 continue;
3524 }
3525
3526 std::map<string, int> matched_nodes_map;
3527 std::set<int> remove_node_indices;
3528
3529 // Softplus + Tanh + Mul to Mish conversion
3530 matched_nodes_map.clear();
3531 remove_node_indices.clear();
3532 if (FindSoftplusAndTanhAndMul(&ctx, i, &matched_nodes_map,
3533 &remove_node_indices)) {
3534 TF_RETURN_IF_ERROR(ReplaceSoftplusTanhAndMulWithMish(
3535 &ctx, &matched_nodes_map, &remove_node_indices, &invalidated_nodes,
3536 &nodes_to_delete));
3537 continue;
3538 }
3539
3540 // Remap BatchMatMul+Mul+AddV2 into the _FusedBatchMatMul.
3541 matched_nodes_map.clear();
3542 remove_node_indices.clear();
3543 if (FindFusedBatchMatMul(&ctx, i, &matched_nodes_map,
3544 &remove_node_indices)) {
3545 TF_RETURN_IF_ERROR(
3546 AddFusedBatchMatMul(&ctx, matched_nodes_map, remove_node_indices,
3547 &invalidated_nodes, &nodes_to_delete));
3548 continue;
3549 }
3550
3551 // Remap the Maximum(x, alpha * x) pattern into LeakyRelu(x).
3552 std::map<string, int> mulmax_matched_nodes_map;
3553 std::set<int> mulmax_remove_node_indices;
3554 if (FindMulAndMaximum(&ctx, i, &mulmax_matched_nodes_map,
3555 &mulmax_remove_node_indices)) {
3556 TF_RETURN_IF_ERROR(ReplaceMulMaximumWithLeakyRelu(
3557 &ctx, mulmax_matched_nodes_map, mulmax_remove_node_indices,
3558 &invalidated_nodes, &nodes_to_delete));
3559 continue;
3560 }
3561
3562 // Remap the Mul(x, Sigmoid(x)) pattern into Swish(x).
3563 std::map<string, int> sigmoidmul_matched_nodes_map;
3564 std::set<int> sigmoidmul_remove_node_indices;
3565 if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map,
3566 &sigmoidmul_remove_node_indices)) {
3567 TF_RETURN_IF_ERROR(ReplaceSigmoidMulWithSwish(
3568 &ctx, sigmoidmul_matched_nodes_map, sigmoidmul_remove_node_indices,
3569 &invalidated_nodes, &nodes_to_delete));
3570 continue;
3571 }
3572
3573 // Remap the smaller ops emitted by the layer norm Python API into _MklLayerNorm.
3574 matched_nodes_map.clear();
3575 remove_node_indices.clear();
3576 if (FindMklLayerNorm(&ctx, i, &matched_nodes_map, &remove_node_indices)) {
3577 TF_RETURN_IF_ERROR(
3578 AddMklLayerNorm(&ctx, matched_nodes_map, remove_node_indices,
3579 &invalidated_nodes, &nodes_to_delete));
3580 continue;
3581 }
3582 }
3583
3584 // Remap the MatMul + BiasAdd + Gelu subgraph into _FusedMatMul.
3585 std::map<string, int> matched_nodes_map;
3586 std::set<int> remove_node_indices;
3587 bool is_gelu_approximate = false;
3588 if (FindMatMulBiasAddAndGelu(&ctx, i, &matched_nodes_map,
3589 &remove_node_indices, &is_gelu_approximate)) {
3590 TF_RETURN_IF_ERROR(AddFusedMatMulBiasAddAndGelu(
3591 &ctx, matched_nodes_map, remove_node_indices, &invalidated_nodes,
3592 &nodes_to_delete, is_gelu_approximate));
3593 continue;
3594 }
3595
3596 // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
3597 // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
3598 ContractionWithBiasAdd contract_with_bias;
3599 if (allow_non_differentiable_rewrites &&
3600 FindContractionWithBias(ctx, i, &contract_with_bias)) {
3601 TF_RETURN_IF_ERROR(AddFusedContractionNode(
3602 &ctx, contract_with_bias, &invalidated_nodes, &nodes_to_delete));
3603 continue;
3604 }
3605
3606 // Remap {Conv2D,DepthwiseConv2D,MatMul,Conv3D}+BiasAdd+Activation into the
3607 // _Fused{Conv2D,DepthwiseConv2dNative,MatMul,Conv3D}.
3608 ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
3609 if (allow_non_differentiable_rewrites &&
3610 FindContractionWithBiasAndActivation(
3611 ctx, cluster, i, &contract_with_bias_and_activation)) {
3612 TF_RETURN_IF_ERROR(
3613 AddFusedContractionNode(&ctx, contract_with_bias_and_activation,
3614 &invalidated_nodes, &nodes_to_delete));
3615 continue;
3616 }
3617
3618 // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
3619 // it for MatMul as well, but in practice this pattern does not appear in
3620 // real TensorFlow graphs.
3621
3622 // Remap {Conv2D, Conv3D}+Squeeze+BiasAdd into the {_FusedConv2D,
3623 // _FusedConv3D}+Squeeze.
3624 ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
3625 if (allow_non_differentiable_rewrites &&
3626 FindConvWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
3627 TF_RETURN_IF_ERROR(AddFusedConvNode(&ctx, contract_with_squeeze_and_bias,
3628 &invalidated_nodes,
3629 &nodes_to_delete));
3630 continue;
3631 }
3632
3633 #ifndef DNNL_AARCH64_USE_ACL
3634 // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
3635 ContractionWithBatchNorm contract_with_batch_norm;
3636 if (allow_non_differentiable_rewrites &&
3637 FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
3638 TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
3639 &invalidated_nodes,
3640 &nodes_to_delete));
3641 continue;
3642 }
3643
3644 // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
3645 ContractionWithBatchNormAndActivation
3646 contract_with_batch_norm_and_activation;
3647 if (allow_non_differentiable_rewrites &&
3648 FindConv2DWithBatchNormAndActivation(
3649 ctx, i, &contract_with_batch_norm_and_activation)) {
3650 TF_RETURN_IF_ERROR(
3651 AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation,
3652 &invalidated_nodes, &nodes_to_delete));
3653 continue;
3654 }
3655 #endif // !DNNL_AARCH64_USE_ACL
3656
3657 // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
3658 FusedBatchNormEx fused_batch_norm_ex;
3659 if (allow_non_differentiable_rewrites &&
3660 FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
3661 TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
3662 &ctx, fused_batch_norm_ex, &invalidated_nodes, &nodes_to_delete));
3663 continue;
3664 }
3665
3666 FusedBatchNormGradEx fused_batch_norm_grad_ex;
3667 if (allow_non_differentiable_rewrites &&
3668 FindFusedBatchNormGradEx(ctx, i, &fused_batch_norm_grad_ex)) {
3669 TF_RETURN_IF_ERROR(
3670 AddFusedBatchNormGradExNode(&ctx, fused_batch_norm_grad_ex,
3671 &invalidated_nodes, &nodes_to_delete));
3672 continue;
3673 }
3674
3675 TensorToHashBucket tensor_to_hash_bucket;
3676 if (allow_non_differentiable_rewrites &&
3677 FindTensorToHashBucket(ctx, i, &tensor_to_hash_bucket)) {
3678 TF_RETURN_IF_ERROR(AddTensorToHashBucketNode(
3679 &ctx, tensor_to_hash_bucket, &invalidated_nodes, &nodes_to_delete));
3680 continue;
3681 }
3682
3683 // During inference, most of the inputs to FusedBatchNorm are constant, and
3684 // we can therefore replace the op with a much cheaper set of primitives.
3685 FusedBatchNorm fused_batch_norm;
3686 if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
3687 TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
3688 continue;
3689 }
3690 }
3691
3692 // Remove invalidated nodes.
3693 utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
3694 for (int i = 0; i < num_nodes; ++i) {
3695 if (nodes_to_delete[i]) {
3696 mutation->RemoveNode(ctx.graph_view.GetNode(i));
3697 }
3698 }
3699 TF_RETURN_IF_ERROR(mutation->Apply());
3700
3701 *optimized_graph = std::move(mutable_item.graph);
3702
3703 return OkStatus();
3704 }
3705
3706 } // namespace grappler
3707 } // namespace tensorflow
3708