/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(intel): Improve error handling in this file; instead of CHECK failing
// all over the place, we should log an error and execute the original graph.
#ifdef INTEL_MKL

#include "tensorflow/core/common_runtime/mkl_layout_pass.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <set>
#include <stack>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     Rewrite happens under the following scenario:
//     - Propagating an Mkl layout as an additional output tensor from every
//       Mkl-supported NN layer (we will loosely call a tensor that carries
//       an Mkl layout an Mkl tensor henceforth).
//
// Example of A: Merging nodes in the graph
// -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
//
//           O = Conv2D(A, B)
//           P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
//           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
//  - The merge for Conv2D and BiasAdd happens only when the output of Conv2D
//    goes _only_ to BiasAdd.
//  - The intersection of the attributes of both nodes must have the same
//    values.
//  - Both nodes must have been assigned to the same device (if any).
//
// Example of B.1: Rewriting nodes to Mkl nodes
// ---------------------------------------------
// Consider a Relu node. The current definition of the Relu node looks like:
//
//           O = Relu(A)
//
// Relu has 1 input (A) and 1 output (O).
//
// This rewrite pass will generate a new graph node for Relu (the new node is
// called MklRelu) as:
//
//          O, O_m = MklRelu(A, A_m)
//
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is some
// kind of metadata for O. A_m is an additional input of MklRelu, and it
// represents metadata for A -- as O_m is metadata for O, A_m is metadata for
// A. MklRelu receives this metadata from the previous node in the graph.
//
// When the previous node in the graph is an Mkl node, A_m will represent a
// valid Mkl tensor. But when the previous node is not an Mkl node, A_m will
// represent a dummy Mkl tensor.
//
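// For example (the names here are illustrative only):
//
//     O1       = Shape(A)            // Shape is not an Mkl op
//     O2, O2_m = MklConv2D(...)      // an Mkl node
//     R1, R1_m = MklRelu(O1, D)      // D is a generated dummy Mkl tensor
//     R2, R2_m = MklRelu(O2, O2_m)   // O2_m is a valid Mkl tensor
//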
// Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    no nodes of that op type will be rewritten.
//  - Number of inputs after rewriting:
//      Since the rewritten node gets an Mkl tensor for every input Tensorflow
//      tensor, the rewritten node gets 2*N inputs, where N is the number of
//      inputs of the original node.
//  - Number of outputs after rewriting:
//      Since the rewritten node generates an Mkl tensor for every output
//      Tensorflow tensor, the rewritten node generates 2*N outputs, where N
//      is the number of outputs of the original node.
//  - Ordering of Tensorflow tensors and Mkl tensors:
//      Since every rewritten node generates twice the number of inputs and
//      outputs, one could imagine various orderings among Tensorflow tensors
//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//      inputs; then the new op '_MklConv2D' can take the inputs A, B, A_m and
//      B_m in A, A_m, B, B_m order or in A, B, A_m, B_m order -- in general,
//      the 2*N inputs admit many possible orderings.
//
//      So the question is: which order do we follow? We support 2 types of
//      orderings: (1) interleaved and (2) contiguous. Interleaved ordering
//      follows an intuitive order where an Mkl tensor immediately follows the
//      corresponding Tensorflow tensor. In the context of the above example,
//      it would be: A, A_m, B, B_m. Note that the ordering rule applies to
//      both the inputs and the outputs. Contiguous ordering means all the
//      Tensorflow tensors are contiguous, followed by all the Mkl tensors.
//      We use contiguous ordering as the default.
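//
//      For the '_MklConv2D' example above, the two orderings would be:
//        Interleaved: A, A_m, B, B_m
//        Contiguous:  A, B, A_m, B_m   (the default)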
//
// Graph rewrite algorithm:
//      Algorithm: Graph Rewrite
//      Input: Graph G, names of the nodes to rewrite and their new names
//      Output: Modified graph G' if any nodes were modified, G otherwise.
//      Start:
//        N = Topological_Sort(G) // N is the set of nodes in toposort order.
//        foreach node n in N
//        do
//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input?
//          then
//            E = set of <incoming edge and its src_output slot> of n
//            E' = {}   // a new set of edges for the rewritten node
//            foreach <e,s> in E
//            do
//              E' = E' U {<e,s>}  // First copy the edge that generates the
//                                 // Tensorflow tensor as it is.
//              m = Source node of edge e
//              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
//              then
//                E' = E' U {<m,s+1>}  // If yes, then m will generate an Mkl
//                                     // tensor as an additional output.
//              else
//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
//                                                 // Mkl tensor.
//                E' = E' U {<d,0>}  // The dummy Mkl tensor has only 1 output
//                                   // slot.
//              fi
//            done
//            n' = Build_New_Node(G, new_name, E')
//            Mark_Rewritten(n')  // Mark the new node as rewritten.
//          fi
//        done
//
//      Explanation:
//        For the graph rewrite, we visit the nodes of the input graph in
//        topological sort order. With this ordering, we visit nodes in a
//        top-to-bottom fashion. We need this order because, while visiting a
//        node, we want all of its input nodes to already have been visited
//        (and rewritten, if applicable). This is because if we need to
//        rewrite a given node, then all of its input nodes need to be fixed
//        (in other words, they cannot be deleted later).
//
//        While visiting a node, we first check if the op type of the node is
//        an Mkl op. If it is, then we rewrite that node after constructing
//        new inputs to the node. If the op type of the node is not an Mkl op,
//        then we do not rewrite that node.
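//
//        A minimal sketch of that visit loop in C++ (illustrative only; the
//        actual implementation is in MklLayoutRewritePass::RunPass below):
//
//          std::vector<Node*> order;
//          GetReversePostOrder(**g, &order);  // topological order
//          for (Node* n : order) {
//            if (/* op type of n is registered for rewrite */) {
//              // Keep each Tensorflow tensor edge, add the producer's Mkl
//              // tensor output or a dummy Mkl tensor, then rebuild n.
//            }
//          }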
//
// Handling workspace propagation for certain ops:
//
//        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
//        passing of a workspace from their respective forward ops. Workspace
//        tensors provide memory for storing results of intermediate
//        operations that are helpful in backward propagation. TensorFlow does
//        not have a notion of a workspace and as a result does not allow
//        producing additional outputs from these forward ops. For these ops,
//        we need to add 2 extra edges between a forward op and its
//        corresponding backward op - the first extra edge carries a workspace
//        tensor and the second one carries an Mkl tensor for the workspace
//        tensor.
//
//        Example:
//
//        A typical graph for MaxPool and its gradient looks like:
//
//        A = MaxPool(T)
//        B = MaxPoolGrad(X, A, Y)
//
//        We will transform this graph to propagate the workspace as:
//        (with the contiguous ordering)
//
//        A, W, A_m, W_m = MklMaxPool(T, T_m)
//        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
//        Here W is the workspace tensor. Transformed tensor names with the
//        suffix _m are Mkl tensors, and this transformation has been done
//        using the algorithm discussed earlier. The transformation for
//        workspace propagation only adds extra outputs (W, W_m) for a forward
//        op and connects them to the corresponding backward ops.
//
//        Terms:
//
//        Forward op name = name of the op in the forward pass
//          where a workspace tensor originates (MaxPool in this example)
//        Backward op name = name of the op in the backward pass that receives
//          a workspace tensor from the forward op (MaxPoolGrad in the example)
//        Slot = position of the output or input slot that will be
//               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
//               output of MaxPool (0 is the 1st); 3 for MklMaxPoolGrad)
//
//        Question:
//
//        How do we associate a backward op with its forward op? There can be
//        more than one op of the same type in the graph.
//
//        In this example, we associate MaxPoolGrad with MaxPool. But there
//        could be more than one MaxPool op. To solve this problem, we look
//        for a _direct_ edge between a forward op and a backward op (tensor A
//        is flowing along this edge in the example).
//
//        How do we transform forward and backward ops when there is no direct
//        edge between them? In such a case, we generate dummy tensors for the
//        workspace tensors. For the example, the transformation of MaxPool
//        will be exactly the same as it would be when there is a direct edge
//        between the forward and the backward op --- it is just that MaxPool
//        won't generate any workspace tensor. For MaxPoolGrad, the
//        transformation will also be the same, but instead of connecting W
//        and W_m with the outputs of MaxPool, we will produce dummy tensors
//        for them, and we will set the workspace_enabled attribute to false.
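//
//        Sketch of that no-direct-edge case, with DW and DW_m being the
//        generated dummy workspace tensors:
//
//        B, B_m = MklMaxPoolGrad(X, A, Y, DW, X_m, A_m, Y_m, DW_m)
//                 (with attribute workspace_enabled = false)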
//
class MklLayoutRewritePass : public GraphOptimizationPass {
 public:
  MklLayoutRewritePass() {
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.avg_pool3d = "AvgPool3D";
    csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
    csinfo_.batch_matmul = "BatchMatMul";
    csinfo_.batch_matmul_v2 = "BatchMatMulV2";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conjugate_transpose = "ConjugateTranspose";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.conv2d_grad_filter_with_bias =
        "__MklDummyConv2DBackpropFilterWithBias";
    csinfo_.conv3d = "Conv3D";
    csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
    csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
    csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
    csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
    csinfo_.depthwise_conv2d_grad_filter =
        "DepthwiseConv2dNativeBackpropFilter";
    csinfo_.dequantize = "Dequantize";
    csinfo_.einsum = "Einsum";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
    csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
    csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
    csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
    csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
    csinfo_.fused_conv2d = "_FusedConv2D";
    csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
    csinfo_.fused_matmul = "_FusedMatMul";
    csinfo_.identity = "Identity";
    csinfo_.leakyrelu = "LeakyRelu";
    csinfo_.leakyrelu_grad = "LeakyReluGrad";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.max_pool3d = "MaxPool3D";
    csinfo_.max_pool3d_grad = "MaxPool3DGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_grad_filter_with_bias =
        "_MklConv2DBackpropFilterWithBias";
    csinfo_.mkl_depthwise_conv2d_grad_input =
        "_MklDepthwiseConv2dNativeBackpropInput";
    csinfo_.mkl_depthwise_conv2d_grad_filter =
        "_MklDepthwiseConv2dNativeBackpropFilter";
    csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
    csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
    csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
    csinfo_.mkl_native_conv2d_grad_filter_with_bias =
        "_MklNativeConv2DBackpropFilterWithBias";
    csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
    csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
    csinfo_.mkl_native_fused_depthwise_conv2d =
        "_MklNativeFusedDepthwiseConv2dNative";
    csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
    csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
    csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
    csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
    csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
    csinfo_.quantized_avg_pool = "QuantizedAvgPool";
    csinfo_.quantized_concatv2 = "QuantizedConcatV2";
    csinfo_.quantized_conv2d = "QuantizedConv2D";
    csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
    csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
    csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
    csinfo_.quantized_conv2d_with_bias_and_requantize =
        "QuantizedConv2DWithBiasAndRequantize";
    csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
    csinfo_.quantized_conv2d_and_relu_and_requantize =
        "QuantizedConv2DAndReluAndRequantize";
    csinfo_.quantized_conv2d_with_bias_and_relu =
        "QuantizedConv2DWithBiasAndRelu";
    csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantized_max_pool = "QuantizedMaxPool";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu =
        "QuantizedConv2DWithBiasSumAndRelu";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSumAndReluAndRequantize";
    csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
    csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
    csinfo_.quantized_matmul_with_bias_and_relu =
        "QuantizedMatMulWithBiasAndRelu";
    csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
        "QuantizedMatMulWithBiasAndReluAndRequantize";
    csinfo_.quantized_matmul_with_bias_and_dequantize =
        "QuantizedMatMulWithBiasAndDequantize";
    csinfo_.quantized_matmul_with_bias_and_requantize =
        "QuantizedMatMulWithBiasAndRequantize";
    csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
    csinfo_.quantized_depthwise_conv2d_with_bias =
        "QuantizedDepthwiseConv2DWithBias";
    csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
        "QuantizedDepthwiseConv2DWithBiasAndRelu";
    csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantize_v2 = "QuantizeV2";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.relu6 = "Relu6";
    csinfo_.relu6_grad = "Relu6Grad";
    csinfo_.requantize = "Requantize";
    csinfo_.tanh = "Tanh";
    csinfo_.tanh_grad = "TanhGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.slice = "Slice";
    csinfo_.softmax = "Softmax";
    csinfo_.split = "Split";
    csinfo_.transpose = "Transpose";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.add_v2 = "AddV2";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    const bool native_fmt = NativeFormatEnabled();
    // NOTE: names are alphabetically sorted.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.avg_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.batch_matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.einsum,
                      mkl_op_registry::GetMklOpName(csinfo_.einsum),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.batch_matmul_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conjugate_transpose,
         mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
         CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_with_bias,
                      native_fmt ? csinfo_.mkl_native_conv2d_with_bias
                                 : csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      native_fmt
                          ? csinfo_.mkl_native_conv2d_grad_filter_with_bias
                          : csinfo_.mkl_conv2d_grad_filter_with_bias,
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
         CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.conv3d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
                      CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.depthwise_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_input,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_filter,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
         CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v2,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    // Using CopyAttrsAll for V3 on CPU, as there are no additional
    // attributes.
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad_v3,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_batch_norm_ex,
                      native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
                                 : csinfo_.mkl_fused_batch_norm_ex,
                      CopyAttrsAll, FusedBatchNormExRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_fused_conv2d
                                 : csinfo_.mkl_fused_conv2d,
                      CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
                      native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
                                 : csinfo_.mkl_fused_depthwise_conv2d,
                      CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.fused_matmul,
                      native_fmt ? csinfo_.mkl_native_fused_matmul
                                 : csinfo_.mkl_fused_matmul,
                      CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
                      GetRewriteCause()});

    rinfo_.push_back(
        {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsAll, LrnRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.matmul,
                      mkl_op_registry::GetMklOpName(csinfo_.matmul),
                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.leakyrelu,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.leakyrelu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
                      CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
         CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.max_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
                      CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_conv2d
                                 : csinfo_.mkl_pad_with_conv2d,
                      CopyAttrsAllCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
                      native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
                                 : csinfo_.mkl_pad_with_fused_conv2d,
                      CopyAttrsAllCheckConstFilter, AlwaysRewrite,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.quantized_avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
                      CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_per_channel,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_and_relu,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_and_relu_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_sum_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_matmul_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
         CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
         kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_relu),
                      CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
         kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_requantize),
                      CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_matmul_with_bias_and_dequantize),
                      CopyAttrsQuantizedMatMulWithBiasAndDequantize,
                      AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_depthwise_conv2d_with_bias),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back(
        {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_
                 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.quantize_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
                      CopyAttrsAll, QuantizeOpRewrite,
                      kRewriteForOpNameChange});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.relu6_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.requantize,
                      mkl_op_registry::GetMklOpName(csinfo_.requantize),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    // Optimized TanhGrad support exists only in DNNL 1.x.
    rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
    rinfo_.push_back(
        {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
         CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      GetRewriteCause()});
    rinfo_.push_back({csinfo_.transpose,
                      mkl_op_registry::GetMklOpName(csinfo_.transpose),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});

    // Add info about which ops to add workspace edge to and the slots.
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
    wsinfo_.push_back(
        {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});

    // Add rules for merging nodes.
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    // Merge Pad and Conv2D only if the pad op is "Pad"; don't merge if the
    // pad op is "PadV2" or "MirrorPad".
    minfo_.push_back(
        {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});

    minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
                      csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});

    // The fusion patterns in "finfo_" that show up first will get applied
    // first. For example, for graph "A->B->C->D" and finfo_ = {A->B->C to
    // ABC, A->B->C->D to ABCD}, since the first pattern gets applied first,
    // the final graph will be ABC->D.

    //
    // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
    // (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
    // occur frequently in Keras.
    // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
    // while "fusion" is for the 3+ node situation.
    //

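    // For example (using the permutation vectors defined just below):
    //   X1 = Transpose(X, {0, 2, 3, 1})   // NCHW -> NHWC
    //   Y  = Conv2D(X1, F)                // computed in NHWC
    //   Z  = Transpose(Y, {0, 3, 1, 2})   // NHWC -> NCHW
    // fuses into a single Mkl Conv2D node operating directly on NCHW data.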
    // Transpose + Conv2d + Transpose:
    std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
                                          NCHW::dim::W, NCHW::dim::C};
    std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
                                          NHWC::dim::H, NHWC::dim::W};
    auto CheckForTransposeToNHWC =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
    auto CheckForConv2dOp =
        std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
    auto CheckForTransposeToNCHW =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
    auto FuseConv2D =
        std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, "NCHW");
    finfo_.push_back(
        {"transpose-elimination for Conv2D",
         {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
         FuseConv2D,
         CopyAttrsConv});

    // Transpose + Conv3d + Transpose:
    std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
                                           NCDHW::dim::H, NCDHW::dim::W,
                                           NCDHW::dim::C};
    std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
                                           NDHWC::dim::D, NDHWC::dim::H,
                                           NDHWC::dim::W};

    auto CheckForTransposeToNDHWC =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
    auto CheckForConv3dOp =
        std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
    auto CheckForTransposeToNCDHW =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
    auto FuseConv3D =
        std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, "NCDHW");

    finfo_.push_back(
        {"transpose-elimination for Conv3D",
         {CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
         FuseConv3D,
         CopyAttrsConv});

    auto CheckForMaxPool3DOp =
        std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
    auto FuseMaxPool3D =
        std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, "NCDHW");
    finfo_.push_back({"transpose-elimination for MaxPool3D",
                      {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
                       CheckForTransposeToNCDHW},
                      FuseMaxPool3D,
                      CopyAttrsPooling});
  }

  // Standard interface to run the pass.
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function that does most of the heavy lifting for rewriting
  // Mkl nodes to propagate the Mkl tensor as an additional output.
  //
  // Extracts common functionality between the Run public interface and
  // the test interface.
  //
  // @return true, if and only if the graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

  /// Cause for rewrite
  /// Currently, we only support 2 causes - either for Mkl layout propagation,
  /// which is the most common case, or for just a name change (used for ops
  /// like MatMul and Transpose, which do not support the Mkl layout).
  enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };

  // Get the op rewrite cause depending on whether native format mode
  // is enabled or not.
  RewriteCause GetRewriteCause() {
    if (NativeFormatEnabled()) {
      return kRewriteForOpNameChange;
    } else {
      return kRewriteForLayoutPropagation;
    }
  }

  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, and the
  /// rule (if any) which must hold for rewriting the node.
  typedef struct {
    string name;      // Original name of op of the node in the graph
    string new_name;  // New name of the op of the node in the graph
    // A function handler to copy attributes from an old node to a new node.
    std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
    // A rule under which to rewrite this node
    std::function<bool(const Node*)> rewrite_rule;
    // Why are we rewriting?
    RewriteCause rewrite_cause;
  } RewriteInfo;
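  // An example entry from the constructor above:
  //   {csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
  //    CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}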

  /// Structure to specify a forward op, a backward op, and the slot numbers
  /// in the forward and backward ops where we will add a workspace edge.
  typedef struct {
    string fwd_op;    // Name of a forward op in the graph
    string bwd_op;    // Name of a backward op in the graph
    int fwd_slot;     // Output slot in the forward op node where actual
                      // output tensor resides
    int bwd_slot;     // Input slot in the backward op node where actual
                      // input tensor resides
    int ws_fwd_slot;  // Output slot in the forward op node where workspace
                      // edge is added
    int ws_bwd_slot;  // Input slot in the backward op node where workspace
                      // edge is added
  } WorkSpaceInfo;
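  // An example entry from the constructor above:
  //   {csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}
  // adds a workspace edge from output slot 1 of MaxPool to input slot 3 of
  // MaxPoolGrad.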

  /// Structure to specify information used in node merge of 2 operators
  typedef struct {
    string op1;       // Node string for one operator.
    string op2;       // Node string for second operator.
    string new_node;  // Name of the node after merge
    // Function that enables the user of the node merger to specify how to
    // find the second operator given the first operator.
    std::function<Node*(const Node*)> get_node_to_be_merged;
  } MergeInfo;
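  // An example entry from the constructor above:
  //   {csinfo_.conv2d, csinfo_.bias_add, csinfo_.conv2d_with_bias,
  //    GetConv2DOrBiasAdd}
  // merges a Conv2D with the BiasAdd it feeds into a single
  // __MklDummyConv2DWithBias node.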

  // Structure to specify information used in node fusion of 3+ operators
  typedef struct {
    std::string pattern_name;  // Name to describe this pattern, such as
                               // "Transpose_Mklop_Transpose".
    std::vector<std::function<bool(const Node*)> >
        node_checkers;  // Extra restriction checkers for these ops
    std::function<Status(
        std::unique_ptr<Graph>*, std::vector<Node*>&,
        std::function<void(const Node*, NodeBuilder* nb, bool)>)>
        fuse_func;
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
  } FusionInfo;

  //
  // Dimension indices for 4D tensors (as used by 2D ops such as Conv2D).
  //
  struct NCHW {
    enum dim { N = 0, C = 1, H = 2, W = 3 };
  };

  struct NHWC {
    enum dim { N = 0, H = 1, W = 2, C = 3 };
  };

  //
  // Dimension indices for 5D tensors (as used by 3D ops such as Conv3D).
  //
  struct NCDHW {
    enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
  };

  struct NDHWC {
    enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
  };

  /// Structure to store all constant strings
  /// NOTE: names are alphabetically sorted.
  typedef struct {
    string addn;
    string add;
    string add_v2;
    string avg_pool;
    string avg_pool_grad;
    string avg_pool3d;
    string avg_pool3d_grad;
    string batch_matmul;
    string batch_matmul_v2;
    string bias_add;
    string bias_add_grad;
    string concat;
    string concatv2;
    string conjugate_transpose;
    string conv2d;
    string conv2d_with_bias;
    string conv2d_grad_input;
    string conv2d_grad_filter;
    string conv2d_grad_filter_with_bias;
    string conv3d;
    string conv3d_grad_input;
    string conv3d_grad_filter;
    string depthwise_conv2d;
    string depthwise_conv2d_grad_input;
    string depthwise_conv2d_grad_filter;
    string dequantize;
    string einsum;
    string fused_batch_norm;
    string fused_batch_norm_grad;
    string fused_batch_norm_ex;
    string fused_batch_norm_v2;
    string fused_batch_norm_grad_v2;
    string fused_batch_norm_v3;
    string fused_batch_norm_grad_v3;
    string fused_conv2d;
    string fused_depthwise_conv2d;
    string fused_matmul;
    string identity;
    string leakyrelu;
    string leakyrelu_grad;
    string lrn;
    string lrn_grad;
    string matmul;
    string max_pool;
    string max_pool_grad;
    string max_pool3d;
    string max_pool3d_grad;
    string maximum;
    string mkl_conv2d;
    string mkl_conv2d_grad_input;
    string mkl_conv2d_grad_filter;
    string mkl_conv2d_grad_filter_with_bias;
    string mkl_conv2d_with_bias;
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
    string mkl_fused_batch_norm_ex;
    string mkl_fused_conv2d;
    string mkl_fused_depthwise_conv2d;
    string mkl_fused_matmul;
    string mkl_native_conv2d_with_bias;
    string mkl_native_conv2d_grad_filter_with_bias;
    string mkl_native_fused_batch_norm_ex;
    string mkl_native_fused_conv2d;
    string mkl_native_fused_depthwise_conv2d;
    string mkl_native_fused_matmul;
    string mkl_native_pad_with_conv2d;
    string mkl_native_pad_with_fused_conv2d;
    string mkl_pad_with_conv2d;
    string mkl_pad_with_fused_conv2d;
    string mul;
    string pad;
    string pad_with_conv2d;
    string pad_with_fused_conv2d;
    string quantized_avg_pool;
    string quantized_conv2d;
    string quantized_conv2d_per_channel;
    string quantized_conv2d_with_requantize;
    string quantized_conv2d_with_bias;
    string quantized_conv2d_with_bias_and_requantize;
    string quantized_conv2d_and_relu;
    string quantized_conv2d_and_relu_and_requantize;
    string quantized_conv2d_with_bias_and_relu;
    string quantized_conv2d_with_bias_and_relu_and_requantize;
    string quantized_concatv2;
    string quantized_max_pool;
    string quantized_conv2d_with_bias_sum_and_relu;
    string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
    string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
    string quantized_matmul_with_bias;
    string quantized_matmul_with_bias_and_relu;
    string quantized_matmul_with_bias_and_relu_and_requantize;
    string quantized_matmul_with_bias_and_requantize;
    string quantized_matmul_with_bias_and_dequantize;
    string quantized_depthwise_conv2d;
    string quantized_depthwise_conv2d_with_bias;
    string quantized_depthwise_conv2d_with_bias_and_relu;
    string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
    string quantize_v2;
    string relu;
    string relu_grad;
    string relu6;
    string relu6_grad;
    string requantize;
    string tanh;
    string tanh_grad;
    string transpose;
    string reshape;
    string slice;
    string softmax;
    string split;
    string squared_difference;
    string sub;
  } ConstStringsInfo;

 private:
  /// Maintain info about nodes to rewrite
  std::vector<RewriteInfo> rinfo_;

  /// Maintain info about nodes to add workspace edge
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Maintain info about nodes to be merged
  std::vector<MergeInfo> minfo_;

  /// Maintain info about nodes to be fused
  std::vector<FusionInfo> finfo_;

  /// Maintain structure of constant strings
  static ConstStringsInfo csinfo_;

 private:
  // Is OpDef::ArgDef a list type? It could be N * T or list(type).
  // Refer to opdef.proto for details of list type.
  inline bool ArgIsList(const OpDef::ArgDef& arg) const {
    return !arg.type_list_attr().empty() || !arg.number_attr().empty();
  }
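  // For instance, 'AddN' declares its input as 'inputs: N * T', so ArgIsList
  // returns true for that argument.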

  // Get length of a list in 'n' if 'arg' is of list type. Refer to
  // description of ArgIsList for definition of list type.
  inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
    CHECK_EQ(ArgIsList(arg), true);
    int N = 0;
    const string attr_name = !arg.type_list_attr().empty()
                                 ? arg.type_list_attr()
                                 : arg.number_attr();
    if (!arg.type_list_attr().empty()) {
      std::vector<DataType> value;
      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
      N = value.size();
    } else {
      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
    }
    return N;
  }
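  // For instance, for an 'AddN' node built with 4 inputs, the 'N' attr is 4
  // and this function returns 4 for the 'inputs' argument.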

  // Can the op represented by node 'n' run on DEVICE_CPU?
  // The op can run on CPU with MKL if the runtime-assigned device or the
  // user-requested device contains device CPU, or both are empty.
  bool CanOpRunOnCPUDevice(const Node* n) {
    bool result = true;
    string reason;

    // Substring that should be checked for in device name for CPU device.
    const char* const kCPUDeviceSubStr = "CPU";
    const char* const kXLACPUDeviceSubStr = "XLA_CPU";

    // If the op has been assigned a runtime device that is not CPU, or that
    // is XLA_CPU, then no.
    if (!n->assigned_device_name().empty() &&
        (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) ||
         absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) {
      result = false;
      reason = "Op has been assigned a runtime device that is not CPU.";
    }

    // If the user has specifically assigned this op to a non-CPU device, or
    // to an XLA_CPU device, then no.
    if (!n->def().device().empty() &&
        (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) ||
         absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) {
      result = false;
      reason = "User has assigned a device that is not CPU.";
    }

    if (result == false) {
      VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
              << n->type_string() << ", reason: " << reason;
    }

    // Otherwise yes.
    return result;
  }
1103 
1104   // Return a node that can be merged with input node 'n'
1105   //
1106   // @return pointer to the node if we can find such a
1107   // node. Otherwise, it returns nullptr.
1108   Node* CheckForNodeMerge(const Node* n) const;
1109 
1110   // Merge node 'm' with node 'n'.
1111   // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
1112   // Conv2DBackpropFilter.
1113   //
1114   // Input nodes m and n may be deleted if the call to
1115   // this function is successful. Attempting to use the pointers
1116   // after the call may result in undefined behavior.
1117   //
1118   // @input g - input graph, m - graph node, n - graph node to be merged with m
1119   // @return Status::OK(), if merging is successful and supported.
1120   //         Returns appropriate Status error code otherwise.
1121   //         Graph is updated in case nodes are merged. Otherwise, it is
1122   //         not updated.
1123   Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1124 
1125   // Helper function to merge different nodes
1126   Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1127   Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1128   Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1129                                                   Node* m, Node* n);
1130 
1131   // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1132   // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1133   // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1134   // node that can be merged with 'm'.
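  // For example, given O = Conv2D(A, B) and P = BiasAdd(O, C), calling this
  // function with the BiasAdd node returns the Conv2D node (its 0th input),
  // and calling it with the Conv2D node returns the consuming BiasAdd node.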
1135   static Node* GetConv2DOrBiasAdd(const Node* m) {
1136     DCHECK(m);
1137     Node* n = nullptr;
1138 
1139     DataType T_m;
1140     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1141 
1142     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1143     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1144 
1145     if (m->type_string() == csinfo_.bias_add) {
1146       // If m is BiasAdd, then Conv2D is the 0th input of BiasAdd.
1147       TF_CHECK_OK(m->input_node(0, &n));
1148     } else {
1149       CHECK_EQ(m->type_string(), csinfo_.conv2d);
1150       // Go over all output edges and search for BiasAdd Node.
1151       // 0th input of BiasAdd is Conv2D.
1152       for (const Edge* e : m->out_edges()) {
1153         if (!e->IsControlEdge() &&
1154             e->dst()->type_string() == csinfo_.bias_add &&
1155             e->dst_input() == 0) {
1156           n = e->dst();
1157           break;
1158         }
1159       }
1160     }
1161 
1162     if (n == nullptr) {
1163       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1164               << "Conv2D and BiasAdd node for merging. Input node: "
1165               << m->DebugString();
1166     }
1167 
1168     return n;
1169   }
1170 
1171   // Find Pad or Conv2D node that can be merged with input node 'm'.
1172   // If input 'm' is Pad, then check if there exists Conv2D node that can be
1173   // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1174   // node that can be merged with 'm'.
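  // For example, given P = Pad(A, paddings) and O = Conv2D(P, F) with
  // padding="VALID", calling this with the Pad node returns the Conv2D node
  // and vice versa; any other padding attribute yields nullptr.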
1175   static Node* GetPadOrConv2D(const Node* m) {
1176     DCHECK(m);
1177     Node* n = nullptr;
1178 
1179     DataType T_m;
1180     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1181 
1182     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1183     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1184 
1185     const Node* conv_node;
1186     if (m->type_string() == csinfo_.pad) {
1187       // If m is Pad, then Conv2D is the output of Pad.
1188       for (const Edge* e : m->out_edges()) {
1189         if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1190           n = e->dst();
1191           conv_node = n;
1192           break;
1193         }
1194       }
1195     } else {
1196       DCHECK_EQ(m->type_string(), csinfo_.conv2d);
1197       // If m is Conv2D, go over all input edges
1198       // and search for a Pad node.
1199       for (const Edge* e : m->in_edges()) {
1200         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1201           n = e->src();
1202           conv_node = m;
1203           break;
1204         }
1205       }
1206     }
1207     // Check whether only the VALID type of padding
1208     // is used.
1209     if (n != nullptr) {
1210       string padding;
1211       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1212       if (padding != "VALID")
1213         // Do not merge: only the VALID type
1214         // of padding in a conv op can be
1215         // merged with a Pad op.
1216         n = nullptr;
1217     } else {
1218       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1219               << "Pad and Conv2D node for merging. Input node: "
1220               << m->DebugString();
1221     }
1222 
1223     return n;
1224   }
1225 
1226   // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1227   // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1228   // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1229   // exists Pad node that can be merged with 'm'.
1230   static Node* GetPadOrFusedConv2D(const Node* m) {
1231     DCHECK(m);
1232     Node* n = nullptr;
1233 
1234     const Node* conv_node;
1235     if (m->type_string() == csinfo_.pad) {
1236       // If m is Pad, then _FusedConv2D is the output of Pad.
1237       for (const Edge* e : m->out_edges()) {
1238         if (!e->IsControlEdge() &&
1239             e->dst()->type_string() == csinfo_.fused_conv2d) {
1240           n = e->dst();
1241           conv_node = n;
1242           break;
1243         }
1244       }
1245     } else {
1246       DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
1247       // If m is _FusedConv2D, Go over all input edges
1248       // and search for Pad node.
1249       for (const Edge* e : m->in_edges()) {
1250         if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1251           n = e->src();
1252           conv_node = m;
1253           break;
1254         }
1255       }
1256     }
1257     // Check whether only the VALID type of padding is used.
1258     if (n != nullptr) {
1259       string padding;
1260       TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1261       if (padding != "VALID") {
1262         // Then do not merge.
1263         n = nullptr;
1264         VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
1265                 << "nodes but cannot merge them. Only conv ops with padding "
1266                 << "type VALID can be merged with Pad op. Input node: "
1267                 << m->DebugString();
1268       }
1269     } else {
1270       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1271               << "Pad and _FusedConv2D node for merging. Input node: "
1272               << m->DebugString();
1273     }
1274 
1275     return n;
1276   }
1277 
1278   // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1279   // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1280   // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1281   // then check if there exists Conv2DBackpropFilter node that can be merged
1282   // with 'm'.
1283   //
1284   // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1285   // would look like:
1286   //
1287   // _ = Conv2DBackpropFilter(F, _, G)
1288   // _ = BiasAddGrad(G)
1289   //
1290   // So 1st input of BiasAddGrad connects with 3rd input of
1291   // Conv2DBackpropFilter and vice versa.
1292   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1293     DCHECK(m);
1294     Node* n = nullptr;
1295     const Node* conv2d_backprop_filter = nullptr;
1296 
1297     DataType T_m;
1298     TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1299 
1300     // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1301     if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1302 
1303     if (m->type_string() == csinfo_.bias_add_grad) {
1304       // Get 1st input 'g' of BiasAddGrad.
1305       Node* g = nullptr;
1306       TF_CHECK_OK(m->input_node(0, &g));
1307       // Now traverse all outgoing edges from g that have destination node as
1308       // Conv2DBackpropFilter.
1309       for (const Edge* e : g->out_edges()) {
1310         if (!e->IsControlEdge() &&
1311             e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1312             e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1313           n = e->dst();
1314           conv2d_backprop_filter = n;
1315           break;
1316         }
1317       }
1318     } else {
1319       conv2d_backprop_filter = m;
1320       CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1321       // Get 3rd input 'g' of Conv2DBackpropFilter.
1322       Node* g = nullptr;
1323       TF_CHECK_OK(m->input_node(2, &g));
1324       // Now traverse all outgoing edges from g that have destination node as
1325       // BiasAddGrad.
1326       for (const Edge* e : g->out_edges()) {
1327         if (!e->IsControlEdge() &&
1328             e->dst()->type_string() == csinfo_.bias_add_grad &&
1329             e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1330           n = e->dst();
1331           break;
1332         }
1333       }
1334     }
1335 
1336     // Do not merge if padding type is EXPLICIT.
1337     // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1338     if (conv2d_backprop_filter != nullptr) {
1339       string padding;
1340       TF_CHECK_OK(
1341           GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1342       if (padding == "EXPLICIT") {
1343         // Then do not merge.
1344         VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1345                 << "and BiasAddGrad nodes but cannot merge them. "
1346                 << "EXPLICIT padding is not supported now. "
1347                 << conv2d_backprop_filter->DebugString();
1348         return nullptr;
1349       }
1350     }
1351 
1352     if (n == nullptr) {
1353       VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1354               << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1355               << "Input node: " << m->DebugString();
1356     }
1357     return n;
1358   }
1359 
1360   // Return a node that can be fused with input node 'n'
1361   //
1362   // @return tuple. If we can find such nodes, the first
1363   // element of the tuple is true. Otherwise, it is false.
1364   std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1365   CheckForNodeFusion(Node* n) const;
1366 
1367   // Fuse nodes in the vector "nodes"
1368   Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1369                   const MklLayoutRewritePass::FusionInfo fi);
1370 
1371   // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1372   // mklop("NCHW").
1373   // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1374   static Status FuseTransposeMklOpTranspose(
1375       std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1376       std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1377       string data_format);
1378 
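  // As a concrete instance of the pattern above (assuming 4D NCHW/NHWC
  // layouts): the first Transpose uses perm {0, 2, 3, 1} (NCHW -> NHWC) and
  // the second uses perm {0, 3, 1, 2} (NHWC -> NCHW). CheckForTranspose below
  // verifies that a Transpose node's "perm" constant matches such a vector.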
1379   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1380     // Check if node's type is "Transpose"
1381     if (node->type_string() != "Transpose") return false;
1382 
1383     // If "Transpose" has multiple output data edges, also don't fuse it.
1384     if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1385 
1386     // Check if the node has an outgoing control edge. If so, this is a
1387     // training graph. Currently we focus on inference and do no fusion in
1388     // training. Note: this constraint will eventually be removed if we
1389     // enable this fusion
1390     // for training in the future.
1391     for (const Edge* e : node->out_edges()) {
1392       if (e->IsControlEdge()) {
1393         return false;
1394       }
1395     }
1396 
1397     // If "Transpose" has input control edges, don't fuse on it.
1398     for (const Edge* e : node->in_edges()) {
1399       if (e->IsControlEdge()) {
1400         return false;
1401       }
1402     }
1403 
1404     // Compare the tensor containing the permutation order ("perm_node")
1405     // with our desired order ("perm"). If they match exactly, this check
1406     // succeeds and returns true.
1407     for (const Edge* e : node->in_edges()) {
1408       if (!e->IsControlEdge()) {
1409         const Node* perm_node = e->src();
1410 
1411         const int kPermTensorIndex = 1;
1412         if (perm_node->type_string() == "Const" &&
1413             e->dst_input() == kPermTensorIndex) {
1414           // We found the "perm" node; now try to retrieve its value.
1415           const TensorProto* proto = nullptr;
1416           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1417 
1418           DataType type;
1419           TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1420 
1421           Tensor tensor;
1422           if (!tensor.FromProto(*proto)) {
1423             TF_CHECK_OK(errors::InvalidArgument(
1424                 "Could not construct Tensor from TensorProto in node: ",
1425                 node->name()));
1426             return false;
1427           }
1428           // The current fusion only supports 4D or 5D tensors, as dictated
1429           // by the `perm` vector; return false otherwise.
1430           if (tensor.dim_size(0) != perm.size()) return false;
1431           DCHECK_EQ(tensor.dims(), 1);
1432           if (type == DT_INT32) {
1433             const auto tensor_content = tensor.flat<int>().data();
1434             for (int i = 0; i < perm.size(); ++i)
1435               if (tensor_content[i] != perm[i]) return false;
1436             return true;
1437           } else if (type == DT_INT64) {
1438             const auto tensor_content = tensor.flat<int64>().data();
1439             for (int i = 0; i < perm.size(); ++i)
1440               if (tensor_content[i] != perm[i]) return false;
1441             return true;
1442           }
1443           return false;
1444         }
1445       }
1446     }
1447     return false;
1448   }
1449 
1450   static bool CheckForMklOp(const Node* node, string name = "") {
1451     if (node == nullptr) return false;
1452 
1453     if (!name.empty() && node->type_string() != name) {
1454       return false;
1455     }
1456 
1457     // If mklop has multiple outputs, don't fuse it.
1458     if (node->num_outputs() > 1) return false;
1459 
1460     if (node->out_edges().size() > 1) return false;
1461 
1462     DataType T;
1463     TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1464     return mkl_op_registry::IsMklOp(
1465         mkl_op_registry::GetMklOpName(node->type_string()), T);
1466   }
1467 
1468   // Check if the node 'n' has any applicable rewrite rule
1469   // We check for 2 scenarios for rewrite.
1470   //
1471   // @return RewriteInfo* for the applicable rewrite rule
1472   const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1473   const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1474 
1475   // Default rewrite rule to be used in scenario 1 for rewrite.
1476   // @return - true (since we want to always rewrite)
1477   static bool AlwaysRewrite(const Node* n) { return true; }
1478 
1479   // Rewrite rule which considers "context" of the current node to decide if we
1480   // should rewrite. By "context" we currently mean all the inputs of the
1481   // current node. The idea is that if none of the inputs of the current node
1482   // are MKL nodes, then rewriting the current node to an MKL node _may not_
1483   // offer any performance improvement.
1484   //
1485   // One such case is element-wise ops. For such ops, we reuse the Eigen
1486   // implementation and pass the MKL metadata tensor through so we can avoid
1487   // conversions. However, if all incoming edges are in TF format, we don't
1488   // need all this overhead, so replace the elementwise node only if at least
1490   // one of its parents is an MKL node.
1490   //
1491   // More generally, all memory- or IO-bound ops (such as Identity) may fall
1492   // under this category.
1493   //
1494   // @input - Input graph node to be rewritten
1495   // @return - true if node is to be rewritten as MKL node; false otherwise.
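  // For example, an elementwise Add whose inputs are both produced by plain
  // TF ops is left unrewritten, while an Add fed by an _MklConv2D output is
  // rewritten so the Mkl metadata tensor can flow through without conversion.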
1496   static bool RewriteIfAtleastOneMklInput(const Node* n) {
1497     DataType T;
1498     if (GetNodeAttr(n->def(), "T", &T).ok() &&
1499         mkl_op_registry::IsMklOp(
1500             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1501       for (auto e : n->in_edges()) {
1502         if (e->IsControlEdge()) continue;
1503         if (mkl_op_registry::IsMklOp(e->src())) {
1504           return true;
1505         }
1506       }
1507     }
1508     return false;
1509   }
1510 
1511   static bool MatMulRewrite(const Node* n) {
1512     DataType T;
1513     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1514     if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1515       VLOG(2) << "Rewriting MatMul to _MklMatMul";
1516       return true;
1517     }
1518     return false;
1519   }
1520   // For oneDNN, only int32 is supported for axis data type
1521   static bool ConcatV2Rewrite(const Node* n) {
1522     DataType T;
1523     TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1524     return (T == DT_INT32);
1525   }
1526 
1527   static bool DequantizeRewrite(const Node* n) {
1528     DCHECK(n);
1529     Node* input = nullptr;
1530     TF_CHECK_OK(n->input_node(0, &input));
1531     string mode_string;
1532     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1533     if (mode_string != "SCALED") {
1534       VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1535               << "This case is not optimized by Intel MKL kernel, thus using "
1536                  "Eigen op for Dequantize op.";
1537       return false;
1538     }
1539     if (input->IsConstant()) {
1540       VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1541               << "could possibly be a filter. "
1542               << "This case is not supported by Intel MKL kernel, thus using "
1543                  "Eigen op for Dequantize op.";
1544       return false;
1545     }
1546     return true;
1547   }
1548 
1549   // Rewrite rule for _FusedMatMul.
1550   // @return - true if the first input is not transposed;
1551   //           false otherwise.
1552   static bool FusedMatMulRewrite(const Node* n) {
1553     bool trans_a;
1554 
1555     // Do not rewrite with transpose attribute because reorder has performance
1556     // impact.
1557     TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1558 
1559     return !trans_a;
1560   }
1561 
1562   // Check if we are performing pooling on depth or batch. If so, then we
1563   // do not rewrite the MaxPool node to the Mkl version.
1564   // @return - true (if it is not a depth/batch wise pooling case);
1565   //           false otherwise.
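  // For example (NHWC): ksize {1, 2, 2, 1} with strides {1, 2, 2, 1} is
  // purely spatial pooling, so it qualifies for rewrite; ksize {1, 1, 1, 2}
  // pools over depth ('C' dimension != 1), so it does not.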
1566   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1567     DCHECK(n);
1568 
1569     string data_format_str;
1570     TensorFormat data_format;
1571     std::vector<int32> ksize, strides;
1572     TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1573     TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1574     TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1575     bool result = FormatFromString(data_format_str, &data_format);
1576     DCHECK(result);
1577 
1578     // Condition that specifies non-batch-wise and non-depth-wise pooling.
1579     if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1580         GetTensorDim(strides, data_format, 'N') == 1 &&
1581         GetTensorDim(ksize, data_format, 'C') == 1 &&
1582         GetTensorDim(strides, data_format, 'C') == 1) {
1583       return true;
1584     }
1585 
1586     return false;
1587   }
1588 
1589   // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
1590   // path, which is slow. Thus we don't rewrite the node and use the default
1591   // Eigen implementation instead. But for depth_radius=2, the MKL DNN
1592   // optimized path is taken, i.e., the Eigen node is rewritten as an MKL DNN node.
1593   static bool LrnRewrite(const Node* n) {
1594     DCHECK(n);
1595 
1596     int depth_radius;
1597     TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1598 
1599     // If the depth_radius of LRN is not 2, don't rewrite the node with MKL
1600     // DNN; use the Eigen node instead.
1601     if (depth_radius == 2) {
1602       return true;
1603     }
1604     VLOG(1) << "LrnRewrite: The model sets depth_radius to a value other "
1605             << "than 2, which is not optimized by Intel MKL; thus using "
1606             << "the Eigen op for LRN.";
1607 
1608     return false;
1609   }
1610 
1611   static bool LrnGradRewrite(const Node* n) {
1612     DCHECK(n);
1613     bool do_rewrite = false;
1614 
1615     for (const Edge* e : n->in_edges()) {
1616       // Rewrite only if there is a corresponding LRN, i.e., a workspace is available
1617       if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1618           e->src()->type_string() ==
1619               mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1620           e->src_output() == 0) {
1621         do_rewrite = true;
1622         break;
1623       }
1624     }
1625     return do_rewrite;
1626   }
1627 
1628   // MKL-DNN's LeakyRelu(feature) = feature          (if feature > 0), or
1629   //                                feature * alpha  (otherwise),
1630   // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1631   // These two algorithms are not consistent when alpha > 1,
1632   // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
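  // For instance, with alpha = 2 and feature = -1, MKL-DNN would produce
  // -2 while TensorFlow produces max(-1, -2) = -1; hence no rewrite when
  // alpha > 1.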
1633   static bool LeakyReluRewrite(const Node* n) {
1634     DCHECK(n);
1635 
1636     float alpha;
1637     bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1638     DCHECK(has_attr);
1639 
1640     // If the alpha of LeakyRelu is less than or equal to 1, rewrite the
1641     // node. Otherwise the Eigen node is used instead.
1642     if (alpha <= 1) {
1643       return true;
1644     }
1645     VLOG(1) << "LeakyReluRewrite: The model sets alpha greater than 1, "
1646             << "which is not optimized by Intel MKL; thus using the Eigen "
1647             << "op for LeakyRelu.";
1648 
1649     return false;
1650   }
1651 
1652   static bool QuantizeOpRewrite(const Node* n) {
1653     DCHECK(n);
1654     Node* filter_node = nullptr;
1655     TF_CHECK_OK(n->input_node(0, &filter_node));
1656     bool narrow_range = false;
1657     int axis = -1;
1658     string mode_string;
1659     string round_mode_string;
1660     DataType type;
1661     TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1662     TryGetNodeAttr(n->def(), "axis", &axis);
1663     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1664     TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1665     TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1666 
1667     if (narrow_range) {
1668       VLOG(1) << "QuantizeOpRewrite: Narrow range is enabled for quantization. "
1669               << "This case is not optimized by Intel MKL, "
1670               << "thus using the Eigen op for Quantize op.";
1671       return false;
1672     }
1673     if (axis != -1) {
1674       VLOG(1) << "QuantizeOpRewrite: A dimension is specified for "
1675               << "per-slice quantization. "
1676               << "This case is not optimized by Intel MKL, "
1677               << "thus using the Eigen op for Quantize op.";
1678       return false;
1679     }
1680     if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1681           (mode_string == "MIN_FIRST"))) {
1682       VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST, and/or "
1683               << "the rounding mode is not HALF_TO_EVEN. "
1684               << "This case is not optimized by Intel MKL, thus using the "
1685               << "Eigen op for Quantize op.";
1686       return false;
1687     }
1688     if (filter_node->IsConstant()) {
1689       VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1690               << "is a constant. "
1691               << "This case is not supported by the kernel, thus using the "
1692               << "Eigen op for Quantize op.";
1693 
1694       return false;
1695     }
1696     if (mode_string == "MIN_FIRST") {
1697       if (type != DT_QUINT8) {
1698         VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1699                 << "not DT_QUINT8. This case is not optimized by Intel MKL, "
1700                 << "thus using the Eigen op for Quantize op.";
1701         return false;
1702       }
1703     }
1704     return true;
1705   }
1706 
1707   static bool MaxpoolGradRewrite(const Node* n) {
1708     DCHECK(n);
1709     bool do_rewrite = false;
1710     for (const Edge* e : n->in_edges()) {
1711       // Rewrite only if there is a corresponding Maxpool, i.e., a workspace
1712       // is available
1713       if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1714           e->dst_input() == 1 &&
1715           e->src()->type_string() ==
1716               mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1717           e->src_output() == 0) {
1718         do_rewrite = true;
1719         break;
1720       }
1721     }
1722     return do_rewrite;
1723   }
1724 
1725   static bool Maxpool3DGradRewrite(const Node* n) {
1726     DCHECK(n);
1727     for (const Edge* e : n->in_edges()) {
1728       // Rewrite only if there is a corresponding Maxpool3D, i.e., a
1729       // workspace is available
1730       if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1731           e->dst_input() == 1 &&
1732           e->src()->type_string() ==
1733               mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1734           e->src_output() == 0) {
1735         return true;
1736       }
1737     }
1738     return false;
1739   }
1740 
1741   static bool FusedBatchNormV3Rewrite(const Node* n) {
1742     DCHECK(n);
1743     if (Check5DFormat(n->def())) {
1744       VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1745               << "support 5D tensors.";
1746       return false;
1747     }
1748     return true;
1749   }
1750 
1751   static bool FusedBatchNormExRewrite(const Node* n) {
1752     DCHECK(n);
1753 
1754     int num_side_inputs;
1755     TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1756     string activation_mode;
1757     TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1758 
1759     // If num_side_inputs is not 0, don't rewrite the node.
1760     if (num_side_inputs != 0) {
1761       VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs "
1762               << "larger than 0, which is not optimized by Intel MKL.";
1763       return false;
1764     }
1765 
1766     // If the activation_mode is not 'Relu', don't rewrite the node.
1767     if (activation_mode != "Relu") {
1768       VLOG(1) << "FusedBatchNormExRewrite: Only the Relu activation mode "
1769               << "is supported by Intel MKL.";
1770       return false;
1771     }
1772 
1773     return true;
1774   }
1775 
1776   static bool FusedConv2DRewrite(const Node* n) {
1777     // MKL DNN currently doesn't support all fusions that grappler fuses
1778     // together with Conv2D (e.g., batchnorm). We rewrite _FusedConv2D only if
1779     // it includes those we support.
1780     DataType T;
1781     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1782         !mkl_op_registry::IsMklOp(NativeFormatEnabled()
1783                                       ? csinfo_.mkl_native_fused_conv2d
1784                                       : csinfo_.mkl_fused_conv2d,
1785                                   T)) {
1786       return false;
1787     }
1788 
1789     std::vector<string> fused_ops;
1790     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1791     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1792             fused_ops == std::vector<string>{"Relu"} ||
1793             fused_ops == std::vector<string>{"Relu6"} ||
1794             fused_ops == std::vector<string>{"Elu"} ||
1795             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1796             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1797             fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1798             fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1799             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1800             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1801             fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1802             fused_ops == std::vector<string>{"LeakyRelu"} ||
1803             fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1804             fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"} ||
1805             fused_ops == std::vector<string>{"FusedBatchNorm"} ||
1806             fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"} ||
1807             fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"} ||
1808             fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"} ||
1809             fused_ops == std::vector<string>{"FusedBatchNorm", "LeakyRelu"});
1810   }
1811 
1812   static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1813     // MKL DNN currently doesn't support all fusions that grappler fuses
1814     // together with DepthwiseConv2D (e.g., batchnorm). We rewrite
1815     // _FusedDepthwiseConv2DNative only if it includes those we support.
1816     DataType T;
1817     if (!TryGetNodeAttr(n->def(), "T", &T) ||
1818         !mkl_op_registry::IsMklOp(
1819             NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d
1820                                   : csinfo_.mkl_fused_depthwise_conv2d,
1821             T)) {
1822       return false;
1823     }
1824 
1825     std::vector<string> fused_ops;
1826     TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1827     return (fused_ops == std::vector<string>{"BiasAdd"} ||
1828             fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1829             fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1830             fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1831   }
1832 
1833   // Rewrites input node to a new node specified by its matching rewrite info.
1834   //
1835   // Method first searches matching rewrite info for input node and then
1836   // uses that info to rewrite.
1837   //
1838   // Input node may be deleted in case of rewrite. Attempting to use the node
1839   // after the call can result in undefined behavior.
1840   //
1841   // @input  g - input graph, n - Node to be rewritten,
1842   //         ri - matching rewriteinfo
1843   // @return Status::OK(), if the input node is rewritten;
1844   //         Returns appropriate Status error code otherwise.
1845   //         Graph is updated in case the input node is rewritten.
1846   //         Otherwise, it is not updated.
1847   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1848 
1849   // Rewrites input node to just change its operator name. The number of
1850   // inputs to the node and the number of outputs remain the same. Attributes
1851   // of the new node could be copied from attributes of the old node or
1852   // modified. copy_attrs field of RewriteInfo controls this.
1853   //
1854   // Conceptually, it allows us to rewrite:
1855   //
1856   //        f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1857   //
1858   // Attributes can be altered without any restrictions --- they could be
1859   // copied, modified, or deleted completely.
1860   //
1861   // @input  g - input graph, orig_node - Node to be rewritten,
1862   //         ri - matching rewriteinfo
1863   // @output new_node - points to newly created node
1864   // @return Status::OK(), if the input node is rewritten;
1865   //         Returns appropriate Status error code otherwise.
1866   //         Graph is only updated when the input node is rewritten.
1867   Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1868                                         const Node* orig_node, Node** new_node,
1869                                         const RewriteInfo* ri);
1870 
1871   // Rewrites input node to enable MKL layout propagation. Please also refer to
1872   // documentation for the function RewriteNodeForJustOpNameChange() to
1873   // understand what it means.
1874   //
1875   // @input  g - input graph, orig_node - Node to be rewritten,
1876   //         ri - matching rewriteinfo
1877   // @output new_node - points to newly created node
1878   // @return Status::OK(), if the input node is rewritten;
1879   //         Returns appropriate Status error code otherwise.
1880   //         Graph is updated in case the input node is rewritten.
1881   //         Otherwise, it is not updated.
1882   Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1883                                          const Node* orig_node, Node** new_node,
1884                                          const RewriteInfo* ri);
1885 
1886   // Get nodes that will feed a list of TF tensors to the new
1887   // node that we are constructing.
1888   //
1889   // @input g - input graph,
1890   // @input inputs - inputs to old node that we are using for constructing
1891   //                 new inputs,
1892   // @input input_idx - the index in the 'inputs' vector pointing to the
1893   //                    current input that we have processed so far
1894   // @output input_idx - index will be incremented by the number of nodes
1895   //                     from 'inputs' that are processed
1896   // @input list_length - The expected length of list of TF tensors
1897   // @output output_nodes - the list of new nodes creating TF tensors
1898   //
1899   // @return None
1900   void GetNodesProducingTFTensorList(
1901       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1902       int* input_idx, int list_length,
1903       std::vector<NodeBuilder::NodeOut>* output_nodes);
1904 
1905   // Get nodes that will feed a list of Mkl tensors to the new
1906   // node that we are constructing.
1907   //
1908   // @input g - input graph,
1909   // @input orig_node - Original node that we are rewriting
1910   // @input inputs - inputs to old node that we are using for constructing
1911   //                 new inputs,
1912   // @input input_idx - the index in the 'inputs' vector pointing to the
1913   //                    current input that we have processed so far
1914   // @output input_idx - index will be incremented by the number of nodes
1915   //                     from 'inputs' that are processed
1916   // @input list_length - The expected length of list of Mkl tensors
1917   // @output output_nodes - the list of new nodes creating Mkl tensors
1918   //
1919   // @return None
1920   void GetNodesProducingMklTensorList(
1921       std::unique_ptr<Graph>* g, const Node* orig_node,
1922       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1923       int* input_idx, int list_length,
1924       std::vector<NodeBuilder::NodeOut>* output_nodes);
1925 
1926   // Get a node that will feed an Mkl tensor to the new
1927   // node that we are constructing. The output node could be (1) 'n'
1928   // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
1929   // if 'n' is not an Mkl layer.
1930   //
1931   // @input g - input graph,
1932   // @input orig_node - Original node that we are rewriting,
1933   // @input n - Node based on which we are creating Mkl node,
1934   // @input n_output_slot - the output slot of node 'n'
1935   //            which is feeding to the node that we are constructing
1936   // @output mkl_node - the new node that will feed Mkl tensor
1937   // @output mkl_node_output_slot - the slot number of mkl_node that
1938   //                                will feed the tensor
1939   // @return None
1940   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1941                                  const Node* orig_node, Node* n,
1942                                  int n_output_slot, Node** mkl_node,
1943                                  int* mkl_node_output_slot);
1944 
1945   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1946   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1947   // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1948   // producing workspace edges if 'are_workspace_tensors_available' is true.
1949   // Otherwise, 'workspace_tensors' is an empty vector.
1950   //
1951   // For details, refer to 'Ordering of inputs after rewriting' section in the
1952   // documentation above.
1953   //
1954   // Returns Status::OK() if setting up inputs is successful, otherwise
1955   // returns appropriate status code.
1956   int SetUpContiguousInputs(
1957       std::unique_ptr<Graph>* g,
1958       const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1959       NodeBuilder* nb, const Node* old_node,
1960       std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1961       bool are_workspace_tensors_available);
1962 
1963   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1964   // in graph 'g'. Original node is input in 'orig_node'.
1965   //
1966   // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1967   // section in the documentation above.
1968   //
1969   // Returns Status::OK() if setting up inputs is successful, otherwise
1970   // returns appropriate status code.
1971   Status SetUpInputs(std::unique_ptr<Graph>* g,
1972                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1973                      NodeBuilder* nb, const Node* orig_node);
1974 
1975   // Create new inputs by copying old inputs 'inputs' for the rewritten node
1976   // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1977   // used in the context of rewrite for just operator name change in which
1978   // inputs of old operator and new operator are same.
1979   //
1980   // Returns Status::OK() if setting up inputs is successful, otherwise
1981   // returns appropriate status code.
1982   Status CopyInputs(const Node* orig_node,
1983                     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1984                     NodeBuilder* nb);
1985 
1986   // Add a workspace edge on the input or output side of node 'orig_node' by
1987   // using NodeBuilder 'nb' for the new node provided. If 'orig_node' does not
1988   // dictate adding a workspace edge, then do not add it. Workspace Tensorflow
1989   // and Mkl tensors, if they need to be added, will be stored in 'ws_tensors'.
1990   // If workspace tensors are set, 'are_ws_tensors_added' will be set to true.
1991   void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1992                                 const Node* orig_node, NodeBuilder* nb,
1993                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1994                                 bool* are_ws_tensors_added);
1995 
1996   // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1997   // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1998   // 'g'. Returns true if fixup was done; otherwise, it returns false.
1999   bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
2000                                   const Edge* e_metadata);
2001 
2002   // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
2003   // connected? If not, then fix them. This is needed because a graph may have
2004   // some input Mkl metadata edges incorrectly setup after node merge and
2005   // rewrite passes. This could happen because GetReversePostOrder function may
2006   // not provide topologically sorted order if a graph contains cycles. The
2007   // function returns true if at least one Mkl metadata edge for node 'n' was
2008   // fixed. Otherwise, it returns false.
2009   //
2010   // Example:
2011   //
2012   // X = MklConv2D(_, _, _)
2013   // Y = MklConv2DWithBias(_, _, _, _, _, _)
2014   // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
2015   //
2016   // For a graph such as shown above, note that 3rd argument of MklAdd contains
2017   // DummyMklTensor. Actually, it should be getting the Mkl metadata from
2018   // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
2019   // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
2020   // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
2021   // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
2022   // data edges (1st and 2nd arguments of MklAdd).
2023   bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
2024 
2025   // Functions specific to operators to copy attributes
2026   // We need operator-specific function to copy attributes because the framework
2027   // does not provide any generic function for it.
2028   // NOTE: names are alphabetically sorted.
2029   static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2030                            bool change_format = false);
2031   static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
2032                                            NodeBuilder* nb,
2033                                            bool change_format = false);
2034 
2035   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2036                             bool change_format = false);
2037   static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
2038                                             NodeBuilder* nb,
2039                                             bool change_format = false);
2040   static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2041                                         const Node* orig_node2, NodeBuilder* nb,
2042                                         bool change_format = false);
2043   static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
2044                                              const Node* orig_node2,
2045                                              NodeBuilder* nb,
2046                                              bool change_format = false);
2047   static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
2048                                        bool change_format = false);
2049   static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2050                                   const std::vector<int32>& strides,
2051                                   const std::vector<int32>& dilations,
2052                                   bool change_format = false);
2053 
2054   static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2055                                                NodeBuilder* nb,
2056                                                bool change_format = false);
2057   static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2058       const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2059   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2060                                bool change_format = false);
2061 
2062   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2063   // using node for original node 'orig_node' and return it in '*out'.
2064   // TODO(nhasabni) We should move this to mkl_util.h
2065   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2066                              const Node* orig_node);
2067   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2068                                    const Node* orig_node);
2069 };
2070 
2071 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2072 
2073 // We register Mkl rewrite pass for phase 1 in post partitioning group.
2074 // We register it here so that we get a complete picture of all users of Mkl
2075 // nodes. Do not change the ordering of the Mkl passes.
2076 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2077     OptimizationPassRegistry::POST_PARTITIONING;
2078 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2079 
2080 //////////////////////////////////////////////////////////////////////////
2081 //           Helper functions for creating new node
2082 //////////////////////////////////////////////////////////////////////////
2083 
2084 static void FillInputs(const Node* n,
2085                        gtl::InlinedVector<Node*, 4>* control_edges,
2086                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2087   control_edges->clear();
2088   for (const Edge* e : n->in_edges()) {
2089     if (e->IsControlEdge()) {
2090       control_edges->push_back(e->src());
2091     } else {
2092       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2093     }
2094   }
2095   std::sort(control_edges->begin(), control_edges->end());
2096 }
2097 
2098 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2099     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2100     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2101   CHECK_LT(*input_idx, inputs.size());
2102   CHECK_GT(list_length, 0);
2103   DCHECK(output_nodes);
2104   output_nodes->reserve(list_length);
2105 
2106   while (list_length != 0) {
2107     CHECK_GT(list_length, 0);
2108     CHECK_LT(*input_idx, inputs.size());
2109     Node* n = inputs[*input_idx].first;
2110     int slot = inputs[*input_idx].second;
2111     // If input node 'n' is just producing a single tensor at
2112     // output slot 'slot' then we just add that single node.
2113     output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2114     (*input_idx)++;
2115     list_length--;
2116   }
2117 }
2118 
2119 // TODO(nhasabni) We should move this to mkl_util.h.
2120 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2121                                                  Node** out,
2122                                                  const Node* orig_node) {
2123   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2124   // a dummy Mkl tensor. 8 = 2*size_t.
2125   const DataType dt = DataTypeToEnum<uint8>::v();
2126   TensorProto proto;
2127   proto.set_dtype(dt);
2128   uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2129   proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2130   TensorShape dummy_shape({8});
2131   dummy_shape.AsProto(proto.mutable_tensor_shape());
2132   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2133                   .Attr("value", proto)
2134                   .Attr("dtype", dt)
2135                   .Device(orig_node->def().device())  // We place this node on
2136                                                       // the same device as the
2137                                                       // device of the original
2138                                                       // node.
2139                   .Finalize(&**g, out));
2140   DCHECK(*out);  // Make sure we got a valid object before using it
2141 
2142   // If number of inputs to the original node is > 0, then we add
2143   // control dependency between 1st input (index 0) of the original node and
2144   // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2145   // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2146   // rewritten node. Adding control edge between 1st input of the original node
2147   // and the dummy Mkl node ensures that the dummy node is in the same frame
2148   // as the original node. Choosing 1st input is not necessary - any input of
2149   // the original node is fine because all the inputs of a node are always in
2150   // the same frame.
2151   if (orig_node->num_inputs() > 0) {
2152     Node* orig_input0 = nullptr;
2153     TF_CHECK_OK(
2154         orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2155     auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2156     DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2157   }
2158 
2159   (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2160 }
2161 
2162 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2163     std::unique_ptr<Graph>* g, const Node* orig_node,
2164     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2165     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2166   CHECK_LT(*input_idx, inputs.size());
2167   CHECK_GT(list_length, 0);
2168   DCHECK(output_nodes);
2169   output_nodes->reserve(list_length);
2170 
2171   while (list_length != 0) {
2172     CHECK_GT(list_length, 0);
2173     CHECK_LT(*input_idx, inputs.size());
2174     Node* n = inputs[*input_idx].first;
2175     int slot = inputs[*input_idx].second;
2176     // If 'n' is producing a single tensor, then create a single Mkl tensor
2177     // node.
2178     Node* mkl_node = nullptr;
2179     int mkl_node_output_slot = 0;
2180     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2181                               &mkl_node_output_slot);
2182     output_nodes->push_back(
2183         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2184     (*input_idx)++;
2185     list_length--;
2186   }
2187 }
2188 
2189 // Get an input node that will feed Mkl tensor to the new
2190 // node that we are constructing. An input node could be (1) 'n'
2191 // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
2192 // if 'n' is not an Mkl layer.
2193 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2194     std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2195     int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2196   DCHECK(n);
2197   DCHECK(mkl_node);
2198   DCHECK(mkl_node_output_slot);
2199 
2200   // If this is an MKL op, then it will create extra output for MKL layout.
2201   DataType T;
2202   if (TryGetNodeAttr(n->def(), "T", &T) &&
2203       mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
2204     // If this is an MKL op, then it will generate an edge that will receive
2205     // Mkl tensor from a node.
2206     // output slot number for Mkl tensor would be N+slot number of TensorFlow
2207     // tensor, where N is total number of TensorFlow tensors.
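    // For example, a rewritten op with a single TensorFlow output has
    // num_outputs() == 2, so the Mkl metadata for TF slot 0 is at slot 1.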
2208     *mkl_node = n;
2209     *mkl_node_output_slot =
2210         GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2211   } else {
2212     // If we have not visited the node and rewritten it, then we need
2213     // to create a dummy node that will feed a dummy Mkl tensor to this node.
2214     // DummyMklTensor node has no input and generates only 1 output
2215     // (dummy Mkl tensor) as output slot number 0.
2216     GetDummyMklTensorNode(g, mkl_node, orig_node);
2217     DCHECK(*mkl_node);
2218     *mkl_node_output_slot = 0;
2219   }
2220 }
2221 
2222 int MklLayoutRewritePass::SetUpContiguousInputs(
2223     std::unique_ptr<Graph>* g,
2224     const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2225     NodeBuilder* nb, const Node* old_node,
2226     std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2227     bool are_workspace_tensors_available) {
2228   DCHECK(workspace_tensors);
2229   CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2230 
  // TODO(nhasabni): Temporary solution to connect filter input of
  // BackpropInput with the converted filter from Conv2D.
  bool do_connect_conv2d_backprop_input_filter = false;
  Node* conv2d_node = nullptr;
  // Filter node is 2nd input (slot index 1) of Conv2D.
  int kConv2DFilterInputSlotIdx = 1;
  int kConv2DBackpropInputFilterInputSlotIdx = 1;
  int kConv2DFilterOutputSlotIdx = 1;
  if (old_node->type_string() == csinfo_.conv2d_grad_input) {
    // We need to find the Conv2D node from Conv2DBackpropInput.
    // For that, let's first find the filter node that is the 2nd input
    // (slot 1) of BackpropInput.
    Node* filter_node = nullptr;
    TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
                                     &filter_node));
    DCHECK(filter_node);

    // Now check which nodes receive from filter_node. Filter feeds as
    // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
    // _MklFusedConv2D.
    for (const Edge* e : filter_node->out_edges()) {
      if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
           e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
          e->dst_input() == kConv2DFilterInputSlotIdx
          /* filter is 2nd input of Conv2D and _MklConv2D. */) {
        if (conv2d_node != nullptr) {
          VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
                  << " feeding multiple Conv2D nodes: "
                  << filter_node->DebugString();
          // To be safe, we will not connect the filter input of
          // Conv2DBackpropInput in this case.
          do_connect_conv2d_backprop_input_filter = false;
          break;
        } else {
          conv2d_node = e->dst();
          do_connect_conv2d_backprop_input_filter = true;
        }
      }
    }
  }

  // Number of input slots to the original op.
  // Input slots are represented by .Input() calls in REGISTER_OP.
  int old_node_input_slots = old_node->op_def().input_arg_size();
  int nn_slot_idx = 0;  // slot index for inputs of new node

  // Let's copy all inputs (TF tensors) of the original node to the new node.
  int iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int tensor_list_length = GetTensorListLength(arg, old_node);
      if (tensor_list_length != 0) {
        GetNodesProducingTFTensorList(old_node_inputs, &iidx,
                                      tensor_list_length, &new_node_inputs);
      }
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      // Special case for connecting filter input of Conv2DBackpropInput.
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
      } else {
        nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
      }
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering, then we need to add the Tensorflow tensor for
  // the workspace here, because the Tensorflow tensor for the workspace is
  // the last tensor in the list of Tensorflow tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Tensorflow tensor
    nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
    nn_slot_idx++;
  }

  // Let's now set up all Mkl inputs to the new node.
  // The number of Mkl inputs must be the same as the number of TF inputs.
  iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int tensor_list_length = GetTensorListLength(arg, old_node);
      if (tensor_list_length != 0) {
        GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
                                       tensor_list_length, &new_node_inputs);
      }
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      Node* mkl_node = nullptr;
      int mkl_node_output_slot = 0;
      // Special case for connecting filter input of Conv2DBackpropInput.
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        GetNodeProducingMklTensor(g, old_node, conv2d_node,
                                  kConv2DFilterOutputSlotIdx, &mkl_node,
                                  &mkl_node_output_slot);
      } else {
        GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
                                  old_node_inputs[iidx].second, &mkl_node,
                                  &mkl_node_output_slot);
      }
      nb->Input(mkl_node, mkl_node_output_slot);
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering, then we need to add the Mkl tensor for
  // the workspace here, because the Mkl tensor for the workspace is
  // the last tensor in the list of Mkl tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Mkl tensor
    nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
    nn_slot_idx++;
  }

  return nn_slot_idx;
}
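
// Illustrative sketch: for an op with TF inputs (A, B) plus a workspace, the
// contiguous layout built above is (A, B, ws_tf, A_m, B_m, ws_mkl), and the
// function returns nn_slot_idx = 2 * 2 + 2 = 6.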

// This method finds out if checking workspace is needed or not. Workspace is
// not used in quantized ops, so checking for it would fail because quantized
// ops do not have the attribute "T".
bool IsWorkspaceCheckNeeded(const Node* node) {
  std::vector<string> quant_ops{
      "Dequantize",
      "QuantizeV2",
      "QuantizedConv2D",
      "QuantizedConv2DWithBias",
      "QuantizedConv2DAndRelu",
      "QuantizedConv2DWithBiasAndRelu",
      "QuantizedConv2DWithBiasSumAndRelu",
      "QuantizedConv2DPerChannel",
      "QuantizedConv2DAndRequantize",
      "QuantizedConv2DWithBiasAndRequantize",
      "QuantizedConv2DAndReluAndRequantize",
      "QuantizedConv2DWithBiasAndReluAndRequantize",
      "QuantizedConv2DWithBiasSumAndReluAndRequantize",
      "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
      "QuantizedMatMulWithBias",
      "QuantizedMatMulWithBiasAndRequantize",
      "QuantizedMatMulWithBiasAndDequantize",
      "QuantizedMatMulWithBiasAndRelu",
      "QuantizedMatMulWithBiasAndReluAndRequantize",
      "QuantizedDepthwiseConv2D",
      "QuantizedDepthwiseConv2DWithBias",
      "QuantizedDepthwiseConv2DWithBiasAndRelu",
      "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
  return std::find(std::begin(quant_ops), std::end(quant_ops),
                   node->type_string()) == std::end(quant_ops);
}
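
// Example: IsWorkspaceCheckNeeded returns true for a "MaxPool" node (not in
// quant_ops, so the workspace check proceeds) and false for a
// "QuantizedConv2D" node (quantized, so the check is skipped).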

Status MklLayoutRewritePass::SetUpInputs(
    std::unique_ptr<Graph>* g,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb, const Node* old_node) {
  // Let's check if we need to add workspace tensors for this node.
  // We add a workspace edge only for MaxPool, LRN and BatchNorm.
  std::vector<NodeBuilder::NodeOut> workspace_tensors;
  bool are_workspace_tensors_available = false;

  if (IsWorkspaceCheckNeeded(old_node)) {
    AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
                             &are_workspace_tensors_available);
  }

  int new_node_input_slots = 0;
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    // TODO(nhasabni): implement this function just for the sake of
    // completeness. We do not use interleaved ordering right now.
    return Status(
        error::Code::UNIMPLEMENTED,
        "Interleaved ordering of tensors is currently not supported.");
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    new_node_input_slots = SetUpContiguousInputs(
        g, old_node_inputs, nb, old_node, &workspace_tensors,
        are_workspace_tensors_available);
  }

  // Sanity check
  int old_node_input_slots = old_node->op_def().input_arg_size();
  if (!are_workspace_tensors_available) {
    // If we are not adding workspace tensors for this op, then the total
    // number of input slots to the new node _must_ be 2 times the number
    // of input slots to the original node: N original Tensorflow tensors
    // and N Mkl tensors, one corresponding to each Tensorflow tensor.
    CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
  } else {
    // If we are adding workspace tensors for this op, then the total
    // number of input slots to the new node _must_ be 2 times the number
    // of input slots to the original node plus 2: N original Tensorflow
    // tensors, N corresponding Mkl tensors, plus the workspace Tensorflow
    // tensor and the workspace Mkl tensor.
    CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
  }

  return Status::OK();
}
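
// Worked example of the sanity check above: a node with 3 input slots must
// end up with 3 * 2 = 6 slots on the rewritten node when no workspace is
// added, and 3 * 2 + 2 = 8 slots when the workspace pair is added.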

Status MklLayoutRewritePass::CopyInputs(
    const Node* old_node,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb) {
  // Number of input slots to the old node.
  // Input slots are represented by .Input() calls in REGISTER_OP.
  int old_node_input_slots = old_node->op_def().input_arg_size();
  // The actual number of inputs can be greater than or equal to the number
  // of input slots because inputs of list type are unfolded.
  auto old_node_input_size = old_node_inputs.size();
  DCHECK_GE(old_node_input_size, old_node_input_slots);

  // Let's copy all inputs of the old node to the new node.
  int iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    DCHECK_LT(iidx, old_node_input_size);
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      if (N != 0) {
        GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                      &new_node_inputs);
      }
      nb->Input(new_node_inputs);
    } else {
      nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
      iidx++;
    }
  }
  return Status::OK();
}
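
// Illustrative example of list unfolding: an op hypothetically registered
// with .Input("values: N * T") has one input slot, but for N == 3 it
// contributes three entries to old_node_inputs; the list branch above
// consumes all three via GetNodesProducingTFTensorList and advances iidx
// past them.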

//////////////////////////////////////////////////////////////////////////
//           Helper functions related to workspace pass
//////////////////////////////////////////////////////////////////////////

// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
    std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
  // We use a uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to
  // represent the workspace tensor.
  GetDummyMklTensorNode(g, out, orig_node);
}

void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
    std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
    std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
  bool workspace_edge_added = false;  // Default initializer
  DCHECK(are_ws_tensors_added);
  *are_ws_tensors_added = false;  // Default initializer

  DataType T;
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  for (auto ws : wsinfo_) {
    if (orig_node->type_string() == ws.fwd_op &&
        mkl_op_registry::IsMklOp(
            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
      // If this op is a fwd op, then we need to check if there is an
      // edge from this node's fwd_slot to the bwd op's bwd_slot. If there
      // is such an edge, then we just add an attribute on this node,
      // setting workspace_enabled to true. We don't add the actual
      // workspace edge in this node; the actual workspace edge gets added
      // in the backward op for this node.
      for (const Edge* e : orig_node->out_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            e->dst()->type_string() == ws.bwd_op &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      if (!workspace_edge_added) {
        // If we are here, then we did not find the backward operator for
        // this node.
        nb->Attr("workspace_enabled", false);
      }
    } else if (orig_node->type_string() == ws.bwd_op &&
               mkl_op_registry::IsMklOp(
                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
                   T)) {
      // If this op is a bwd op, then we need to add a workspace edge and
      // its Mkl tensor edge between its corresponding fwd op and this
      // op. The corresponding fwd op is specified in the 'fwd_op' field of
      // the workspace info. fwd_slot and bwd_slot in the workspace info
      // specify which slots connect the forward and backward op.
      // Once all these criteria match, we add a workspace edge between
      // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
      // determined by the interleaved/contiguous ordering. Function
      // DataIndexToMetaDataIndex tells us the location of the Mkl tensor
      // from the location of the Tensorflow tensor.
      for (const Edge* e : orig_node->in_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            // We would have rewritten the forward op, so we need to use
            // the GetMklOpName call to get its Mkl name.
            e->src()->type_string() ==
                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          DCHECK(ws_tensors);
          // Add workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
          // Check if we are running in native format mode. If so,
          // we don't need an Mkl metadata tensor for the workspace.
          if (!NativeFormatEnabled()) {
            // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
            ws_tensors->push_back(NodeBuilder::NodeOut(
                e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
                                                   e->src()->num_outputs())));
          }
          *are_ws_tensors_added = true;
          // In terms of input ordering, we add these Input calls here
          // because the workspace edge (and its Mkl tensor) is the last
          // edge in the fwd op and bwd op, so all inputs before the
          // workspace tensor have been added by the SetUpInputs function.
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      // If we did not find a fwd op that feeds this bwd op, then we need
      // to generate dummy tensors for the workspace input and the Mkl
      // tensor for the workspace, and set workspace_enabled to false.
      if (!workspace_edge_added) {
        nb->Attr("workspace_enabled", false);
        Node* dmt_ws = nullptr;      // Dummy tensor for workspace
        Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
        GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
        GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
        DCHECK(dmt_ws);
        DCHECK(dmt_mkl_ws);
        DCHECK(ws_tensors);
        // We add the dummy tensor as the workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
        // We add the dummy tensor as the Mkl tensor for the workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
        *are_ws_tensors_added = true;
        VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
                << orig_node->type_string();
      }
    } else {
      // If this node does not match any workspace info, then we do not
      // do anything special for workspace propagation for it.
    }
  }
}
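
// Illustrative sketch (slot assignments assumed for illustration): for a
// workspace-info entry pairing fwd_op "MaxPool" with bwd_op "MaxPoolGrad",
// rewriting the forward node only sets workspace_enabled; rewriting the
// backward node appends (fwd_node:ws_fwd_slot) and, unless native format is
// enabled, (fwd_node:DataIndexToMetaDataIndex(ws_fwd_slot, ...)) to
// ws_tensors, which SetUpContiguousInputs then wires in as the last TF and
// Mkl inputs.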

//////////////////////////////////////////////////////////////////////////
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////

// Generic function to copy all attributes from original node to target.
void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
                                        bool change_format) {
  string name;
  AttrSlice attr_list(orig_node->def());

  auto iter = attr_list.begin();
  while (iter != attr_list.end()) {
    name = iter->first;
    auto attr = iter->second;
    nb->Attr(name, attr);
    ++iter;
  }
}

// Generic function to copy all attributes and check if filter is const.
void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
                                                        NodeBuilder* nb,
                                                        bool change_format) {
  CopyAttrsAll(orig_node, nb, change_format);

  // Check and set filter attribute.
  Node* filter_node = nullptr;
  TF_CHECK_OK(orig_node->input_node(1, &filter_node));
  nb->Attr("is_filter_const", filter_node->IsConstant());
}

void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
                                                         NodeBuilder* nb,
                                                         bool change_format) {
  CopyAttrsConv(orig_node, nb, change_format);

  // Check and set filter attribute.
  Node* filter_node = nullptr;
  TF_CHECK_OK(orig_node->input_node(1, &filter_node));
  nb->Attr("is_filter_const", filter_node->IsConstant());
}

void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
                                         bool change_format) {
  DataType T;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  std::vector<int32> explicit_paddings;

  // Get all attributes from old node.
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));

  // Check `explicit_paddings` first because some Conv ops don't have
  // this attribute.
  if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
                     &explicit_paddings) &&
      !explicit_paddings.empty()) {
    nb->Attr("explicit_paddings", explicit_paddings);
  }

  // Add attributes to new node.
  nb->Attr("T", T);
  nb->Attr("padding", padding);

  // Add attributes related to `data_format`.
  CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
}

// Used with MergePadWithConv2D
void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
                                                     const Node* orig_node2,
                                                     NodeBuilder* nb,
                                                     bool change_format) {
  DataType Tpaddings;
  DataType T;
  string data_format;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  bool use_cudnn_on_gpu;

  // Get all attributes from old node 1.
  TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
  TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
  TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
  TF_CHECK_OK(
      GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  // Get all attributes from old node 2.
  TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));

  // Add attributes to new node.
  nb->Attr("T", T);
  nb->Attr("strides", strides);
  nb->Attr("dilations", dilations);
  nb->Attr("padding", padding);
  nb->Attr("data_format", data_format);
  nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
  nb->Attr("Tpaddings", Tpaddings);
}

void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
    const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
    bool change_format) {
  DataType T;
  int num_args;
  string data_format;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  float epsilon;
  std::vector<string> fused_ops;
  DataType Tpaddings;
  float leakyrelu_alpha;

  // Get all attributes from old node.
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
  TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
  TF_CHECK_OK(
      GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
  TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));

  // Add attributes to new node.
  nb->Attr("T", T);
  nb->Attr("num_args", num_args);
  nb->Attr("strides", strides);
  nb->Attr("padding", padding);
  nb->Attr("data_format", data_format);
  nb->Attr("dilations", dilations);
  nb->Attr("epsilon", epsilon);
  nb->Attr("Tpaddings", Tpaddings);
  nb->Attr("fused_ops", fused_ops);
  nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
}

void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
                                                    NodeBuilder* nb,
                                                    bool change_format) {
  DataType Tinput, Tfilter, out_type;
  string padding;
  string data_format("NHWC");
  std::vector<int32> strides, dilations, padding_list;
  bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");

  // Get all attributes from old node.
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
  if (has_padding_list) {
    TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
  }

  Node* filter_node = nullptr;
  TF_CHECK_OK(orig_node->input_node(1, &filter_node));

  // Add attributes to new node.
  nb->Attr("Tinput", Tinput);
  nb->Attr("Tfilter", Tfilter);
  nb->Attr("out_type", out_type);
  nb->Attr("padding", padding);
  nb->Attr("is_filter_const", filter_node->IsConstant());
  nb->Attr("strides", strides);
  nb->Attr("dilations", dilations);
  nb->Attr("data_format", data_format);
  if (has_padding_list) {
    nb->Attr("padding_list", padding_list);
  }

  // Requantization attr Tbias: copy it only if the op actually has it.
  DataType Tbias;
  Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
  if (bias_status.ok()) nb->Attr("Tbias", Tbias);
}

void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
    const Node* orig_node, NodeBuilder* nb, bool change_format) {
  CopyAttrsAll(orig_node, nb, change_format);

  // Check and set weight attribute.
  Node* filter_node = nullptr;
  TF_CHECK_OK(orig_node->input_node(1, &filter_node));
  nb->Attr("is_weight_const", filter_node->IsConstant());
}

void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
    const Node* orig_node, NodeBuilder* nb, bool change_format) {
  DataType T1, T2, Toutput;

  // Get all attributes from old node.
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));

  Node* weight_node = nullptr;
  TF_CHECK_OK(orig_node->input_node(1, &weight_node));

  // Add attributes to new node.
  nb->Attr("T1", T1);
  nb->Attr("T2", T2);
  nb->Attr("Toutput", Toutput);
  nb->Attr("is_weight_const", weight_node->IsConstant());

  // Requantization attr Tbias: copy it only if the op actually has it.
  DataType Tbias;
  Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
  if (bias_status.ok()) nb->Attr("Tbias", Tbias);
}

void MklLayoutRewritePass::CopyFormatAttrsConv(
    const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
    const std::vector<int32>& dilations, bool change_format) {
  string data_format;

  if (!change_format) {
    nb->Attr("strides", strides);
    nb->Attr("dilations", dilations);

    TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
    nb->Attr("data_format", data_format);
  } else {
    std::vector<int32> new_strides;
    std::vector<int32> new_dilations;
    if (strides.size() == 5) {
      // `strides` and `dilations` also need to be changed according to
      // `data_format`. In this case, from `NDHWC` to `NCDHW`.
      new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
                     strides[NDHWC::dim::D], strides[NDHWC::dim::H],
                     strides[NDHWC::dim::W]};

      new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
                       dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
                       dilations[NDHWC::dim::W]};
    } else {
      // `strides` and `dilations` also need to be changed according to
      // `data_format`. In this case, from `NHWC` to `NCHW`.

      new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
                     strides[NHWC::dim::H], strides[NHWC::dim::W]};

      new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
                       dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
    }
    nb->Attr("strides", new_strides);
    nb->Attr("dilations", new_dilations);
  }
}
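
// Worked example: with NHWC strides {1, 2, 2, 1} (N=1, H=2, W=2, C=1), the
// permutation above yields NCHW strides {1, 1, 2, 2}; dilations are permuted
// the same way.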

void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
                                            NodeBuilder* nb,
                                            bool change_format) {
  DataType T;
  string data_format;
  string padding;
  std::vector<int32> ksize, strides;

  // Get all attributes from old node.
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));

  // Add attributes to new node.
  nb->Attr("T", T);
  nb->Attr("padding", padding);

  if (!change_format) {
    nb->Attr("strides", strides);
    nb->Attr("ksize", ksize);

    nb->Attr("data_format", data_format);
  } else {
    std::vector<int32> new_strides;
    std::vector<int32> new_ksize;
    if (strides.size() == 5) {
      DCHECK(data_format == "NCDHW");
      // `strides` and `ksize` also need to be changed according to
      // `data_format`. In this case, from `NDHWC` to `NCDHW`.
      new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
                     strides[NDHWC::dim::D], strides[NDHWC::dim::H],
                     strides[NDHWC::dim::W]};

      new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
                   ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
                   ksize[NDHWC::dim::W]};

    } else {
      // `strides` and `ksize` also need to be changed according to
      // `data_format`. In this case, from `NHWC` to `NCHW`.
      DCHECK(data_format == "NCHW");
      new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
                     strides[NHWC::dim::H], strides[NHWC::dim::W]};

      new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
                   ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
    }
    nb->Attr("strides", new_strides);
    nb->Attr("ksize", new_ksize);
  }
}
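
// Worked example: a 2x2 pooling window registered as NHWC ksize {1, 2, 2, 1}
// becomes NCHW ksize {1, 1, 2, 2} under the permutation above.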

//////////////////////////////////////////////////////////////////////////
//           Helper functions related to node merge pass
//////////////////////////////////////////////////////////////////////////

Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
  // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
  // once we support BiasAddGrad as Mkl layer.

  // Search for all matching mergeinfo.
  // We allow more than one match for extensibility.
  std::vector<const MergeInfo*> matching_mi;
  for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
    if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
      matching_mi.push_back(&*mi);
    }
  }

  for (const MergeInfo* mi : matching_mi) {
    // Get the operand with which 'a' can be merged.
    Node* b = nullptr;
    if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
      continue;
    }

    // Get the control edges and inputs of both nodes.
    const int N_in = a->num_inputs();
    gtl::InlinedVector<Node*, 4> a_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
    FillInputs(a, &a_control_edges, &a_in);

    const int B_in = b->num_inputs();
    gtl::InlinedVector<Node*, 4> b_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
    FillInputs(b, &b_control_edges, &b_in);

    // We shouldn't merge if a and b have different control edges.
    if (a_control_edges != b_control_edges) {
      continue;
    } else {
      // We found a match.
      return b;
    }
  }

  return nullptr;
}
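
// Illustrative example: if 'a' is a Conv2D node and minfo_ contains a
// {Conv2D, BiasAdd} entry, CheckForNodeMerge returns the BiasAdd node found
// by get_node_to_be_merged, provided both nodes carry the same set of
// incoming control edges; otherwise it returns nullptr.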

Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
                                                    Node* m, Node* n) {
  CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
             n->type_string() == csinfo_.conv2d)) ||
               ((n->type_string() == csinfo_.bias_add &&
                 m->type_string() == csinfo_.conv2d)),
           true);

  // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
  // BiasAdd is the successor node and Conv2D the predecessor node.
  Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
  Node* succ = m->type_string() == csinfo_.bias_add ? m : n;

  // 1. Get all attributes from input nodes.
  DataType T_pred, T_succ;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  string data_format_pred, data_format_succ;
  bool use_cudnn_on_gpu;
  TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  // We check that the data formats of both succ and pred are the same. We
  // expect them to be the same, so we could enforce this with an assert, but
  // an assert can be too strict. Instead we enforce it as a check: if the
  // check fails, we simply do not merge the two nodes. We do the same check
  // for devices.
  if (data_format_pred != data_format_succ || T_pred != T_succ ||
      pred->assigned_device_name() != succ->assigned_device_name() ||
      pred->def().device() != succ->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "data_format or T attribute or devices of Conv2D and "
                  "BiasAdd do not match. Will skip node merge optimization");
  }

  const int succ_num = succ->num_inputs();
  gtl::InlinedVector<Node*, 4> succ_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
  FillInputs(succ, &succ_control_edges, &succ_in);

  const int pred_num = pred->num_inputs();
  gtl::InlinedVector<Node*, 4> pred_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
  FillInputs(pred, &pred_control_edges, &pred_in);

  // We need to ensure that Conv2D only feeds BiasAdd (no other operator
  // expects the output of this Conv2D). If this is not the case, then we
  // cannot merge Conv2D with BiasAdd.
  const int kFirstOutputSlot = 0;
  for (const Edge* e : pred->out_edges()) {
    if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Conv2D does not feed to BiasAdd, or "
                    "it feeds BiasAdd but has multiple outputs. "
                    "Will skip node merge optimization");
    }
  }

  // 2. Get inputs from both the nodes.
  // Find the 2 inputs of Conv2D and the bias input of BiasAdd.
  // Get operands 0 and 1 of Conv2D.
  CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
  // Get operand 1 of BiasAdd.
  // BiasAdd must have 2 inputs: Conv2D output and bias.
  CHECK_EQ(succ->in_edges().size(), 2);

  // We will use the node name of BiasAdd as the name of the new node.
  // Build the new node: same name as the original node, but a different
  // op name.
  NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
  nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
  // pred_in[1] will be the 2nd Tensorflow tensor for Conv2D.
  nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
  // In1 of BiasAdd is the same as the output of Conv2D.
  nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd

  // Copy attributes from Conv2D to Conv2DWithBias.
  CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);

  // Copy the device assigned to old node to new node.
  nb.Device(succ->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));

  // In the following code of this function, an unordered set is used to make
  // sure no duplicated edges are added into the new node. Therefore, we can
  // pass allow_duplicates = true in the AddControlEdge call to skip the
  // O(#edges) check in the routine.

  // Incoming data edges from the 'pred' and 'succ' nodes to the new
  // 'new_node' are already copied in BuildNode. We handle control edges now.
  std::unordered_set<Node*> unique_node;
  for (const Edge* e : pred->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), new_node, true);
      }
    }
  }
  unique_node.clear();

  for (const Edge* e : succ->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), new_node, true);
      }
    }
  }
  unique_node.clear();

  // Incoming edges are fixed; we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from the 'pred' node.
  for (const Edge* e : pred->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(new_node, e->dst(), true);
      }
    }
  }
  unique_node.clear();

  // Second, we will fix outgoing control and data edges from the 'succ' node.
  for (const Edge* e : succ->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(new_node, e->dst(), true);
      }
    } else {
      // BiasAdd has only 1 output (at slot 0) and the merged node also has
      // only 1 output (at slot 0).
      const int kConv2DWithBiasOutputSlot = 0;
      auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
                                    e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }

  // Copy the device assigned to the old node to the new node.
  // It's ok to use pred or succ, as we have enforced a check that
  // both have the same device assigned.
  new_node->set_assigned_device_name(pred->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
          << ", and node: " << succ->DebugString()
          << ", into node:" << new_node->DebugString();

  (*g)->RemoveNode(succ);
  (*g)->RemoveNode(pred);

  return Status::OK();
}

Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
                                                Node* m, Node* n) {
  DCHECK((m->type_string() == csinfo_.pad &&
          (n->type_string() == csinfo_.conv2d ||
           n->type_string() == csinfo_.fused_conv2d)) ||
         (n->type_string() == csinfo_.pad &&
          (m->type_string() == csinfo_.conv2d ||
           m->type_string() == csinfo_.fused_conv2d)));

  bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
                         m->type_string() == csinfo_.fused_conv2d;
  // Conv2D is the successor node, and Pad the predecessor node.
  Node* pred = m->type_string() == csinfo_.pad ? m : n;
  Node* succ = m->type_string() == csinfo_.pad ? n : m;

  // 1. Get all attributes from input nodes.
  DataType T_pred, T_succ;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  string data_format_pred, data_format_succ;

  TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
  // Check if the devices of both succ and pred are the same.
  // An assert is not used because it can be too strict.
  // We don't need to check data formats because Pad does not have one.
  if (T_pred != T_succ ||
      pred->assigned_device_name() != succ->assigned_device_name() ||
      pred->def().device() != succ->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "T attribute or devices of Conv2D and "
                  "Pad do not match. Will skip node merge optimization");
  }

  const int succ_num = succ->num_inputs();
  gtl::InlinedVector<Node*, 4> succ_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
  FillInputs(succ, &succ_control_edges, &succ_in);

  const int pred_num = pred->num_inputs();
  gtl::InlinedVector<Node*, 4> pred_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
  FillInputs(pred, &pred_control_edges, &pred_in);

  // We need to ensure that Pad only feeds Conv2D (no other operator expects
  // the output of this Pad). If this is not the case, then we cannot merge
  // Conv2D with Pad.
  const int kFirstOutputSlot = 0;
  for (const Edge* e : pred->out_edges()) {
    if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Pad does not feed to Conv2D, or "
                    "it feeds Conv2D but has multiple outputs. "
                    "Will skip node merge optimization");
    }
  }

  // 2. Get inputs from both the nodes.

  // Pad must have 2 data inputs: "input" and paddings.
  int PadDataInputEdges = 0;
  for (const Edge* e : pred->in_edges()) {
    if (!e->IsControlEdge()) {
      PadDataInputEdges++;
    }
  }
  DCHECK_EQ(PadDataInputEdges, 2);

  // Conv2D must have 2 data inputs: Pad output and filter.
  // FusedConv2D has 3 data inputs: Pad output, filter, and args.
  int ConvDataInputEdges = 0;
  for (const Edge* e : succ->in_edges()) {
    if (!e->IsControlEdge()) {
      ConvDataInputEdges++;
    }
  }

  DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);

  // We will use the node name of Conv2D as the name of the new node.
  // Build the new node: same name as the original node, but a different
  // op name.

  NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
                                               : csinfo_.pad_with_conv2d);
  nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data) of Pad
  // succ_in[1] is the filter, i.e. the 2nd input of Conv2D.
  nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
  // In1 of Conv2D is the same as the output of Pad.
  // Thus, we only need to add In2 of Conv2D.

  if (is_fused_conv2d) {
    // FusedConv2D has one additional input, args.
    std::vector<NodeBuilder::NodeOut> args;
    int num_args = 1;
    GetNodeAttr(succ->def(), "num_args", &num_args);
    for (int i = 0; i < num_args; i++) {
      args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second);
    }
    nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
        args});                                     // In3 (args) of FusedConv2D
    nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
    // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
    CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
                                   const_cast<const Node*>(pred), &nb);
  } else {
    nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
    // Copy attributes from Pad and Conv2D to PadWithConv2D.
    CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
                              const_cast<const Node*>(pred), &nb);
  }

  // Copy the device assigned to old node to new node.
  nb.Device(succ->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  // No need to check if new_node is null because it will be null only when
  // Finalize fails.

  // Incoming data edges from the 'pred' and 'succ' nodes to the new
  // 'new_node' are already copied in BuildNode.
  // We handle control edges now.
  for (const Edge* e : pred->in_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edges.
      (*g)->AddControlEdge(e->src(), new_node, false);
    }
  }
  for (const Edge* e : succ->in_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edges.
      (*g)->AddControlEdge(e->src(), new_node, false);
    }
  }

  // Incoming edges are fixed; we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from the 'pred' node.
  for (const Edge* e : pred->out_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edges.
      (*g)->AddControlEdge(new_node, e->dst(), false);
    }
  }

  // Second, we will fix outgoing control and data edges from the 'succ' node.
  for (const Edge* e : succ->out_edges()) {
    if (e->IsControlEdge()) {
      // Pass allow_duplicates = false: AddControlEdge simply returns NULL
      // (instead of adding a second edge) if the control edge already exists.
      (*g)->AddControlEdge(new_node, e->dst(), false);
    } else {
      // Conv2D has only 1 output (at slot 0) and the merged node also has
      // only 1 output (at slot 0).
      const int kPadWithConv2DOutputSlot = 0;
      (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
                    e->dst_input());
    }
  }

  // Copy the device assigned to the old node to the new node.
  // It's ok to use pred or succ, as we have enforced a check that
  // both have the same device assigned.
  new_node->set_assigned_device_name(pred->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
          << ", and node: " << succ->DebugString()
          << ", into node:" << new_node->DebugString();

  (*g)->RemoveNode(succ);
  (*g)->RemoveNode(pred);

  return Status::OK();
}
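
// Illustrative before/after for this merge: Q = Pad(X, paddings);
// R = Conv2D(Q, filter) becomes R = PadWithConv2D(X, filter, paddings),
// where the merged op's name is taken from csinfo_.pad_with_conv2d (or
// csinfo_.pad_with_fused_conv2d for the fused variant).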

Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
    std::unique_ptr<Graph>* g, Node* m, Node* n) {
  CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
             n->type_string() == csinfo_.conv2d_grad_filter)) ||
               ((n->type_string() == csinfo_.bias_add_grad &&
                 m->type_string() == csinfo_.conv2d_grad_filter)),
           true);

  // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
  Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
  Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;

  // Sanity check for attributes from input nodes.
  DataType T_b, T_f;
  string data_format_b, data_format_f;
  TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
  TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
  TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
  TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
  if (data_format_b != data_format_f || T_b != T_f ||
      badd->assigned_device_name() != fltr->assigned_device_name() ||
      badd->def().device() != fltr->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "data_format or T attribute or devices of "
                  "Conv2DBackpropFilter and BiasAddGrad do not match. "
                  "Will skip node merge optimization");
  }

  // We will use the node name of Conv2DBackpropFilter as the name of the new
  // node, because BackpropFilterWithBias is going to emit the bias output
  // as well.
  NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
  // Since Conv2DBackpropFilterWithBias has the same number of inputs as
  // Conv2DBackpropFilter, we can just copy the input edges directly. We don't
  // need to copy any data input of BiasAddGrad because that input also goes
  // to Conv2DBackpropFilter.
  const int fltr_ins = fltr->num_inputs();
  gtl::InlinedVector<Node*, 4> fltr_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
  FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
  for (int idx = 0; idx < fltr_ins; idx++) {
    nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
  }

  // Copy attributes from Conv2DBackpropFilter.
  CopyAttrsConv(const_cast<const Node*>(fltr), &nb);

  // Copy the device assigned to old node to new node.
  nb.Device(fltr->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));

  // In the following code of this function, an unordered set is used to make
  // sure no duplicated edges are added into the new node. Therefore, we can
  // pass allow_duplicates = true in the AddControlEdge call to skip the
  // O(#edges) check in the routine.

  // Incoming data edges from the BiasAddGrad node and Conv2DBackpropFilter
  // node to the new 'new_node' are already copied in BuildNode. We handle
  // control edges now.
  std::unordered_set<Node*> unique_node;
  for (const Edge* e : badd->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), new_node, true);
      }
    }
  }
  unique_node.clear();
  for (const Edge* e : fltr->in_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->src());
      if (result.second) {
        (*g)->AddControlEdge(e->src(), new_node, true);
      }
    }
  }
  unique_node.clear();

  // Incoming edges are fixed; we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from the 'badd' node.
  // Conv2DBackpropFilter has 1 output -- filter_grad.
  // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
  // bias_grad. filter_grad is at the same slot number (0) in both nodes,
  // while bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias and
  // at slot number 0 in BiasAddGrad.
  const int kMergedNodeFilterGradOutputIdx = 0;
  const int kMergedNodeBiasGradOutputIdx = 1;

  for (const Edge* e : badd->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(new_node, e->dst(), true);
      }
    } else {
      auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
                                    e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }
  unique_node.clear();

  // Second, we will fix outgoing control and data edges from the 'fltr' node.
  for (const Edge* e : fltr->out_edges()) {
    if (e->IsControlEdge()) {
      auto result = unique_node.insert(e->dst());
      if (result.second) {
        (*g)->AddControlEdge(new_node, e->dst(), true);
      }
    } else {
      auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
                                    e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }

  // Copy the device assigned to the old node to the new node.
  // It's ok to use badd or fltr, as we have enforced a check that
  // both have the same device assigned.
  new_node->set_assigned_device_name(badd->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
          << ", and node: " << fltr->DebugString()
          << ", into node:" << new_node->DebugString();

  (*g)->RemoveNode(badd);
  (*g)->RemoveNode(fltr);

  return Status::OK();
}
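
// Output-slot mapping example for the merge above: a data edge from
// BiasAddGrad:0 is re-routed to merged_node:1 (bias_grad), while a data edge
// from Conv2DBackpropFilter:0 is re-routed to merged_node:0 (filter_grad).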

Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
                                       Node* n) {
  DCHECK(m);
  DCHECK(n);

  if (((m->type_string() == csinfo_.bias_add &&
        n->type_string() == csinfo_.conv2d)) ||
      ((n->type_string() == csinfo_.bias_add &&
        m->type_string() == csinfo_.conv2d))) {
    return this->MergeConv2DWithBiasAdd(g, m, n);
  }
  if ((m->type_string() == csinfo_.pad &&
       (n->type_string() == csinfo_.conv2d ||
        (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
      (n->type_string() == csinfo_.pad &&
       (m->type_string() == csinfo_.conv2d ||
        (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
    return this->MergePadWithConv2D(g, m, n);
  }

  if (((m->type_string() == csinfo_.bias_add_grad &&
        n->type_string() == csinfo_.conv2d_grad_filter)) ||
      ((n->type_string() == csinfo_.bias_add_grad &&
        m->type_string() == csinfo_.conv2d_grad_filter))) {
    return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
  }

  return Status(error::Code::UNIMPLEMENTED,
                "Unimplemented case for node merge optimization.");
}
3477 
3478 //////////////////////////////////////////////////////////////////////////
3479 //           Helper functions for node rewrite
3480 //////////////////////////////////////////////////////////////////////////
3481 
RewriteNodeForLayoutPropagation(std::unique_ptr<Graph> * g,const Node * orig_node,Node ** new_node,const RewriteInfo * ri)3482 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3483     std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3484     const RewriteInfo* ri) {
3485   // Get all data inputs.
3486   int num_data_inputs = orig_node->in_edges().size();
3487   // Do not count control edges among the inputs.
3488   for (const Edge* e : orig_node->in_edges()) {
3489     if (e->IsControlEdge()) {
3490       num_data_inputs--;
3491     }
3492   }
3493 
3494   gtl::InlinedVector<Node*, 4> control_edges;
3495   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3496   FillInputs(orig_node, &control_edges, &inputs);
3497 
3498   // Build the new node. We use the same name as the original node but a new op type.
3499   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3500   // Copy user-specified device assigned to original node to new node.
3501   nb.Device(orig_node->def().device());
3502   // Set up new inputs to the rewritten node.
3503   Status s = SetUpInputs(g, inputs, &nb, orig_node);
3504   if (s != Status::OK()) {
3505     return s;
3506   }
3507 
3508   const bool kPartialCopyAttrs = false;
3509   ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3510 
3511   // Set the Mkl layer label for this op.
3512   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3513       DataTypeIsQuantized(orig_node->output_type(0))) {
3514     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3515   } else {
3516     nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3517   }
3518   // Finalize graph and get new node.
3519   s = nb.Finalize(&**g, new_node);
3520   if (s != Status::OK()) {
3521     return s;
3522   }
3523 
3524   // In the following code of this function, an unordered set is used to make
3525   // sure no duplicate edges are added to the new node. Therefore, we can
3526   // pass allow_duplicates = true in the AddControlEdge call to skip the
3527   // O(#edges) check in the routine.
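  // (Sketch of the dedup idiom used below, per standard C++ semantics:
  // std::unordered_set::insert returns an {iterator, bool} pair whose bool is
  // true only on first insertion, so each control-edge source is added once.)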
3528 
3529   // Incoming data edges from the 'orig_node' to the new 'new_node' are
3530   // already copied in BuildNode. We need to handle control edges now.
3531   std::unordered_set<Node*> unique_node;
3532   for (const Edge* e : orig_node->in_edges()) {
3533     if (e->IsControlEdge()) {
3534       auto result = unique_node.insert(e->src());
3535       if (result.second) {
3536         (*g)->AddControlEdge(e->src(), *new_node, true);
3537       }
3538     }
3539   }
3540   unique_node.clear();
3541 
3542   // Copy outgoing edges from the 'orig_node' to the new 'new_node', since
3543   // the outputs follow the same ordering among TensorFlow tensors and Mkl
3544   // tensors. We need to connect the TensorFlow tensors appropriately.
3545   // Specifically, the nth output of the original node becomes the 2*nth
3546   // output of the Mkl node under the interleaved ordering of the tensors;
3547   // under the contiguous ordering, it remains the nth output.
3548   // GetTensorDataIndex provides this mapping function.
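  // Illustrative sketch (hypothetical slot counts): for an original node with
  // 2 outputs, output slot 1 maps to slot 2*1 = 2 on the Mkl node under the
  // interleaved ordering, and stays at slot 1 under the contiguous ordering.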
3549   for (const Edge* e : orig_node->out_edges()) {
3550     if (e->IsControlEdge()) {
3551       auto result = unique_node.insert(e->dst());
3552       if (result.second) {
3553         (*g)->AddControlEdge(*new_node, e->dst(), true);
3554       }
3555     } else {
3556       auto new_edge = (*g)->AddEdge(
3557           *new_node,
3558           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3559           e->dst(), e->dst_input());
3560       DCHECK(new_edge);
3561     }
3562   }
3563   return Status::OK();
3564 }
3565 
3566 Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3567     std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3568     const RewriteInfo* ri) {
3569   // Get all data inputs.
3570   int num_data_inputs = orig_node->in_edges().size();
3571   // Do not count control edges among the inputs.
3572   for (const Edge* e : orig_node->in_edges()) {
3573     if (e->IsControlEdge()) {
3574       num_data_inputs--;
3575     }
3576   }
3577   gtl::InlinedVector<Node*, 4> control_edges;
3578   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3579   FillInputs(orig_node, &control_edges, &inputs);
3580 
3581   // Build the new node. We use the same name as the original node but a new op type.
3582   NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3583   // Copy user-specified device assigned to original node to new node.
3584   nb.Device(orig_node->def().device());
3585 
3586   Status s = CopyInputs(orig_node, inputs, &nb);
3587   if (s != Status::OK()) {
3588     return s;
3589   }
3590 
3591   std::vector<NodeBuilder::NodeOut> workspace_tensors;
3592   bool are_workspace_tensors_available = false;
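  // (Workspace edges pair a forward op with its gradient op -- e.g.
  // MaxPool/MaxPoolGrad -- so the backward kernel can reuse data computed in
  // the forward pass instead of recomputing it.)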
3593   if (IsWorkspaceCheckNeeded(orig_node)) {
3594     AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3595                              &are_workspace_tensors_available);
3596     if (are_workspace_tensors_available) {
3597       DCHECK_EQ(workspace_tensors.size(), 1);
3598       nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3599     }
3600   }
3601 
3602   if (!NativeFormatEnabled()) {
3603     ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3604   } else {
3605     ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3606   }
3607 
3608   if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3609       DataTypeIsQuantized(orig_node->output_type(0))) {
3610     nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3611   } else {
3612     nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3613   }
3614 
3615   // Finalize graph and get new node.
3616   s = nb.Finalize(&**g, new_node);
3617   if (s != Status::OK()) {
3618     return s;
3619   }
3620 
3621   // In the following code of this function, an unordered set is used to make
3622   // sure no duplicate edges are added to the new node. Therefore, we can
3623   // pass allow_duplicates = true in the AddControlEdge call to skip the
3624   // O(#edges) check in the routine.
3625 
3626   // Incoming data edges from the 'orig_node' to the new 'new_node' are
3627   // already copied in BuildNode. We need to handle control edges now.
3628   std::unordered_set<Node*> unique_node;
3629   for (const Edge* e : orig_node->in_edges()) {
3630     if (e->IsControlEdge()) {
3631       auto result = unique_node.insert(e->src());
3632       if (result.second) {
3633         (*g)->AddControlEdge(e->src(), *new_node, true);
3634       }
3635     }
3636   }
3637   unique_node.clear();
3638 
3639   // Transfer outgoing edges from the 'orig_node' to the new 'new_node'.
3640   for (const Edge* e : orig_node->out_edges()) {
3641     if (e->IsControlEdge()) {
3642       auto result = unique_node.insert(e->dst());
3643       if (result.second) {
3644         (*g)->AddControlEdge(*new_node, e->dst(), true);
3645       }
3646     } else {
3647       auto result =
3648           (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3649       DCHECK(result != nullptr);
3650     }
3651   }
3652 
3653   return Status::OK();
3654 }
3655 
3656 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3657                                          Node* orig_node,
3658                                          const RewriteInfo* ri) {
3659   DCHECK(ri != nullptr);
3660   DCHECK(orig_node != nullptr);
3661 
3662   VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3663 
3664   Status ret_status = Status::OK();
3665   Node* new_node = nullptr;
3666   if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3667     ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3668   } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3669     ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3670   } else {
3671     ret_status = Status(error::Code::INVALID_ARGUMENT,
3672                         "Unsupported rewrite cause found. "
3673                         "RewriteNode will fail.");
3674   }
3675   TF_CHECK_OK(ret_status);
3676 
3677   // Copy the runtime-assigned device from the original node to the new node.
3678   new_node->set_assigned_device_name(orig_node->assigned_device_name());
3679 
3680   // Delete original node and mark new node as rewritten.
3681   (*g)->RemoveNode(orig_node);
3682 
3683   VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3684   return ret_status;
3685 }
3686 
3687 // TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3688 // having attributes other than "T"?
3689 // Current implementation reflects only QuantizedConv2D and its fused Ops.
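// Illustrative example (assumed attribute values, not taken from this file):
// a QuantizedConv2D node carrying Tinput = DT_QUINT8 and Tfilter = DT_QINT8
// matches the first branch below when a quantized Mkl kernel is registered
// for that type pair.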
3690 const MklLayoutRewritePass::RewriteInfo*
3691 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3692   DataType T1, T2;
3693   DataType Tinput, Tfilter;
3694   bool type_attrs_present = false;
3695 
3696   if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3697       TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3698       mkl_op_registry::IsMklQuantizedOp(
3699           mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3700     type_attrs_present = true;
3701   } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3702              TryGetNodeAttr(n->def(), "T2", &T2) &&
3703              mkl_op_registry::IsMklQuantizedOp(
3704                  mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3705     type_attrs_present = true;
3706   }
3707 
3708   if (type_attrs_present) {
3709     for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3710       if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3711         return &*ri;
3712       }
3713     }
3714   }
3715 
3716   return nullptr;
3717 }
3718 
3719 const MklLayoutRewritePass::RewriteInfo*
3720 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3721   DCHECK(n);
3722 
3723   // Quantized ops may have attributes other than "T", so that check is
3724   // decoupled into a separate function, CheckForQuantizedNodeRewrite(const Node*).
3725   const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3726   if (ri != nullptr) return ri;
3727 
3728   // First, check whether the node and its type are supported by the MKL layer.
3729   // We do not want to rewrite an op into Mkl op if types are not supported.
3730   // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3731   // MklRelu if type is INT32.
3732   DataType T;
3733   if (!TryGetNodeAttr(n->def(), "T", &T)) {
3734     return nullptr;
3735   }
3736 
3737   // We make an exception for Conv2DGrad and MaxPool-related ops, as
3738   // the corresponding MKL ops do not yet support the case of
3739   // padding == EXPLICIT.
3740   // TODO(intel): support `EXPLICIT` padding for ConvGrad
3741   if (n->type_string() == csinfo_.conv2d_grad_input ||
3742       n->type_string() == csinfo_.conv2d_grad_filter ||
3743       n->type_string() == csinfo_.max_pool ||
3744       n->type_string() == csinfo_.max_pool_grad ||
3745       n->type_string() == csinfo_.max_pool3d ||
3746       n->type_string() == csinfo_.max_pool3d_grad) {
3747     string padding;
3748     TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3749     if (padding == "EXPLICIT") return nullptr;
3750   }
3751 
3752   // We make an exception for __MklDummyConv2DWithBias,
3753   // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3754   // names do not match Mkl node names.
3755   if (n->type_string() != csinfo_.conv2d_with_bias &&
3756       n->type_string() != csinfo_.pad_with_conv2d &&
3757       n->type_string() != csinfo_.pad_with_fused_conv2d &&
3758       n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3759       n->type_string() != csinfo_.fused_batch_norm_ex &&
3760       n->type_string() != csinfo_.fused_conv2d &&
3761       n->type_string() != csinfo_.fused_depthwise_conv2d &&
3762       n->type_string() != csinfo_.fused_matmul &&
3763       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3764                                 T)) {
3765     return nullptr;
3766   }
3767 
3768   // We now check whether a rewrite rule applies to this op. If one passes,
3769   // we rewrite the op to its Mkl counterpart.
3770   // Find the matching RewriteInfo and then check that its rewrite rule applies.
3771   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3772     if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3773       return &*ri;
3774     }
3775   }
3776 
3777   // Else return not found.
3778   return nullptr;
3779 }
3780 
3781 //////////////////////////////////////////////////////////////////////////
3782 //           Helper functions for node fusion
3783 //////////////////////////////////////////////////////////////////////////
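// Sketch of the pattern handled below (illustrative): the chain
//   Transpose(NCHW->NHWC) -> MklOp(data_format=NHWC) -> Transpose(NHWC->NCHW)
// collapses into a single MklOp whose data_format attribute is rewritten
// (to NCHW in this example), reading directly from the first Transpose's
// input and feeding the last Transpose's consumers.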
3784 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3785     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3786     std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3787     string data_format) {
3788   Node* transpose_to_nhwc = nodes[0];
3789   Node* mklop = nodes[1];
3790   Node* transpose_to_nchw = nodes[2];
3791 
3792   const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3793   gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3794   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3795       transpose_nhwc_num_inputs);
3796   FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3797              &transpose_nhwc_in);
3798 
3799   const int mklop_num_inputs = mklop->num_inputs();
3800   gtl::InlinedVector<Node*, 4> mklop_control_edges;
3801   gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3802   FillInputs(mklop, &mklop_control_edges, &mklop_in);
3803 
3804   const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3805   gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3806   gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3807       transpose_nchw_num_inputs);
3808   FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3809              &transpose_nchw_in);
3810 
3811   // We reuse the name and op type of the original Mkl op; only the inputs,
3812   // the data_format attribute, and the outputs change.
3813   NodeBuilder nb(mklop->name(), mklop->type_string());
3814 
3815   // Set up the new node's inputs from the output slots of the input nodes.
3816   for (int i = 0; i < mklop_num_inputs; i++) {
3817     if (mklop_in[i].first == transpose_to_nhwc) {
3818       // Fill "x":
3819       nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3820     } else {
3821       // Fill inputs other than "x":
3822       nb.Input(mklop_in[i].first, mklop_in[i].second);
3823     }
3824   }
3825 
3826   copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3827   nb.Attr("data_format", data_format);
3828 
3829   // Copy the device assigned to old node to new node.
3830   nb.Device(mklop->def().device());
3831 
3832   // Create node.
3833   Node* new_node;
3834   TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3835   // No need to check if new_node is null because it will be null only when
3836   // Finalize fails.
3837 
3838   // Fill outputs.
3839   for (const Edge* e : transpose_to_nchw->out_edges()) {
3840     if (!e->IsControlEdge()) {
3841       const int kTransposeWithMklOpOutputSlot = 0;
3842       auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3843                                     e->dst(), e->dst_input());
3844       DCHECK(new_edge);
3845     }
3846   }
3847 
3848   // Copy the device assigned to the old node to the new node.
3849   new_node->set_assigned_device_name(mklop->assigned_device_name());
3850 
3851   // Copy requested_device and assigned_device_name_index
3852   new_node->set_requested_device(mklop->requested_device());
3853   new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3854 
3855   (*g)->RemoveNode(transpose_to_nhwc);
3856   (*g)->RemoveNode(mklop);
3857   (*g)->RemoveNode(transpose_to_nchw);
3858 
3859   return Status::OK();
3860 }
3861 
3862 Status MklLayoutRewritePass::FuseNode(
3863     std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3864     const MklLayoutRewritePass::FusionInfo fi) {
3865   return fi.fuse_func(g, nodes, fi.copy_attrs);
3866 }
3867 
3868 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
3869 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3870   // Stores matched nodes, in the same order as node_checkers.
3871   std::vector<Node*> nodes;
3872 
3873   for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3874     //
3875     // Make sure node "a" and its succeeding nodes (b, c ...), match the pattern
3876     // defined in fusion info (ops[0], ops[1], ...),
3877     // a.k.a. "a->b->c" matches "op1->op2->op3"
3878     //
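    // Illustrative walk (hypothetical two-op pattern {op1, op2}): if "a"
    // matches op1, we push "a" and scan its out-edges; the first successor
    // matching op2 completes the pattern. Otherwise we backtrack by popping
    // "a" and move on to the next FusionInfo.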
3879 
3880     // Stores the first unvisited outgoing edge of each matched node in "nodes".
3881     std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3882     nodes.clear();
3883 
3884     auto node_checker = fi->node_checkers.begin();
3885     if (a != nullptr && (*node_checker)(a)) {
3886       nodes.push_back(a);
3887       current_neighbor_stack.push(a->out_edges().begin());
3888       ++node_checker;
3889     }
3890 
3891     while (!nodes.empty()) {
3892       auto& current_neighbor_iter = current_neighbor_stack.top();
3893 
3894       if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3895         // Found an unvisited edge. Follows the edge to get the neighbor node.
3896         Node* neighbor_node = (*current_neighbor_iter)->dst();
3897         ++current_neighbor_stack.top();  // Advances to the next unvisited edge.
3898 
3899         if ((*node_checker)(neighbor_node)) {
3900           // Found a match. Stores the node and moves to the next checker.
3901           nodes.push_back(neighbor_node);
3902           current_neighbor_stack.push(neighbor_node->out_edges().begin());
3903           if (++node_checker == fi->node_checkers.end()) {
3904             return make_tuple(true, nodes, *fi);
3905           }
3906         }
3907       } else {
3908         // Removes the current node since none of its neighbors leads to a
3909         // further match.
3910         nodes.pop_back();
3911         current_neighbor_stack.pop();
3912         --node_checker;
3913       }
3914     }
3915   }
3916 
3917   return make_tuple(false, std::vector<Node*>(), FusionInfo());
3918 }
3919 
3920 ///////////////////////////////////////////////////////////////////////////////
3921 //              Post-rewrite Mkl metadata fixup pass
3922 ///////////////////////////////////////////////////////////////////////////////
3923 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3924                                                       const Edge* e_data,
3925                                                       const Edge* e_metadata) {
3926   if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3927     return false;
3928   }
3929 
3930   Node* n_data = e_data->src();
3931   int n_data_op_slot = e_data->src_output();
3932   int n_metadata_op_slot =
3933       GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3934 
3935   // If the source of the meta edge is a constant node (producing a dummy Mkl
3936   // metadata tensor), then we need to fix the edge.
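  // Illustrative fix (hypothetical nodes): if X:0 --> Y:idx is the data edge
  // and Const --> Y:meta_idx is the meta edge, we rewire the meta input to
  // X:GetTensorMetaDataIndex(0, X->num_outputs()) and drop the Const edge.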
3937   if (IsConstant(e_metadata->src())) {
3938     Node* e_metadata_dst = e_metadata->dst();
3939     int e_metadata_in_slot = e_metadata->dst_input();
3940     auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3941                                   e_metadata_in_slot);
3942     DCHECK(new_edge);
3943 
3944     (*g)->RemoveEdge(e_metadata);
3945     return true;
3946   }
3947 
3948   return false;
3949 }
3950 
3951 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3952                                                Node* n) {
3953   bool result = false;
3954 
3955   // If the graph node is not an Mkl node, then return.
3956   DataType T = DT_INVALID;
3957   if (!TryGetNodeAttr(n->def(), "T", &T) ||
3958       !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
3959     return result;
3960   }
3961 
3962   // If it is an Mkl node, then check whether the input edges to this node
3963   // that carry Mkl metadata are linked up correctly with the source node.
3964 
3965   // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3966   // data tensors + n for Mkl metadata tensors). We need to check for correct
3967   // connection of n metadata tensors only.
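  // Sketch (assuming the contiguous ordering): with n = 2 data inputs the
  // node has 4 input slots in total -- data tensors at slots 0 and 1, and
  // their Mkl metadata tensors at slots 2 and 3.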
3968   int num_data_inputs = n->num_inputs() / 2;
3969   for (int idx = 0; idx < num_data_inputs; idx++) {
3970     // Get the edge connecting input slot with index (idx).
3971     const Edge* e = nullptr;
3972     TF_CHECK_OK(n->input_edge(idx, &e));
3973 
3974     // If 'e' is a control edge, then skip it.
3975     if (e->IsControlEdge()) {
3976       continue;
3977     }
3978 
3979     // Check whether the source node of edge 'e' is an Mkl node. If it is not,
3980     // then we don't need to do anything.
3981     Node* e_src = e->src();
3982     if (TryGetNodeAttr(e_src->def(), "T", &T) &&
3983         mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
3984       // The source node of edge 'e' is an Mkl node.
3985       // The destination node and destination input slot of 'e' are node 'n'
3986       // and 'idx', respectively.
3987       CHECK_EQ(e->dst(), n);
3988       CHECK_EQ(e->dst_input(), idx);
3989 
3990       // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
3991       // 'e'. For that, let's first get the input slot of 'n' where the meta
3992       // edge will feed the value.
3993       int e_meta_in_slot =
3994           GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3995       const Edge* e_meta = nullptr;
3996       TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3997 
3998       // Let's check if we need to fix this meta edge.
3999       if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
4000         result = true;
4001       }
4002     }
4003   }
4004 
4005   return result;
4006 }
4007 
4008 ///////////////////////////////////////////////////////////////////////////////
4009 //              Run function for the pass
4010 ///////////////////////////////////////////////////////////////////////////////
4011 
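// The function below applies four sub-passes, each over a fresh topological
// order of the graph: (1) node merge, (2) node fusion, (3) node rewrite, and
// (4) Mkl metadata-edge fixup. It returns true if any sub-pass changed the
// graph.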
4012 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
4013   bool result = false;
4014   DCHECK(g);
4015 
4016   DumpGraph("Before running MklLayoutRewritePass", &**g);
4017 
4018   std::vector<Node*> order;
4019   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4020   for (Node* n : order) {
4021     // If node is not an op or it cannot run on CPU device, then skip.
4022     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4023       continue;
4024     }
4025 
4026     Node* m = nullptr;
4027     if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
4028       // Check if the node 'n' can be merged with any other node. If it can
4029       // be, 'm' contains the node with which it can be merged.
4030       string n1_name = n->name();
4031       string n2_name = m->name();
4032 
4033       VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
4034               << n2_name << " for merging";
4035 
4036       if (MergeNode(g, n, m) == Status::OK()) {
4037         VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
4038                 << n2_name;
4039         result = true;
4040       }
4041     }
4042   }
4043 
4044   DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
4045 
4046   order.clear();
4047   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4048   for (Node* n : order) {
4049     // If node is not an op or it cannot run on CPU device, then skip.
4050     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4051       continue;
4052     }
4053 
4054     auto check_result = CheckForNodeFusion(n);
4055     bool found_pattern = std::get<0>(check_result);
4056     std::vector<Node*> nodes = std::get<1>(check_result);
4057     const FusionInfo fi = std::get<2>(check_result);
4058 
4059     // If "found_pattern" is true, we can do the fusion.
4060     if (found_pattern) {
4061       if (FuseNode(g, nodes, fi) == Status::OK()) {
4062         result = true;
4063       }
4064     }
4065   }
4066   DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
4067 
4068   order.clear();
4069   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4070   for (Node* n : order) {
4071     // If node is not an op or it cannot run on CPU device, then skip.
4072     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4073       continue;
4074     }
4075 
4076     const RewriteInfo* ri = nullptr;
4077     // First, check whether the node is to be rewritten.
4078     if ((ri = CheckForNodeRewrite(n)) != nullptr) {
4079       string node_name = n->name();
4080       string op_name = n->type_string();
4081 
4082       VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
4083               << " with op " << op_name << " for rewrite using"
4084               << " layout optimization.";
4085 
4086       if (RewriteNode(g, n, ri) == Status::OK()) {
4087         VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
4088                 << " with op " << op_name << " for Mkl layout optimization.";
4089         result = true;
4090       }
4091     }
4092   }
4093 
4094   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
4095 
4096   order.clear();
4097   GetReversePostOrder(**g, &order);  // This will give us topological sort.
4098   for (Node* n : order) {
4099     // If node is not an op or it cannot run on CPU device, then skip.
4100     if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4101       continue;
4102     }
4103     if (FixMklMetaDataEdges(g, n)) {
4104       string node_name = n->name();
4105       string op_name = n->type_string();
4106 
4107       VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
4108               << node_name << " with op " << op_name;
4109       result = true;
4110     }
4111   }
4112   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
4113             &**g);
4114 
4115   return result;
4116 }
4117 
4118 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
4119   return MklLayoutRewritePass().RunPass(g);
4120 }
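// Minimal usage sketch (illustrative only; the Graph setup below is assumed,
// mirroring a typical unit-test style driver, and is not part of this file):
//   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
//   // ... populate 'graph' with CPU-assigned nodes ...
//   bool changed = RunMklLayoutRewritePass(&graph);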
4121 
4122 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
4123   if (options.graph == nullptr && options.partition_graphs == nullptr) {
4124     return Status::OK();
4125   }
4126   if (!IsMKLEnabled()) {
4127     VLOG(2) << "TF-MKL: MKL is not enabled";
4128     return Status::OK();
4129   }
4130 
4131   auto process_graph = [&](std::unique_ptr<Graph>* g) {
4132     // Take ownership of the graph.
4133     std::unique_ptr<Graph>* ng = std::move(g);
4134     RunPass(ng);
4135     // Return ownership of the graph.
4136     g->reset(ng->release());
4137   };
4138 
4139   if (kMklLayoutRewritePassGroup !=
4140       OptimizationPassRegistry::POST_PARTITIONING) {
4141     // For any pre-partitioning phase, a graph is stored in options.graph.
4142     process_graph(options.graph);
4143   } else {
4144     // For post partitioning phase, graphs are stored in
4145     // options.partition_graphs.
4146     for (auto& pg : *options.partition_graphs) {
4147       process_graph(&pg.second);
4148     }
4149   }
4150 
4151   return Status::OK();
4152 }
4153 
4154 }  // namespace tensorflow
4155 
4156 #endif  // INTEL_MKL
4157