/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/remapper.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_graph_util.h"
#endif  // INTEL_MKL

namespace tensorflow {
namespace grappler {

// Supported patterns:
//
// Conv2D + ... -> _FusedConv2D
//   (1) Conv2D + BiasAdd + <Activation>
//   (2) Conv2D + FusedBatchNorm + <Activation>
//   (3) Conv2D + Squeeze + BiasAdd
//
// MatMul + ... -> _FusedMatMul:
//   (1) MatMul + BiasAdd + <Activation>
//
// DepthwiseConv2dNative + ... -> _FusedDepthwiseConv2dNative:
//   (1) DepthwiseConv2dNative + BiasAdd + <Activation>
//
// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
//   (1) FusedBatchNorm + <Activation>
//   (2) FusedBatchNorm + SideInput + <Activation>
//
// In all cases, the supported activation functions are Relu, Relu6, Elu, and
// LeakyRelu (plus Tanh in MKL builds; see IsSupportedActivation below).
//
// Both Conv2D and MatMul are implemented as Tensor contractions (on CPU), so
// all the patterns are "ContractionWith...".
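//
// As a rough sketch of the rewrite (illustrative node names, not taken from
// any particular graph): the subgraph
//
//   conv = Conv2D(input, filter)
//   bias = BiasAdd(conv, bias_vector)
//   relu = Relu(bias)
//
// is remapped to a single node that keeps the activation's name:
//
//   relu = _FusedConv2D(input, filter, bias_vector)
//          with attrs fused_ops = ["BiasAdd", "Relu"], num_args = 1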
namespace {

constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";

constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";

constexpr int kMissingIndex = -1;

struct RemapperContext {
  explicit RemapperContext(GrapplerItem* item, Status* status)
      : nodes_to_preserve(item->NodesToPreserve()),
        graph_view(&item->graph, status),
        graph_properties(*item),
        inferred_graph_properties(false) {}

  std::unordered_set<string> nodes_to_preserve;
  utils::MutableGraphView graph_view;
  GraphProperties graph_properties;
  bool inferred_graph_properties;
};

// FusedBatchNorm that can be replaced with a cheaper set of primitives.
struct FusedBatchNorm {
  FusedBatchNorm() = default;
  explicit FusedBatchNorm(int fused_batch_norm)
      : fused_batch_norm(fused_batch_norm) {}

  int fused_batch_norm = kMissingIndex;
};

// FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
  FusedBatchNormEx() = default;

  int fused_batch_norm = kMissingIndex;
  int side_input = kMissingIndex;
  int activation = kMissingIndex;
  // Add node that will be invalidated by fusing side input and fused batch
  // norm.
  int invalidated = kMissingIndex;
};

// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
  ContractionWithBiasAdd() = default;
  ContractionWithBiasAdd(int contraction, int bias_add)
      : contraction(contraction), bias_add(bias_add) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
};

// Contraction node followed by a BiasAdd and Activation.
struct ContractionWithBiasAddAndActivation {
  ContractionWithBiasAddAndActivation() = default;
  ContractionWithBiasAddAndActivation(int contraction, int bias_add,
                                      int activation)
      : contraction(contraction), bias_add(bias_add), activation(activation) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int activation = kMissingIndex;
};

// Contraction node followed by a Squeeze and BiasAdd.
struct ContractionWithSqueezeAndBiasAdd {
  ContractionWithSqueezeAndBiasAdd() = default;
  ContractionWithSqueezeAndBiasAdd(int contraction, int squeeze, int bias_add)
      : contraction(contraction), squeeze(squeeze), bias_add(bias_add) {}

  int contraction = kMissingIndex;
  int squeeze = kMissingIndex;
  int bias_add = kMissingIndex;
};

// Contraction node followed by a FusedBatchNorm.
struct ContractionWithBatchNorm {
  ContractionWithBatchNorm() = default;
  ContractionWithBatchNorm(int contraction, int fused_batch_norm,
                           float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  float epsilon = 0.0;
};

// Contraction node followed by a FusedBatchNorm and Activation.
struct ContractionWithBatchNormAndActivation {
  ContractionWithBatchNormAndActivation() = default;
  ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm,
                                        int activation, float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        activation(activation),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  int activation = kMissingIndex;
  float epsilon = 0.0;
};

#ifdef INTEL_MKL
// Contraction node followed by a BiasAdd and Add.
struct ContractionWithBiasAddAndAdd {
  ContractionWithBiasAddAndAdd() = default;
  ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
                               int port_id)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  int port_id = 0;
};

// Contraction node followed by a BiasAdd, Add and Relu.
struct ContractionWithBiasAndAddActivation {
  ContractionWithBiasAndAddActivation() = default;
  ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
                                      int port_id, int activation)
      : contraction(contraction),
        bias_add(bias_add),
        add(add),
        port_id(port_id),
        activation(activation) {}

  int contraction = kMissingIndex;
  int bias_add = kMissingIndex;
  int add = kMissingIndex;
  int port_id = 0;
  int activation = kMissingIndex;
};
#endif  // INTEL_MKL

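// Returns true if `node` is in the preserve set (e.g. a fetch node); such
// nodes must not be absorbed into a fused op.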
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
  return ctx.nodes_to_preserve.count(node->name()) > 0;
}

bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
                      const string& type_attr = "T") {
  DataType lhs_attr = GetDataTypeFromAttr(*lhs, type_attr);
  DataType rhs_attr = GetDataTypeFromAttr(*rhs, type_attr);

  return lhs_attr != DT_INVALID && rhs_attr != DT_INVALID &&
         lhs_attr == rhs_attr;
}

bool HasDataType(const NodeDef* node, const DataType& expected,
                 const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*node, type_attr);
  return dtype == expected;
}

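// Returns true if the contraction's data type is supported by the CPU fused
// kernels; the allowed set depends on whether this is an MKL build (and on
// ENABLE_INTEL_MKL_BFLOAT16).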
bool IsCpuCompatibleDataType(const NodeDef* contraction,
                             const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
#if defined(INTEL_MKL)
#if defined(ENABLE_INTEL_MKL_BFLOAT16)
  if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
      IsMatMul(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_BFLOAT16;
#else
  if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
      IsMatMul(*contraction)) {
    return dtype == DT_FLOAT;
#endif  // ENABLE_INTEL_MKL_BFLOAT16
#else
  if (IsConv2D(*contraction)) {
    return dtype == DT_FLOAT || dtype == DT_DOUBLE;
  } else if (IsMatMul(*contraction)) {
    return dtype == DT_FLOAT;
#endif  // INTEL_MKL
  } else {
    return false;
  }
}

bool IsGpuCompatibleDataType(const NodeDef* contraction,
                             const string& type_attr = "T") {
  DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
  if (IsConv2D(*contraction)) {
    return dtype == DT_FLOAT;
  } else {
    return false;
  }
}

bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  const string& data_format = conv2d->attr().at(kDataFormat).s();
#ifndef INTEL_MKL
  return data_format == "NHWC";
#else
  return data_format == "NHWC" || data_format == "NCHW";
#endif  // !INTEL_MKL
}

bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  const string& data_format = conv2d->attr().at(kDataFormat).s();
  return data_format == "NHWC" || data_format == "NCHW";
}

bool IsCpuCompatibleConv2D(const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
         IsCpuCompatibleDataFormat(conv2d);
}

bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {
  DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
  return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
         IsGpuCompatibleDataFormat(conv2d);
}

bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
  DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
  return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
}

bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
  DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
      << "Expected DepthwiseConv2dNative op";
  return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
}

// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
template <typename Pattern>
bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
  const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
  if (IsConv2D(node)) {
    return IsCpuCompatibleConv2D(&node);
  } else if (IsDepthwiseConv2dNative(node)) {
#ifdef INTEL_MKL
    if (DisableMKL()) {
      return false;
    }
    return IsCpuCompatibleDepthwiseConv2dNative(&node);
#else
    return false;
#endif  // INTEL_MKL
  } else if (IsMatMul(node)) {
    return IsCpuCompatibleMatMul(&node);
  } else {
    return false;
  }
}

// Checks if we can rewrite a pattern to the `_FusedConv2D` on a GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAddAndActivation& matched) {
#if TENSORFLOW_USE_ROCM
  // ROCm does not support _FusedConv2D.
  return false;
#endif
  const GraphDef* graph = ctx.graph_view.graph();
  const NodeDef& contraction_node = graph->node(matched.contraction);
  if (!IsConv2D(contraction_node)) return false;

  const std::vector<OpInfo::TensorProperties>& input_props =
      ctx.graph_properties.GetInputProperties(contraction_node.name());
  const TensorShapeProto& filter_shape =
      input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();

  // FusedConv2D on GPU with 1x1 convolution is marginally faster than
  // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
  // and significantly slower in large scale benchmarks.
  bool is_spatial_conv = Rank(filter_shape) == 4 &&          //
                         IsKnown(filter_shape.dim(1)) &&     //
                         IsKnown(filter_shape.dim(2)) &&     //
                         filter_shape.dim(1).size() != 1 &&  //
                         filter_shape.dim(2).size() != 1;

  // We rely on cuDNN for fused convolution, and it currently supports only
  // Relu activation.
  const NodeDef& activation_node = graph->node(matched.activation);
  bool is_relu = IsRelu(activation_node);

  return is_relu && is_spatial_conv &&
         IsGpuCompatibleConv2D(&contraction_node);
}

bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithBiasAdd& matched) {
  return false;
}

bool IsGpuCompatible(const RemapperContext& ctx,
                     const ContractionWithSqueezeAndBiasAdd& matched) {
  return false;
}

// Returns true if the given pattern is supported on the assigned device.
template <typename Pattern>
bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
  return IsCpuCompatible(ctx, matched) || IsGpuCompatible(ctx, matched);
}

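// Activations that the fusions can absorb. The MKL build additionally allows
// Tanh; some Find* matchers below further restrict which activation may pair
// with which contraction.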
bool IsSupportedActivation(const NodeDef& node) {
#ifdef INTEL_MKL
  return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node) ||
         IsTanh(node);
#else
  return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
#endif
}

inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
  return node_view.NumControllingFanins() > 0 ||
         node_view.NumControlledFanouts() > 0;
}

// Returns true if at most one fanout reads output at port 0 (output used once).
inline bool HasAtMostOneFanoutAtPort0(const utils::MutableNodeView& node_view) {
  return node_view.GetRegularFanout(0).size() <= 1;
}

// Returns true if at most one fanout reads actual tensor data at output port 0
// (output used once for any data computation).
inline bool HasAtMostOneDataFanoutAtPort0(
    const utils::MutableNodeView& node_view) {
  const auto predicate = [](const auto& fanout) -> bool {
    const NodeDef* node = fanout.node_view()->node();
    return !IsShape(*node) && !IsRank(*node);
  };
  return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
}

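// Matches a {Conv2D, MatMul, DepthwiseConv2dNative} + BiasAdd pattern rooted
// at `node_index` (the BiasAdd). The contraction's output must feed only the
// BiasAdd, and both nodes must agree on data type.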
bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
                             ContractionWithBiasAdd* matched,
                             bool check_device_compatible = true) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be a BiasAdd.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  if (!IsBiasAdd(*node_def)) return false;

  // Input to the BiasAdd must be a Conv2D, MatMul or DepthwiseConv2dNative.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* contraction_node_view = regular_fanin_0.node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  bool is_contraction = IsConv2D(*contraction_node_def) ||
                        IsMatMul(*contraction_node_def) ||
                        IsDepthwiseConv2dNative(*contraction_node_def);

  if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
      HasControlFaninOrFanout(*contraction_node_view) ||
      !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
      IsInPreserveSet(ctx, contraction_node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
                                       node_index};
  if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
    return false;

  // We successfully found a {Conv2D, MatMul, DepthwiseConv2dNative}+BiasAdd
  // pattern.
  *matched = pattern;

  return true;
}

bool FindContractionWithBiasAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBiasAddAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // Root of the pattern must be an activation node.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // The input to the activation node must match the ContractionWithBiasAdd
  // pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* bias_add_node_view = regular_fanin_0.node_view();
  const auto* bias_add_node_def = bias_add_node_view->node();

  ContractionWithBiasAdd base;
  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), &base,
                               /*check_device_compatible=*/false) ||
      !HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;

  // Get the contraction node.
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(0).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only matmul + bias + tanh is enabled.
  if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;

  // Currently, only (conv | matmul) + bias + leakyrelu is enabled.
  if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def)) &&
      IsLeakyRelu(*node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithBiasAddAndActivation pattern{base.contraction,
                                                    base.bias_add, node_index};
  if (!IsDeviceCompatible(ctx, pattern)) return false;

  // We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
  *matched = pattern;

  return true;
}

bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx, int node_index,
                                  ContractionWithSqueezeAndBiasAdd* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be a BiasAdd.
  const auto* node_def = node_view->node();
  if (!IsBiasAdd(*node_def)) return false;

  // Input to the BiasAdd must be a Squeeze.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* squeeze_node_view = regular_fanin_0.node_view();
  const auto* squeeze_node_def = squeeze_node_view->node();

  if (!IsSqueeze(*squeeze_node_def) ||
      !HaveSameDataType(node_def, squeeze_node_def, "T") ||
      HasControlFaninOrFanout(*squeeze_node_view) ||
      !HasAtMostOneFanoutAtPort0(*squeeze_node_view) ||
      IsInPreserveSet(ctx, squeeze_node_def))
    return false;

  // Squeeze must not squeeze output channel dimension.
  std::vector<int32> dims;
  if (!TryGetNodeAttr(*squeeze_node_def, "squeeze_dims", &dims)) return false;
  for (auto dim : dims) {
    if (dim == 3) return false;
  }

  // Input to the Squeeze must be a Conv2D.
  if (squeeze_node_view->NumRegularFanins() < 1) return false;
  const auto& squeeze_regular_fanin_0 = squeeze_node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = squeeze_regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();

  if (!IsConv2D(*conv2d_node_def) ||
      !HaveSameDataType(node_def, conv2d_node_def, "T") ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // Check that data type and data format are supported on assigned device.
  const ContractionWithSqueezeAndBiasAdd pattern{
      conv2d_node_view->node_index(), squeeze_node_view->node_index(),
      node_index};
  if (!IsDeviceCompatible(ctx, pattern)) return false;

  // We successfully found a Conv2D+Squeeze+BiasAdd pattern.
  *matched = pattern;

  return true;
}

bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
                             ContractionWithBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // Root of the pattern must be a FusedBatchNorm.
  if (!IsFusedBatchNorm(*node_def)) return false;

  // FusedBatchNormV2 and V3 have an extra type parameter.
  if (node_view->GetOp() != "FusedBatchNorm" &&
      !HasDataType(node_def, DT_FLOAT, "U"))
    return false;

  // Check that batch normalization is in inference mode.
  const auto* training_attr = node_view->GetAttr(kIsTraining);
  if (training_attr != nullptr && training_attr->b()) return false;

  // Check that only the 0th output is consumed by other nodes.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view) ||
      !node_view->GetRegularFanout(1).empty() ||  // batch_mean
      !node_view->GetRegularFanout(2).empty() ||  // batch_variance
      !node_view->GetRegularFanout(3).empty() ||  // reserve_space_1
      !node_view->GetRegularFanout(4).empty())    // reserve_space_2
    return false;

  // Input to the FusedBatchNorm must be a Conv2D.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();

  if (!IsConv2D(*conv2d_node_def) || !NodeIsOnCpu(conv2d_node_def) ||
      !HaveSameDataType(node_def, conv2d_node_def) ||
      !IsCpuCompatibleDataType(conv2d_node_def) ||
      !IsCpuCompatibleDataFormat(conv2d_node_def) ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm pattern.
  matched->contraction = conv2d_node_view->node_index();
  matched->fused_batch_norm = node_index;
  if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;

  return true;
}

bool FindConv2DWithBatchNormAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBatchNormAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // The input to the activation node must match the Conv2DWithBatchNorm
  // pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* batch_norm_node_view = regular_fanin_0.node_view();

  ContractionWithBatchNorm base;
  if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base))
    return false;

  const auto* fused_batch_norm_node_view =
      ctx.graph_view.GetNode(base.fused_batch_norm);
  const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
  if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
      !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
      IsInPreserveSet(ctx, fused_batch_norm_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
  matched->contraction = base.contraction;
  matched->fused_batch_norm = base.fused_batch_norm;
  matched->activation = node_index;
  matched->epsilon = base.epsilon;

  return true;
}

#ifdef INTEL_MKL
// As AddN has multiple inputs, this function tries to find the Conv2D + Bias
// pattern in a specific input port.
bool FindContractionWithBiasInPort(const RemapperContext& ctx,
                                   const utils::MutableNodeView& add_node_view,
                                   const NodeDef& add_node_def, int port_id,
                                   ContractionWithBiasAdd* base) {
  // Input to AddN must match the ContractionWithBiasAdd pattern.
  if (add_node_view.NumRegularFanins() < port_id + 1) return false;
  const auto& bias_add_node_view =
      add_node_view.GetRegularFanin(port_id).node_view();
  if (bias_add_node_view == nullptr) return false;
  const auto* bias_add_node_def = bias_add_node_view->node();

  if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base,
                               /*check_device_compatible=*/false))
    return false;
  if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) ||
      !HaveSameDataType(&add_node_def, bias_add_node_def) ||
      IsInPreserveSet(ctx, bias_add_node_def))
    return false;
  return true;
}

bool IsAddWithNoBroadcast(const RemapperContext& ctx, const NodeDef& node) {
  if (!IsAdd(node)) return false;

  // Check whether this is a case of broadcasting - the Add node supports
  // broadcasting, while the fusion requires identical input shapes.
  const auto& props = ctx.graph_properties.GetInputProperties(node.name());
  if (props.size() == 2 &&
      ShapesSymbolicallyEqual(props[0].shape(), props[1].shape())) {
    return true;
  }
  return false;
}

bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
                                      const utils::MutableNodeView& node_view,
                                      ContractionWithBiasAddAndAdd* matched) {
  if (DisableMKL()) return false;
  // Fusion with AddN is supported only when it has two inputs.
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(node_view) || node_view.NumRegularFanins() != 2)
    return false;

  // Root of the pattern must be an AddN or an Add with the same input shapes
  // (no broadcasting).
  const auto* node_def = node_view.node();
  if (!IsAddN(*node_def) && !IsAddWithNoBroadcast(ctx, *node_def)) return false;

#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // MKL AddN ops only support float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;
#else
  // MKL AddN ops only support the float data type.
  if (!HasDataType(node_def, DT_FLOAT)) return false;
#endif  // ENABLE_INTEL_MKL_BFLOAT16

  ContractionWithBiasAdd base;
  matched->port_id = 0;

  // Find the conv+bias pattern in a specific port.
  if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                     matched->port_id, &base)) {
    matched->port_id = 1;
    if (!FindContractionWithBiasInPort(ctx, node_view, *node_def,
                                       matched->port_id, &base)) {
      return false;
    }
  }

  // We successfully found a Conv2D+BiasAdd+{AddN,Add} pattern.
  matched->contraction = base.contraction;
  matched->bias_add = base.bias_add;
  matched->add = node_view.node_index();

  return true;
}

bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
                                      int node_index,
                                      ContractionWithBiasAddAndAdd* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  return FindContractionWithBiasAddAndAdd(ctx, *node_view, matched);
}

bool FindContractionWithBiasAndAddActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBiasAndAddActivation* matched) {
  if (DisableMKL()) return false;
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (node_def == nullptr) return false;
  if (!IsSupportedActivation(*node_def)) return false;

  // Currently, the Contraction + Bias + Add + Tanh pattern is not supported.
  if (IsTanh(*node_def)) return false;

#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // MKL activation ops only support float and bfloat16 data types.
  if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
    return false;
#else
  // MKL activation ops only support the float data type.
  if (!HasDataType(node_def, DT_FLOAT)) return false;
#endif  // ENABLE_INTEL_MKL_BFLOAT16

  // The input to the activation must match the ContractionWithBiasAddAndAdd
  // pattern.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* add_node_view = regular_fanin_0.node_view();

  ContractionWithBiasAddAndAdd base;

  if (!FindContractionWithBiasAddAndAdd(ctx, *add_node_view, &base)) {
    return false;
  }

  // Get the contraction node.
  const auto* bias_add_node_view =
      add_node_view->GetRegularFanin(base.port_id).node_view();
  const auto* contraction_node_view =
      bias_add_node_view->GetRegularFanin(0).node_view();
  const auto* contraction_node_def = contraction_node_view->node();

  // Currently, only conv + bias + add + leakyrelu is enabled.
  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;

  // We successfully found a Conv2D+BiasAdd+{AddN,Add}+Activation pattern.
  const ContractionWithBiasAndAddActivation pattern{
      base.contraction, base.bias_add, base.add, base.port_id, node_index};
  *matched = pattern;

  return true;
}
#endif  // INTEL_MKL

bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
                        FusedBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  if (!IsFusedBatchNorm(*node_def)) return false;
  if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;

  // Check that the node is in inference mode.
  bool is_training = true;
  if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
  if (is_training) return false;

  const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());

  // a. Scaling factor can be const folded:
  //      scaling_factor = (variance + epsilon).rsqrt() * scale
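  //
  //    With a const scaling factor, inference-mode batch norm folds into an
  //    affine transform (a sketch of the algebra, not code in this file):
  //      y = (x - mean) * scaling_factor + offset
  //        = x * scaling_factor + (offset - mean * scaling_factor)
  //    i.e. a handful of cheap Mul/Sub/Add primitives after const folding.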
  bool const_scaling_factor =
      props.size() == 5 &&     // [x, scale, offset, mean, variance]
      props[1].has_value() &&  // scale
      props[4].has_value();    // variance aka estimated variance

  // b. Or input can be const folded into some other expression.
  auto const_inputs = std::count_if(
      props.begin(), props.end(),
      [](const OpInfo::TensorProperties& props) { return props.has_value(); });

  // TODO(bsteiner): use the cost model to compare the cost of fused batch
  // norm against that of the optimized form.
  bool can_remap = const_scaling_factor || const_inputs >= 4;
  if (!can_remap) return false;

  // The optimized version only generates the first output.
  if (node_view->GetRegularFanouts().size() > 1) {
    return false;
  }

  // We found a fused batch norm node that can be replaced with primitive ops.
  matched->fused_batch_norm = node_index;

  return true;
}

// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
bool BatchnormSpatialPersistentEnabled() {
#if CUDNN_VERSION >= 7402
  static bool is_enabled = [] {
    bool is_enabled = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
        /*default_val=*/false, &is_enabled));
    return is_enabled;
  }();
  return is_enabled;
#else
  return false;
#endif
}

bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
                          FusedBatchNormEx* matched) {
  // Root of the pattern must be a Relu.
  // TODO(ezhulenev): Forward control dependencies.
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // TODO(lyandy): Forward controls for patterns with control dependencies.
  if (!IsRelu(*node_def) || HasControlFaninOrFanout(*node_view)) return false;

  // Returns true iff the node is a compatible FusedBatchNorm node.
  const auto valid_batch_norm =
      [&](const utils::MutableNodeView& fused_batch_norm) -> bool {
    const auto* fused_batch_norm_node_def = fused_batch_norm.node();
    if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false;

#ifndef ENABLE_MKLDNN_V1
    // We fuse FusedBatchNorm on GPU or MKL CPU.
    if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false;
#else
    if (DisableMKL()) return false;
#endif

    DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T");
#ifndef ENABLE_MKLDNN_V1
    if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
#else
    if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false;
#endif

    // Get the FusedBatchNorm training mode.
    bool is_training;
    if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
             .ok())
      return false;
    // In training mode we rely on cuDNN for computing FusedBatchNorm with side
    // inputs and activation, and it has its own limitations. In inference mode
    // we have a custom CUDA kernel that doesn't have these constraints.
    if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) {
      // cuDNN only supports NHWC data layout.
      string data_format;
      if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format)
               .ok())
        return false;
      if (data_format != "NHWC") return false;

      // Data type must be DT_HALF.
      if (t_dtype != DT_HALF) return false;

      // Channel dimension must be a multiple of 4.
      const auto& props = ctx.graph_properties.GetInputProperties(
          fused_batch_norm_node_def->name());

      const bool valid_channel_dim = !props.empty() &&
                                     props[0].shape().dim_size() == 4 &&
                                     props[0].shape().dim(3).size() % 4 == 0;
      if (!valid_channel_dim) return false;

      // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
      if (!BatchnormSpatialPersistentEnabled()) return false;
    }

    // FusedBatchNormV2 and V3 have an extra type parameter.
    if ((fused_batch_norm_node_def->op() != "FusedBatchNorm") &&
        !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
      return false;

    // Check that only one node consumes the 0-th output of a FusedBatchNorm.
    if (HasControlFaninOrFanout(fused_batch_norm) ||
        !HasAtMostOneDataFanoutAtPort0(fused_batch_norm) ||
        IsInPreserveSet(ctx, fused_batch_norm_node_def))
      return false;

    return true;
  };

  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* relu_fanin_0_node_view = regular_fanin_0.node_view();
  const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();

  // Input to a Relu can be a FusedBatchNorm.
  if (valid_batch_norm(*relu_fanin_0_node_view)) {
    matched->activation = node_index;
    matched->fused_batch_norm = regular_fanin_0.node_index();
    return true;
  }

  // Input to a Relu can be an Add node with FusedBatchNorm as one of the
  // inputs.
  if (IsAdd(*relu_fanin_0_node_def)) {
    // Currently there is no CPU implementation for "FusedBatchNorm + SideInput
    // + <Activation>".
#ifdef ENABLE_MKLDNN_V1
    return false;
#endif

    // Check that only the Relu node consumes the output of the Add node.
    if (HasControlFaninOrFanout(*relu_fanin_0_node_view) ||
        !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) ||
        IsInPreserveSet(ctx, relu_fanin_0_node_def))
      return false;

    // Add node supports broadcasting, FusedBatchNormEx does not.
    const auto& props =
        ctx.graph_properties.GetInputProperties(relu_fanin_0_node_def->name());
    if (props.size() < 2 ||
        !ShapesSymbolicallyEqual(props[0].shape(), props[1].shape()))
      return false;

    if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
    const auto& add_regular_fanin_0 =
        relu_fanin_0_node_view->GetRegularFanin(0);
    const auto& add_regular_fanin_1 =
        relu_fanin_0_node_view->GetRegularFanin(1);

    if (valid_batch_norm(*add_regular_fanin_0.node_view())) {
      matched->activation = node_index;
      matched->side_input = add_regular_fanin_1.node_index();
      matched->fused_batch_norm = add_regular_fanin_0.node_index();
      matched->invalidated = regular_fanin_0.node_index();
      return true;
    }

    if (valid_batch_norm(*add_regular_fanin_1.node_view())) {
      matched->activation = node_index;
      matched->side_input = add_regular_fanin_0.node_index();
      matched->fused_batch_norm = add_regular_fanin_1.node_index();
      matched->invalidated = regular_fanin_0.node_index();
      return true;
    }
  }

  return false;
}

void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
                          const NodeDef* activation = nullptr) {
  DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";

  auto* attr = fused_conv2d->mutable_attr();
  auto& src_attr = conv2d.attr();

  (*attr)["T"] = src_attr.at("T");
  (*attr)["strides"] = src_attr.at("strides");
  (*attr)["padding"] = src_attr.at("padding");
  (*attr)["explicit_paddings"] = src_attr.at("explicit_paddings");
  (*attr)["dilations"] = src_attr.at("dilations");
  (*attr)["data_format"] = src_attr.at("data_format");
  (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
  // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha.
  if (activation != nullptr && IsLeakyRelu(*activation)) {
    auto& activation_attr = activation->attr();
    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
  }
}

void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
                                         NodeDef* fused_dw_conv2d,
                                         const NodeDef* activation = nullptr) {
  DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
      << "Input node must be a DepthwiseConv2dNative";

  auto* attr = fused_dw_conv2d->mutable_attr();
  auto& src_attr = dw_conv2d.attr();

  (*attr)["T"] = src_attr.at("T");
  (*attr)["strides"] = src_attr.at("strides");
  (*attr)["padding"] = src_attr.at("padding");
  (*attr)["dilations"] = src_attr.at("dilations");
  (*attr)["data_format"] = src_attr.at("data_format");
  // Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr
  // leakyrelu_alpha.
  if (activation != nullptr && IsLeakyRelu(*activation)) {
    auto& activation_attr = activation->attr();
    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
  }
}

void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
                                  NodeDef* fused_batch_norm_ex) {
  DCHECK(IsFusedBatchNorm(fused_batch_norm))
      << "Input node must be a FusedBatchNorm";

  auto* attr = fused_batch_norm_ex->mutable_attr();
  auto src_attr = fused_batch_norm.attr();

  (*attr)["T"] = src_attr.at("T");
  (*attr)["is_training"] = src_attr.at("is_training");
  (*attr)["data_format"] = src_attr.at("data_format");
  (*attr)["epsilon"] = src_attr.at("epsilon");
  (*attr)["exponential_avg_factor"] = src_attr.at("exponential_avg_factor");

  // FusedBatchNormV2 and V3 have an extra type parameter.
  if (fused_batch_norm.op() != "FusedBatchNorm") {
    SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
  } else {
#ifndef ENABLE_MKLDNN_V1
    SetAttrValue(src_attr.at("T"), &(*attr)["U"]);
#else
    SetAttrValue(DT_FLOAT, &(*attr)["U"]);
#endif
  }
}

void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
                          const NodeDef* activation = nullptr) {
  DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";

  auto* attr = fused_matmul->mutable_attr();
  auto& src_attr = matmul.attr();

  (*attr)["T"] = src_attr.at("T");
  (*attr)["transpose_a"] = src_attr.at("transpose_a");
  (*attr)["transpose_b"] = src_attr.at("transpose_b");
  // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha.
  if (activation != nullptr && IsLeakyRelu(*activation)) {
    auto& activation_attr = activation->attr();
    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
  }
}

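// Sets the common attributes of a fused op: the list of fused post-ops
// (`fused_ops`), the number of extra inputs (`num_args`), and `epsilon`
// (only meaningful when FusedBatchNorm is among the fused ops).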
void SetFusedOpAttributes(NodeDef* fused,
                          const absl::Span<const absl::string_view> fused_ops,
                          int num_args = 1, float epsilon = 0.0) {
  auto* attr = fused->mutable_attr();
  SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
  SetAttrValue(num_args, &(*attr)["num_args"]);
  SetAttrValue(epsilon, &(*attr)["epsilon"]);  // required only for BatchNorm
}

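// Replaces the matched {Conv2D, DepthwiseConv2dNative, MatMul} + BiasAdd
// subgraph with a single fused node that takes the bias as a third input and
// keeps the BiasAdd's name, so downstream consumers are unaffected.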
Status AddFusedContractionNode(RemapperContext* ctx,
                               const ContractionWithBiasAdd& matched,
                               std::vector<bool>* invalidated_nodes,
                               std::vector<bool>* nodes_to_delete) {
  DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";

  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  const NodeDef& bias_add = graph->node(matched.bias_add);
  VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd: "
          << " bias_add=" << bias_add.name()
          << " contraction=" << contraction.name();

  NodeDef fused_op;
  fused_op.set_name(bias_add.name());
  fused_op.set_device(contraction.device());
  fused_op.add_input(contraction.input(0));  // 0: input
  fused_op.add_input(contraction.input(1));  // 1: filter
  fused_op.add_input(bias_add.input(1));     // 2: bias

  if (IsConv2D(contraction)) {
    fused_op.set_op(kFusedConv2D);
    CopyConv2DAttributes(contraction, &fused_op);
  } else if (IsDepthwiseConv2dNative(contraction)) {
    fused_op.set_op(kFusedDepthwiseConv2dNative);
    CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
  } else if (IsMatMul(contraction)) {
    fused_op.set_op(kFusedMatMul);
    CopyMatMulAttributes(contraction, &fused_op);
  }

  SetFusedOpAttributes(&fused_op, {"BiasAdd"});

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}

Status AddFusedContractionNode(
    RemapperContext* ctx, const ContractionWithBiasAddAndActivation& matched,
    std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
  DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";

  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  const NodeDef& bias_add = graph->node(matched.bias_add);
  const NodeDef& activation = graph->node(matched.activation);

  VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
          << activation.op() << ":"
          << " activation=" << activation.name()
          << " bias_add=" << bias_add.name()
          << " contraction=" << contraction.name();

  NodeDef fused_op;
  fused_op.set_name(activation.name());
  fused_op.set_device(contraction.device());
  fused_op.add_input(contraction.input(0));  // 0: input
  fused_op.add_input(contraction.input(1));  // 1: filter
  fused_op.add_input(bias_add.input(1));     // 2: bias

  if (IsConv2D(contraction)) {
    fused_op.set_op(kFusedConv2D);
    // LeakyRelu has a special attribute alpha.
    CopyConv2DAttributes(contraction, &fused_op, &activation);
  } else if (IsDepthwiseConv2dNative(contraction)) {
    fused_op.set_op(kFusedDepthwiseConv2dNative);
    CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
  } else if (IsMatMul(contraction)) {
    fused_op.set_op(kFusedMatMul);
    CopyMatMulAttributes(contraction, &fused_op, &activation);
  }

  SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.bias_add] = true;
  (*invalidated_nodes)[matched.activation] = true;

  return Status::OK();
}

Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithSqueezeAndBiasAdd& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern";

  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

  const NodeDef& bias_add = graph->node(matched.bias_add);
  const NodeDef& squeeze = graph->node(matched.squeeze);
  VLOG(2) << "Fuse Conv2D with Squeeze and BiasAdd: "
          << " bias_add=" << bias_add.name() << " squeeze=" << squeeze.name()
          << " conv2d=" << contraction.name();

  // Replace the Conv2D node with a fused Conv2D. The matched pattern
  // guarantees that it has a single consumer (only the squeeze node).
  NodeDef fused_conv2d;
  fused_conv2d.set_name(contraction.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));  // 0: input
  fused_conv2d.add_input(contraction.input(1));  // 1: filter
  fused_conv2d.add_input(bias_add.input(1));     // 2: bias

  CopyConv2DAttributes(contraction, &fused_conv2d);
  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd"});

  // Replace the BiasAdd node with a Squeeze.
  NodeDef remapped_squeeze = squeeze;
  remapped_squeeze.set_name(bias_add.name());
  remapped_squeeze.set_input(0, contraction.name());

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  mutation->AddNode(std::move(remapped_squeeze), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.contraction] = true;
  (*invalidated_nodes)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.squeeze] = true;

  return Status::OK();
}

Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithBatchNorm& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm="
          << fused_batch_norm.name() << " conv2d=" << contraction.name();

  NodeDef fused_conv2d;
  fused_conv2d.set_name(fused_batch_norm.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));       // 0: input
  fused_conv2d.add_input(contraction.input(1));       // 1: filter
  fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
  fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
  fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
  fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance

  CopyConv2DAttributes(contraction, &fused_conv2d);
  SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"},
                       /*num_args=*/4, /*epsilon=*/matched.epsilon);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.fused_batch_norm] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}

Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithBatchNormAndActivation& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);

  DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

  const NodeDef& activation = graph->node(matched.activation);
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
          << ": activation=" << activation.name()
          << " batch_norm=" << fused_batch_norm.name()
          << " conv2d=" << contraction.name();

  NodeDef fused_conv2d;
  fused_conv2d.set_name(activation.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));       // 0: input
  fused_conv2d.add_input(contraction.input(1));       // 1: filter
  fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
  fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
  fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
  fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance

  CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
  SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()},
                       /*num_args=*/4, /*epsilon=*/matched.epsilon);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.activation] = true;
  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.fused_batch_norm] = true;

  return Status::OK();
}

#ifdef INTEL_MKL
Status AddFusedContractionNode(RemapperContext* ctx,
                               const ContractionWithBiasAddAndAdd& matched,
                               std::vector<bool>* invalidated_nodes,
                               std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  const NodeDef& bias_add = graph->node(matched.bias_add);

  // The MKL version only supports fusion for Conv2D and MatMul.
  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));

  NodeDef contraction_node;
  const NodeDef& add = graph->node(matched.add);
  contraction_node.set_name(add.name());
  contraction_node.set_device(contraction.device());
  contraction_node.add_input(
      contraction.input(0));  // 0: input(conv) / a (matmul)
  contraction_node.add_input(
      contraction.input(1));  // 1: filter(conv) / b (matmul)
  contraction_node.add_input(bias_add.input(1));  // 2: bias

  // The Add op has two inputs: one is the conv+bias/matmul+bias pattern
  // matched previously; the other input to the Add is fused here.
  contraction_node.add_input(add.input(1 - matched.port_id));

  if (IsConv2D(contraction)) {
    contraction_node.set_op(kFusedConv2D);
    CopyConv2DAttributes(contraction, &contraction_node);
  } else if (IsMatMul(contraction)) {
    contraction_node.set_op(kFusedMatMul);
    CopyMatMulAttributes(contraction, &contraction_node);
  }

  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(contraction_node), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.add] = true;
  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.bias_add] = true;

  return Status::OK();
}

Status AddFusedContractionNode(
    RemapperContext* ctx, const ContractionWithBiasAndAddActivation& matched,
    std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  // The MKL version only supports fusion for Conv2D.
  const NodeDef& contraction = graph->node(matched.contraction);
  DCHECK(IsConv2D(contraction));
  const NodeDef& activation = graph->node(matched.activation);

  NodeDef fused_conv2d;
  fused_conv2d.set_name(activation.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));  // 0: input
  fused_conv2d.add_input(contraction.input(1));  // 1: filter
  const NodeDef& bias_add = graph->node(matched.bias_add);
  fused_conv2d.add_input(bias_add.input(1));  // 2: bias

  // The Add op has two inputs: one is the conv+bias pattern matched
  // previously; the other input to the Add is fused here.
  const NodeDef& add = graph->node(matched.add);
  fused_conv2d.add_input(add.input(1 - matched.port_id));

  CopyConv2DAttributes(contraction, &fused_conv2d);
  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.activation] = true;
  (*nodes_to_delete)[matched.add] = true;
  (*nodes_to_delete)[matched.bias_add] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}
#endif  // INTEL_MKL

1375 Status AddFusedBatchNormExNode(RemapperContext* ctx,
1376 const FusedBatchNormEx& matched,
1377 std::vector<bool>* invalidated_nodes,
1378 std::vector<bool>* nodes_to_delete) {
1379 const GraphDef* graph = ctx->graph_view.graph();
1380 const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
1381 const NodeDef& activation = graph->node(matched.activation);
1382
1383 VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
1384 << " activation=" << activation.name() << " side_input="
1385 << (matched.side_input != kMissingIndex
1386 ? graph->node(matched.side_input).name()
1387 : "<none>")
1388 << " invalidated="
1389 << (matched.invalidated != kMissingIndex
1390 ? graph->node(matched.invalidated).name()
1391 : "<none>")
1392 << " fused_batch_norm=" << fused_batch_norm.name();
1393
  // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
  NodeDef fused_op;
  fused_op.set_op(kFusedBatchNormEx);
  fused_op.set_name(fused_batch_norm.name());
  fused_op.set_device(fused_batch_norm.device());

  fused_op.add_input(fused_batch_norm.input(0));  // 0: input
  fused_op.add_input(fused_batch_norm.input(1));  // 1: scale
  fused_op.add_input(fused_batch_norm.input(2));  // 2: offset
  fused_op.add_input(fused_batch_norm.input(3));  // 3: estimated_mean
  fused_op.add_input(fused_batch_norm.input(4));  // 4: estimated_var

  CopyFusedBatchNormAttributes(fused_batch_norm, &fused_op);

  auto* attrs = fused_op.mutable_attr();
  SetAttrValue(activation.op(), &(*attrs)["activation_mode"]);

  if (matched.side_input != kMissingIndex) {
    SetAttrValue(1, &(*attrs)["num_side_inputs"]);
    const NodeDef& side_input = graph->node(matched.side_input);
    fused_op.add_input(side_input.name());  // 5: side_input
  } else {
    SetAttrValue(0, &(*attrs)["num_side_inputs"]);
  }

  // Turn the activation node into an Identity node that forwards the fused
  // op's output, so existing consumers of the activation keep a valid input.
  NodeDef identity_op;
  identity_op.set_op("Identity");
  identity_op.set_name(activation.name());
  identity_op.set_device(fused_batch_norm.device());
  identity_op.add_input(fused_batch_norm.name());
  (*identity_op.mutable_attr())["T"] = attrs->at("T");

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  mutation->AddNode(std::move(identity_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.fused_batch_norm] = true;
  (*invalidated_nodes)[matched.activation] = true;
  if (matched.side_input != kMissingIndex) {
    (*nodes_to_delete)[matched.invalidated] = true;
  }

  return Status::OK();
}

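// Splits an inference-time FusedBatchNorm into primitive ops, computing
//
//   y = (x - mean) * rsqrt(variance + epsilon) * scale + offset
//
// as x * scaled + (offset - mean * scaled), where
// scaled = rsqrt(variance + epsilon) * scale. With frozen mean, variance,
// scale, and offset, everything except the per-element Mul and Add can be
// constant-folded.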
Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
  VLOG(2) << "Optimizing fused batch norm node "
          << SummarizeNodeDef(fused_node);

  const string& x = fused_node.input(0);
  string scale = fused_node.input(1);
  string offset = fused_node.input(2);
  string mean = fused_node.input(3);
  string variance = fused_node.input(4);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;

  string x_format = fused_node.attr().at(kDataFormat).s();
  if (x_format == "NCHW" || x_format == "NCDHW") {
    // In NCHW/NCDHW layouts the channel dimension is dimension 1, so the four
    // per-channel vector inputs must be reshaped to [1, C, 1, 1] (or
    // [1, C, 1, 1, 1]) for the arithmetic ops below to broadcast correctly.
    NodeDef new_shape;
    const string new_shape_name =
        AddPrefixToNodeName(x_format + "Shape", fused_node.name());
    new_shape.set_name(new_shape_name);
    new_shape.set_op("Const");
    new_shape.set_device(fused_node.device());
    *new_shape.add_input() = AsControlDependency(scale);
    (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
    if (x_format == "NCHW") {
      Tensor t(DT_INT32, {4});
      t.flat<int32>()(0) = 1;
      t.flat<int32>()(1) = -1;
      t.flat<int32>()(2) = 1;
      t.flat<int32>()(3) = 1;
      t.AsProtoTensorContent(
          (*new_shape.mutable_attr())["value"].mutable_tensor());
    } else {
      Tensor t(DT_INT32, {5});
      t.flat<int32>()(0) = 1;
      t.flat<int32>()(1) = -1;
      t.flat<int32>()(2) = 1;
      t.flat<int32>()(3) = 1;
      t.flat<int32>()(4) = 1;
      t.AsProtoTensorContent(
          (*new_shape.mutable_attr())["value"].mutable_tensor());
    }
    mutation->AddNode(std::move(new_shape), &status);
    TF_RETURN_IF_ERROR(status);

    // The four Reshape nodes below are identical up to the value being
    // reshaped, so build them with a shared helper.
    auto add_reshape = [&](const string& suffix, string* input) -> Status {
      NodeDef reshaped;
      reshaped.set_name(
          AddPrefixToNodeName(x_format + suffix, fused_node.name()));
      reshaped.set_op("Reshape");
      reshaped.set_device(fused_node.device());
      *reshaped.add_input() = *input;
      *reshaped.add_input() = new_shape_name;
      (*reshaped.mutable_attr())["T"] = fused_node.attr().at("T");
      (*reshaped.mutable_attr())["Tshape"].set_type(DT_INT32);
      *input = reshaped.name();
      mutation->AddNode(std::move(reshaped), &status);
      return status;
    };

    TF_RETURN_IF_ERROR(add_reshape("ShapedScale", &scale));
    TF_RETURN_IF_ERROR(add_reshape("ShapedOffset", &offset));
    TF_RETURN_IF_ERROR(add_reshape("ShapedMean", &mean));
    TF_RETURN_IF_ERROR(add_reshape("ShapedVariance", &variance));
  }

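  // Materialize epsilon as a scalar Const node so it can be added to the
  // variance below.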
  float epsilon = 0.0f;
  if (fused_node.attr().count("epsilon")) {
    epsilon = fused_node.attr().at("epsilon").f();
  }
  DataType dtype = fused_node.attr().at("T").type();
  Tensor value(dtype, TensorShape());
  value.scalar<float>()() = epsilon;
  NodeDef variance_epsilon;
  const string variance_epsilon_name =
      AddPrefixToNodeName("Const", fused_node.name());
  TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
      variance_epsilon_name, TensorValue(&value), &variance_epsilon));
  variance_epsilon.set_device(fused_node.device());
  mutation->AddNode(std::move(variance_epsilon), &status);
  TF_RETURN_IF_ERROR(status);

  NodeDef variance_plus_epsilon;
  const string variance_plus_epsilon_name =
      AddPrefixToNodeName("VarPlusEpsilon", fused_node.name());
  variance_plus_epsilon.set_name(variance_plus_epsilon_name);
  variance_plus_epsilon.set_op("Add");
  (*variance_plus_epsilon.mutable_attr())["T"].set_type(dtype);
  variance_plus_epsilon.set_device(fused_node.device());
  *variance_plus_epsilon.add_input() = variance;
  *variance_plus_epsilon.add_input() = variance_epsilon_name;
  mutation->AddNode(std::move(variance_plus_epsilon), &status);
  TF_RETURN_IF_ERROR(status);

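  // inv = rsqrt(variance + epsilon)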
  NodeDef inv;
  const string inv_name = AddPrefixToNodeName("Inv", fused_node.name());
  inv.set_name(inv_name);
  inv.set_op("Rsqrt");
  inv.set_device(fused_node.device());
  (*inv.mutable_attr())["T"].set_type(dtype);
  *inv.add_input() = variance_plus_epsilon_name;
  mutation->AddNode(std::move(inv), &status);
  TF_RETURN_IF_ERROR(status);

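  // scaled = rsqrt(variance + epsilon) * scale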
  NodeDef scaled;
  const string scaled_name = AddPrefixToNodeName("Scaled", fused_node.name());
  scaled.set_name(scaled_name);
  scaled.set_op("Mul");
  scaled.set_device(fused_node.device());
  (*scaled.mutable_attr())["T"].set_type(dtype);
  *scaled.add_input() = inv_name;
  *scaled.add_input() = scale;
  mutation->AddNode(std::move(scaled), &status);
  TF_RETURN_IF_ERROR(status);

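  // a = x * scaled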
  NodeDef a;
  const string a_name = AddPrefixToNodeName("Mul", fused_node.name());
  a.set_name(a_name);
  a.set_op("Mul");
  a.set_device(fused_node.device());
  (*a.mutable_attr())["T"].set_type(dtype);
  *a.add_input() = x;
  *a.add_input() = scaled_name;
  mutation->AddNode(std::move(a), &status);
  TF_RETURN_IF_ERROR(status);

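  // b = mean * scaled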
  NodeDef b;
  const string b_name = AddPrefixToNodeName("Mul2", fused_node.name());
  b.set_name(b_name);
  b.set_op("Mul");
  b.set_device(fused_node.device());
  (*b.mutable_attr())["T"].set_type(dtype);
  *b.add_input() = mean;
  *b.add_input() = scaled_name;
  mutation->AddNode(std::move(b), &status);
  TF_RETURN_IF_ERROR(status);

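  // c = offset - mean * scaled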
  NodeDef c;
  const string c_name = AddPrefixToNodeName("Offset", fused_node.name());
  c.set_name(c_name);
  c.set_op("Sub");
  c.set_device(fused_node.device());
  (*c.mutable_attr())["T"].set_type(dtype);
  *c.add_input() = offset;
  *c.add_input() = b_name;
  mutation->AddNode(std::move(c), &status);
  TF_RETURN_IF_ERROR(status);

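  // r = a + c = x * scaled + (offset - mean * scaled). The result takes over
  // the FusedBatchNorm node's name, so its consumers are rewired implicitly.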
  NodeDef r;
  r.set_name(fused_node.name());
  r.set_op("Add");
  r.set_device(fused_node.device());
  (*r.mutable_attr())["T"].set_type(dtype);
  *r.add_input() = a_name;
  *r.add_input() = c_name;
  mutation->AddNode(std::move(r), &status);
  TF_RETURN_IF_ERROR(status);

  return mutation->Apply();
}

#ifdef INTEL_MKL
bool IsConv2DOrMatMul(const NodeDef& node) {
  return IsConv2D(node) || IsMatMul(node);
}

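// Returns true if the node has a regular fanin that is an Add fed, directly
// or through a BiasAdd, by a Conv2D or MatMul, i.e. the node is a candidate
// for the MKL-only Conv2D/MatMul + (BiasAdd) + Add fusions.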
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);

  // Candidate for Conv2D + Add, Conv2D + BiasAdd + Add, MatMul + Add, or
  // MatMul + BiasAdd + Add fusion.
  auto is_supported_add_input = [](const auto* node_view) -> bool {
    // Currently only Conv2D and MatMul are supported.
    if (IsConv2DOrMatMul(*node_view->node())) return true;
    if (IsBiasAdd(*node_view->node())) {
      if (node_view->NumRegularFanins() < 2) return false;
      const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
      const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
      return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
             IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
    }
    return false;
  };

  auto is_supported_add = [&](const auto* node_view) -> bool {
    const auto* node_def = node_view->node();
    if (IsAdd(*node_def)) {
      if (node_view->NumRegularFanins() < 2) return false;
      const auto& add_fanin_0 = node_view->GetRegularFanin(0);
      const auto& add_fanin_1 = node_view->GetRegularFanin(1);
      return is_supported_add_input(add_fanin_0.node_view()) ||
             is_supported_add_input(add_fanin_1.node_view());
    }
    return false;
  };

  bool ret = false;
  for (int i = 0; i < node_view->NumRegularFanins(); i++) {
    const auto& fanin_i = node_view->GetRegularFanin(i);
    ret = is_supported_add(fanin_i.node_view());
    if (ret) break;
  }

  return ret;
}
#endif  // INTEL_MKL

// Check if a node is a candidate for one of the patterns that require inferred
// shapes:
// (1) Splitting FusedBatchNorm into primitives.
// (2) Fusing side input and/or activation into FusedBatchNorm.
// (3) Fusing Conv2D + BiasAdd + Relu on GPU.
// (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
  // Candidate for a FusedBatchNorm splitting.
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  const auto is_batch_norm_candidate = [&]() -> bool {
    if (!IsFusedBatchNorm(*node_def)) return false;
    if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;

    bool is_training = true;
    if (!TryGetNodeAttr(*node_def, kIsTraining, &is_training)) return false;
    if (is_training) return false;

    return true;
  };

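  // Candidate for a Conv2D + BiasAdd + Relu fusion on GPU.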
  const auto is_relu_biasadd_conv2d_candidate = [&]() -> bool {
    if (!IsRelu(*node_def)) return false;
    if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false;

    if (node_view->NumRegularFanins() < 1) return false;
    const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
    const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
    const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();

    if (!IsBiasAdd(*relu_fanin_0_node_def)) return false;
    if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT)
      return false;

    if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false;

    const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
    const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();

    if (!IsConv2D(*biasadd_fanin_0_node_def)) return false;
    if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT)
      return false;

    return true;
  };

  // Candidate for a FusedBatchNorm fusion.
  const auto is_batch_norm_fusion_candidate = [&]() -> bool {
    if (!IsRelu(*node_def)) return false;

    if (node_view->NumRegularFanins() < 1) return false;
    const auto& relu_fanin_0 = node_view->GetRegularFanin(0);
    const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
    const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();

    if (IsFusedBatchNorm(*relu_fanin_0_node_def)) {
      // FusedBatchNorm + Relu.
      return true;

    } else if (IsAdd(*relu_fanin_0_node_def)) {
      // FusedBatchNorm + Add + Relu.

      if (relu_fanin_0_node_view->NumRegularFanins() < 2) return false;
      const auto& add_regular_fanin_0 =
          relu_fanin_0_node_view->GetRegularFanin(0);
      if (IsFusedBatchNorm(*add_regular_fanin_0.node_view()->node()))
        return true;
      const auto& add_regular_fanin_1 =
          relu_fanin_0_node_view->GetRegularFanin(1);
      if (IsFusedBatchNorm(*add_regular_fanin_1.node_view()->node()))
        return true;
    }

    return false;
  };

#ifdef INTEL_MKL
  (void)is_relu_biasadd_conv2d_candidate;  // Avoid an unused-variable warning.
  return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
         IsContractionWithAdd(ctx, node_index);
#else
  return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
         is_batch_norm_fusion_candidate();
#endif  // INTEL_MKL
}

}  // namespace

Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
                          GraphDef* optimized_graph) {
  GrapplerItem mutable_item = item;
  Status status;
  RemapperContext ctx(&mutable_item, &status);
  TF_RETURN_IF_ERROR(status);
  // Processing the graph in reverse topological order makes it possible to
  // remap longer chains of dependent ops in a single pass.
  TF_RETURN_IF_ERROR(
      ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {}));

  const int num_nodes = item.graph.node_size();
  // Skip nodes invalidated by an earlier remap, e.g. do not process BiasAdd
  // and Activation nodes that were fused into a Conv2D node.
  std::vector<bool> invalidated_nodes(num_nodes);
  std::vector<bool> nodes_to_delete(num_nodes);

  // _Fused{...} kernels do not have a registered gradient function, so we
  // must not perform the rewrite if the graph will be differentiated later.
  bool allow_non_differentiable_rewrites =
      item.optimization_options().allow_non_differentiable_rewrites;

  for (int i = num_nodes - 1; i >= 0; --i) {
    // Check if node was invalidated by one of the previous remaps.
    if (invalidated_nodes[i] || nodes_to_delete[i]) {
      continue;
    }

    // Infer properties lazily in case they are not needed.
    if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, i)) {
      const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
          assume_valid_feeds,
          /*aggressive_shape_inference=*/false,
          /*include_input_tensor_values=*/true,
          /*include_output_tensor_values=*/false));
      ctx.inferred_graph_properties = true;
    }

#ifdef INTEL_MKL
    ContractionWithBiasAddAndAdd contract_with_bias_and_add;
    ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;

    if (!item.optimization_options().is_eager_mode) {
      // Remap Conv2D+BiasAdd+Add+Relu into the _FusedConv2D.
      if (FindContractionWithBiasAndAddActivation(
              ctx, i, &contract_with_bias_and_add_activation)) {
        TF_RETURN_IF_ERROR(
            AddFusedContractionNode(&ctx, contract_with_bias_and_add_activation,
                                    &invalidated_nodes, &nodes_to_delete));
        continue;
      }

      // Remap Conv2D+BiasAdd+Add into the _FusedConv2D.
      if (FindContractionWithBiasAddAndAdd(ctx, i,
                                           &contract_with_bias_and_add)) {
        TF_RETURN_IF_ERROR(
            AddFusedContractionNode(&ctx, contract_with_bias_and_add,
                                    &invalidated_nodes, &nodes_to_delete));
        continue;
      }
    }
#endif  // INTEL_MKL

    // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
    // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
    ContractionWithBiasAdd contract_with_bias;
    if (allow_non_differentiable_rewrites &&
        FindContractionWithBias(ctx, i, &contract_with_bias)) {
      TF_RETURN_IF_ERROR(AddFusedContractionNode(
          &ctx, contract_with_bias, &invalidated_nodes, &nodes_to_delete));
      continue;
    }

    // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
    // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
    ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
    if (allow_non_differentiable_rewrites &&
        FindContractionWithBiasAndActivation(
            ctx, i, &contract_with_bias_and_activation)) {
      TF_RETURN_IF_ERROR(
          AddFusedContractionNode(&ctx, contract_with_bias_and_activation,
                                  &invalidated_nodes, &nodes_to_delete));
      continue;
    }

    // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
    // it for MatMul as well, but in practice this pattern does not appear in
    // real TensorFlow graphs.

    // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
    ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
    if (allow_non_differentiable_rewrites &&
        FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
      TF_RETURN_IF_ERROR(
          AddFusedConv2DNode(&ctx, contract_with_squeeze_and_bias,
                             &invalidated_nodes, &nodes_to_delete));
      continue;
    }

// TODO(intel-tf):
// Remove this once TF-MKL supports _FusedConv2D with these operations.
#ifndef INTEL_MKL
    // Remap Conv2D+FusedBatchNorm into the _FusedConv2D.
    ContractionWithBatchNorm contract_with_batch_norm;
    if (allow_non_differentiable_rewrites &&
        FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
      TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
                                            &invalidated_nodes,
                                            &nodes_to_delete));
      continue;
    }

    // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D.
    ContractionWithBatchNormAndActivation
        contract_with_batch_norm_and_activation;
    if (allow_non_differentiable_rewrites &&
        FindConv2DWithBatchNormAndActivation(
            ctx, i, &contract_with_batch_norm_and_activation)) {
      TF_RETURN_IF_ERROR(
          AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation,
                             &invalidated_nodes, &nodes_to_delete));
      continue;
    }
#endif  // !INTEL_MKL

    // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
    FusedBatchNormEx fused_batch_norm_ex;
    if (allow_non_differentiable_rewrites &&
        FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
      TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
          &ctx, fused_batch_norm_ex, &invalidated_nodes, &nodes_to_delete));
      continue;
    }

    // During inference, most of the inputs to FusedBatchNorm are constant, and
    // we can therefore replace the op with a much cheaper set of primitives.
    FusedBatchNorm fused_batch_norm;
    if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
      TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
      continue;
    }
  }

  // Remove invalidated nodes.
  utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder();
  for (int i = 0; i < num_nodes; ++i) {
    if (nodes_to_delete[i]) {
      mutation->RemoveNode(ctx.graph_view.GetNode(i));
    }
  }
  TF_RETURN_IF_ERROR(mutation->Apply());

  *optimized_graph = std::move(mutable_item.graph);

  return Status::OK();
}

void Remapper::Feedback(Cluster* cluster, const GrapplerItem& item,
                        const GraphDef& optimized_graph, double result) {
  // Nothing to do for RemapperOptimizer.
}

}  // namespace grappler
}  // namespace tensorflow