/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(intel): Improve error handling in this file; instead of CHECK failing
// all over the place, we should log an error and execute the original graph.
#ifdef INTEL_MKL

#include "tensorflow/core/common_runtime/mkl_layout_pass.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <set>
#include <stack>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     The rewrite happens under the following scenario:
//     - Propagating the Mkl layout as an additional output tensor from every
//       Mkl-supported NN layer (we will loosely call a tensor that carries
//       the Mkl layout an Mkl tensor henceforth).
//
// Example of A : Merging nodes in the graph
// -----------------------------------------
// Currently, we merge Conv2D and BiasAdd together. Consider Conv2D and
// BiasAdd as:
//
//           O = Conv2D(A, B)
//           P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
//           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
//  - The merge of Conv2D and BiasAdd happens only when the output of Conv2D
//    goes solely to BiasAdd.
//  - The attributes common to both nodes must have the same values.
//  - Both nodes must have been assigned to the same device (if any).
//
// Example of B.1 : Rewriting nodes to Mkl nodes
// ---------------------------------------------
// Consider a Relu node. The current definition of the Relu node is:
//
//           O = Relu(A)
//
// Relu has 1 input (A) and 1 output (O).
//
// This rewrite pass will generate a new graph node for Relu (the new node is
// called MklRelu) as:
//
//          O, O_m = MklRelu(A, A_m)
//
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is some
// kind of metadata for O. A_m is an additional input of MklRelu, and it
// represents metadata for A -- as O_m is metadata for O, A_m is metadata for
// A. MklRelu receives this metadata from the previous node in the graph.
//
// When the previous node in the graph is an Mkl node, A_m will represent a
// valid Mkl tensor. But when the previous node is not an Mkl node, A_m will
// represent a dummy Mkl tensor.
//
// Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    no nodes of that op type will be rewritten.
//  - Number of inputs after rewriting:
//      Since for every input Tensorflow tensor the rewritten node gets an
//      Mkl tensor, the rewritten node gets 2*N inputs, where N is the number
//      of inputs of the original node.
//  - Number of outputs after rewriting:
//      Since for every output Tensorflow tensor the rewritten node generates
//      an Mkl tensor, the rewritten node generates 2*N outputs, where N is
//      the number of outputs of the original node.
//  - Ordering of Tensorflow tensors and Mkl tensors:
//      Since every rewritten node generates twice the number of inputs and
//      outputs, one could imagine various orderings among Tensorflow tensors
//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//      inputs; then the new op '_MklConv2D' can take the inputs A, B, A_m
//      and B_m in the order A, A_m, B, B_m, or in the order A, B, A_m, B_m.
//      In general, many such permutations are possible.
//
//      So the question is: which order do we follow? We support 2 types of
//      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
//      follows an intuitive order where an Mkl tensor follows the
//      corresponding Tensorflow tensor immediately. In the context of the
//      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
//      applies to both the inputs and outputs. Contiguous ordering means
//      all the Tensorflow tensors are contiguous, followed by all the Mkl
//      tensors. We use contiguous ordering as the default (see the sketch
//      below).
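//
//      To make the two orderings concrete, here is a hedged sketch with
//      hypothetical helpers (illustrative only, not part of this pass) that
//      compute the slot of the Mkl tensor paired with Tensorflow tensor 'i',
//      given 'n' original tensors:
//
//        // Contiguous: TF tensors occupy slots [0, n); Mkl tensors follow.
//        int MklSlotContiguous(int i, int n) { return n + i; }
//        // Interleaved: each Mkl tensor immediately follows its TF tensor.
//        int MklSlotInterleaved(int i) { return 2 * i + 1; }
//
//      For _MklConv2D(A, B, A_m, B_m) above (contiguous, n = 2), the Mkl
//      tensor B_m of input B (i = 1) sits at slot n + i = 3.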
//
// Graph rewrite algorithm:
//      Algorithm: Graph Rewrite
//      Input: Graph G, Names of the nodes to rewrite and their new names
//      Output: Modified Graph G' if the nodes are modified, G otherwise.
//      Start:
//        N = Topological_Sort(G) // N is a set of nodes in toposort order.
//        foreach node n in N
//        do
//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
//          then
//            E = set of <incoming edge and its src_output slot> of n
//            E' = {}   // a new set of edges for rewritten node
//            foreach <e,s> in E
//            do
//              E' U {<e,s>}  // First copy edge which generates Tensorflow
//                            // tensor as it is
//              m = Source node of edge e
//              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
//              then
//                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
//                                  // tensor as an additional output.
//              else
//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
//                                                 // Mkl tensor.
//                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
//              fi
//            done
//            n' = Build_New_Node(G,new_name,E')
//            Mark_Rewritten(n')  // Mark the new node as being rewritten.
//          fi
//        done
//
//      Explanation:
//        For graph rewrite, we visit nodes of the input graph in
//        topological sort order. With this ordering, we visit nodes in a
//        top-to-bottom fashion. We need this order because, while visiting a
//        node, we want all of its input nodes to have been visited (and
//        rewritten, if applicable) already. If we need to rewrite a given
//        node, then all of its input nodes need to be fixed first (in other
//        words, they cannot be deleted later).
//
//        While visiting a node, we first check if the op type of the node is
//        an Mkl op. If it is, then we rewrite that node after constructing
//        new inputs to the node. If the op type of the node is not an Mkl
//        op, then we do not rewrite that node.
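//
//        A hedged sketch of the visit order (simplified; the real logic
//        lives in RunPass below): reverse post-order, which is a topological
//        order for acyclic graphs, can be obtained from the graph library:
//
//          std::vector<Node*> order;
//          GetReversePostOrder(**g, &order);  // g: std::unique_ptr<Graph>*
//          for (Node* n : order) {
//            // Skip nodes whose op type is not registered for rewrite;
//            // otherwise build new inputs (Tensorflow tensor edges plus Mkl
//            // tensor edges or dummies) and replace 'n'.
//          }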
//
// Handling workspace propagation for certain ops:
//
//        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
//        passing of a workspace from their respective forward ops. Workspace
//        tensors provide memory for storing results of intermediate
//        operations which are helpful in backward propagation. TensorFlow
//        does not have a notion of a workspace and as a result does not
//        allow producing additional outputs from these forward ops. For
//        these ops, we need to add 2 extra edges between forward ops and
//        their corresponding backward ops - the first extra edge carries a
//        workspace tensor and the second one carries an Mkl tensor for the
//        workspace tensor.
//
//        Example:
//
//        A typical graph for MaxPool and its gradient looks like:
//
//        A = MaxPool(T)
//        B = MaxPoolGrad(X, A, Y)
//
//        We will transform this graph to propagate the workspace as:
//        (with the contiguous ordering)
//
//        A, W, A_m, W_m = MklMaxPool(T, T_m)
//        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
//        Here W is the workspace tensor. Transformed tensor names with the
//        suffix _m are Mkl tensors, and this transformation has been done
//        using the algorithm discussed earlier. The transformation for
//        workspace propagation only adds extra outputs (W, W_m) for a
//        forward op and connects them to the corresponding backward ops.
//
//        Terms:
//
//        Forward op name = name of the op in the forward pass
//          where a workspace tensor originates (MaxPool in this example)
//        Backward op name = name of the op in the backward pass that
//          receives a workspace tensor from the forward op (MaxPoolGrad in
//          the example)
//        Slot = Position of the output or input slot that will be
//               used by the workspace tensor (1 for MklMaxPool as W is the
//               2nd output of MaxPool (0 is the 1st); 3 for MklMaxPoolGrad)
//
//        Question:
//
//        How do we associate a backward op with a forward op? There can be
//        more than one op with the exact same name.
//
//        In this example, we associate MaxPoolGrad with MaxPool. But there
//        could be more than one MaxPool op. To solve this problem, we look
//        for a _direct_ edge between a forward op and a backward op (tensor
//        A is flowing along this edge in the example).
//
//        How do we transform forward and backward ops when there is no
//        direct edge between them? In such a case, we generate dummy tensors
//        for the workspace tensors. For the example, the transformation of
//        MaxPool will be exactly the same as it would be when there is a
//        direct edge between the forward and the backward op --- it is just
//        that MaxPool won't generate any workspace tensor. For MaxPoolGrad,
//        the transformation will also be the same, but instead of connecting
//        W and W_m with the outputs of MaxPool, we will produce dummy
//        tensors for them, and we will set the workspace_enabled attribute
//        to false.
//
class MklLayoutRewritePass : public GraphOptimizationPass {
 public:
  MklLayoutRewritePass() {
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.avg_pool3d = "AvgPool3D";
    csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
    csinfo_.batch_matmul = "BatchMatMul";
    csinfo_.batch_matmul_v2 = "BatchMatMulV2";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conjugate_transpose = "ConjugateTranspose";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.conv2d_grad_filter_with_bias =
        "__MklDummyConv2DBackpropFilterWithBias";
    csinfo_.conv3d = "Conv3D";
    csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
    csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
    csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
    csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
    csinfo_.depthwise_conv2d_grad_filter =
        "DepthwiseConv2dNativeBackpropFilter";
    csinfo_.dequantize = "Dequantize";
    csinfo_.einsum = "Einsum";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
    csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
    csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
    csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
    csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
    csinfo_.fused_conv2d = "_FusedConv2D";
    csinfo_.fused_conv3d = "_FusedConv3D";
    csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
    csinfo_.fused_matmul = "_FusedMatMul";
    csinfo_.identity = "Identity";
    csinfo_.leakyrelu = "LeakyRelu";
    csinfo_.leakyrelu_grad = "LeakyReluGrad";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.max_pool3d = "MaxPool3D";
    csinfo_.max_pool3d_grad = "MaxPool3DGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_grad_filter_with_bias =
        "_MklConv2DBackpropFilterWithBias";
    csinfo_.mkl_depthwise_conv2d_grad_input =
        "_MklDepthwiseConv2dNativeBackpropInput";
    csinfo_.mkl_depthwise_conv2d_grad_filter =
        "_MklDepthwiseConv2dNativeBackpropFilter";
    csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
    csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
    csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
    csinfo_.mkl_native_conv2d_grad_filter_with_bias =
        "_MklNativeConv2DBackpropFilterWithBias";
    csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
    csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
    csinfo_.mkl_native_fused_conv3d = "_MklNativeFusedConv3D";
    csinfo_.mkl_native_fused_depthwise_conv2d =
        "_MklNativeFusedDepthwiseConv2dNative";
    csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
    csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
    csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
    csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
    csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
    csinfo_.quantized_avg_pool = "QuantizedAvgPool";
    csinfo_.quantized_concatv2 = "QuantizedConcatV2";
    csinfo_.quantized_conv2d = "QuantizedConv2D";
    csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
    csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
    csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
    csinfo_.quantized_conv2d_with_bias_and_requantize =
        "QuantizedConv2DWithBiasAndRequantize";
    csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
    csinfo_.quantized_conv2d_and_relu_and_requantize =
        "QuantizedConv2DAndReluAndRequantize";
    csinfo_.quantized_conv2d_with_bias_and_relu =
        "QuantizedConv2DWithBiasAndRelu";
    csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantized_max_pool = "QuantizedMaxPool";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu =
        "QuantizedConv2DWithBiasSumAndRelu";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSumAndReluAndRequantize";
    csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
    csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
    csinfo_.quantized_matmul_with_bias_and_relu =
        "QuantizedMatMulWithBiasAndRelu";
    csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
        "QuantizedMatMulWithBiasAndReluAndRequantize";
    csinfo_.quantized_matmul_with_bias_and_dequantize =
        "QuantizedMatMulWithBiasAndDequantize";
    csinfo_.quantized_matmul_with_bias_and_requantize =
        "QuantizedMatMulWithBiasAndRequantize";
    csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
    csinfo_.quantized_depthwise_conv2d_with_bias =
        "QuantizedDepthwiseConv2DWithBias";
    csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
        "QuantizedDepthwiseConv2DWithBiasAndRelu";
    csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantize_v2 = "QuantizeV2";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.relu6 = "Relu6";
    csinfo_.relu6_grad = "Relu6Grad";
    csinfo_.requantize = "Requantize";
    csinfo_.tanh = "Tanh";
    csinfo_.tanh_grad = "TanhGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.slice = "Slice";
    csinfo_.softmax = "Softmax";
    csinfo_.split = "Split";
    csinfo_.transpose = "Transpose";
    // Element-wise ops. Ensure you also add any new ops to the
    // IsMklElementWiseOp method in MklUtil.h so that the MklInputConversion
    // op is added before them.
    csinfo_.add = "Add";
    csinfo_.add_v2 = "AddV2";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    const bool native_fmt = NativeFormatEnabled();
    // NOTE: names are alphabetically sorted.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.batch_matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.einsum,
                      mkl_op_registry::GetMklOpName(csinfo_.einsum),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.batch_matmul_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conjugate_transpose,
         mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
         CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_with_bias,
                      native_fmt ? csinfo_.mkl_native_conv2d_with_bias
                                 : csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      native_fmt
                          ? csinfo_.mkl_native_conv2d_grad_filter_with_bias
                          : csinfo_.mkl_conv2d_grad_filter_with_bias,
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.depthwise_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_input,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_filter,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
         CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    // Using CopyAttrsAll for V3 on CPU, as there are no additional
    // attributes.
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
                      native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
                                 : csinfo_.mkl_fused_batch_norm_ex,
                      CopyAttrsAll, FusedBatchNormExRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_fused_conv2d
                                 : csinfo_.mkl_fused_conv2d,
                      CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_conv3d, csinfo_.mkl_native_fused_conv3d,
                      CopyAttrsAllCheckConstFilter, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
                      native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
                                 : csinfo_.mkl_fused_depthwise_conv2d,
                      CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_matmul,
                      native_fmt ? csinfo_.mkl_native_fused_matmul
                                 : csinfo_.mkl_fused_matmul,
                      CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsAll, LrnRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.leakyrelu,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.leakyrelu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
                      CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_conv2d
                                 : csinfo_.mkl_pad_with_conv2d,
                      CopyAttrsAllCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
                                 : csinfo_.mkl_pad_with_fused_conv2d,
                      CopyAttrsAllCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
                      CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_per_channel,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_and_relu,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_and_relu_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_sum_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_matmul_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
         CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
         kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_relu),
                      CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
         kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_requantize),
                      CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_dequantize),
                      CopyAttrsQuantizedMatMulWithBiasAndDequantize,
                      AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_depthwise_conv2d_with_bias),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_
                 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantize_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
                      CopyAttrsAll, QuantizeOpRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.requantize,
                      mkl_op_registry::GetMklOpName(csinfo_.requantize),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    // Optimized TanhGrad support exists only in DNNL 1.x.
    rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.transpose,
                      mkl_op_registry::GetMklOpName(csinfo_.transpose),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});

    // Add info about which ops to add a workspace edge to, and the slot
    // numbers (one entry is decoded below).
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
    wsinfo_.push_back(
        {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
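    // Reading the MaxPool entry against the WorkSpaceInfo fields declared
    // below (fwd_op, bwd_op, fwd_slot, bwd_slot, ws_fwd_slot, ws_bwd_slot):
    // the actual output is at slot 0 of MaxPool and is consumed at slot 1 of
    // MaxPoolGrad, while the workspace edge uses output slot 1 and input
    // slot 3 -- matching the "Slot" discussion in the header comment above.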

    // Add a rule for merging nodes.
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    // Merge Pad and Conv2D only if the pad op is "Pad"; do not merge if the
    // pad op is "PadV2" or "MirrorPad".
    minfo_.push_back(
        {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});

    minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
                      csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});

    // The fusion patterns in "finfo_" that show up first get applied first.
    // For example, for graph "A->B->C->D", if finfo_ is {A->B->C to ABC,
    // A->B->C->D to ABCD}, then since the first pattern gets applied first,
    // the final graph will be ABC->D.
  }

  // Standard interface to run the pass.
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of the heavy lifting for rewriting
  // Mkl nodes to propagate an Mkl tensor as an additional output.
  //
  // Extracts common functionality between the Run public interface and the
  // test interface.
  //
  // @return true, if and only if the graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

  /// Cause for rewrite.
  /// Currently, we support only 2 causes - either for Mkl layout propagation,
  /// which is the most common case, or for just a name change (used for ops
  /// like MatMul and Transpose, which do not support the Mkl layout).
  enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

  // Get the op rewrite cause depending on whether native format mode
  // is enabled or not.
  RewriteCause GetRewriteCause() {
    if (NativeFormatEnabled()) {
      return kRewriteForOpNameChange;
    } else {
      return kRewriteForLayoutPropagation;
    }
  }

  /// Structure to specify the name of an original node, its new name after
  /// the rewrite, the function to be used to copy attributes for the op, the
  /// rule (if any) which must hold for rewriting the node, and the cause of
  /// the rewrite.
  typedef struct {
    string name;      // Original name of the op of the node in the graph
    string new_name;  // New name of the op of the node in the graph
    // A function handler to copy attributes from an old node to a new node.
    std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
    // A rule under which to rewrite this node
    std::function<bool(const Node*)> rewrite_rule;
    // Why are we rewriting?
    RewriteCause rewrite_cause;
  } RewriteInfo;
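
  // For illustration (hedged; this mirrors the rinfo_ entries pushed in the
  // constructor above), an entry such as
  //   {csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
  //    CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}
  // reads: rewrite "Relu" to its Mkl op name, copy all of its attributes,
  // rewrite whenever the node is otherwise eligible, and record a cause that
  // depends on whether native format mode is enabled.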

  /// Structure to specify a forward op, a backward op, and the slot numbers
  /// in the forward and backward ops where we will add a workspace edge.
  typedef struct {
    string fwd_op;    // Name of a forward op in the graph
    string bwd_op;    // Name of a backward op in the graph
    int fwd_slot;     // Output slot in the forward op node where the actual
                      // output tensor resides
    int bwd_slot;     // Input slot in the backward op node where the actual
                      // input tensor resides
    int ws_fwd_slot;  // Output slot in the forward op node where the
                      // workspace edge is added
    int ws_bwd_slot;  // Input slot in the backward op node where the
                      // workspace edge is added
  } WorkSpaceInfo;

  /// Structure to specify information used in the node merge of 2 operators.
  typedef struct {
    string op1;       // Node string for the first operator.
    string op2;       // Node string for the second operator.
    string new_node;  // Name of the node after the merge.
    // Function that enables the user of the node merger to specify how to
    // find the second operator given the first operator.
    std::function<Node*(const Node*)> get_node_to_be_merged;
  } MergeInfo;

  // Structure to specify information used in node fusion of 3+ operators.
  typedef struct {
    std::string pattern_name;  // Name to describe this pattern, such as
                               // "Transpose_Mklop_Transpose".
    std::vector<std::function<bool(const Node*)> >
        node_checkers;  // Extra restriction checkers for these ops
    std::function<Status(
        std::unique_ptr<Graph>*, std::vector<Node*>&,
        std::function<void(const Node*, NodeBuilder* nb, bool)>)>
        fuse_func;
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
  } FusionInfo;

  //
  // Dimension indices for 4D tensors (2D spatial data).
  //
  struct NCHW {
    enum dim { N = 0, C = 1, H = 2, W = 3 };
  };

  struct NHWC {
    enum dim { N = 0, H = 1, W = 2, C = 3 };
  };

  //
  // Dimension indices for 5D tensors (3D spatial data).
  //
  struct NCDHW {
    enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
  };

  struct NDHWC {
    enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
  };

  /// Structure to store all constant strings
  /// NOTE: names are alphabetically sorted.
  typedef struct {
    string addn;
    string add;
    string add_v2;
    string avg_pool;
    string avg_pool_grad;
    string avg_pool3d;
    string avg_pool3d_grad;
    string batch_matmul;
    string batch_matmul_v2;
    string bias_add;
    string bias_add_grad;
    string concat;
    string concatv2;
    string conjugate_transpose;
    string conv2d;
    string conv2d_with_bias;
    string conv2d_grad_input;
    string conv2d_grad_filter;
    string conv2d_grad_filter_with_bias;
    string conv3d;
    string conv3d_grad_input;
    string conv3d_grad_filter;
    string depthwise_conv2d;
    string depthwise_conv2d_grad_input;
    string depthwise_conv2d_grad_filter;
    string dequantize;
    string einsum;
    string fused_batch_norm;
    string fused_batch_norm_grad;
    string fused_batch_norm_ex;
    string fused_batch_norm_v2;
    string fused_batch_norm_grad_v2;
    string fused_batch_norm_v3;
    string fused_batch_norm_grad_v3;
    string fused_conv2d;
    string fused_conv3d;
    string fused_depthwise_conv2d;
    string fused_matmul;
    string identity;
    string leakyrelu;
    string leakyrelu_grad;
    string lrn;
    string lrn_grad;
    string matmul;
    string max_pool;
    string max_pool_grad;
    string max_pool3d;
    string max_pool3d_grad;
    string maximum;
    string mkl_conv2d;
    string mkl_conv2d_grad_input;
    string mkl_conv2d_grad_filter;
    string mkl_conv2d_grad_filter_with_bias;
    string mkl_conv2d_with_bias;
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
    string mkl_fused_batch_norm_ex;
    string mkl_fused_conv2d;
    string mkl_fused_depthwise_conv2d;
    string mkl_fused_matmul;
    string mkl_native_conv2d_with_bias;
    string mkl_native_conv2d_grad_filter_with_bias;
    string mkl_native_fused_batch_norm_ex;
    string mkl_native_fused_conv2d;
    string mkl_native_fused_conv3d;
    string mkl_native_fused_depthwise_conv2d;
    string mkl_native_fused_matmul;
    string mkl_native_pad_with_conv2d;
    string mkl_native_pad_with_fused_conv2d;
    string mkl_pad_with_conv2d;
    string mkl_pad_with_fused_conv2d;
    string mul;
    string pad;
    string pad_with_conv2d;
    string pad_with_fused_conv2d;
    string quantized_avg_pool;
    string quantized_conv2d;
    string quantized_conv2d_per_channel;
    string quantized_conv2d_with_requantize;
    string quantized_conv2d_with_bias;
    string quantized_conv2d_with_bias_and_requantize;
    string quantized_conv2d_and_relu;
    string quantized_conv2d_and_relu_and_requantize;
    string quantized_conv2d_with_bias_and_relu;
    string quantized_conv2d_with_bias_and_relu_and_requantize;
    string quantized_concatv2;
    string quantized_max_pool;
    string quantized_conv2d_with_bias_sum_and_relu;
    string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
    string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
    string quantized_matmul_with_bias;
    string quantized_matmul_with_bias_and_relu;
    string quantized_matmul_with_bias_and_relu_and_requantize;
    string quantized_matmul_with_bias_and_requantize;
    string quantized_matmul_with_bias_and_dequantize;
    string quantized_depthwise_conv2d;
    string quantized_depthwise_conv2d_with_bias;
    string quantized_depthwise_conv2d_with_bias_and_relu;
    string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
    string quantize_v2;
    string relu;
    string relu_grad;
    string relu6;
    string relu6_grad;
    string requantize;
    string tanh;
    string tanh_grad;
    string transpose;
    string reshape;
    string slice;
    string softmax;
    string split;
    string squared_difference;
    string sub;
  } ConstStringsInfo;

 private:
  /// Maintain info about nodes to rewrite
  std::vector<RewriteInfo> rinfo_;

  /// Maintain info about nodes to add workspace edge
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Maintain info about nodes to be merged
  std::vector<MergeInfo> minfo_;

  /// Maintain info about nodes to be fused
  std::vector<FusionInfo> finfo_;

  /// Maintain structure of constant strings
  static ConstStringsInfo csinfo_;

 private:
  // Is OpDef::ArgDef a list type? It could be N * T or list(type).
  // Refer to op_def.proto for details of the list type.
  inline bool ArgIsList(const OpDef::ArgDef& arg) const {
    return !arg.type_list_attr().empty() || !arg.number_attr().empty();
  }

  // Get the length of a list in 'n' if 'arg' is of list type. Refer to the
  // description of ArgIsList for the definition of list type.
  inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
    CHECK_EQ(ArgIsList(arg), true);
    int N = 0;
    const string attr_name = !arg.type_list_attr().empty()
                                 ? arg.type_list_attr()
                                 : arg.number_attr();
    if (!arg.type_list_attr().empty()) {
      std::vector<DataType> value;
      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
      N = value.size();
    } else {
      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
    }
    return N;
  }
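
  // A hedged usage sketch (hypothetical snippet, for illustration only):
  // counting the total number of input tensors of a node by expanding
  // list-typed arguments.
  //
  //   int total = 0;
  //   for (const OpDef::ArgDef& arg : n->op_def().input_arg()) {
  //     total += ArgIsList(arg) ? GetTensorListLength(arg, n) : 1;
  //   }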

  // Can the op represented by node 'n' run on DEVICE_CPU?
  // The op can run on CPU with MKL if the runtime-assigned device or the
  // user-requested device contains device CPU, or if both are empty.
  bool CanOpRunOnCPUDevice(const Node* n) {
    bool result = true;
    string reason;

    // Substrings that should be checked for in the device name for the CPU
    // device.
    const char* const kCPUDeviceSubStr = "CPU";
    const char* const kXLACPUDeviceSubStr = "XLA_CPU";

    // If the op has been specifically assigned to a non-CPU or XLA_CPU
    // device, then no.
    if (!n->assigned_device_name().empty() &&
        (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) ||
         absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) {
      result = false;
      reason = "Op has been assigned a runtime device that is not CPU.";
    }

    // If the user has specifically assigned this op to a non-CPU or XLA_CPU
    // device, then no.
    if (!n->def().device().empty() &&
        (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) ||
         absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) {
      result = false;
      reason = "User has assigned a device that is not CPU.";
    }

    if (result == false) {
      VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
              << n->type_string() << ", reason: " << reason;
    }

    // Otherwise yes.
    return result;
  }

  // Return a node that can be merged with input node 'n'.
  //
  // @return pointer to the node if we can find such a node; otherwise,
  // returns nullptr.
  Node* CheckForNodeMerge(const Node* n) const;

  // Merge node 'm' with node 'n'.
  // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
  // Conv2DBackpropFilter.
  //
  // Input nodes m and n may be deleted if the call to this function is
  // successful. Attempting to use the pointers after the call may result in
  // undefined behavior.
  //
  // @input g - input graph, m - graph node, n - graph node to be merged with m
  // @return Status::OK(), if merging is successful and supported.
  //         Returns an appropriate Status error code otherwise.
  //         The graph is updated in case nodes are merged. Otherwise, it is
  //         not updated.
  Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);

  // Helper functions to merge different nodes.
  Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
  Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
                                                  Node* m, Node* n);

  // Find a BiasAdd or Conv2D node that can be merged with input node 'm'.
  // If input 'm' is BiasAdd, then check if there exists a Conv2D node that
  // can be merged with 'm'. If input 'm' is Conv2D, then check if there
  // exists a BiasAdd node that can be merged with 'm'.
GetConv2DOrBiasAdd(const Node * m)1078   static Node* GetConv2DOrBiasAdd(const Node* m) {
1079     DCHECK(m);
1080     Node* n = nullptr;
1081 
1082     DataType T_m;
1083     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1084 
1085     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1086     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1087 
1088     if (m->type_string() == csinfo_.bias_add) {
1089       // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
1090       TF_CHECK_OK(m->input_node(0, &n));
1091     } else {
1092       CHECK_EQ(m->type_string(), csinfo_.conv2d);
1093       // Go over all output edges and search for BiasAdd Node.
1094       // 0th input of BiasAdd is Conv2D.
1095       for (const Edge* e : m->out_edges()) {
1096         if (!e->IsControlEdge() &&
1097             e->dst()->type_string() == csinfo_.bias_add &&
1098             e->dst_input() == 0) {
1099           n = e->dst();
1100           break;
1101         }
1102       }
1103     }
1104 
1105     if (n == nullptr) {
1106       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1107               << "Conv2D and BiasAdd node for merging. Input node: "
1108               << m->DebugString();
1109     }
1110 
1111     return n;
1112   }
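
  // For illustration, in a graph fragment such as:
  //
  //   C = Conv2D(A, B)
  //   P = BiasAdd(C, bias)
  //
  // calling GetConv2DOrBiasAdd with the BiasAdd node returns the Conv2D node
  // (its 0th input), and calling it with the Conv2D node returns the BiasAdd
  // node found among its output edges; MergeNode() then merges the pair.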
1113 
1114   // Find Pad or Conv2D node that can be merged with input node 'm'.
1115   // If input 'm' is Pad, then check if there exists Conv2D node that can be
1116   // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1117   // node that can be merged with 'm'.
1118   static Node* GetPadOrConv2D(const Node* m) {
1119     DCHECK(m);
1120     Node* n = nullptr;
1121 
1122     DataType T_m;
1123     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1124 
1125     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1126     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1127 
1128     const Node* conv_node = nullptr;
1129     if (m->type_string() == csinfo_.pad) {
1130       // If m is Pad, then Conv2D is the output of Pad.
1131       for (const Edge* e : m->out_edges()) {
1132         if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1133           n = e->dst();
1134           conv_node = n;
1135           break;
1136         }
1137       }
1138     } else {
1139       DCHECK_EQ(m->type_string(), csinfo_.conv2d);
1140       // If m is Conv2D, go over all input edges
1141       // and search for a Pad node.
1142       for (const Edge* e : m->in_edges()) {
1143         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1144           n = e->src();
1145           conv_node = m;
1146           break;
1147         }
1148       }
1149     }
1150     // Check that the Conv2D node uses only the VALID
1151     // type of padding.
1152     if (n != nullptr) {
1153       string padding;
1154       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1155       if (padding != "VALID")
1156         // Then do not merge.
1157         // Only the VALID type of padding in a conv op can be
1158         // merged with a Pad op.
1159         n = nullptr;
1160     } else {
1161       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1162               << "Pad and Conv2D node for merging. Input node: "
1163               << m->DebugString();
1164     }
1165 
1166     return n;
1167   }
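
  // For illustration, a merge candidate is found for a pattern such as:
  //
  //   P = Pad(X, paddings)
  //   O = Conv2D(P, F)      // with attribute padding == "VALID"
  //
  // whereas the same pattern with padding == "SAME" returns nullptr, since
  // only a conv op using VALID padding can absorb a preceding Pad op.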
1168 
1169   // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1170   // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1171   // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1172   // exists Pad node that can be merged with 'm'.
1173   static Node* GetPadOrFusedConv2D(const Node* m) {
1174     DCHECK(m);
1175     Node* n = nullptr;
1176 
1177     const Node* conv_node = nullptr;
1178     if (m->type_string() == csinfo_.pad) {
1179       // If m is Pad, then _FusedConv2D is the output of Pad.
1180       for (const Edge* e : m->out_edges()) {
1181         if (!e->IsControlEdge() &&
1182             e->dst()->type_string() == csinfo_.fused_conv2d) {
1183           n = e->dst();
1184           conv_node = n;
1185           break;
1186         }
1187       }
1188     } else {
1189       DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
1190       // If m is _FusedConv2D, Go over all input edges
1191       // and search for Pad node.
1192       for (const Edge* e : m->in_edges()) {
1193         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1194           n = e->src();
1195           conv_node = m;
1196           break;
1197         }
1198       }
1199     }
1200     // Check that the _FusedConv2D node uses only the VALID type of padding.
1201     if (n != nullptr) {
1202       string padding;
1203       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1204       if (padding != "VALID") {
1205         // Then do not merge.
1206         n = nullptr;
1207         VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
1208                 << "nodes but cannot merge them. Only conv ops with padding "
1209                 << "type VALID can be merged with Pad op Input node: "
1210                 << m->DebugString();
1211       }
1212     } else {
1213       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1214               << "Pad and _FusedConv2D node for merging. Input node: "
1215               << m->DebugString();
1216     }
1217 
1218     return n;
1219   }
1220 
1221   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1222   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1223   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1224   // then check if there exists Conv2DBackpropFilter node that can be merged
1225   // with 'm'.
1226   //
1227   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1228   // would look like:
1229   //
1230   // _ = Conv2DBackpropFilter(F, _, G)
1231   // _ = BiasAddGrad(G)
1232   //
1233   // So 1st input of BiasAddGrad connects with 3rd input of
1234   // Conv2DBackpropFilter and vice versa.
1235   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1236     DCHECK(m);
1237     Node* n = nullptr;
1238     const Node* conv2d_backprop_filter = nullptr;
1239 
1240     DataType T_m;
1241     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1242 
1243     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1244     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1245 
1246     if (m->type_string() == csinfo_.bias_add_grad) {
1247       // Get 1st input 'g' of BiasAddGrad.
1248       Node* g = nullptr;
1249       TF_CHECK_OK(m->input_node(0, &g));
1250       // Now traverse all outgoing edges from g that have destination node as
1251       // Conv2DBackpropFilter.
1252       for (const Edge* e : g->out_edges()) {
1253         if (!e->IsControlEdge() &&
1254             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1255             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1256           n = e->dst();
1257           conv2d_backprop_filter = n;
1258           break;
1259         }
1260       }
1261     } else {
1262       conv2d_backprop_filter = m;
1263       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1264       // Get 3rd input 'g' of Conv2DBackpropFilter.
1265       Node* g = nullptr;
1266       TF_CHECK_OK(m->input_node(2, &g));
1267       // Now traverse all outgoing edges from g that have destination node as
1268       // BiasAddGrad.
1269       for (const Edge* e : g->out_edges()) {
1270         if (!e->IsControlEdge() &&
1271             e->dst()->type_string() == csinfo_.bias_add_grad &&
1272             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1273           n = e->dst();
1274           break;
1275         }
1276       }
1277     }
1278 
1279     // Do not merge if padding type is EXPLICIT.
1280     // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1281     if (conv2d_backprop_filter != nullptr) {
1282       string padding;
1283       TF_CHECK_OK(
1284           GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1285       if (padding == "EXPLICIT") {
1286         // Then do not merge.
1287         VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1288                 << "and BiasAddGrad nodes but cannot merge them. "
1289                 << "EXPLICIT padding is not supported now. "
1290                 << conv2d_backprop_filter->DebugString();
1291         return nullptr;
1292       }
1293     }
1294 
1295     if (n == nullptr) {
1296       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1297               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1298               << "Input node: " << m->DebugString();
1299     }
1300     return n;
1301   }
1302 
1303   // Return a node that can be fused with input node 'n'.
1304   //
1305   // @return tuple. If we can find such nodes, the first
1306   // element of the tuple is true. Otherwise, it's false.
1307   std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1308   CheckForNodeFusion(Node* n) const;
1309 
1310   // Fuse nodes in the vector "nodes"
1311   Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1312                   const MklLayoutRewritePass::FusionInfo fi);
1313 
1314   // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1315   // mklop("NCHW").
1316   // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1317   static Status FuseTransposeMklOpTranspose(
1318       std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1319       std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1320       string data_format);
1321 
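  // For illustration, with the permutation vectors written inline, the pattern
  // targeted by FuseTransposeMklOpTranspose looks like:
  //
  //   X' = Transpose(X, {0, 2, 3, 1})   // NCHW -> NHWC
  //   Y  = Conv2D(X', F)                // data_format == "NHWC"
  //   Z  = Transpose(Y, {0, 3, 1, 2})   // NHWC -> NCHW
  //
  // and is rewritten into a single Mkl op that runs directly on X with
  // data_format "NCHW", eliminating both Transpose nodes.
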
1322   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1323     // Check if node's type is "Transpose"
1324     if (node->type_string() != "Transpose") return false;
1325 
1326     // If "Transpose" has multiple output data edges, also don't fuse it.
1327     if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1328 
1329     // Check if the node has an output control edge. If so, this is likely a
1330     // training graph. Currently we focus on inference and do no fusion in
1331     // training.
1332     // Note: this constraint will eventually be removed once this fusion is
1333     // enabled for training in the future.
1334     for (const Edge* e : node->out_edges()) {
1335       if (e->IsControlEdge()) {
1336         return false;
1337       }
1338     }
1339 
1340     // If "Transpose" has input control edges, don't fuse on it.
1341     for (const Edge* e : node->in_edges()) {
1342       if (e->IsControlEdge()) {
1343         return false;
1344       }
1345     }
1346 
1347     // We compare the tensor containing the permutation order ("perm_node")
1348     // with our desired order ("perm"). If they match exactly, this check
1349     // succeeds and returns true.
1350     for (const Edge* e : node->in_edges()) {
1351       if (!e->IsControlEdge()) {
1352         const Node* perm_node = e->src();
1353 
1354         const int kPermTensorIndex = 1;
1355         if (perm_node->type_string() == "Const" &&
1356             e->dst_input() == kPermTensorIndex) {
1357           // We found the "perm" node; now try to retrieve its value.
1358           const TensorProto* proto = nullptr;
1359           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1360 
1361           DataType type;
1362           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1363 
1364           Tensor tensor;
1365           if (!tensor.FromProto(*proto)) {
1366             TF_CHECK_OK(errors::InvalidArgument(
1367                 "Could not construct Tensor from TensorProto in node: ",
1368                 node->name()));
1369             return false;
1370           }
1371           // The current fusion only supports 4D or 5D tensors, as given by
1372           // the `perm` vector; return false otherwise.
1373           if (tensor.dim_size(0) != perm.size()) return false;
1374           DCHECK_EQ(tensor.dims(), 1);
1375           if (type == DT_INT32) {
1376             const auto tensor_content = tensor.flat<int>().data();
1377             for (int i = 0; i < perm.size(); ++i)
1378               if (tensor_content[i] != perm[i]) return false;
1379             return true;
1380           } else if (type == DT_INT64) {
1381             const auto tensor_content = tensor.flat<int64_t>().data();
1382             for (int i = 0; i < perm.size(); ++i)
1383               if (tensor_content[i] != perm[i]) return false;
1384             return true;
1385           }
1386           return false;
1387         }
1388       }
1389     }
1390     return false;
1391   }
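
  // As a concrete example, CheckForTranspose(node, {0, 3, 1, 2}) succeeds only
  // when 'node' is a Transpose whose "perm" input is a 4-element Const holding
  // exactly {0, 3, 1, 2} (the NHWC-to-NCHW permutation) and the node has a
  // single data output edge and no control edges on either side.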
1392 
1393   static bool CheckForMklOp(const Node* node, string name = "") {
1394     if (node == nullptr) return false;
1395 
1396     if (!name.empty() && node->type_string() != name) {
1397       return false;
1398     }
1399 
1400     // If the mklop has multiple outputs, don't fuse it.
1401     if (node->num_outputs() > 1) return false;
1402 
1403     if (node->out_edges().size() > 1) return false;
1404 
1405     DataType T;
1406     TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1407     return mkl_op_registry::IsMklOp(
1408         mkl_op_registry::GetMklOpName(node->type_string()), T);
1409   }
1410 
1411   // Check if the node 'n' has any applicable rewrite rule
1412   // We check for 2 scenarios for rewrite.
1413   //
1414   // @return RewriteInfo* for the applicable rewrite rule
1415   const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1416   const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1417 
1418   // Default rewrite rule to be used in scenario 1 for rewrite.
1419   // @return - true (since we want to always rewrite)
1420   static bool AlwaysRewrite(const Node* n) { return true; }
1421 
1422   // Rewrite rule which considers the "context" of the current node to decide
1423   // if we should rewrite. By "context" we currently mean all the inputs of
1424   // the current node. The idea is that if none of the inputs of the current
1425   // node are MKL nodes, then rewriting the current node to an MKL node
1426   // _may not_ offer any performance improvement.
1427   //
1428   // One such case is element-wise ops. For such ops, we reuse the Eigen
1429   // implementation and pass the MKL metadata tensor through so we can avoid
1430   // conversions. However, if all incoming edges are in TF format, we don't
1431   // need all this overhead, so replace the elementwise node only if at least
1432   // one of its parents is an MKL node.
1433   //
1434   // More generally, all memory- or IO-bound ops (such as Identity) may fall
1435   // under this category.
1436   //
1437   // @input - Input graph node to be rewritten
1438   // @return - true if node is to be rewritten as MKL node; false otherwise.
1439   static bool RewriteIfAtleastOneMklInput(const Node* n) {
1440     DataType T;
1441     if (GetNodeAttr(n->def(), "T", &T).ok() &&
1442         mkl_op_registry::IsMklOp(
1443             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1444       for (auto e : n->in_edges()) {
1445         if (e->IsControlEdge()) continue;
1446         if (mkl_op_registry::IsMklOp(e->src())) {
1447           return true;
1448         }
1449       }
1450     }
1451     return false;
1452   }
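
  // For illustration: in
  //
  //   X = _MklConv2D(...)   // an MKL node
  //   Y = Add(X, W)         // element-wise op
  //
  // the Add node qualifies because one of its inputs is already an MKL node,
  // while an Add whose inputs are all plain TF nodes is left unchanged.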
1453 
1454   static bool MatMulRewrite(const Node* n) {
1455     DataType T;
1456     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1457     if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1458       VLOG(2) << "Rewriting MatMul to _MklMatMul";
1459       return true;
1460     }
1461     return false;
1462   }
1463   // For oneDNN, only int32 is supported for the axis data type.
1464   static bool ConcatV2Rewrite(const Node* n) {
1465     DataType T;
1466     TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1467     return (T == DT_INT32);
1468   }
1469 
1470   static bool DequantizeRewrite(const Node* n) {
1471     DCHECK(n);
1472     Node* input = nullptr;
1473     TF_CHECK_OK(n->input_node(0, &input));
1474     string mode_string;
1475     int axis = -1;
1476     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1477     TF_CHECK_OK(GetNodeAttr(n->def(), "axis", &axis));
1478     if (mode_string != "SCALED") {
1479       VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1480               << "This case is not optimized by Intel MKL kernel, thus using "
1481                  "Eigen op for Dequantize op.";
1482       return false;
1483     }
1484     if (input->IsConstant()) {
1485       VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1486               << "could possibly be a filter. "
1487               << "This case is not supported by Intel MKL kernel, thus using "
1488                  "Eigen op for Dequantize op.";
1489       return false;
1490     }
1491 
1492     if (axis != -1) {
1493       VLOG(1) << "DequantizeRewrite: Using Eigen op for Dequantize op because "
1494               << "dimension is specified for per slice dequantization. ";
1495       return false;
1496     }
1497     return true;
1498   }
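
  // For illustration: a Dequantize node with mode == "SCALED", axis == -1, and
  // a non-Const producer is rewritten; using mode "MIN_FIRST", requesting
  // per-slice dequantization via 'axis', or feeding the op directly from a
  // Const all fall back to the Eigen op.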
1499 
1500   // Rewrite rule for _FusedMatMul.
1501   // @return - true if the first input is not transposed (transpose_a is
1502   //           false); false otherwise.
1503   static bool FusedMatMulRewrite(const Node* n) {
1504     bool trans_a;
1505 
1506     // Do not rewrite with transpose attribute because reorder has performance
1507     // impact.
1508     TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1509 
1510     return !trans_a;
1511   }
1512 
1513   // Check if we are performing pooling on depth or batch. If so, then we
1514   // do not rewrite the MaxPool node to the Mkl version.
1515   // @return - true (if it is not a depth/batch wise pooling case);
1516   //           false otherwise.
1517   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1518     DCHECK(n);
1519 
1520     string data_format_str;
1521     TensorFormat data_format;
1522     std::vector<int32> ksize, strides;
1523     TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1524     TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1525     TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1526     bool result = FormatFromString(data_format_str, &data_format);
1527     DCHECK(result);
1528 
1529     // Condition that specifies non-batch-wise and non-depth-wise pooling.
1530     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1531         GetTensorDim(strides, data_format, 'N') == 1 &&
1532         GetTensorDim(ksize, data_format, 'C') == 1 &&
1533         GetTensorDim(strides, data_format, 'C') == 1) {
1534       return true;
1535     }
1536 
1537     return false;
1538   }
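
  // Worked example, assuming data_format == "NHWC": ksize = {1, 2, 2, 1} and
  // strides = {1, 2, 2, 1} give N == 1 and C == 1 for both, so this is a
  // spatial pooling and the node is rewritten; ksize = {1, 1, 1, 2} pools
  // over depth (C == 2), so the function returns false and the Eigen op is
  // kept.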
1539 
1540   // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
1541   // path. The unoptimized path is slow. Thus we don't rewrite the node
1542   // and use the default Eigen op. But for depth_radius=2, the optimized
1543   // MKL DNN path is taken, i.e., the Eigen node is rewritten to an MKL DNN node.
1544   static bool LrnRewrite(const Node* n) {
1545     DCHECK(n);
1546 
1547     int depth_radius;
1548     TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1549 
1550     // If the depth_radius of LRN is not 2, don't rewrite the node to MKL DNN
1551     // and use the Eigen node instead.
1552     if (depth_radius == 2) {
1553       return true;
1554     }
1555     VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
1556             << "case is not optimized by Intel MKL, thus using Eigen op"
1557             << "for LRN ";
1558 
1559     return false;
1560   }
1561 
1562   static bool LrnGradRewrite(const Node* n) {
1563     DCHECK(n);
1564     bool do_rewrite = false;
1565 
1566     for (const Edge* e : n->in_edges()) {
1567       // Rewrite only if there is a corresponding LRN, i.e., a workspace is available.
1568       if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1569           e->src()->type_string() ==
1570               mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1571           e->src_output() == 0) {
1572         do_rewrite = true;
1573         break;
1574       }
1575     }
1576     return do_rewrite;
1577   }
1578 
1579   // MKL-DNN's LeakyRelu(feature) = feature          (if feature > 0), or
1580   //                                feature * alpha  (otherwise),
1581   // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1582   // These two algorithms are not consistent when alpha > 1,
1583   // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
1584   static bool LeakyReluRewrite(const Node* n) {
1585     DCHECK(n);
1586 
1587     float alpha;
1588     bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1589     DCHECK(has_attr);
1590 
1591     // If the alpha of LeakyRelu is less than or equal to 1, rewrite the node.
1592     // Otherwise the Eigen node is used instead.
1593     if (alpha <= 1) {
1594       return true;
1595     }
1596     VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
1597             << "which case is not optimized by Intel MKL, thus using Eigen op"
1598             << "for LeakyRelu ";
1599 
1600     return false;
1601   }
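
  // Worked example of the inconsistency for alpha > 1: with feature = -1 and
  // alpha = 2, MKL-DNN computes feature * alpha = -2, while TensorFlow
  // computes max(-1, -2) = -1. The two agree whenever alpha <= 1, which is
  // why the rewrite is gated on that condition.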
1602 
1603   static bool QuantizeOpRewrite(const Node* n) {
1604     DCHECK(n);
1605     Node* filter_node = nullptr;
1606     TF_CHECK_OK(n->input_node(0, &filter_node));
1607     bool narrow_range = false;
1608     int axis = -1;
1609     string mode_string;
1610     string round_mode_string;
1611     DataType type;
1612     TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1613     TryGetNodeAttr(n->def(), "axis", &axis);
1614     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1615     TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1616     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1617 
1618     if (narrow_range) {
1619       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
1620               << "This case is not optimized by Intel MKL, "
1621               << "thus using Eigen op for Quantize op ";
1622       return false;
1623     }
1624     if (axis != -1) {
1625       VLOG(1) << "QuantizeOpRewrite: dimension is specified for "
1626               << "per slice quantization."
1627               << "This case is not optimized by Intel MKL, "
1628               << "thus using Eigen op for Quantize op ";
1629       return false;
1630     }
1631     if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1632           (mode_string == "MIN_FIRST"))) {
1633       VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or"
1634               << "rounding mode is not HALF_TO_EVEN. "
1635               << "This case is not optimized by Intel MKL, thus using Eigen op"
1636               << "for Quantize op ";
1637       return false;
1638     }
1639     if (filter_node->IsConstant()) {
1640       VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1641               << "is a constant. "
1642               << "This case is not supported by the kernel, thus using Eigen op"
1643               << "for Quantize op ";
1644 
1645       return false;
1646     }
1647     if (mode_string == "MIN_FIRST") {
1648       if (type != DT_QUINT8) {
1649         VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1650                 << "not DT_UINT8. This case is not optimized by Intel MKL, "
1651                 << "thus using Eigen op for Quantize op ";
1652         return false;
1653       }
1654     }
1655     return true;
1656   }
1657 
1658   static bool MaxpoolGradRewrite(const Node* n) {
1659     DCHECK(n);
1660     bool do_rewrite = false;
1661     for (const Edge* e : n->in_edges()) {
1662       // Rewrite only if there is a corresponding MaxPool, i.e., a workspace
1663       // is available.
1664       if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1665           e->dst_input() == 1 &&
1666           e->src()->type_string() ==
1667               mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1668           e->src_output() == 0) {
1669         do_rewrite = true;
1670         break;
1671       }
1672     }
1673     return do_rewrite;
1674   }
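
  // For illustration, the rewrite fires for a pattern such as:
  //
  //   Y  = _MklMaxPool(X)          // already-rewritten forward op
  //   dX = MaxPoolGrad(X, Y, dY)   // Y:0 feeds input slot 1
  //
  // because the rewritten forward op can supply the workspace tensor that the
  // Mkl backward op needs; a MaxPoolGrad with no matching _MklMaxPool keeps
  // the default implementation.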
1675 
1676   static bool Maxpool3DGradRewrite(const Node* n) {
1677     DCHECK(n);
1678     for (const Edge* e : n->in_edges()) {
1679       // Rewrite only if there is a corresponding MaxPool3D, i.e., a
1680       // workspace is available.
1681       if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1682           e->dst_input() == 1 &&
1683           e->src()->type_string() ==
1684               mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1685           e->src_output() == 0) {
1686         return true;
1687       }
1688     }
1689     return false;
1690   }
1691 
1692   static bool FusedBatchNormV3Rewrite(const Node* n) {
1693     DCHECK(n);
1694     if (Check5DFormat(n->def())) {
1695       VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1696               << "support 5D tensors.";
1697       return false;
1698     }
1699     return true;
1700   }
1701 
1702   static bool FusedBatchNormExRewrite(const Node* n) {
1703     DCHECK(n);
1704 
1705     int num_side_inputs;
1706     TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1707     string activation_mode;
1708     TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1709 
1710     // if the num_side_inputs is not 0, don't rewrite the node.
1711     if (num_side_inputs != 0) {
1712       VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs"
1713               << "larger than 0 is not optimized by Intel MKL.";
1714       return false;
1715     }
1716 
1717     // if the activation_mode is not 'Relu', don't rewrite the node.
1718     if (activation_mode != "Relu") {
1719       VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is"
1720               << "supported by Intel MKL.";
1721       return false;
1722     }
1723 
1724     return true;
1725   }
1726 
1727   static bool FusedConv2DRewrite(const Node* n) {
1728     // MKL DNN currently doesn't support all fusions that grappler fuses
1729     // together with Conv2D (e.g., batch norm). We rewrite _FusedConv2D only if
1730     // it includes those we support.
1731     DataType T;
1732     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1733         !mkl_op_registry::IsMklOp(NativeFormatEnabled()
1734                                       ? csinfo_.mkl_native_fused_conv2d
1735                                       : csinfo_.mkl_fused_conv2d,
1736                                   T)) {
1737       return false;
1738     }
1739 
1740     std::vector<string> fused_ops;
1741     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1742     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1743             fused_ops == std::vector<string>{"Relu"} ||
1744             fused_ops == std::vector<string>{"Relu6"} ||
1745             fused_ops == std::vector<string>{"Elu"} ||
1746             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1747             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1748             fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1749             fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1750             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1751             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1752             fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1753             fused_ops == std::vector<string>{"LeakyRelu"} ||
1754             fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1755             fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"} ||
1756             fused_ops == std::vector<string>{"FusedBatchNorm"} ||
1757             fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"} ||
1758             fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"} ||
1759             fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"} ||
1760             fused_ops == std::vector<string>{"FusedBatchNorm", "LeakyRelu"});
1761   }
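
  // For example, a _FusedConv2D whose fused_ops attribute equals
  // {"BiasAdd", "Relu"} matches the list above and is rewritten, whereas an
  // unsupported combination such as {"FusedBatchNorm", "Add"} does not, and
  // the node keeps its default implementation.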
1762 
1763   static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1764     // MKL DNN currently doesn't support all fusions that grappler fuses
1765     // together with DepthwiseConv2D (e.g., batch norm). We rewrite
1766     // _FusedDepthwiseConv2DNative only if it includes those we support.
1767     DataType T;
1768     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1769         !mkl_op_registry::IsMklOp(
1770             NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d
1771                                   : csinfo_.mkl_fused_depthwise_conv2d,
1772             T)) {
1773       return false;
1774     }
1775 
1776     std::vector<string> fused_ops;
1777     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1778     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1779             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1780             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1781             fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1782   }
1783 
1784   // Rewrites input node to a new node specified by its matching rewrite info.
1785   //
1786   // The method first searches for matching rewrite info for the input node
1787   // and then uses that info to rewrite.
1788   //
1789   // Input node may be deleted in case of rewrite. Attempting to use the node
1790   // after the call can result in undefined behavior.
1791   //
1792   // @input  g - input graph, n - Node to be rewritten,
1793   //         ri - matching rewriteinfo
1794   // @return Status::OK(), if the input node is rewritten;
1795   //         Returns appropriate Status error code otherwise.
1796   //         Graph is updated in case the input node is rewritten.
1797   //         Otherwise, it is not updated.
1798   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1799 
1800   // Rewrites input node to just change its operator name. The number of
1801   // inputs to the node and the number of outputs remain the same. Attributes
1802   // of the new node could be copied from attributes of the old node or
1803   // modified. copy_attrs field of RewriteInfo controls this.
1804   //
1805   // Conceptually, it allows us to rewrite:
1806   //
1807   //        f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1808   //
1809   // Attributes can be altered without any restrictions --- they could be
1810   // copied, modified, or deleted completely.
1811   //
1812   // @input  g - input graph, orig_node - Node to be rewritten,
1813   //         ri - matching rewriteinfo
1814   // @output new_node - points to newly created node
1815   // @return Status::OK(), if the input node is rewritten;
1816   //         Returns appropriate Status error code otherwise.
1817   //         Graph is only updated when the input node is rewritten.
1818   Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1819                                         const Node* orig_node, Node** new_node,
1820                                         const RewriteInfo* ri);
1821 
1822   // Rewrites input node to enable MKL layout propagation. Please also refer to
1823   // documentation for the function RewriteNodeForJustOpNameChange() to
1824   // understand what it means.
1825   //
1826   // @input  g - input graph, orig_node - Node to be rewritten,
1827   //         ri - matching rewriteinfo
1828   // @output new_node - points to newly created node
1829   // @return Status::OK(), if the input node is rewritten;
1830   //         Returns appropriate Status error code otherwise.
1831   //         Graph is updated in case the input node is rewritten.
1832   //         Otherwise, it is not updated.
1833   Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1834                                          const Node* orig_node, Node** new_node,
1835                                          const RewriteInfo* ri);
1836 
1837   // Get nodes that will feed a list of TF tensors to the new
1838   // node that we are constructing.
1839   //
1840   // @input g - input graph,
1841   // @input inputs - inputs to old node that we are using for constructing
1842   //                 new inputs,
1843   // @input input_idx - the index in the 'inputs' vector pointing to the
1844   //                    current input that we have processed so far
1845   // @output input_idx - index will be incremented by the number of nodes
1846   //                     from 'inputs' that are processed
1847   // @input list_length - The expected length of list of TF tensors
1848   // @output output_nodes - the list of new nodes creating TF tensors
1849   //
1850   // @return None
1851   void GetNodesProducingTFTensorList(
1852       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1853       int* input_idx, int list_length,
1854       std::vector<NodeBuilder::NodeOut>* output_nodes);
1855 
1856   // Get nodes that will feed a list of Mkl tensors to the new
1857   // node that we are constructing.
1858   //
1859   // @input g - input graph,
1860   // @input orig_node - Original node that we are rewriting
1861   // @input inputs - inputs to old node that we are using for constructing
1862   //                 new inputs,
1863   // @input input_idx - the index in the 'inputs' vector pointing to the
1864   //                    current input that we have processed so far
1865   // @output input_idx - index will be incremented by the number of nodes
1866   //                     from 'inputs' that are processed
1867   // @input list_length - The expected length of list of Mkl tensors
1868   // @output output_nodes - the list of new nodes creating Mkl tensors
1869   //
1870   // @return None
1871   void GetNodesProducingMklTensorList(
1872       std::unique_ptr<Graph>* g, const Node* orig_node,
1873       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1874       int* input_idx, int list_length,
1875       std::vector<NodeBuilder::NodeOut>* output_nodes);
1876 
1877   // Get a node that will feed an Mkl tensor to the new
1878   // node that we are constructing. The output node could be (1) 'n'
1879   // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1880   // if 'n' is not an Mkl layer.
1881   //
1882   // @input g - input graph,
1883   // @input orig_node - Original node that we are rewriting,
1884   // @input n - Node based on which we are creating Mkl node,
1885   // @input n_output_slot - the output slot of node 'n'
1886   //            which is feeding to the node that we are constructing
1887   // @output mkl_node - the new node that will feed Mkl tensor
1888   // @output mkl_node_output_slot - the slot number of mkl_node that
1889   //                                will feed the tensor
1890   // @return None
1891   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1892                                  const Node* orig_node, Node* n,
1893                                  int n_output_slot, Node** mkl_node,
1894                                  int* mkl_node_output_slot);
1895 
1896   // Set up new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1897   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1898   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1899   // producing workspace edges if 'are_workspace_tensors_available' is true.
1900   // Otherwise, 'workspace_tensors' is empty vector.
1901   //
1902   // For details, refer to 'Ordering of inputs after rewriting' section in the
1903   // documentation above.
1904   //
1905   // Returns the number of input slots set up on the node builder
1906   // (TF tensor, Mkl tensor, and workspace inputs combined).
1907   int SetUpContiguousInputs(
1908       std::unique_ptr<Graph>* g,
1909       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1910       NodeBuilder* nb, const Node* old_node,
1911       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1912       bool are_workspace_tensors_available);
1913 
1914   // Set up new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1915   // in graph 'g'. Original node is input in 'orig_node'.
1916   //
1917   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1918   // section in the documentation above.
1919   //
1920   // Returns Status::OK() if setting up inputs is successful, otherwise
1921   // returns appropriate status code.
1922   Status SetUpInputs(std::unique_ptr<Graph>* g,
1923                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1924                      NodeBuilder* nb, const Node* orig_node);
1925 
1926   // Create new inputs by copying old inputs 'inputs' for the rewritten node
1927   // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1928   // used in the context of a rewrite for just an operator name change, in
1929   // which the inputs of the old and new operators are the same.
1930   //
1931   // Returns Status::OK() if setting up inputs is successful, otherwise
1932   // returns appropriate status code.
1933   Status CopyInputs(const Node* orig_node,
1934                     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1935                     NodeBuilder* nb);
1936 
1937   // Add a workspace edge on the input or output side of Node 'orig_node' by
1938   // using NodeBuilder 'nb' for the new node provided. If 'orig_node' does not
1939   // dictate adding a workspace edge, then do not add it. Workspace Tensorflow
1940   // and Mkl tensors, if they need to be added, are returned in 'ws_tensors'.
1941   // If we set workspace tensors, then '*are_ws_tensors_added' is set to true.
1942   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1943                                 const Node* orig_node, NodeBuilder* nb,
1944                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1945                                 bool* are_ws_tensors_added);
1946 
1947   // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1948   // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1949   // 'g'. Returns true if fixup was done; otherwise, it returns false.
1950   bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1951                                   const Edge* e_metadata);
1952 
1953   // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1954   // connected? If not, then fix them. This is needed because a graph may have
1955   // some input Mkl metadata edges incorrectly setup after node merge and
1956   // rewrite passes. This could happen because GetReversePostOrder function may
1957   // not provide topologically sorted order if a graph contains cycles. The
1958   // function returns true if at least one Mkl metadata edge for node 'n' was
1959   // fixed. Otherwise, it returns false.
1960   //
1961   // Example:
1962   //
1963   // X = MklConv2D(_, _, _)
1964   // Y = MklConv2DWithBias(_, _, _, _, _, _)
1965   // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
1966   //
1967   // For a graph such as shown above, note that 3rd argument of MklAdd contains
1968   // DummyMklTensor. Actually, it should be getting the Mkl metadata from
1969   // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
1970   // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
1971   // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
1972   // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
1973   // data edges (1st and 2nd arguments of MklAdd).
1974   bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
1975 
1976   // Functions specific to operators to copy attributes
1977   // We need operator-specific function to copy attributes because the framework
1978   // does not provide any generic function for it.
1979   // NOTE: names are alphabetically sorted.
1980   static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
1981                            bool change_format = false);
1982   static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
1983                                            NodeBuilder* nb,
1984                                            bool change_format = false);
1985 
1986   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
1987                             bool change_format = false);
1988   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
1989                                             NodeBuilder* nb,
1990                                             bool change_format = false);
1991   static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
1992                                         const Node* orig_node2, NodeBuilder* nb,
1993                                         bool change_format = false);
1994   static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
1995                                              const Node* orig_node2,
1996                                              NodeBuilder* nb,
1997                                              bool change_format = false);
1998   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
1999                                        bool change_format = false);
2000   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2001                                   const std::vector<int32>& strides,
2002                                   const std::vector<int32>& dilations,
2003                                   bool change_format = false);
2004 
2005   static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2006                                                NodeBuilder* nb,
2007                                                bool change_format = false);
2008   static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2009       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2010   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2011                                bool change_format = false);
2012 
2013   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2014   // using node for original node 'orig_node' and return it in '*out'.
2015   // TODO(nhasabni) We should move this to mkl_util.h
2016   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2017                              const Node* orig_node);
2018   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2019                                    const Node* orig_node);
2020 };
2021 
2022 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2023 
2024 // We register Mkl rewrite pass for phase 1 in post partitioning group.
2025 // We register it here so that we get a complete picture of all users of Mkl
2026 // nodes. Do not change the ordering of the Mkl passes.
2027 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2028     OptimizationPassRegistry::POST_PARTITIONING;
2029 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2030 
2031 //////////////////////////////////////////////////////////////////////////
2032 //           Helper functions for creating new node
2033 //////////////////////////////////////////////////////////////////////////
2034 
2035 static void FillInputs(const Node* n,
2036                        gtl::InlinedVector<Node*, 4>* control_edges,
2037                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2038   control_edges->clear();
2039   for (const Edge* e : n->in_edges()) {
2040     if (e->IsControlEdge()) {
2041       control_edges->push_back(e->src());
2042     } else {
2043       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2044     }
2045   }
2046   std::sort(control_edges->begin(), control_edges->end());
2047 }
2048 
2049 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2050     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2051     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2052   CHECK_LT(*input_idx, inputs.size());
2053   CHECK_GT(list_length, 0);
2054   DCHECK(output_nodes);
2055   output_nodes->reserve(list_length);
2056 
2057   while (list_length != 0) {
2058     CHECK_GT(list_length, 0);
2059     CHECK_LT(*input_idx, inputs.size());
2060     Node* n = inputs[*input_idx].first;
2061     int slot = inputs[*input_idx].second;
2062     // If input node 'n' is producing a single tensor at
2063     // output slot 'slot', then we add that single node.
2064     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2065     (*input_idx)++;
2066     list_length--;
2067   }
2068 }
2069 
2070 // TODO(nhasabni) We should move this to mkl_util.h.
2071 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2072                                                  Node** out,
2073                                                  const Node* orig_node) {
2074   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2075   // dummy Mkl tensor. 8 = 2*size_t.
2076   const DataType dt = DataTypeToEnum<uint8>::v();
2077   TensorProto proto;
2078   proto.set_dtype(dt);
2079   uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2080   proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2081   TensorShape dummy_shape({8});
2082   dummy_shape.AsProto(proto.mutable_tensor_shape());
2083   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2084                   .Attr("value", proto)
2085                   .Attr("dtype", dt)
2086                   .Device(orig_node->def().device())  // We place this node on
2087                                                       // the same device as the
2088                                                       // device of the original
2089                                                       // node.
2090                   .Finalize(&**g, out));
2091   DCHECK(*out);  // Make sure we got a valid object before using it
2092 
2093   // If number of inputs to the original node is > 0, then we add
2094   // control dependency between 1st input (index 0) of the original node and
2095   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2096   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2097   // rewritten node. Adding control edge between 1st input of the original node
2098   // and the dummy Mkl node ensures that the dummy node is in the same frame
2099   // as the original node. Choosing 1st input is not necessary - any input of
2100   // the original node is fine because all the inputs of a node are always in
2101   // the same frame.
2102   if (orig_node->num_inputs() > 0) {
2103     Node* orig_input0 = nullptr;
2104     TF_CHECK_OK(
2105         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2106     auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2107     DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2108   }
2109 
2110   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2111 }
2112 
2113 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2114     std::unique_ptr<Graph>* g, const Node* orig_node,
2115     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2116     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2117   CHECK_LT(*input_idx, inputs.size());
2118   CHECK_GT(list_length, 0);
2119   DCHECK(output_nodes);
2120   output_nodes->reserve(list_length);
2121 
2122   while (list_length != 0) {
2123     CHECK_GT(list_length, 0);
2124     CHECK_LT(*input_idx, inputs.size());
2125     Node* n = inputs[*input_idx].first;
2126     int slot = inputs[*input_idx].second;
2127     // If 'n' is producing a single tensor, then create a single Mkl tensor
2128     // node.
2129     Node* mkl_node = nullptr;
2130     int mkl_node_output_slot = 0;
2131     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2132                               &mkl_node_output_slot);
2133     output_nodes->push_back(
2134         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2135     (*input_idx)++;
2136     list_length--;
2137   }
2138 }
2139 
2140 // Get an input node that will feed Mkl tensor to the new
2141 // node that we are constructing. An input node could be (1) 'n'
2142 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
2143 // if 'n' is not an Mkl layer.
2144 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2145     std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2146     int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2147   DCHECK(n);
2148   DCHECK(mkl_node);
2149   DCHECK(mkl_node_output_slot);
2150 
2151   // If this is an MKL op, then it will create extra output for MKL layout.
2152   DataType T;
2153   if (TryGetNodeAttr(n->def(), "T", &T) &&
2154       mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
2155     // If this is an MKL op, then it will generate an edge that will receive
2156     // Mkl tensor from a node.
2157     // The output slot number for the Mkl tensor is N + the slot number of the
2158     // TensorFlow tensor, where N is the total number of TensorFlow tensors.
2159     *mkl_node = n;
2160     *mkl_node_output_slot =
2161         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2162   } else {
2163     // If we have not visited the node and rewritten it, then we need
2164     // to create a dummy node that will feed a dummy Mkl tensor to this node.
2165     // DummyMklTensor node has no input and generates only 1 output
2166     // (dummy Mkl tensor) as output slot number 0.
2167     GetDummyMklTensorNode(g, mkl_node, orig_node);
2168     DCHECK(*mkl_node);
2169     *mkl_node_output_slot = 0;
2170   }
2171 }
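
// Worked example of the contiguous metadata indexing above: for an Mkl op
// with 2 data outputs and 2 metadata outputs (num_outputs == 4, i.e., N == 2
// TensorFlow tensors), the Mkl metadata for data slot 1 lives at output slot
// N + 1 == 3. A producer that is not an Mkl op instead contributes a dummy
// Mkl tensor at output slot 0.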
2172 
2173 int MklLayoutRewritePass::SetUpContiguousInputs(
2174     std::unique_ptr<Graph>* g,
2175     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2176     NodeBuilder* nb, const Node* old_node,
2177     std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2178     bool are_workspace_tensors_available) {
2179   DCHECK(workspace_tensors);
2180   CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2181 
2182   // TODO(nhasabni): Temporary solution to connect filter input of
2183   // BackpropInput with the converted filter from Conv2D.
2184   bool do_connect_conv2d_backprop_input_filter = false;
2185   Node* conv2d_node = nullptr;
2186   // Filter node is 2nd input (slot index 1) of Conv2D.
2187   int kConv2DFilterInputSlotIdx = 1;
2188   int kConv2DBackpropInputFilterInputSlotIdx = 1;
2189   int kConv2DFilterOutputSlotIdx = 1;
2190   if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2191     // We need to find Conv2D node from Conv2DBackpropInput.
2192     // For that let's first find filter node that is 2nd input (slot 1)
2193     // of BackpropInput.
2194     Node* filter_node = nullptr;
2195     TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2196                                      &filter_node));
2197     DCHECK(filter_node);
2198 
2199     // Now check which nodes receive from filter_node. Filter feeds as
2200     // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2201     // _MklFusedConv2D.
2202     for (const Edge* e : filter_node->out_edges()) {
2203       if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2204            e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2205            e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2206            e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2207            e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2208           e->dst_input() == kConv2DFilterInputSlotIdx
2209           /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2210         if (conv2d_node != nullptr) {
2211           VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2212                   << " feeding multiple Conv2D nodes: "
2213                   << filter_node->DebugString();
2214           // We will not connect filter input of Conv2DBackpropInput
2215           // to be safe here.
2216           do_connect_conv2d_backprop_input_filter = false;
2217           break;
2218         } else {
2219           conv2d_node = e->dst();
2220           do_connect_conv2d_backprop_input_filter = true;
2221         }
2222       }
2223     }
2224   }
2225 
2226   // Number of input slots to original op
2227   // Input slots are represented by .Input() calls in REGISTER_OP.
2228   int old_node_input_slots = old_node->op_def().input_arg_size();
2229   int nn_slot_idx = 0;  // slot index for inputs of new node
2230 
2231   // Let's copy all inputs (TF tensors) of original node to new node.
2232   int iidx = 0;
2233   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2234     // An input slot could be a single tensor or a list. We need
2235     // to handle this case accordingly.
2236     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2237     if (ArgIsList(arg)) {
2238       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2239       int tensor_list_length = GetTensorListLength(arg, old_node);
2240       if (tensor_list_length != 0) {
2241         GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2242                                       tensor_list_length, &new_node_inputs);
2243       }
2244       nb->Input(new_node_inputs);
2245       nn_slot_idx++;
2246     } else {
2247       // Special case for connecting filter input of Conv2DBackpropInput
2248       if (do_connect_conv2d_backprop_input_filter &&
2249           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2250         nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2251       } else {
2252         nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2253       }
2254       iidx++;
2255       nn_slot_idx++;
2256     }
2257   }
2258 
2259   // If workspace tensors are available for this op and we are using
2260   // contiguous ordering then we need to add Tensorflow tensor for
2261   // workspace here because Tensorflow tensor for workspace is the
2262   // last tensor in the list of Tensorflow tensors.
2263   if (are_workspace_tensors_available) {
2264     CHECK_EQ(workspace_tensors->size(), 2);
2265     // Tensorflow tensor
2266     nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2267     nn_slot_idx++;
2268   }
2269 
2270   // Let's now set up all Mkl inputs to the new node.
2271   // The number of Mkl inputs must be the same as the number of TF inputs.
2272   iidx = 0;
2273   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2274     // An input slot could be a single tensor or a list. We need
2275     // to handle this case accordingly.
2276     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2277     if (ArgIsList(arg)) {
2278       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2279       int tensor_list_length = GetTensorListLength(arg, old_node);
2280       if (tensor_list_length != 0) {
2281         GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2282                                        tensor_list_length, &new_node_inputs);
2283       }
2284       nb->Input(new_node_inputs);
2285       nn_slot_idx++;
2286     } else {
2287       Node* mkl_node = nullptr;
2288       int mkl_node_output_slot = 0;
2289       // Special case for connecting filter input of Conv2DBackpropInput
2290       if (do_connect_conv2d_backprop_input_filter &&
2291           iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2292         GetNodeProducingMklTensor(g, old_node, conv2d_node,
2293                                   kConv2DFilterOutputSlotIdx, &mkl_node,
2294                                   &mkl_node_output_slot);
2295       } else {
2296         GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2297                                   old_node_inputs[iidx].second, &mkl_node,
2298                                   &mkl_node_output_slot);
2299       }
2300       nb->Input(mkl_node, mkl_node_output_slot);
2301       iidx++;
2302       nn_slot_idx++;
2303     }
2304   }
2305 
2306   // If workspace tensors are available for this op and we are using
2307   // contiguous ordering, then we need to add the Mkl tensor for the
2308   // workspace here, because it is the last tensor in the list of Mkl
2309   // tensors.
2310   if (are_workspace_tensors_available) {
2311     CHECK_EQ(workspace_tensors->size(), 2);
2312     // Mkl tensor
2313     nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2314     nn_slot_idx++;
2315   }
2316 
2317   return nn_slot_idx;
2318 }
2319 
2320 // This method determines whether a workspace check is needed. Workspace is
2321 // not used by quantized ops, so the check would fail on them because
2322 // quantized ops do not have the attribute "T".
2323 bool IsWorkspaceCheckNeeded(const Node* node) {
2324   std::vector<string> quant_ops{
2325       "Dequantize",
2326       "QuantizeV2",
2327       "QuantizedConv2D",
2328       "QuantizedConv2DWithBias",
2329       "QuantizedConv2DAndRelu",
2330       "QuantizedConv2DWithBiasAndRelu",
2331       "QuantizedConv2DWithBiasSumAndRelu",
2332       "QuantizedConv2DPerChannel",
2333       "QuantizedConv2DAndRequantize",
2334       "QuantizedConv2DWithBiasAndRequantize",
2335       "QuantizedConv2DAndReluAndRequantize",
2336       "QuantizedConv2DWithBiasAndReluAndRequantize",
2337       "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2338       "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2339       "QuantizedMatMulWithBias",
2340       "QuantizedMatMulWithBiasAndRequantize",
2341       "QuantizedMatMulWithBiasAndDequantize",
2342       "QuantizedMatMulWithBiasAndRelu",
2343       "QuantizedMatMulWithBiasAndReluAndRequantize",
2344       "QuantizedDepthwiseConv2D",
2345       "QuantizedDepthwiseConv2DWithBias",
2346       "QuantizedDepthwiseConv2DWithBiasAndRelu",
2347       "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2348   return std::find(std::begin(quant_ops), std::end(quant_ops),
2349                    node->type_string()) == std::end(quant_ops);
2350 }
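
// For example, this returns true for "MaxPool", so the workspace logic in
// AddWorkSpaceEdgeIfNeeded runs for it, but false for ops such as
// "QuantizedConv2D", which carry no "T" attribute.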
2351 
2352 Status MklLayoutRewritePass::SetUpInputs(
2353     std::unique_ptr<Graph>* g,
2354     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2355     NodeBuilder* nb, const Node* old_node) {
2356   // Let's check if we need to add workspace tensors for this node.
2357   // We add a workspace edge only for MaxPool, LRN and BatchNorm.
2358   std::vector<NodeBuilder::NodeOut> workspace_tensors;
2359   bool are_workspace_tensors_available = false;
2360 
2361   if (IsWorkspaceCheckNeeded(old_node)) {
2362     AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2363                              &are_workspace_tensors_available);
2364   }
2365 
2366   int new_node_input_slots = 0;
2367   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
2368     // TODO(nhasabni): implement this function, just for the sake of
2369     // completeness. We do not use interleaved ordering right now.
2370     return Status(
2371         error::Code::UNIMPLEMENTED,
2372         "Interleaved ordering of tensors is currently not supported.");
2373   } else {
2374     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2375     new_node_input_slots = SetUpContiguousInputs(
2376         g, old_node_inputs, nb, old_node, &workspace_tensors,
2377         are_workspace_tensors_available);
2378   }
2379 
2380   // Sanity check
2381   int old_node_input_slots = old_node->op_def().input_arg_size();
2382   if (!are_workspace_tensors_available) {
2383     // If we are not adding workspace tensors for this op, then the total
2384     // number of input slots to the new node _must_ be 2 times the number
2385     // of input slots to the original node: N original Tensorflow tensors
2386     // and N Mkl tensors corresponding to them.
2387     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2388   } else {
2389     // If we are adding workspace tensors for this op, then the total
2390     // number of input slots to the new node _must_ be 2 times the number
2391     // of input slots to the original node plus 2: N original Tensorflow
2392     // tensors, N Mkl tensors corresponding to them, plus the workspace
2393     // Tensorflow tensor and the workspace Mkl tensor.
2394     CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2395   }
2396 
2397   return Status::OK();
2398 }
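
// Worked example of the sanity check above: an op with 2 input slots is
// rewritten to 4 input slots, (A, B, A_m, B_m); if workspace tensors are
// added, it gets 6 input slots, (A, B, ws, A_m, B_m, ws_m).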
2399 
2400 Status MklLayoutRewritePass::CopyInputs(
2401     const Node* old_node,
2402     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2403     NodeBuilder* nb) {
2404   // Number of input slots to old node
2405   // Input slots are represented by .Input() calls in REGISTER_OP.
2406   int old_node_input_slots = old_node->op_def().input_arg_size();
2407   // The actual number of inputs can be greater than or equal to the number
2408   // of input slots because list-type inputs are unfolded.
2409   auto old_node_input_size = old_node_inputs.size();
2410   DCHECK_GE(old_node_input_size, old_node_input_slots);
2411 
2412   // Let's copy all inputs of old node to new node.
2413   int iidx = 0;
2414   for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2415     // An input slot could be a single tensor or a list. We need
2416     // to handle this case accordingly.
2417     DCHECK_LT(iidx, old_node_input_size);
2418     const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2419     if (ArgIsList(arg)) {
2420       std::vector<NodeBuilder::NodeOut> new_node_inputs;
2421       int N = GetTensorListLength(arg, old_node);
2422       if (N != 0) {
2423         GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2424                                       &new_node_inputs);
2425       }
2426       nb->Input(new_node_inputs);
2427     } else {
2428       nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2429       iidx++;
2430     }
2431   }
2432   return Status::OK();
2433 }
2434 
2435 //////////////////////////////////////////////////////////////////////////
2436 //           Helper functions related to workspace pass
2437 //////////////////////////////////////////////////////////////////////////
2438 
2439 // TODO(nhasabni) We should move this to mkl_util.h.
2440 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2441     std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
2442   // We use a uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to
2443   // represent the workspace tensor.
2444   GetDummyMklTensorNode(g, out, orig_node);
2445 }
2446 
2447 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2448     std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2449     std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2450   bool workspace_edge_added = false;  // Default initializer
2451   DCHECK(are_ws_tensors_added);
2452   *are_ws_tensors_added = false;  // Default initializer
2453 
2454   DataType T;
2455   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2456   for (const auto& ws : wsinfo_) {
2457     if (orig_node->type_string() == ws.fwd_op &&
2458         mkl_op_registry::IsMklOp(
2459             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2460       // If this op is a fwd op, then we need to check if there is an
2461       // edge from this node's fwd_slot to the bwd op's bwd_slot. If such
2462       // an edge exists, we just set the attribute workspace_enabled to
2463       // true on this node; we don't add the actual workspace edge here.
2464       // The actual workspace edge gets added in the backward op for
2465       // this node.
2466       for (const Edge* e : orig_node->out_edges()) {
2467         if (e->src_output() == ws.fwd_slot &&
2468             e->dst()->type_string() == ws.bwd_op &&
2469             e->dst_input() == ws.bwd_slot) {
2470           nb->Attr("workspace_enabled", true);
2471           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2472                   << orig_node->type_string();
2473           workspace_edge_added = true;
2474           // We found the edge that we were looking for, so break.
2475           break;
2476         }
2477       }
2478 
2479       if (!workspace_edge_added) {
2480         // If we are here, then we did not find backward operator for this
2481         // node.
2482         nb->Attr("workspace_enabled", false);
2483       }
2484     } else if (orig_node->type_string() == ws.bwd_op &&
2485                mkl_op_registry::IsMklOp(
2486                    mkl_op_registry::GetMklOpName(orig_node->type_string()),
2487                    T)) {
2488       // If this op is a bwd op, then we need to add a workspace edge and
2489       // its Mkl tensor edge between its corresponding fwd op and this
2490       // op. The corresponding fwd op is specified in the 'fwd_op' field
2491       // of the workspace info. fwd_slot and bwd_slot in the workspace
2492       // info specify which slots of the forward and backward ops the
2493       // edge connects. Once all these criteria match, we add a workspace
2494       // edge between ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl
2495       // tensor is determined by interleaved/contiguous ordering. The
2496       // function DataIndexToMetaDataIndex gives the location of the Mkl
2497       // tensor from the location of the Tensorflow tensor.
2498       for (const Edge* e : orig_node->in_edges()) {
2499         if (e->src_output() == ws.fwd_slot &&
2500             // We would have rewritten the forward op, so we need to use
2501             // GetMklOpName call to get its Mkl name.
2502             e->src()->type_string() ==
2503                 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2504             e->dst_input() == ws.bwd_slot) {
2505           nb->Attr("workspace_enabled", true);
2506           DCHECK(ws_tensors);
2507           // Add workspace edge between fwd op and bwd op.
2508           ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2509           // Check if we are running in native format mode. If so,
2510           // we don't need to have an Mkl metadata tensor for the workspace.
2511           if (!NativeFormatEnabled()) {
2512             // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2513             ws_tensors->push_back(NodeBuilder::NodeOut(
2514                 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2515                                                    e->src()->num_outputs())));
2516           }
2517           *are_ws_tensors_added = true;
2518           // In terms of input ordering, we add these Input calls here
2519           // because the workspace edge (and its Mkl tensor) is the last
2520           // edge in the fwd op and the bwd op, so all inputs before the
2521           // workspace tensor have already been added by SetUpInputs.
2522           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2523                   << orig_node->type_string();
2524           workspace_edge_added = true;
2525           // We found the edge that we were looking for, so break.
2526           break;
2527         }
2528       }
2529 
2530       // If we are here, it means we did not find a fwd op that feeds
2531       // this bwd op. In this case, we generate dummy tensors for the
2532       // workspace input and its Mkl tensor, and set workspace_enabled
2533       // to false.
2534       if (!workspace_edge_added) {
2535         nb->Attr("workspace_enabled", false);
2536         Node* dmt_ws = nullptr;      // Dummy tensor for workspace
2537         Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
2538         GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2539         GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2540         DCHECK(dmt_ws);
2541         DCHECK(dmt_mkl_ws);
2542         DCHECK(ws_tensors);
2543         // We add dummy tensor as workspace tensor.
2544         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2545         // We add dummy tensor as Mkl tensor for workspace tensor.
2546         ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2547         *are_ws_tensors_added = true;
2548         VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2549                 << orig_node->type_string();
2550       }
2551     } else {
2552       // If this node does not match any workspace info, then we do not
2553       // do anything special for workspace propagation for it.
2554     }
2555   }
2556 }
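
// Illustrative example (assuming wsinfo_ holds a MaxPool/MaxPoolGrad
// entry): when the rewritten _MklMaxPool's output at ws.fwd_slot feeds
// MaxPoolGrad at ws.bwd_slot, the rewritten MaxPoolGrad receives a real
// workspace edge from slot ws.ws_fwd_slot of _MklMaxPool (plus its Mkl
// metadata edge, unless running in native format mode); otherwise it
// receives dummy workspace tensors and workspace_enabled is set to false.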
2557 
2558 //////////////////////////////////////////////////////////////////////////
2559 // Op-specific functions to copy attributes from old node to new node
2560 //////////////////////////////////////////////////////////////////////////
2561 
2562 // Generic function to copy all attributes from original node to target.
2563 void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2564                                         bool change_format) {
2565   AttrSlice attr_list(orig_node->def());
2566 
2567   // Copy every attribute of the original node to the new node.
2568   for (const auto& name_attr_pair : attr_list) {
2569     nb->Attr(name_attr_pair.first, name_attr_pair.second);
2570   }
2575 }
2576 
2577 // Generic function to copy all attributes and check if filter is const.
2578 void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2579                                                         NodeBuilder* nb,
2580                                                         bool change_format) {
2581   CopyAttrsAll(orig_node, nb, change_format);
2582 
2583   // Check and set filter attribute.
2584   Node* filter_node = nullptr;
2585   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2586 
2587   bool is_filter_const = false;
2588   if (HasNodeAttr(orig_node->def(), "is_filter_const")) {
2589     GetNodeAttr(orig_node->def(), "is_filter_const", &is_filter_const);
2590   }
2591 
2592   // If (1) orig_node does not have the attribute 'is_filter_const', or
2593   // (2) it has the attribute but its value is false, then we set the
2594   // attribute on 'nb' based on whether filter_node is constant.
2595   // If is_filter_const == true, there is no need to call nb->Attr(), as
2596   // CopyAttrsAll() has already copied the attribute from orig_node to nb.
2597   if (!is_filter_const) {
2598     nb->Attr("is_filter_const", filter_node->IsConstant());
2599   }
2600 }
2601 
2602 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2603                                                          NodeBuilder* nb,
2604                                                          bool change_format) {
2605   CopyAttrsConv(orig_node, nb, change_format);
2606 
2607   // Check and set filter attribute.
2608   Node* filter_node = nullptr;
2609   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2610   nb->Attr("is_filter_const", filter_node->IsConstant());
2611 }
2612 
2613 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2614                                          bool change_format) {
2615   DataType T;
2616   string padding;
2617   std::vector<int32> strides;
2618   std::vector<int32> dilations;
2619   std::vector<int32> explicit_paddings;
2620 
2621   // Get all attributes from old node.
2622   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2623   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2624   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2625   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2626 
2627   // Check `explicit_paddings` first because some Conv ops don't have
2628   // this attribute.
2629   if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2630                      &explicit_paddings) &&
2631       !explicit_paddings.empty()) {
2632     nb->Attr("explicit_paddings", explicit_paddings);
2633   }
2634 
2635   // Add attributes to new node.
2636   nb->Attr("T", T);
2637   nb->Attr("padding", padding);
2638 
2639   // Add attributes related to `data_format`.
2640   CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2641 }
2642 
2643 // Used with MergePadWithConv2D
2644 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2645                                                      const Node* orig_node2,
2646                                                      NodeBuilder* nb,
2647                                                      bool change_format) {
2648   DataType Tpaddings;
2649   DataType T;
2650   string data_format;
2651   string padding;
2652   std::vector<int32> strides;
2653   std::vector<int32> dilations;
2654   bool use_cudnn_on_gpu;
2655 
2656   // Get all attributes from old node 1.
2657   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2658   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2659   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2660   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2661   TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2662   TF_CHECK_OK(
2663       GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2664   // Get all attributes from old node 2.
2665   TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2666 
2667   // Add attributes to new node.
2668   nb->Attr("T", T);
2669   nb->Attr("strides", strides);
2670   nb->Attr("dilations", dilations);
2671   nb->Attr("padding", padding);
2672   nb->Attr("data_format", data_format);
2673   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2674   nb->Attr("Tpaddings", Tpaddings);
2675 }
2676 
2677 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2678     const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2679     bool change_format) {
2680   DataType T;
2681   int num_args;
2682   string data_format;
2683   string padding;
2684   std::vector<int32> strides;
2685   std::vector<int32> dilations;
2686   float epsilon;
2687   std::vector<string> fused_ops;
2688   DataType Tpaddings;
2689   float leakyrelu_alpha;
2690 
2691   // Get all attributes from old node.
2692   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2693   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2694   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2695   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2696   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2697   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2698   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2699   TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2700   TF_CHECK_OK(
2701       GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2702   TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2703 
2704   // Add attributes to new node.
2705   nb->Attr("T", T);
2706   nb->Attr("num_args", num_args);
2707   nb->Attr("strides", strides);
2708   nb->Attr("padding", padding);
2709   nb->Attr("data_format", data_format);
2710   nb->Attr("dilations", dilations);
2711   nb->Attr("epsilon", epsilon);
2712   nb->Attr("Tpaddings", Tpaddings);
2713   nb->Attr("fused_ops", fused_ops);
2714   nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2715 }
2716 
2717 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2718                                                     NodeBuilder* nb,
2719                                                     bool change_format) {
2720   DataType Tinput, Tfilter, out_type;
2721   string padding;
2722   string data_format("NHWC");
2723   std::vector<int32> strides, dilations, padding_list;
2724   bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2725 
2726   // Get all attributes from old node.
2727   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2728   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2729   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2730   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2731   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2732   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2733   if (has_padding_list) {
2734     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2735   }
2736 
2737   Node* filter_node = nullptr;
2738   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2739 
2740   // Add attributes to new node.
2741   nb->Attr("Tinput", Tinput);
2742   nb->Attr("Tfilter", Tfilter);
2743   nb->Attr("out_type", out_type);
2744   nb->Attr("padding", padding);
2745   nb->Attr("is_filter_const", filter_node->IsConstant());
2746   nb->Attr("strides", strides);
2747   nb->Attr("dilations", dilations);
2748   nb->Attr("data_format", data_format);
2749   if (has_padding_list) {
2750     nb->Attr("padding_list", padding_list);
2751   }
2752 
2753   // Requantization attr Tbias.
2754   DataType Tbias;
2755   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2756   if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2757 }
2758 
2759 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2760     const Node* orig_node, NodeBuilder* nb, bool change_format) {
2761   CopyAttrsAll(orig_node, nb, change_format);
2762 
2763   // Check and set filter attribute.
2764   Node* filter_node = nullptr;
2765   TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2766   nb->Attr("is_weight_const", filter_node->IsConstant());
2767 }
2768 
2769 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2770     const Node* orig_node, NodeBuilder* nb, bool change_format) {
2771   DataType T1, T2, Toutput;
2772 
2773   // Get all attributes from old node.
2774   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2775   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2776   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2777 
2778   Node* weight_node = nullptr;
2779   TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2780 
2781   // Add attributes to new node.
2782   nb->Attr("T1", T1);
2783   nb->Attr("T2", T2);
2784   nb->Attr("Toutput", Toutput);
2785   nb->Attr("is_weight_const", weight_node->IsConstant());
2786 
2787   // Requantization attr Tbias
2788   DataType Tbias;
2789   Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2790   if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2791 }
2792 
2793 void MklLayoutRewritePass::CopyFormatAttrsConv(
2794     const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2795     const std::vector<int32>& dilations, bool change_format) {
2796   string data_format;
2797 
2798   if (!change_format) {
2799     nb->Attr("strides", strides);
2800     nb->Attr("dilations", dilations);
2801 
2802     TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2803     nb->Attr("data_format", data_format);
2804   } else {
2805     std::vector<int32> new_strides;
2806     std::vector<int32> new_dilations;
2807     if (strides.size() == 5) {
2808       // `strides` and `dilations` also need to be changed according to
2809       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2810       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2811                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2812                      strides[NDHWC::dim::W]};
2813 
2814       new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2815                        dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2816                        dilations[NDHWC::dim::W]};
2817     } else {
2818       // `strides` and `dilations` also need to be changed according to
2819       // `data_format`. In this case, from `NHWC` to `NCHW`.
2820 
2821       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2822                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
2823 
2824       new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2825                        dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
2826     }
2827     nb->Attr("strides", new_strides);
2828     nb->Attr("dilations", new_dilations);
2829   }
2830 }
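
// Worked example for change_format = true: 4D strides {1, 2, 3, 1} given
// in NHWC order (N=1, H=2, W=3, C=1) are permuted to {1, 1, 2, 3} in NCHW
// order (N, C, H, W); dilations are permuted the same way.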
2831 
2832 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2833                                             NodeBuilder* nb,
2834                                             bool change_format) {
2835   DataType T;
2836   string data_format;
2837   string padding;
2838   std::vector<int32> ksize, strides;
2839 
2840   // Get all attributes from old node.
2841   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2842   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2843   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2844   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2845   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2846 
2847   // Add attributes to new node.
2848   nb->Attr("T", T);
2849   nb->Attr("padding", padding);
2850 
2851   if (!change_format) {
2852     nb->Attr("strides", strides);
2853     nb->Attr("ksize", ksize);
2854 
2855     nb->Attr("data_format", data_format);
2856   } else {
2857     std::vector<int32> new_strides;
2858     std::vector<int32> new_ksize;
2859     if (strides.size() == 5) {
2860       DCHECK(data_format == "NCDHW");
2861       // `strides` and `ksize` also need to be changed according to
2862       // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2863       new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2864                      strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2865                      strides[NDHWC::dim::W]};
2866 
2867       new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2868                    ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2869                    ksize[NDHWC::dim::W]};
2870 
2871     } else {
2872       // `strides` and `ksize` also need to be changed according to
2873       // `data_format`. In this case, from `NHWC` to `NCHW`.
2874       DCHECK(data_format == "NCHW");
2875       new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2876                      strides[NHWC::dim::H], strides[NHWC::dim::W]};
2877 
2878       new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2879                    ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
2880     }
2881     nb->Attr("strides", new_strides);
2882     nb->Attr("ksize", new_ksize);
2883   }
2884 }
2885 
2886 //////////////////////////////////////////////////////////////////////////
2887 //           Helper functions related to node merge pass
2888 //////////////////////////////////////////////////////////////////////////
2889 
2890 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
2891   // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
2892   // once we support BiasAddGrad as Mkl layer.
2893 
2894   // Search for all matching mergeinfo.
2895   // We allow more than one match for extensibility.
2896   std::vector<const MergeInfo*> matching_mi;
2897   for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
2898     if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
2899       matching_mi.push_back(&*mi);
2900     }
2901   }
2902 
2903   for (const MergeInfo* mi : matching_mi) {
2904     // Get the operand with which 'a' can be merged.
2905     Node* b = nullptr;
2906     if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
2907       continue;
2908     }
2909 
2910     // Get the control edges and input of node
2911     const int N_in = a->num_inputs();
2912     gtl::InlinedVector<Node*, 4> a_control_edges;
2913     gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
2914     FillInputs(a, &a_control_edges, &a_in);
2915 
2916     const int B_in = b->num_inputs();
2917     gtl::InlinedVector<Node*, 4> b_control_edges;
2918     gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
2919     FillInputs(b, &b_control_edges, &b_in);
2920 
2921     // Shouldn't merge if a and b have different control edges.
2922     if (a_control_edges != b_control_edges) {
2923       continue;
2924     } else {
2925       // We found a match.
2926       return b;
2927     }
2928   }
2929 
2930   return nullptr;
2931 }
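
// Illustrative example: for a BiasAdd node 'a', the Conv2D+BiasAdd
// mergeinfo matches; get_node_to_be_merged then returns the Conv2D that
// feeds 'a' (or nullptr if there is none), and the candidate is skipped
// when the control edges of the two nodes differ.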
2932 
2933 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2934                                                     Node* m, Node* n) {
2935   CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2936              n->type_string() == csinfo_.conv2d)) ||
2937                ((n->type_string() == csinfo_.bias_add &&
2938                  m->type_string() == csinfo_.conv2d)),
2939            true);
2940 
2941   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2942   // BiasAdd is successor node, and Conv2D predecessor node.
2943   Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2944   Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2945 
2946   // 1. Get all attributes from input nodes.
2947   DataType T_pred, T_succ;
2948   string padding;
2949   std::vector<int32> strides;
2950   std::vector<int32> dilations;
2951   string data_format_pred, data_format_succ;
2952   bool use_cudnn_on_gpu;
2953   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2954   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2955   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2956   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2957   TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2958   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2959   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2960   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2961   // We check to ensure that the data formats of both succ and pred are
2962   // the same. We expect them to match, so this could be an assert; but
2963   // an assert can be too strict, so we enforce it as a check instead.
2964   // If the check fails, we do not merge the two nodes.
2965   // We do the same check for devices.
2966   if (data_format_pred != data_format_succ || T_pred != T_succ ||
2967       pred->assigned_device_name() != succ->assigned_device_name() ||
2968       pred->def().device() != succ->def().device()) {
2969     return Status(error::Code::INVALID_ARGUMENT,
2970                   "data_format or T attribute or devices of Conv2D and "
2971                   "BiasAdd do not match. Will skip node merge optimization");
2972   }
2973 
2974   const int succ_num = succ->num_inputs();
2975   gtl::InlinedVector<Node*, 4> succ_control_edges;
2976   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
2977   FillInputs(succ, &succ_control_edges, &succ_in);
2978 
2979   const int pred_num = pred->num_inputs();
2980   gtl::InlinedVector<Node*, 4> pred_control_edges;
2981   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
2982   FillInputs(pred, &pred_control_edges, &pred_in);
2983 
2984   // We need to ensure that Conv2D only feeds BiasAdd (no other operator
2985   // consumes the output of Conv2D). If this is not the case, then we
2986   // cannot merge Conv2D with BiasAdd.
2987   const int kFirstOutputSlot = 0;
2988   for (const Edge* e : pred->out_edges()) {
2989     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
2990       return Status(error::Code::INVALID_ARGUMENT,
2991                     "Conv2D does not feed to BiasAdd, or "
2992                     "it feeds BiasAdd but has multiple outputs. "
2993                     "Will skip node merge optimization");
2994     }
2995   }
2996 
2997   // 2. Get inputs from both nodes.
2998   // Find the 2 inputs of Conv2D and the bias input of BiasAdd.
2999   // Get operands 0 and 1 of Conv2D.
3000   CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
3001   // Get operand 1 of BiasAdd.
3002   // BiasAdd must have 2 inputs: Conv2D output and bias.
3003   CHECK_EQ(succ->in_edges().size(), 2);
3004 
3005   // We will use the node name of BiasAdd as the name of new node
3006   // Build new node. We use same name as original node, but change the op
3007   // name.
3008   NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3009   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
3010   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3011   nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
3012   // In1 of BiasAdd is same as output of Conv2D.
3013   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
3014 
3015   // Copy attributes from Conv2D to Conv2DWithBias.
3016   CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3017 
3018   // Copy the device assigned to old node to new node.
3019   nb.Device(succ->def().device());
3020 
3021   // Create node.
3022   Node* new_node;
3023   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3024 
3025   // In the following code of this function, an unordered set is used to
3026   // make sure no duplicate edges are added to the new node. Therefore, we
3027   // can pass allow_duplicates = true in the AddControlEdge call to skip
3028   // the O(#edges) check in the routine.
3029 
3030   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3031   // node are already copied in BuildNode. We handle control edges now.
3032   std::unordered_set<Node*> unique_node;
3033   for (const Edge* e : pred->in_edges()) {
3034     if (e->IsControlEdge()) {
3035       auto result = unique_node.insert(e->src());
3036       if (result.second) {
3037         (*g)->AddControlEdge(e->src(), new_node, true);
3038       }
3039     }
3040   }
3041   unique_node.clear();
3042 
3043   for (const Edge* e : succ->in_edges()) {
3044     if (e->IsControlEdge()) {
3045       auto result = unique_node.insert(e->src());
3046       if (result.second) {
3047         (*g)->AddControlEdge(e->src(), new_node, true);
3048       }
3049     }
3050   }
3051   unique_node.clear();
3052 
3053   // Incoming edges are fixed, we will fix the outgoing edges now.
3054   // First, we will fix outgoing control edges from 'pred' node.
3055   for (const Edge* e : pred->out_edges()) {
3056     if (e->IsControlEdge()) {
3057       auto result = unique_node.insert(e->dst());
3058       if (result.second) {
3059         (*g)->AddControlEdge(new_node, e->dst(), true);
3060       }
3061     }
3062   }
3063   unique_node.clear();
3064 
3065   // Second, we will fix outgoing control and data edges from 'succ' node.
3066   for (const Edge* e : succ->out_edges()) {
3067     if (e->IsControlEdge()) {
3068       auto result = unique_node.insert(e->dst());
3069       if (result.second) {
3070         (*g)->AddControlEdge(new_node, e->dst(), true);
3071       }
3072     } else {
3073       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3074       // output (at slot 0).
3075       const int kConv2DWithBiasOutputSlot = 0;
3076       auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3077                                     e->dst(), e->dst_input());
3078       DCHECK(new_edge);
3079     }
3080   }
3081 
3082   // Copy device assigned to old node to new node.
3083   // It's ok to use pred or succ as we have enforced a check that
3084   // both have same device assigned.
3085   new_node->set_assigned_device_name(pred->assigned_device_name());
3086 
3087   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3088           << ", and node: " << succ->DebugString()
3089           << ", into node:" << new_node->DebugString();
3090 
3091   (*g)->RemoveNode(succ);
3092   (*g)->RemoveNode(pred);
3093 
3094   return Status::OK();
3095 }
3096 
3097 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3098                                                 Node* m, Node* n) {
3099   DCHECK((m->type_string() == csinfo_.pad &&
3100           (n->type_string() == csinfo_.conv2d ||
3101            n->type_string() == csinfo_.fused_conv2d)) ||
3102          (n->type_string() == csinfo_.pad &&
3103           (m->type_string() == csinfo_.conv2d ||
3104            m->type_string() == csinfo_.fused_conv2d)));
3105 
3106   bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3107                          m->type_string() == csinfo_.fused_conv2d;
3108   // Conv2D is successor node, and Pad predecessor node.
3109   Node* pred = m->type_string() == csinfo_.pad ? m : n;
3110   Node* succ = m->type_string() == csinfo_.pad ? n : m;
3111 
3112   // 1. Get all attributes from input nodes.
3113   DataType T_pred, T_succ;
3114   string padding;
3115   std::vector<int32> strides;
3116   std::vector<int32> dilations;
3117   string data_format_pred, data_format_succ;
3118 
3119   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3120   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3121   TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3122   TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3123   TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3124   // Check if the devices of both succ and pred are the same.
3125   // An assert is not used because it can be too strict. We don't need to
3126   // check data formats because Pad has no data_format attribute.
3127   if (T_pred != T_succ ||
3128       pred->assigned_device_name() != succ->assigned_device_name() ||
3129       pred->def().device() != succ->def().device()) {
3130     return Status(error::Code::INVALID_ARGUMENT,
3131                   "T attribute or devices of Conv2D and "
3132                   "Pad do not match. Will skip node merge optimization");
3133   }
3134 
3135   const int succ_num = succ->num_inputs();
3136   gtl::InlinedVector<Node*, 4> succ_control_edges;
3137   gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3138   FillInputs(succ, &succ_control_edges, &succ_in);
3139 
3140   const int pred_num = pred->num_inputs();
3141   gtl::InlinedVector<Node*, 4> pred_control_edges;
3142   gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3143   FillInputs(pred, &pred_control_edges, &pred_in);
3144 
3145   // We need to ensure that Pad only feeds Conv2D (no other operator
3146   // consumes the output of Pad). If this is not the case, then we cannot
3147   // merge Conv2D with Pad.
3148   const int kFirstOutputSlot = 0;
3149   for (const Edge* e : pred->out_edges()) {
3150     if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3151       return Status(error::Code::INVALID_ARGUMENT,
3152                     "Pad does not feed to Conv2D, or "
3153                     "it feeds Conv2D but has multiple outputs. "
3154                     "Will skip node merge optimization");
3155     }
3156   }
3157 
3158   // 2. Get inputs from both the nodes.
3159 
3160   // Pad must have 2 data inputs: "input" and paddings.
3161   int pad_data_input_edges = 0;
3162   for (const Edge* e : pred->in_edges()) {
3163     if (!e->IsControlEdge()) {
3164       pad_data_input_edges++;
3165     }
3166   }
3167   DCHECK_EQ(pad_data_input_edges, 2);
3168 
3169   // Conv2D must have 2 data inputs: Pad output and Filter.
3170   // FusedConv2D has 3 data inputs: Pad output, Filter and Args.
3171   int conv_data_input_edges = 0;
3172   for (const Edge* e : succ->in_edges()) {
3173     if (!e->IsControlEdge()) {
3174       conv_data_input_edges++;
3175     }
3176   }
3177 
3178   DCHECK_EQ(conv_data_input_edges, is_fused_conv2d ? 3 : 2);
3179 
3180   // We will use the node name of Conv2D as the name of new node
3181   // Build new node. We use same name as original node, but change the op
3182   // name.
3183 
3184   NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3185                                                : csinfo_.pad_with_conv2d);
3186   nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data)  of Pad
3187   // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3188   nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
3189   // In1 of Conv2D is same as output of Pad.
3190   // Thus, only need to add In2 of Conv2D
3191 
3192   if (is_fused_conv2d) {
3193     // FusedConv2D has one additional input, args
3194     std::vector<NodeBuilder::NodeOut> args;
3195     int num_args = 1;
3196     TF_CHECK_OK(GetNodeAttr(succ->def(), "num_args", &num_args));
3197     for (int i = 0; i < num_args; i++) {
3198       args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second);
3199     }
3200     nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3201         args});                                     // In3 (args) of FusedConv2D
3202     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
3203     // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3204     CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3205                                    const_cast<const Node*>(pred), &nb);
3206   } else {
3207     nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
3208     // Copy attributes from Pad and conv2D to PadWithConv2D.
3209     CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3210                               const_cast<const Node*>(pred), &nb);
3211   }
3212 
3213   // Copy the device assigned to old node to new node.
3214   nb.Device(succ->def().device());
3215 
3216   // Create node.
3217   Node* new_node;
3218   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3219   // No need to check if new_node is null because it will be null only when
3220   // Finalize fails.
3221 
3222   // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3223   // node are already copied in BuildNode.
3224   // We handle control edges now.
3225   for (const Edge* e : pred->in_edges()) {
3226     if (e->IsControlEdge()) {
3227       // Don't allow duplicate edge
3228       (*g)->AddControlEdge(e->src(), new_node, false);
3229     }
3230   }
3231   for (const Edge* e : succ->in_edges()) {
3232     if (e->IsControlEdge()) {
3233       // Don't allow duplicate edge
3234       (*g)->AddControlEdge(e->src(), new_node, false);
3235     }
3236   }
3237 
3238   // Incoming edges are fixed, we will fix the outgoing edges now.
3239   // First, we will fix outgoing control edges from 'pred' node.
3240   for (const Edge* e : pred->out_edges()) {
3241     if (e->IsControlEdge()) {
3242       // Don't allow duplicate edge
3243       (*g)->AddControlEdge(new_node, e->dst(), false);
3244     }
3245   }
3246 
3247   // Second, we will fix outgoing control and data edges from 'succ' node.
3248   for (const Edge* e : succ->out_edges()) {
3249     if (e->IsControlEdge()) {
3250       // Allow duplicate while adding control edge as it would fail (return
3251       // NULL) if we try to add duplicate edge.
3252       (*g)->AddControlEdge(new_node, e->dst(), false);
3253     } else {
3254       // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3255       // output (at slot 0).
3256       const int kPadWithConv2DOutputSlot = 0;
3257       (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3258                     e->dst_input());
3259     }
3260   }
3261 
3262   // Copy device assigned to old node to new node.
3263   // It's ok to use pred or succ as we have enforced a check that
3264   // both have same device assigned.
3265   new_node->set_assigned_device_name(pred->assigned_device_name());
3266 
3267   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3268           << ", and node: " << succ->DebugString()
3269           << ", into node:" << new_node->DebugString();
3270 
3271   (*g)->RemoveNode(succ);
3272   (*g)->RemoveNode(pred);
3273 
3274   return Status::OK();
3275 }
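
// Illustrative example of the merge above:
//
//           X = Pad(input, paddings)
//           Y = Conv2D(X, filter)
//
// becomes a single node whose data inputs are (input, filter, paddings),
// matching the Input() calls above: pred_in[0], succ_in[1], pred_in[1].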
3276 
3277 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3278     std::unique_ptr<Graph>* g, Node* m, Node* n) {
3279   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3280              n->type_string() == csinfo_.conv2d_grad_filter)) ||
3281                ((n->type_string() == csinfo_.bias_add_grad &&
3282                  m->type_string() == csinfo_.conv2d_grad_filter)),
3283            true);
3284 
3285   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3286   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3287   Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3288 
3289   // Sanity check for attributes from input nodes.
3290   DataType T_b, T_f;
3291   string data_format_b, data_format_f;
3292   TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3293   TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3294   TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3295   TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3296   if (data_format_b != data_format_f || T_b != T_f ||
3297       badd->assigned_device_name() != fltr->assigned_device_name() ||
3298       badd->def().device() != fltr->def().device()) {
3299     return Status(error::Code::INVALID_ARGUMENT,
3300                   "data_format or T attribute or devices of "
3301                   "Conv2DBackpropFilter and BiasAddGrad do not match. "
3302                   "Will skip node merge optimization");
3303   }
3304 
3305   // We will use the node name of Conv2DBackpropFilter as the name of new node.
3306   // This is because BackpropFilterWithBias is going to emit bias output also.
3307   NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3308   // Since Conv2DBackpropFilterWithBias has same number of inputs as
3309   // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3310   // to copy any data input of BiasAddGrad because that input also goes to
3311   // Conv2DBackpropFilter.
3312   const int fltr_ins = fltr->num_inputs();
3313   gtl::InlinedVector<Node*, 4> fltr_control_edges;
3314   gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3315   FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3316   for (int idx = 0; idx < fltr_ins; idx++) {
3317     nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3318   }
3319 
3320   // Copy attributes from Conv2DBackpropFilter.
3321   CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3322 
3323   // Copy the device assigned to old node to new node.
3324   nb.Device(fltr->def().device());
3325 
3326   // Create node.
3327   Node* new_node;
3328   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3329 
3330   // In the following code of this function, an unordered set is used to
3331   // make sure no duplicate edges are added to the new node. Therefore, we
3332   // can pass allow_duplicates = true in the AddControlEdge call to skip
3333   // the O(#edges) check in the routine.
3334 
3335   // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3336   // new 'new_node' node are already copied in BuildNode. We handle control
3337   // edges now.
3338   std::unordered_set<Node*> unique_node;
3339   for (const Edge* e : badd->in_edges()) {
3340     if (e->IsControlEdge()) {
3341       auto result = unique_node.insert(e->src());
3342       if (result.second) {
3343         (*g)->AddControlEdge(e->src(), new_node, true);
3344       }
3345     }
3346   }
3347   unique_node.clear();
3348   for (const Edge* e : fltr->in_edges()) {
3349     if (e->IsControlEdge()) {
3350       auto result = unique_node.insert(e->src());
3351       if (result.second) {
3352         (*g)->AddControlEdge(e->src(), new_node, true);
3353       }
3354     }
3355   }
3356   unique_node.clear();
3357 
3358   // Incoming edges are fixed, we will fix the outgoing edges now.
3359   // First, we will fix outgoing control edges from 'badd' node.
3360   // Conv2DBackpropFilter has 1 output -- filter_grad.
3361   // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3362   // bias_grad. But filter_grad is at same slot number (0) in both the
3363   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3364   // it is at slot number 0 in BiasAddGrad.
3365   const int kMergedNodeFilterGradOutputIdx = 0;
3366   const int kMergedNodeBiasGradOutputIdx = 1;
3367 
3368   for (const Edge* e : badd->out_edges()) {
3369     if (e->IsControlEdge()) {
3370       auto result = unique_node.insert(e->dst());
3371       if (result.second) {
3372         (*g)->AddControlEdge(new_node, e->dst(), true);
3373       }
3374     } else {
3375       auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3376                                     e->dst(), e->dst_input());
3377       DCHECK(new_edge);
3378     }
3379   }
3380   unique_node.clear();
3381 
3382   // Second, we will fix outgoing control and data edges from 'fltr' node.
3383   for (const Edge* e : fltr->out_edges()) {
3384     if (e->IsControlEdge()) {
3385       auto result = unique_node.insert(e->dst());
3386       if (result.second) {
3387         (*g)->AddControlEdge(new_node, e->dst(), true);
3388       }
3389     } else {
3390       auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3391                                     e->dst(), e->dst_input());
3392       DCHECK(new_edge);
3393     }
3394   }
3395 
3396   // Copy device assigned to old node to new node.
3397   // It's ok to use badd or fltr as we have enforced a check that
3398   // both have same device assigned.
3399   new_node->set_assigned_device_name(badd->assigned_device_name());
3400 
3401   VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3402           << ", and node: " << fltr->DebugString()
3403           << ", into node:" << new_node->DebugString();
3404 
3405   (*g)->RemoveNode(badd);
3406   (*g)->RemoveNode(fltr);
3407 
3408   return Status::OK();
3409 }
3410 
3411 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3412                                        Node* n) {
3413   DCHECK(m);
3414   DCHECK(n);
3415 
3416   if (((m->type_string() == csinfo_.bias_add &&
3417         n->type_string() == csinfo_.conv2d)) ||
3418       ((n->type_string() == csinfo_.bias_add &&
3419         m->type_string() == csinfo_.conv2d))) {
3420     return this->MergeConv2DWithBiasAdd(g, m, n);
3421   }
3422   if ((m->type_string() == csinfo_.pad &&
3423        (n->type_string() == csinfo_.conv2d ||
3424         (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3425       (n->type_string() == csinfo_.pad &&
3426        (m->type_string() == csinfo_.conv2d ||
3427         (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3428     return this->MergePadWithConv2D(g, m, n);
3429   }
3430 
3431   if (((m->type_string() == csinfo_.bias_add_grad &&
3432         n->type_string() == csinfo_.conv2d_grad_filter)) ||
3433       ((n->type_string() == csinfo_.bias_add_grad &&
3434         m->type_string() == csinfo_.conv2d_grad_filter))) {
3435     return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3436   }
3437 
3438   return Status(error::Code::UNIMPLEMENTED,
3439                 "Unimplemented case for node merge optimization.");
3440 }
3441 
3442 //////////////////////////////////////////////////////////////////////////
3443 //           Helper functions for node rewrite
3444 //////////////////////////////////////////////////////////////////////////
3445 
3446 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3447     std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3448     const RewriteInfo* ri) {
3449   // Get all data inputs.
3450   int num_data_inputs = orig_node->in_edges().size();
3451   // Subtract the control edges from the count.
3452   for (const Edge* e : orig_node->in_edges()) {
3453     if (e->IsControlEdge()) {
3454       num_data_inputs--;
3455     }
3456   }
3457 
3458   gtl::InlinedVector<Node*, 4> control_edges;
3459   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3460   FillInputs(orig_node, &control_edges, &inputs);
3461 
3462   // Build new node. We use same name as original node, but change the op name.
3463   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3464   // Copy user-specified device assigned to original node to new node.
3465   nb.Device(orig_node->def().device());
3466   // Set up new inputs to the rewritten node.
3467   Status s = SetUpInputs(g, inputs, &nb, orig_node);
3468   if (!s.ok()) {
3469     return s;
3470   }
3471 
3472   const bool kPartialCopyAttrs = false;
3473   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3474 
3475   // Set the Mkl layer label for this op.
3476   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3477       DataTypeIsQuantized(orig_node->output_type(0))) {
3478     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3479   } else {
3480     nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3481   }
3482   // Finalize graph and get new node.
3483   s = nb.Finalize(&**g, new_node);
3484   if (!s.ok()) {
3485     return s;
3486   }
3487 
3488   // In the following code of this function, an unordered set is used to
3489   // make sure no duplicate edges are added to the new node. Therefore, we
3490   // can pass allow_duplicates = true in the AddControlEdge call to skip
3491   // the O(#edges) check in the routine.

  // Incoming data edges from 'orig_node' were already set up above when the
  // new node was built. We only need to handle control edges here.
  std::unordered_set<Node*> unique_node;
  for (const Edge* e : orig_node->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), *new_node, true);
      }
    }
  }
  unique_node.clear();

  // Copy outgoing edges from 'orig_node' to 'new_node'. The outputs keep the
  // same ordering among TensorFlow tensors and Mkl tensors, so we only need
  // to connect the TensorFlow tensors appropriately. Specifically, the nth
  // output of the original node becomes the 2*nth output of the Mkl node for
  // the interleaved ordering of tensors, and stays the nth output for the
  // contiguous ordering. GetTensorDataIndex provides this mapping.
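  // For example, for an original node with 2 outputs, output 1 maps to slot
  // 2 of the Mkl node under the interleaved ordering (slots 1 and 3 then
  // carry the Mkl metadata tensors), and to slot 1 under the contiguous
  // ordering.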
  for (const Edge* e : orig_node->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(*new_node, e->dst(), true);
      }
    } else {
      auto new_edge = (*g)->AddEdge(
          *new_node,
          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
          e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }
  return Status::OK();
}

Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
    std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
    const RewriteInfo* ri) {
  // Get all data inputs.
  int num_data_inputs = orig_node->in_edges().size();
  // Do not count control edges among the data inputs.
  for (const Edge* e : orig_node->in_edges()) {
    if (e->IsControlEdge()) {
      num_data_inputs--;
    }
  }
  gtl::InlinedVector<Node*, 4> control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
  FillInputs(orig_node, &control_edges, &inputs);

  // Build the new node. We use the same name as the original node but change
  // the op type.
  NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
  // Copy the user-specified device assigned to the original node to the new
  // node.
  nb.Device(orig_node->def().device());

  Status s = CopyInputs(orig_node, inputs, &nb);
  if (s != Status::OK()) {
    return s;
  }

  std::vector<NodeBuilder::NodeOut> workspace_tensors;
  bool are_workspace_tensors_available = false;
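  // Note: a few ops (e.g. the MKL pooling ops) pass an extra "workspace"
  // tensor from the forward op to the corresponding gradient op; when such a
  // tensor is needed, it is appended below as an additional input.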
  if (IsWorkspaceCheckNeeded(orig_node)) {
    AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
                             &are_workspace_tensors_available);
    if (are_workspace_tensors_available) {
      DCHECK_EQ(workspace_tensors.size(), 1);
      nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
    }
  }

  ri->copy_attrs(const_cast<const Node*>(orig_node), &nb,
                 !NativeFormatEnabled());

  if (DataTypeIsQuantized(orig_node->input_type(0)) ||
      DataTypeIsQuantized(orig_node->output_type(0))) {
    nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
  } else {
    nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
  }

  // Finalize graph and get new node.
  s = nb.Finalize(&**g, new_node);
  if (s != Status::OK()) {
    return s;
  }

  // In the rest of this function an unordered set is used to make sure that
  // no duplicate edges are added to the new node. This lets us pass
  // allow_duplicates = true to AddControlEdge and skip its O(#edges)
  // duplicate check.

  // Incoming data edges from 'orig_node' were already set up above when the
  // new node was built. We only need to handle control edges here.
  std::unordered_set<Node*> unique_node;
  for (const Edge* e : orig_node->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), *new_node, true);
      }
    }
  }
  unique_node.clear();

  // Transfer outgoing edges from 'orig_node' to 'new_node'.
  for (const Edge* e : orig_node->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(*new_node, e->dst(), true);
      }
    } else {
      auto result =
          (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
      DCHECK(result != nullptr);
    }
  }

  return Status::OK();
}

Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
                                         Node* orig_node,
                                         const RewriteInfo* ri) {
  DCHECK(ri != nullptr);
  DCHECK(orig_node != nullptr);

  VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();

  Status ret_status = Status::OK();
  Node* new_node = nullptr;
  if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
    ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
  } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
    ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
  } else {
    ret_status = Status(error::Code::INVALID_ARGUMENT,
                        "Unsupported rewrite cause found. "
                        "RewriteNode will fail.");
  }
  TF_CHECK_OK(ret_status);

  // Copy the runtime device assigned to the original node to the new node.
  new_node->set_assigned_device_name(orig_node->assigned_device_name());

  // Delete original node and mark new node as rewritten.
  (*g)->RemoveNode(orig_node);

  VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
  return ret_status;
}

// TODO(mdfaijul): Is there a more elegant way to check for quantized ops
// that carry type attributes other than "T"?
// The current implementation covers only QuantizedConv2D and its fused ops.
const MklLayoutRewritePass::RewriteInfo*
MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
  DataType T1, T2;
  DataType Tinput, Tfilter;
  bool type_attrs_present = false;

  if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
      TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
      mkl_op_registry::IsMklQuantizedOp(
          mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
    type_attrs_present = true;
  } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
             TryGetNodeAttr(n->def(), "T2", &T2) &&
             mkl_op_registry::IsMklQuantizedOp(
                 mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
    type_attrs_present = true;
  }

  if (type_attrs_present) {
    for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
      if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
        return &*ri;
      }
    }
  }

  return nullptr;
}

const MklLayoutRewritePass::RewriteInfo*
MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
  DCHECK(n);

  // Quantized ops may carry type attributes other than "T", so that check is
  // decoupled into a separate function, CheckForQuantizedNodeRewrite.
  const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
  if (ri != nullptr) return ri;

  // First check if the node along with its type is supported by the MKL
  // layer. We do not want to rewrite an op into an Mkl op if its type is not
  // supported. E.g., MklRelu does not support INT32, so we cannot rewrite
  // Relu to MklRelu for INT32.
  DataType T;
  if (!TryGetNodeAttr(n->def(), "T", &T)) {
    return nullptr;
  }

  // We make an exception for Conv2DGrad and MaxPool related ops since the
  // corresponding MKL ops do not yet support padding == EXPLICIT.
  // TODO(intel): support `EXPLICIT` padding for ConvGrad
  if (n->type_string() == csinfo_.conv2d_grad_input ||
      n->type_string() == csinfo_.conv2d_grad_filter ||
      n->type_string() == csinfo_.depthwise_conv2d_grad_filter ||
      n->type_string() == csinfo_.depthwise_conv2d_grad_input ||
      n->type_string() == csinfo_.conv3d_grad_filter ||
      n->type_string() == csinfo_.conv3d_grad_input ||
      n->type_string() == csinfo_.max_pool ||
      n->type_string() == csinfo_.max_pool_grad ||
      n->type_string() == csinfo_.max_pool3d ||
      n->type_string() == csinfo_.max_pool3d_grad) {
    string padding;
    TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
    if (padding == "EXPLICIT") return nullptr;
  }

  // We make an exception for __MklDummyConv2DWithBias,
  // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
  // names do not match Mkl node names.
  if (n->type_string() != csinfo_.conv2d_with_bias &&
      n->type_string() != csinfo_.pad_with_conv2d &&
      n->type_string() != csinfo_.pad_with_fused_conv2d &&
      n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
      n->type_string() != csinfo_.fused_batch_norm_ex &&
      n->type_string() != csinfo_.fused_conv2d &&
      n->type_string() != csinfo_.fused_depthwise_conv2d &&
      n->type_string() != csinfo_.fused_matmul &&
      n->type_string() != csinfo_.fused_conv3d &&
      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
                                T)) {
    return nullptr;
  }

  // We now check if the rewrite rule applies for this op. If it does, we
  // rewrite the op to its Mkl version: find the matching RewriteInfo and
  // check that its rewrite rule passes.
  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
      return &*ri;
    }
  }

  // Else return not found.
  return nullptr;
}

//////////////////////////////////////////////////////////////////////////
//           Helper functions for node fusion
//////////////////////////////////////////////////////////////////////////
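// Fuses a Transpose -> MklOp -> Transpose chain into a single MklOp that
// operates directly on the pre-transpose layout by rewiring inputs/outputs
// and overriding the op's data_format attribute; e.g. a
// Transpose(NCHW->NHWC) -> Op(NHWC) -> Transpose(NHWC->NCHW) sequence
// collapses into Op(NCHW).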
Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
    std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
    string data_format) {
  Node* transpose_to_nhwc = nodes[0];
  Node* mklop = nodes[1];
  Node* transpose_to_nchw = nodes[2];

  const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
  gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
      transpose_nhwc_num_inputs);
  FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
             &transpose_nhwc_in);

  const int mklop_num_inputs = mklop->num_inputs();
  gtl::InlinedVector<Node*, 4> mklop_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
  FillInputs(mklop, &mklop_control_edges, &mklop_in);

  const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
  gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
      transpose_nchw_num_inputs);
  FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
             &transpose_nchw_in);

  // We reuse the name and op type of the original MklOp for the fused node.
  NodeBuilder nb(mklop->name(), mklop->type_string());

  // Set up the inputs. The input fed by transpose_to_nhwc ("x") is rewired
  // to the transpose's own input, bypassing the transpose.
  for (int i = 0; i < mklop_num_inputs; i++) {
    if (mklop_in[i].first == transpose_to_nhwc) {
      // Fill "x":
      nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
    } else {
      // Fill inputs other than "x":
      nb.Input(mklop_in[i].first, mklop_in[i].second);
    }
  }

  copy_attrs(const_cast<const Node*>(mklop), &nb, true);
  nb.Attr("data_format", data_format);

  // Copy the device assigned to the old node to the new node.
  nb.Device(mklop->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  // No need to check if new_node is null because it will be null only when
  // Finalize fails.

  // Fill outputs.
  for (const Edge* e : transpose_to_nchw->out_edges()) {
    if (!e->IsControlEdge()) {
      const int kTransposeWithMklOpOutputSlot = 0;
      auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
                                    e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }

  // Copy the device assigned to the old node to the new node.
  new_node->set_assigned_device_name(mklop->assigned_device_name());

  // Copy requested_device and assigned_device_name_index.
  new_node->set_requested_device(mklop->requested_device());
  new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());

  (*g)->RemoveNode(transpose_to_nhwc);
  (*g)->RemoveNode(mklop);
  (*g)->RemoveNode(transpose_to_nchw);

  return Status::OK();
}

Status MklLayoutRewritePass::FuseNode(
    std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
    const MklLayoutRewritePass::FusionInfo fi) {
  return fi.fuse_func(g, nodes, fi.copy_attrs);
}

std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
  // Stores matched nodes, in the same order as node_checkers.
  std::vector<Node*> nodes;

  for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
    //
    // Make sure node "a" and its succeeding nodes (b, c, ...) match the
    // pattern defined in the fusion info (ops[0], ops[1], ...), i.e.,
    // "a->b->c" matches "op1->op2->op3".
    //
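    // The search below is a depth-first walk with backtracking: the stack
    // holds, for each matched node, the next out-edge to try. When a node's
    // out-edges are exhausted without completing the pattern, that node is
    // popped and the previous checker is retried on the remaining neighbors.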

    // Stores the first unvisited outgoing edge of each matched node in
    // "nodes".
    std::stack<EdgeSet::const_iterator> current_neighbor_stack;
    nodes.clear();

    auto node_checker = fi->node_checkers.begin();
    if (a != nullptr && (*node_checker)(a)) {
      nodes.push_back(a);
      current_neighbor_stack.push(a->out_edges().begin());
      ++node_checker;
    }

    while (!nodes.empty()) {
      auto& current_neighbor_iter = current_neighbor_stack.top();

      if (current_neighbor_iter != nodes.back()->out_edges().end()) {
        // Found an unvisited edge. Follow it to get the neighbor.
        Node* neighbor_node = (*current_neighbor_iter)->dst();
        ++current_neighbor_stack.top();  // Advance to the next unvisited edge.

        if ((*node_checker)(neighbor_node)) {
          // Found a match. Store the node and move to the next checker.
          nodes.push_back(neighbor_node);
          current_neighbor_stack.push(neighbor_node->out_edges().begin());
          if (++node_checker == fi->node_checkers.end()) {
            return make_tuple(true, nodes, *fi);
          }
        }
      } else {
        // Remove the current node since none of its neighbors leads to a
        // further match.
        nodes.pop_back();
        current_neighbor_stack.pop();
        --node_checker;
      }
    }
  }

  return make_tuple(false, std::vector<Node*>(), FusionInfo());
}

///////////////////////////////////////////////////////////////////////////////
//              Post-rewrite Mkl metadata fixup pass
///////////////////////////////////////////////////////////////////////////////
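// Rewires an Mkl metadata edge whose source is a dummy constant so that it
// originates from the metadata output slot matching the data edge: data
// flowing from slot k of a node implies metadata flowing from slot
// GetTensorMetaDataIndex(k, num_outputs) of the same node.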
bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
                                                      const Edge* e_data,
                                                      const Edge* e_metadata) {
  if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
    return false;
  }

  Node* n_data = e_data->src();
  int n_data_op_slot = e_data->src_output();
  int n_metadata_op_slot =
      GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());

  // If the source of the meta edge is a constant node (producing a dummy Mkl
  // metadata tensor), then we need to fix the edge.
  if (IsConstant(e_metadata->src())) {
    Node* e_metadata_dst = e_metadata->dst();
    int e_metadata_in_slot = e_metadata->dst_input();
    auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
                                  e_metadata_in_slot);
    DCHECK(new_edge);

    (*g)->RemoveEdge(e_metadata);
    return true;
  }

  return false;
}

bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
                                               Node* n) {
  bool result = false;

  // If the graph node is not an Mkl node, then return.
  DataType T = DT_INVALID;
  if (!TryGetNodeAttr(n->def(), "T", &T) ||
      !mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
    return result;
  }

  // If it is an Mkl node, then check that the input edges to this node that
  // carry Mkl metadata are linked up correctly with their source nodes.

  // For Mkl nodes, we generate twice the number of input tensors (n Mkl data
  // tensors + n Mkl metadata tensors). We only need to check the correct
  // connection of the n metadata tensors.
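  // For example, a node with 4 inputs has 2 data inputs and 2 metadata
  // inputs; under the contiguous ordering the data tensors occupy slots 0..1
  // and the metadata tensors slots 2..3. GetTensorMetaDataIndex encapsulates
  // the mapping for whichever ordering is in use.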
  int num_data_inputs = n->num_inputs() / 2;
  for (int idx = 0; idx < num_data_inputs; idx++) {
    // Get the edge connecting the input slot with index (idx).
    const Edge* e = nullptr;
    TF_CHECK_OK(n->input_edge(idx, &e));

    // If e is a control edge, then skip.
    if (e->IsControlEdge()) {
      continue;
    }

    // Check whether the source node for edge 'e' is an Mkl node. If it is
    // not, then we don't need to do anything.
    Node* e_src = e->src();
    if (TryGetNodeAttr(e_src->def(), "T", &T) &&
        mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) {
      // The source node for edge 'e' is an Mkl node. The destination node
      // and destination input slot of e are node 'n' and 'idx' respectively.
      CHECK_EQ(e->dst(), n);
      CHECK_EQ(e->dst_input(), idx);

      // Get the edge that carries the Mkl metadata corresponding to the Mkl
      // data edge 'e'. For that, first get the input slot of 'n' where the
      // meta edge feeds its value.
      int e_meta_in_slot =
          GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
      const Edge* e_meta = nullptr;
      TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));

      // Check if we need to fix this meta edge.
      if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
        result = true;
      }
    }
  }

  return result;
}

///////////////////////////////////////////////////////////////////////////////
//              Run function for the pass
///////////////////////////////////////////////////////////////////////////////

bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;
  DCHECK(g);

  DumpGraph("Before running MklLayoutRewritePass", &**g);

  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    // Check if node 'n' can be merged with any other node. If it can be,
    // 'm' contains the node with which it can be merged.
    Node* m = nullptr;
    if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      string n1_name = n->name();
      string n2_name = m->name();

      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
              << n2_name << " for merging";

      if (MergeNode(g, n, m) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
                << n2_name;
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    auto check_result = CheckForNodeFusion(n);
    bool found_pattern = std::get<0>(check_result);
    std::vector<Node*> nodes = std::get<1>(check_result);
    const FusionInfo fi = std::get<2>(check_result);

    // If "found_pattern" is true, we can do the fusion.
    if (found_pattern) {
      if (FuseNode(g, nodes, fi) == Status::OK()) {
        result = true;
      }
    }
  }
  DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    const RewriteInfo* ri = nullptr;
    // We will first search if node is to be rewritten.
    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
              << " with op " << op_name << " for rewrite using"
              << " layout optimization.";

      if (RewriteNode(g, n, ri) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
                << " with op " << op_name << " for Mkl layout optimization.";
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

  if (!NativeFormatEnabled()) {
    order.clear();
    GetReversePostOrder(**g, &order);  // This will give us topological sort.
    for (Node* n : order) {
      // If node is not an op or it cannot run on CPU device, then skip.
      if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
        continue;
      }
      if (FixMklMetaDataEdges(g, n)) {
        string node_name = n->name();
        string op_name = n->type_string();

        VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
                << node_name << " with op " << op_name;
        result = true;
      }
    }
    DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
              &**g);
  }

  return result;
}

bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
  return MklLayoutRewritePass().RunPass(g);
}

Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return Status::OK();
  }
  if (!IsMKLEnabled()) {
    VLOG(2) << "TF-MKL: MKL is not enabled";
    return Status::OK();
  }

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Take ownership of the graph.
    std::unique_ptr<Graph>* ng = std::move(g);
    RunPass(ng);
    // Return ownership of the graph.
    g->reset(ng->release());
  };
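  // Note: the std::move above is a formality; 'ng' aliases the same
  // unique_ptr as 'g', so the release/reset pair leaves ownership unchanged
  // after RunPass has rewritten the graph in place.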

  if (kMklLayoutRewritePassGroup !=
      OptimizationPassRegistry::POST_PARTITIONING) {
    // For any pre-partitioning phase, the graph is stored in options.graph.
    process_graph(options.graph);
  } else {
    // For the post-partitioning phase, graphs are stored in
    // options.partition_graphs.
    for (auto& pg : *options.partition_graphs) {
      process_graph(&pg.second);
    }
  }

  return Status::OK();
}

}  // namespace tensorflow

#endif  // INTEL_MKL