1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19
20 #include "tensorflow/core/common_runtime/mkl_layout_pass.h"
21
22 #include <algorithm>
23 #include <functional>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <stack>
28 #include <tuple>
29 #include <unordered_set>
30 #include <utility>
31 #include <vector>
32
33 #include "absl/base/call_once.h"
34 #include "tensorflow/core/common_runtime/function.h"
35 #include "tensorflow/core/common_runtime/optimization_registry.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/mkl_graph_util.h"
41 #include "tensorflow/core/graph/node_builder.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/array_slice.h"
44 #include "tensorflow/core/lib/gtl/map_util.h"
45 #include "tensorflow/core/lib/hash/hash.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/util/tensor_format.h"
48 #include "tensorflow/core/util/util.h"
49
50 namespace tensorflow {
51
52 // This pass implements rewriting of the graph to support the following
53 // scenarios:
54 // (A) Merging nodes in the graph
55 // (B) Rewriting a node in the graph to a new node
56 //     The rewrite propagates the Mkl layout as an additional output tensor
57 //     from every Mkl-supported NN layer (we will loosely call a tensor that
58 //     carries an Mkl layout an "Mkl tensor" henceforth).
59 //
60 // Example of A : Merging nodes in the graph
61 // -----------------------------------------
62 // Currently, we merge Conv2D+AddBias together. Consider Conv2D and BiasAdd as:
63 //
64 // O = Conv2D(A, B)
65 // P = BiasAdd(O, C)
66 //
67 // We merge them into Conv2DWithBias as:
68 // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
69 //
70 // The meaning of A_m, B_m and C_m is explained in B.1.
71 //
72 // Merge rules (a rough eligibility sketch follows below):
73 // - The merge for Conv2D and BiasAdd happens only when the output of Conv2D
74 //   goes to BiasAdd and nowhere else.
75 // - The intersection of the attributes of both nodes must have the same
76 //   values.
77 // - Both nodes must have been assigned to the same device (if any).
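//
// A rough sketch of this eligibility check (illustrative only; the helper
// name and the attribute-comparison function below are hypothetical, not
// the actual methods used later in this pass):
//
//   bool CanMergeConv2DWithBiasAdd(const Node* conv, const Node* bias) {
//     // Conv2D's data output must feed only BiasAdd (ignore control edges).
//     int data_consumers = 0;
//     for (const Edge* e : conv->out_edges())
//       if (!e->IsControlEdge()) ++data_consumers;
//     if (data_consumers != 1) return false;
//     // Common attributes (e.g. "T", "data_format") must agree, and both
//     // nodes must sit on the same device, if one was assigned.
//     return CommonAttrsMatch(conv, bias) &&  // hypothetical helper
//            conv->assigned_device_name() == bias->assigned_device_name();
//   }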
78 //
79 // Example of B.1 : Rewriting nodes to Mkl nodes
80 // ---------------------------------------------
81 // Consider a Relu node. Current definition of Relu node looks like:
82 //
83 // O = Relu(A)
84 //
85 // Relu has 1 input (A), and 1 output (O).
86 //
87 // This rewrite pass will generate a new graph node for Relu (new node is
88 // called MklRelu) as:
89 //
90 // O, O_m = MklRelu(A, A_m)
91 //
92 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
93 // same as input A of Relu; output O is same as output O of Relu. O_m is the
94 // additional output tensor that will be set by MklRelu, and it represents
95 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
96 // metadata for O. A_m is an additional input of MklRelu, and it represents
97 // metadata for A - just as O_m is metadata for O, A_m is metadata for A.
98 // MklRelu receives this metadata from the previous node in the graph.
99 //
100 // When a previous node in the graph is an Mkl node, A_m will represent a valid
101 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
102 // a dummy Mkl tensor.
103 //
104 // Rewriting rules:
105 // - Selection of a node for rewriting happens by registering the op type of
106 //   the node with the rewriting pass. If the op type is not registered, then
107 //   no node of this op type will be rewritten.
108 // - Number of inputs after rewriting:
109 //     Since the rewritten node gets an Mkl tensor for every input Tensorflow
110 //     tensor, the rewritten node has 2*N inputs, where N is the number of
111 //     inputs of the original node.
112 // - Number of outputs after rewriting:
113 //     Since the rewritten node generates an Mkl tensor for every output
114 //     Tensorflow tensor, the rewritten node has 2*N outputs, where N is the
115 //     number of outputs of the original node.
116 // - Ordering of Tensorflow tensors and Mkl tensors:
117 // Since every rewritten node generates twice the number of inputs and
118 // outputs, one could imagine various orderings among Tensorflow tensors
119 // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
120 // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
121 // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
122 //      order. Among the 2*N inputs one can get (2*N)! permutations.
123 //
124 // So the question is: which order do we follow? We support 2 types of
125 // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
126 // follows an intuitive order where an Mkl tensor follows the
127 // corresponding Tensorflow tensor immediately. In the context of the
128 // above example, it will be: A, A_m, B, B_m. Note that the ordering rule
129 // applies to both the inputs and outputs. Contiguous ordering means
130 // all the Tensorflow tensors are contiguous followed by all the Mkl
131 // tensors. We use contiguous ordering as default.
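//
//      For example, with contiguous ordering and an op that carries N
//      Tensorflow tensors plus N Mkl tensors, the slot of each tensor can be
//      computed as below (an illustrative sketch; this pass relies on index
//      helpers from mkl_graph_util.h rather than these exact functions):
//
//        // Slot of the i-th Tensorflow (data) tensor.
//        int DataSlot(int i, int total_slots) { return i; }
//        // Slot of the Mkl (metadata) tensor paired with the i-th tensor.
//        int MetaDataSlot(int i, int total_slots) {
//          return i + total_slots / 2;
//        }
//
//      So _MklConv2D(A, B, A_m, B_m) has A at slot 0, B at slot 1, A_m at
//      slot 2, and B_m at slot 3.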
132 //
133 // Graph rewrite algorithm:
134 // Algorithm: Graph Rewrite
135 // Input: Graph G, Names of the nodes to rewrite and their new names
136 // Output: Modified Graph G' if the nodes are modified, G otherwise.
137 // Start:
138 // N = Topological_Sort(G) // N is a set of nodes in toposort order.
139 // foreach node n in N
140 // do
141 // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
142 // then
143 // E = set of <incoming edge and its src_output slot> of n
144 // E' = {} // a new set of edges for rewritten node
145 // foreach <e,s> in E
146 // do
147 // E' U {<e,s>} // First copy edge which generates Tensorflow
148 // // tensor as it is
149 // m = Source node of edge e
150 // if Is_Rewritten(m) // Did we rewrite this node in this pass?
151 // then
152 // E' U {<m,s+1>} // If yes, then m will generate an Mkl
153 // // tensor as an additional output.
154 // else
155 // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
156 // // Mkl tensor.
157 // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
158 // fi
159 // done
160 // n' = Build_New_Node(G,new_name,E')
161 // Mark_Rewritten(n') // Mark the new node as being rewritten.
162 // fi
163 // done
164 //
165 // Explanation:
166 //   For the graph rewrite, we visit the nodes of the input graph in
167 //   topological sort order, i.e., in a top-to-bottom fashion. We need
168 //   this order because, while visiting a node, we want all of its input
169 //   nodes to have already been visited (and rewritten, if applicable).
170 //   This is because if we need to rewrite a given node, then all of its
171 //   input nodes need to be fixed first (in other words, they cannot be
172 //   deleted later).
173 //
174 // While visiting a node, we first check if the op type of the node is
175 // an Mkl op. If it is, then we rewrite that node after constructing
176 //   new inputs to the node. If the op type of the node is not an Mkl op,
177 // then we do not rewrite that node.
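//
//   A rough C++ sketch of this loop (illustrative only; NodeIsMklOp,
//   IsRewritten, BuildRewrittenNode, and dummy_mkl_node are placeholders,
//   not the actual members of this pass):
//
//     std::vector<Node*> order;
//     GetReversePostOrder(**g, &order);  // producers before consumers
//     for (Node* n : order) {
//       if (!NodeIsMklOp(n)) continue;
//       std::vector<NodeBuilder::NodeOut> new_inputs;
//       for (const Edge* e : n->in_edges()) {
//         if (e->IsControlEdge()) continue;
//         // Copy the Tensorflow tensor edge as is.
//         new_inputs.emplace_back(e->src(), e->src_output());
//         // Follow it with the matching Mkl tensor (interleaved form, as in
//         // the pseudocode above), or with a dummy Mkl tensor.
//         if (IsRewritten(e->src()))
//           new_inputs.emplace_back(e->src(), e->src_output() + 1);
//         else
//           new_inputs.emplace_back(dummy_mkl_node, 0);
//       }
//       BuildRewrittenNode(g, n, new_inputs);  // replaces n in the graph
//     }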
178 //
179 // Handling workspace propagation for certain ops:
180 //
181 // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
182 // passing of a workspace from their respective forward ops. Workspace
183 // tensors provide memory for storing results of intermediate operations
184 // which are helpful in backward propagation. TensorFlow does not have
185 // a notion of a workspace and as a result does not allow producing
186 // additional outputs from these forward ops. For these ops, we need
187 // to add 2 extra edges between forward ops and their corresponding
188 // backward ops - the first extra edge carries a workspace tensor and
189 // the second one carries an Mkl tensor for the workspace tensor.
190 //
191 // Example:
192 //
193 // Typical graph for MaxPool and its gradient looks like:
194 //
195 // A = MaxPool(T)
196 // B = MaxPoolGrad(X, A, Y)
197 //
198 // We will transform this graph to propagate the workspace as:
199 // (with the contiguous ordering)
200 //
201 // A, W, A_m, W_m = MklMaxPool(T, T_m)
202 // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
203 //
204 // Here W is the workspace tensor. Transformed tensor names with the
205 // suffix _m are Mkl tensors, and this transformation has been done
206 // using the algorithm discussed earlier. The transformation for
207 // workspace propagation only adds extra outputs (W, W_m) for a forward
208 // op and connects them to the corresponding backward ops.
209 //
210 // Terms:
211 //
212 // Forward op name = name of the op in the forward pass
213 // where a workspace tensor originates (MaxPool in this example)
214 // Backward op name = name of the op in the backward pass that receives
215 // a workspace tensor from the forward op (MaxPoolGrad in the example)
216 // Slot = Position of the output or input slot that will be
217 // used by the workspace tensor (1 for MklMaxPool as W is the 2nd
218 // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
219 //
220 // Question:
221 //
222 // How do we associate a backward op to a forward op? There can be more
223 // than one op with the exact same name.
224 //
225 // In this example, we associate MaxPoolGrad with MaxPool. But there
226 //  could be more than one MaxPool op. To solve this problem, we look
227 // for _direct_ edge between a forward op and a backward op (tensor A is
228 // flowing along this edge in the example).
229 //
230 // How do we transform forward and backward ops when there is no direct
231 // edge between them? In such a case, we generate dummy tensors for
232 //  workspace tensors. For the example, the transformation of MaxPool will
233 //  be exactly the same as it would be when there is a direct edge between
234 //  the forward and the backward op --- it is just that MaxPool won't
235 //  generate any workspace tensor. For MaxPoolGrad, the transformation
236 //  will also be the same, but instead of connecting W and W_m with the
237 //  outputs of MaxPool, we will produce dummy tensors for them, and we
238 //  will set the workspace_enabled attribute to false.
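//
//  As a concrete illustration, the MaxPool/MaxPoolGrad pair above is
//  registered later in this file (see wsinfo_ in the constructor) as:
//
//    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad,
//                       0 /* fwd_slot: A */, 1 /* bwd_slot: A */,
//                       1 /* ws_fwd_slot: W */, 3 /* ws_bwd_slot: W */});
//
//  i.e., output slot 0 of MaxPool (tensor A) must feed input slot 1 of
//  MaxPoolGrad for the two nodes to be associated, and the workspace W is
//  produced at output slot 1 of MklMaxPool and consumed at input slot 3 of
//  MklMaxPoolGrad.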
239 //
240 class MklLayoutRewritePass : public GraphOptimizationPass {
241 public:
242   MklLayoutRewritePass() {
243 // NOTE: names are alphabetically sorted.
244 csinfo_.addn = "AddN";
245 csinfo_.avg_pool = "AvgPool";
246 csinfo_.avg_pool_grad = "AvgPoolGrad";
247 csinfo_.avg_pool3d = "AvgPool3D";
248 csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
249 csinfo_.batch_matmul = "BatchMatMul";
250 csinfo_.batch_matmul_v2 = "BatchMatMulV2";
251 csinfo_.bias_add = "BiasAdd";
252 csinfo_.bias_add_grad = "BiasAddGrad";
253 csinfo_.concat = "Concat";
254 csinfo_.concatv2 = "ConcatV2";
255 csinfo_.conjugate_transpose = "ConjugateTranspose";
256 csinfo_.conv2d = "Conv2D";
257 csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
258 csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
259 csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
260 csinfo_.conv2d_grad_filter_with_bias =
261 "__MklDummyConv2DBackpropFilterWithBias";
262 csinfo_.conv3d = "Conv3D";
263 csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
264 csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
265 csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
266 csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
267 csinfo_.depthwise_conv2d_grad_filter =
268 "DepthwiseConv2dNativeBackpropFilter";
269 csinfo_.dequantize = "Dequantize";
270 csinfo_.einsum = "Einsum";
271 csinfo_.fused_batch_norm = "FusedBatchNorm";
272 csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
273 csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
274 csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
275 csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
276 csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
277 csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
278 csinfo_.fused_conv2d = "_FusedConv2D";
279 csinfo_.fused_conv3d = "_FusedConv3D";
280 csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
281 csinfo_.fused_matmul = "_FusedMatMul";
282 csinfo_.identity = "Identity";
283 csinfo_.leakyrelu = "LeakyRelu";
284 csinfo_.leakyrelu_grad = "LeakyReluGrad";
285 csinfo_.lrn = "LRN";
286 csinfo_.lrn_grad = "LRNGrad";
287 csinfo_.matmul = "MatMul";
288 csinfo_.max_pool = "MaxPool";
289 csinfo_.max_pool_grad = "MaxPoolGrad";
290 csinfo_.max_pool3d = "MaxPool3D";
291 csinfo_.max_pool3d_grad = "MaxPool3DGrad";
292 csinfo_.mkl_conv2d = "_MklConv2D";
293 csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
294 csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
295 csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
296 csinfo_.mkl_conv2d_grad_filter_with_bias =
297 "_MklConv2DBackpropFilterWithBias";
298 csinfo_.mkl_depthwise_conv2d_grad_input =
299 "_MklDepthwiseConv2dNativeBackpropInput";
300 csinfo_.mkl_depthwise_conv2d_grad_filter =
301 "_MklDepthwiseConv2dNativeBackpropFilter";
302 csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
303 csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
304 csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
305 csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
306 csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
307 csinfo_.mkl_native_conv2d_grad_filter_with_bias =
308 "_MklNativeConv2DBackpropFilterWithBias";
309 csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
310 csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
311 csinfo_.mkl_native_fused_conv3d = "_MklNativeFusedConv3D";
312 csinfo_.mkl_native_fused_depthwise_conv2d =
313 "_MklNativeFusedDepthwiseConv2dNative";
314 csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
315 csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
316 csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
317 csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
318 csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
319 csinfo_.pad = "Pad";
320 csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
321 csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
322 csinfo_.quantized_avg_pool = "QuantizedAvgPool";
323 csinfo_.quantized_concatv2 = "QuantizedConcatV2";
324 csinfo_.quantized_conv2d = "QuantizedConv2D";
325 csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
326 csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
327 csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
328 csinfo_.quantized_conv2d_with_bias_and_requantize =
329 "QuantizedConv2DWithBiasAndRequantize";
330 csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
331 csinfo_.quantized_conv2d_and_relu_and_requantize =
332 "QuantizedConv2DAndReluAndRequantize";
333 csinfo_.quantized_conv2d_with_bias_and_relu =
334 "QuantizedConv2DWithBiasAndRelu";
335 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
336 "QuantizedConv2DWithBiasAndReluAndRequantize";
337 csinfo_.quantized_max_pool = "QuantizedMaxPool";
338 csinfo_.quantized_conv2d_with_bias_sum_and_relu =
339 "QuantizedConv2DWithBiasSumAndRelu";
340 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
341 "QuantizedConv2DWithBiasSumAndReluAndRequantize";
342 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
343 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
344 csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
345 csinfo_.quantized_matmul_with_bias_and_relu =
346 "QuantizedMatMulWithBiasAndRelu";
347 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
348 "QuantizedMatMulWithBiasAndReluAndRequantize";
349 csinfo_.quantized_matmul_with_bias_and_dequantize =
350 "QuantizedMatMulWithBiasAndDequantize";
351 csinfo_.quantized_matmul_with_bias_and_requantize =
352 "QuantizedMatMulWithBiasAndRequantize";
353 csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
354 csinfo_.quantized_depthwise_conv2d_with_bias =
355 "QuantizedDepthwiseConv2DWithBias";
356 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
357 "QuantizedDepthwiseConv2DWithBiasAndRelu";
358 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
359 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
360 csinfo_.quantize_v2 = "QuantizeV2";
361 csinfo_.relu = "Relu";
362 csinfo_.relu_grad = "ReluGrad";
363 csinfo_.relu6 = "Relu6";
364 csinfo_.relu6_grad = "Relu6Grad";
365 csinfo_.requantize = "Requantize";
366 csinfo_.tanh = "Tanh";
367 csinfo_.tanh_grad = "TanhGrad";
368 csinfo_.reshape = "Reshape";
369 csinfo_.slice = "Slice";
370 csinfo_.softmax = "Softmax";
371 csinfo_.split = "Split";
372 csinfo_.transpose = "Transpose";
373 // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
374     // in mkl_util.h (the IsMklElementWiseOp method) to ensure that the
375 // MklInputConversion op is added before it.
376 csinfo_.add = "Add";
377 csinfo_.add_v2 = "AddV2";
378 csinfo_.maximum = "Maximum";
379 csinfo_.mul = "Mul";
380 csinfo_.squared_difference = "SquaredDifference";
381 csinfo_.sub = "Sub";
382 // End - element-wise ops. See note above.
383
384 const bool native_fmt = NativeFormatEnabled();
385 // NOTE: names are alphabetically sorted.
386 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
387 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
388 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
389 CopyAttrsAll, RewriteIfAtleastOneMklInput,
390 GetRewriteCause()});
391 rinfo_.push_back(
392 {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
393 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
394 rinfo_.push_back({csinfo_.avg_pool,
395 mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
396 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
397 rinfo_.push_back({csinfo_.avg_pool_grad,
398 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
399 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
400 rinfo_.push_back({csinfo_.avg_pool3d,
401 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
402 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
403 rinfo_.push_back({csinfo_.avg_pool3d_grad,
404 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
405 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
406 rinfo_.push_back({csinfo_.batch_matmul,
407 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
408 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
409 rinfo_.push_back({csinfo_.einsum,
410 mkl_op_registry::GetMklOpName(csinfo_.einsum),
411 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
412 rinfo_.push_back({csinfo_.batch_matmul_v2,
413 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
414 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
415 rinfo_.push_back({csinfo_.concat,
416 mkl_op_registry::GetMklOpName(csinfo_.concat),
417 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
418 rinfo_.push_back({csinfo_.concatv2,
419 mkl_op_registry::GetMklOpName(csinfo_.concatv2),
420 CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
421 rinfo_.push_back(
422 {csinfo_.conjugate_transpose,
423 mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
424 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
425 rinfo_.push_back(
426 {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
427 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
428 rinfo_.push_back({csinfo_.conv2d_with_bias,
429 native_fmt ? csinfo_.mkl_native_conv2d_with_bias
430 : csinfo_.mkl_conv2d_with_bias,
431 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
432 GetRewriteCause()});
433 rinfo_.push_back({csinfo_.conv2d_grad_filter,
434 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
435 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
436 rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
437 native_fmt
438 ? csinfo_.mkl_native_conv2d_grad_filter_with_bias
439 : csinfo_.mkl_conv2d_grad_filter_with_bias,
440 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
441 rinfo_.push_back({csinfo_.conv2d_grad_input,
442 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
443 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
444 rinfo_.push_back(
445 {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
446 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
447 rinfo_.push_back({csinfo_.conv3d_grad_filter,
448 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
449 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
450 rinfo_.push_back({csinfo_.conv3d_grad_input,
451 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
452 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
453 rinfo_.push_back({csinfo_.depthwise_conv2d,
454 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
455 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
456 GetRewriteCause()});
457 rinfo_.push_back(
458 {csinfo_.depthwise_conv2d_grad_input,
459 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
460 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
461 rinfo_.push_back(
462 {csinfo_.depthwise_conv2d_grad_filter,
463 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
464 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
465 rinfo_.push_back(
466 {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
467 CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
468 rinfo_.push_back({csinfo_.fused_batch_norm,
469 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
470 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
471 rinfo_.push_back(
472 {csinfo_.fused_batch_norm_grad,
473 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
474 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
475 rinfo_.push_back(
476 {csinfo_.fused_batch_norm_v2,
477 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
478 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
479 rinfo_.push_back(
480 {csinfo_.fused_batch_norm_grad_v2,
481 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
482 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
483
484 // Using CopyAttrsAll for V3 on CPU, as there are no additional
485 // attributes.
486 rinfo_.push_back(
487 {csinfo_.fused_batch_norm_v3,
488 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
489 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
490 rinfo_.push_back(
491 {csinfo_.fused_batch_norm_grad_v3,
492 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
493 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
494 rinfo_.push_back({csinfo_.fused_batch_norm_ex,
495 native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
496 : csinfo_.mkl_fused_batch_norm_ex,
497 CopyAttrsAll, FusedBatchNormExRewrite,
498 GetRewriteCause()});
499 rinfo_.push_back({csinfo_.fused_conv2d,
500 native_fmt ? csinfo_.mkl_native_fused_conv2d
501 : csinfo_.mkl_fused_conv2d,
502 CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
503 GetRewriteCause()});
504 rinfo_.push_back({csinfo_.fused_conv3d, csinfo_.mkl_native_fused_conv3d,
505 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
506 kRewriteForOpNameChange});
507 rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
508 native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
509 : csinfo_.mkl_fused_depthwise_conv2d,
510 CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
511 GetRewriteCause()});
512 rinfo_.push_back({csinfo_.fused_matmul,
513 native_fmt ? csinfo_.mkl_native_fused_matmul
514 : csinfo_.mkl_fused_matmul,
515 CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
516 GetRewriteCause()});
517 rinfo_.push_back(
518 {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
519 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
520 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
521 CopyAttrsAll, LrnRewrite, GetRewriteCause()});
522 rinfo_.push_back({csinfo_.lrn_grad,
523 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
524 CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
525 rinfo_.push_back({csinfo_.matmul,
526 mkl_op_registry::GetMklOpName(csinfo_.matmul),
527 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
528 rinfo_.push_back({csinfo_.leakyrelu,
529 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
530 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
531 rinfo_.push_back({csinfo_.leakyrelu_grad,
532 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
533 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
534 rinfo_.push_back(
535 {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
536 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
537 rinfo_.push_back({csinfo_.max_pool_grad,
538 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
539 CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
540 rinfo_.push_back(
541 {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
542 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
543 rinfo_.push_back({csinfo_.max_pool3d_grad,
544 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
545 CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
546 rinfo_.push_back(
547 {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
548 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
549 rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
550 CopyAttrsAll, RewriteIfAtleastOneMklInput,
551 GetRewriteCause()});
552 rinfo_.push_back({csinfo_.pad_with_conv2d,
553 native_fmt ? csinfo_.mkl_native_pad_with_conv2d
554 : csinfo_.mkl_pad_with_conv2d,
555 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
556 GetRewriteCause()});
557 rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
558 native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
559 : csinfo_.mkl_pad_with_fused_conv2d,
560 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
561 GetRewriteCause()});
562 rinfo_.push_back({csinfo_.quantized_avg_pool,
563 mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
564 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
565 rinfo_.push_back({csinfo_.quantized_concatv2,
566 mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
567 CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
568 rinfo_.push_back({csinfo_.quantized_conv2d,
569 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
570 CopyAttrsQuantizedConv2D, AlwaysRewrite,
571 kRewriteForOpNameChange});
572 rinfo_.push_back(
573 {csinfo_.quantized_conv2d_per_channel,
574 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
575 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
576 rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
577 mkl_op_registry::GetMklOpName(
578 csinfo_.quantized_conv2d_with_requantize),
579 CopyAttrsQuantizedConv2D, AlwaysRewrite,
580 kRewriteForOpNameChange});
581 rinfo_.push_back(
582 {csinfo_.quantized_conv2d_with_bias,
583 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
584 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
585 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
586 mkl_op_registry::GetMklOpName(
587 csinfo_.quantized_conv2d_with_bias_and_requantize),
588 CopyAttrsQuantizedConv2D, AlwaysRewrite,
589 kRewriteForOpNameChange});
590 rinfo_.push_back(
591 {csinfo_.quantized_conv2d_and_relu,
592 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
593 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
594 rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
595 mkl_op_registry::GetMklOpName(
596 csinfo_.quantized_conv2d_and_relu_and_requantize),
597 CopyAttrsQuantizedConv2D, AlwaysRewrite,
598 kRewriteForOpNameChange});
599 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
600 mkl_op_registry::GetMklOpName(
601 csinfo_.quantized_conv2d_with_bias_and_relu),
602 CopyAttrsQuantizedConv2D, AlwaysRewrite,
603 kRewriteForOpNameChange});
604 rinfo_.push_back(
605 {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
606 mkl_op_registry::GetMklOpName(
607 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
608 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
609 rinfo_.push_back({csinfo_.quantized_max_pool,
610 mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
611 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
612 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
613 mkl_op_registry::GetMklOpName(
614 csinfo_.quantized_conv2d_with_bias_sum_and_relu),
615 CopyAttrsQuantizedConv2D, AlwaysRewrite,
616 kRewriteForOpNameChange});
617 rinfo_.push_back(
618 {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
619 mkl_op_registry::GetMklOpName(
620 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
621 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
622 rinfo_.push_back(
623 {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
624 mkl_op_registry::GetMklOpName(
625 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
626 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
627 rinfo_.push_back(
628 {csinfo_.quantized_matmul_with_bias,
629 mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
630 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
631 kRewriteForOpNameChange});
632 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
633 mkl_op_registry::GetMklOpName(
634 csinfo_.quantized_matmul_with_bias_and_relu),
635 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
636 kRewriteForOpNameChange});
637 rinfo_.push_back(
638 {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
639 mkl_op_registry::GetMklOpName(
640 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
641 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
642 kRewriteForOpNameChange});
643 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
644 mkl_op_registry::GetMklOpName(
645 csinfo_.quantized_matmul_with_bias_and_requantize),
646 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
647 kRewriteForOpNameChange});
648 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
649 mkl_op_registry::GetMklOpName(
650 csinfo_.quantized_matmul_with_bias_and_dequantize),
651 CopyAttrsQuantizedMatMulWithBiasAndDequantize,
652 AlwaysRewrite, kRewriteForOpNameChange});
653 rinfo_.push_back(
654 {csinfo_.quantized_depthwise_conv2d,
655 mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
656 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
657 rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
658 mkl_op_registry::GetMklOpName(
659 csinfo_.quantized_depthwise_conv2d_with_bias),
660 CopyAttrsQuantizedConv2D, AlwaysRewrite,
661 kRewriteForOpNameChange});
662 rinfo_.push_back(
663 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
664 mkl_op_registry::GetMklOpName(
665 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
666 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
667 rinfo_.push_back(
668 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
669 mkl_op_registry::GetMklOpName(
670 csinfo_
671 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
672 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
673 rinfo_.push_back({csinfo_.quantize_v2,
674 mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
675 CopyAttrsAll, QuantizeOpRewrite,
676 kRewriteForOpNameChange});
677 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
678 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
679 rinfo_.push_back({csinfo_.relu_grad,
680 mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
681 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
682 rinfo_.push_back({csinfo_.relu6,
683 mkl_op_registry::GetMklOpName(csinfo_.relu6),
684 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
685 rinfo_.push_back({csinfo_.relu6_grad,
686 mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
687 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
688 rinfo_.push_back({csinfo_.requantize,
689 mkl_op_registry::GetMklOpName(csinfo_.requantize),
690 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
691 // Optimized TanhGrad support exists only in DNNL 1.x.
692 rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
693 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
694 rinfo_.push_back({csinfo_.tanh_grad,
695 mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
696 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
697 rinfo_.push_back({csinfo_.reshape,
698 mkl_op_registry::GetMklOpName(csinfo_.reshape),
699 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
700 rinfo_.push_back(
701 {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
702 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
703 rinfo_.push_back({csinfo_.softmax,
704 mkl_op_registry::GetMklOpName(csinfo_.softmax),
705 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
706
707 rinfo_.push_back({csinfo_.squared_difference,
708 mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
709 CopyAttrsAll, RewriteIfAtleastOneMklInput,
710 GetRewriteCause()});
711 rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
712 CopyAttrsAll, RewriteIfAtleastOneMklInput,
713 GetRewriteCause()});
714 rinfo_.push_back({csinfo_.transpose,
715 mkl_op_registry::GetMklOpName(csinfo_.transpose),
716 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
717
718 // Add info about which ops to add workspace edge to and the slots.
719 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
720 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
721 wsinfo_.push_back(
722 {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
723
724 // Add a rule for merging nodes
725 minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
726 csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
727
728     // Merge Pad and Conv2D only if the pad op is "Pad";
729     // don't merge if the pad op is "PadV2" or "MirrorPad".
730 minfo_.push_back(
731 {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
732
733 minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
734 csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
735
736 minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
737 csinfo_.conv2d_grad_filter_with_bias,
738 GetConv2DBackpropFilterOrBiasAddGrad});
739
740     // The fusion patterns in "finfo_" that show up first will get applied
741     // first. For example, for the graph "A->B->C->D", if finfo_ is
742     // {A->B->C to ABC, A->B->C->D to ABCD}, then since the first pattern is
743     // applied first, the final graph will be ABC->D.
744 }
745
746 // Standard interface to run pass
747 Status Run(const GraphOptimizationPassOptions& options);
748
749   // Helper function that does most of the heavy lifting for rewriting
750   // Mkl nodes to propagate an Mkl tensor as an additional output.
751 //
752 // Extracts common functionality between Run public interface and
753 // test interface.
754 //
755   // @return true if and only if the graph is mutated; false otherwise.
756 bool RunPass(std::unique_ptr<Graph>* g);
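
  // For reference, a pass like this is hooked into the runtime through the
  // optimization pass registry. A typical registration (shown here only as
  // an illustration; the actual phase and priority are set where this pass
  // is registered, outside this excerpt) looks like:
  //
  //   REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PARTITIONING, 1,
  //                         MklLayoutRewritePass);
  //
  // Run() is then invoked by the registry with the graph to optimize and is
  // expected to delegate the actual rewriting to RunPass().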
757
758 /// Cause for rewrite
759 /// Currently, we only support 2 causes - either for Mkl layout propagation
760 /// which is the most common case, or for just a name change (used in case
761 /// of ops like MatMul, Transpose, which do not support Mkl layout)
762 enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };
763
764 // Get the op rewrite cause depending on whether native format mode
765 // is enabled or not.
766   RewriteCause GetRewriteCause() {
767 if (NativeFormatEnabled()) {
768 return kRewriteForOpNameChange;
769 } else {
770 return kRewriteForLayoutPropagation;
771 }
772 }
773
774 /// Structure to specify the name of an original node, its new name after
775 /// rewrite, the number of inputs to the original node, the function to
776 /// be used to copy attributes for the op, and the rule (if any) which
777 /// must hold for rewriting the node
778 typedef struct {
779 string name; // Original name of op of the node in the graph
780 string new_name; // New name of the op of the node in the graph
781 // A function handler to copy attributes from an old node to a new node.
782 std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
783 // A rule under which to rewrite this node
784 std::function<bool(const Node*)> rewrite_rule;
785 // Why are we rewriting?
786 RewriteCause rewrite_cause;
787 } RewriteInfo;
788
789 /// Structure to specify a forward op, a backward op, and the slot numbers
790 /// in the forward and backward ops where we will add a workspace edge.
791 typedef struct {
792 string fwd_op; // Name of a forward op in the graph
793 string bwd_op; // Name of a backward op in the graph
794 int fwd_slot; // Output slot in the forward op node where actual
795 // output tensor resides
796 int bwd_slot; // Input slot in the backward op node where actual
797 // input tensor resides
798 int ws_fwd_slot; // Output slot in the forward op node where workspace
799 // edge is added
800 int ws_bwd_slot; // Input slot in the backward op node where workspace
801 // edge is added
802 } WorkSpaceInfo;
803
804 /// Structure to specify information used in node merge of 2 operators
805 typedef struct {
806 string op1; // Node string for one operator.
807 string op2; // Node string for second operator.
808 string new_node; // Name of the node after merge
809 // Function that enables user of the node merger to specify how to find
810 // second operator given the first operator.
811 std::function<Node*(const Node*)> get_node_to_be_merged;
812 } MergeInfo;
813
814 // Structure to specify information used in node fusion of 3+ operators
815 typedef struct {
816 std::string pattern_name; // Name to describe this pattern, such as
817 // "Transpose_Mklop_Transpose".
818 std::vector<std::function<bool(const Node*)> >
819 node_checkers; // Extra restriction checker for these ops
820 std::function<Status(
821 std::unique_ptr<Graph>*, std::vector<Node*>&,
822 std::function<void(const Node*, NodeBuilder* nb, bool)>)>
823 fuse_func;
824 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
825 } FusionInfo;
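
  // As an illustrative (not verbatim) example, a FusionInfo entry for the
  // Transpose + MklOp + Transpose pattern handled further below could be
  // registered roughly as follows (the pattern name here is made up):
  //
  //   finfo_.push_back(
  //       {"transpose_conv2d_transpose",
  //        {[](const Node* n) { return CheckForTranspose(n, {0, 2, 3, 1}); },
  //         [](const Node* n) { return CheckForMklOp(n, "Conv2D"); },
  //         [](const Node* n) { return CheckForTranspose(n, {0, 3, 1, 2}); }},
  //        [](std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
  //           std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs) {
  //          return FuseTransposeMklOpTranspose(g, nodes, copy_attrs, "NCHW");
  //        },
  //        CopyAttrsConv});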
826
827 //
828   // Dimension indices for a 4D tensor (used by 2D ops such as Conv2D).
829 //
830 struct NCHW {
831 enum dim { N = 0, C = 1, H = 2, W = 3 };
832 };
833
834 struct NHWC {
835 enum dim { N = 0, H = 1, W = 2, C = 3 };
836 };
837
838 //
839   // Dimension indices for a 5D tensor (used by 3D ops such as Conv3D).
840 //
841 struct NCDHW {
842 enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
843 };
844
845 struct NDHWC {
846 enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
847 };
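
  // For example (illustrative use only), the permutation that reorders NCHW
  // data into NHWC order can be written with these indices:
  //
  //   const std::vector<int> nchw_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
  //                                          NCHW::dim::W, NCHW::dim::C};
  //   // i.e. {0, 2, 3, 1}, the "perm" tensor a Transpose node would expect.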
848
849 /// Structure to store all constant strings
850 /// NOTE: names are alphabetically sorted.
851 typedef struct {
852 string addn;
853 string add;
854 string add_v2;
855 string avg_pool;
856 string avg_pool_grad;
857 string avg_pool3d;
858 string avg_pool3d_grad;
859 string batch_matmul;
860 string batch_matmul_v2;
861 string bias_add;
862 string bias_add_grad;
863 string concat;
864 string concatv2;
865 string conjugate_transpose;
866 string conv2d;
867 string conv2d_with_bias;
868 string conv2d_grad_input;
869 string conv2d_grad_filter;
870 string conv2d_grad_filter_with_bias;
871 string conv3d;
872 string conv3d_grad_input;
873 string conv3d_grad_filter;
874 string depthwise_conv2d;
875 string depthwise_conv2d_grad_input;
876 string depthwise_conv2d_grad_filter;
877 string dequantize;
878 string einsum;
879 string fused_batch_norm;
880 string fused_batch_norm_grad;
881 string fused_batch_norm_ex;
882 string fused_batch_norm_v2;
883 string fused_batch_norm_grad_v2;
884 string fused_batch_norm_v3;
885 string fused_batch_norm_grad_v3;
886 string fused_conv2d;
887 string fused_conv3d;
888 string fused_depthwise_conv2d;
889 string fused_matmul;
890 string identity;
891 string leakyrelu;
892 string leakyrelu_grad;
893 string lrn;
894 string lrn_grad;
895 string matmul;
896 string max_pool;
897 string max_pool_grad;
898 string max_pool3d;
899 string max_pool3d_grad;
900 string maximum;
901 string mkl_conv2d;
902 string mkl_conv2d_grad_input;
903 string mkl_conv2d_grad_filter;
904 string mkl_conv2d_grad_filter_with_bias;
905 string mkl_conv2d_with_bias;
906 string mkl_depthwise_conv2d_grad_input;
907 string mkl_depthwise_conv2d_grad_filter;
908 string mkl_fused_batch_norm_ex;
909 string mkl_fused_conv2d;
910 string mkl_fused_depthwise_conv2d;
911 string mkl_fused_matmul;
912 string mkl_native_conv2d_with_bias;
913 string mkl_native_conv2d_grad_filter_with_bias;
914 string mkl_native_fused_batch_norm_ex;
915 string mkl_native_fused_conv2d;
916 string mkl_native_fused_conv3d;
917 string mkl_native_fused_depthwise_conv2d;
918 string mkl_native_fused_matmul;
919 string mkl_native_pad_with_conv2d;
920 string mkl_native_pad_with_fused_conv2d;
921 string mkl_pad_with_conv2d;
922 string mkl_pad_with_fused_conv2d;
923 string mul;
924 string pad;
925 string pad_with_conv2d;
926 string pad_with_fused_conv2d;
927 string quantized_avg_pool;
928 string quantized_conv2d;
929 string quantized_conv2d_per_channel;
930 string quantized_conv2d_with_requantize;
931 string quantized_conv2d_with_bias;
932 string quantized_conv2d_with_bias_and_requantize;
933 string quantized_conv2d_and_relu;
934 string quantized_conv2d_and_relu_and_requantize;
935 string quantized_conv2d_with_bias_and_relu;
936 string quantized_conv2d_with_bias_and_relu_and_requantize;
937 string quantized_concatv2;
938 string quantized_max_pool;
939 string quantized_conv2d_with_bias_sum_and_relu;
940 string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
941 string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
942 string quantized_matmul_with_bias;
943 string quantized_matmul_with_bias_and_relu;
944 string quantized_matmul_with_bias_and_relu_and_requantize;
945 string quantized_matmul_with_bias_and_requantize;
946 string quantized_matmul_with_bias_and_dequantize;
947 string quantized_depthwise_conv2d;
948 string quantized_depthwise_conv2d_with_bias;
949 string quantized_depthwise_conv2d_with_bias_and_relu;
950 string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
951 string quantize_v2;
952 string relu;
953 string relu_grad;
954 string relu6;
955 string relu6_grad;
956 string requantize;
957 string tanh;
958 string tanh_grad;
959 string transpose;
960 string reshape;
961 string slice;
962 string softmax;
963 string split;
964 string squared_difference;
965 string sub;
966 } ConstStringsInfo;
967
968 private:
969 /// Maintain info about nodes to rewrite
970 std::vector<RewriteInfo> rinfo_;
971
972 /// Maintain info about nodes to add workspace edge
973 std::vector<WorkSpaceInfo> wsinfo_;
974
975 /// Maintain info about nodes to be merged
976 std::vector<MergeInfo> minfo_;
977
978 /// Maintain info about nodes to be fused
979 std::vector<FusionInfo> finfo_;
980
981 /// Maintain structure of constant strings
982 static ConstStringsInfo csinfo_;
983
984 private:
985 // Is OpDef::ArgDef a list type? It could be N * T or list(type).
986   // Refer to op_def.proto for details of list types.
987   inline bool ArgIsList(const OpDef::ArgDef& arg) const {
988 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
989 }
990
991 // Get length of a list in 'n' if 'arg' is of list type. Refer to
992 // description of ArgIsList for definition of list type.
993   inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
994 CHECK_EQ(ArgIsList(arg), true);
995 int N = 0;
996 const string attr_name = !arg.type_list_attr().empty()
997 ? arg.type_list_attr()
998 : arg.number_attr();
999 if (!arg.type_list_attr().empty()) {
1000 std::vector<DataType> value;
1001 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
1002 N = value.size();
1003 } else {
1004 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
1005 }
1006 return N;
1007 }
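
  // For example (illustrative): AddN declares its inputs as "N * T", so the
  // corresponding ArgDef has number_attr() == "N". For an AddN node built
  // with four inputs,
  //
  //   GetTensorListLength(arg, n);  // reads attribute "N" and returns 4
  //
  // whereas an op whose ArgDef uses a "list(type)" attribute would instead
  // return the length of that type list.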
1008
1009 // Can op represented by node 'n' run on DEVICE_CPU?
1010   // An op can run on CPU with MKL if the runtime-assigned device or the
1011   // user-requested device contains "CPU", or if both are empty.
1012   bool CanOpRunOnCPUDevice(const Node* n) {
1013 bool result = true;
1014 string reason;
1015
1016 // Substring that should be checked for in device name for CPU device.
1017 const char* const kCPUDeviceSubStr = "CPU";
1018 const char* const kXLACPUDeviceSubStr = "XLA_CPU";
1019
1020     // If the op has been specifically assigned to a non-CPU device or to an
1021     // XLA_CPU device, then no.
1022 if (!n->assigned_device_name().empty() &&
1023 (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) ||
1024 absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) {
1025 result = false;
1026 reason = "Op has been assigned a runtime device that is not CPU.";
1027 }
1028
1029     // If the user has specifically assigned this op to a non-CPU device or
1030     // to an XLA_CPU device, then no.
1031 if (!n->def().device().empty() &&
1032 (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) ||
1033 absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) {
1034 result = false;
1035 reason = "User has assigned a device that is not CPU.";
1036 }
1037
1038 if (result == false) {
1039 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
1040 << n->type_string() << ", reason: " << reason;
1041 }
1042
1043 // Otherwise Yes.
1044 return result;
1045 }
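
  // Illustrative examples of how the device checks above behave:
  //
  //   ""                                           -> eligible (empty)
  //   "/job:worker/replica:0/task:0/device:CPU:0"  -> eligible
  //   "/job:worker/replica:0/task:0/device:GPU:0"  -> not eligible
  //   "/device:XLA_CPU:0"                          -> not eligible (contains
  //                                                   "CPU" but is XLA_CPU)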
1046
1047 // Return a node that can be merged with input node 'n'
1048 //
1049 // @return pointer to the node if we can find such a
1050 // node. Otherwise, it returns nullptr.
1051 Node* CheckForNodeMerge(const Node* n) const;
1052
1053 // Merge node 'm' with node 'n'.
1054 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
1055 // Conv2DBackpropFilter.
1056 //
1057 // Input nodes m and n may be deleted if the call to
1058 // this function is successful. Attempt to use the pointers
1059 // after the call to function may result in undefined behaviors.
1060 //
1061 // @input g - input graph, m - graph node, n - graph node to be merged with m
1062 // @return Status::OK(), if merging is successful and supported.
1063 // Returns appropriate Status error code otherwise.
1064 // Graph is updated in case nodes are merged. Otherwise, it is
1065 // not updated.
1066 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1067
1068 // Helper function to merge different nodes
1069 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1070 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1071 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1072 Node* m, Node* n);
1073
1074 // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1075 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1076 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1077 // node that can be merged with 'm'.
1078   static Node* GetConv2DOrBiasAdd(const Node* m) {
1079 DCHECK(m);
1080 Node* n = nullptr;
1081
1082 DataType T_m;
1083 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1084
1085 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1086 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1087
1088 if (m->type_string() == csinfo_.bias_add) {
1089       // If 'm' is BiasAdd, then Conv2D is the 0th input of BiasAdd.
1090 TF_CHECK_OK(m->input_node(0, &n));
1091 } else {
1092 CHECK_EQ(m->type_string(), csinfo_.conv2d);
1093 // Go over all output edges and search for BiasAdd Node.
1094 // 0th input of BiasAdd is Conv2D.
1095 for (const Edge* e : m->out_edges()) {
1096 if (!e->IsControlEdge() &&
1097 e->dst()->type_string() == csinfo_.bias_add &&
1098 e->dst_input() == 0) {
1099 n = e->dst();
1100 break;
1101 }
1102 }
1103 }
1104
1105 if (n == nullptr) {
1106 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1107 << "Conv2D and BiasAdd node for merging. Input node: "
1108 << m->DebugString();
1109 }
1110
1111 return n;
1112 }
1113
1114 // Find Pad or Conv2D node that can be merged with input node 'm'.
1115 // If input 'm' is Pad, then check if there exists Conv2D node that can be
1116 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1117 // node that can be merged with 'm'.
1118   static Node* GetPadOrConv2D(const Node* m) {
1119 DCHECK(m);
1120 Node* n = nullptr;
1121
1122 DataType T_m;
1123 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1124
1125 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1126 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1127
1128 const Node* conv_node;
1129 if (m->type_string() == csinfo_.pad) {
1130 // If m is Pad, then Conv2D is the output of Pad.
1131 for (const Edge* e : m->out_edges()) {
1132 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1133 n = e->dst();
1134 conv_node = n;
1135 break;
1136 }
1137 }
1138 } else {
1139 DCHECK_EQ(m->type_string(), csinfo_.conv2d);
1140       // If m is Conv2D, go over all input edges
1141       // and search for a Pad node.
1142 for (const Edge* e : m->in_edges()) {
1143 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1144 n = e->src();
1145 conv_node = m;
1146 break;
1147 }
1148 }
1149 }
1150 // Check if only VALID type of padding is used
1151 // or not.
1152 if (n != nullptr) {
1153 string padding;
1154 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1155 if (padding != "VALID")
1156 // Then do not merge.
1157 // Only VALID type of padding in conv op can be
1158 // merged with Pad op.
1159 n = nullptr;
1160 } else {
1161 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1162 << "Pad and Conv2D node for merging. Input node: "
1163 << m->DebugString();
1164 }
1165
1166 return n;
1167 }
1168
1169 // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1170 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1171 // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1172 // exists Pad node that can be merged with 'm'.
1173   static Node* GetPadOrFusedConv2D(const Node* m) {
1174 DCHECK(m);
1175 Node* n = nullptr;
1176
1177 const Node* conv_node;
1178 if (m->type_string() == csinfo_.pad) {
1179 // If m is Pad, then _FusedConv2D is the output of Pad.
1180 for (const Edge* e : m->out_edges()) {
1181 if (!e->IsControlEdge() &&
1182 e->dst()->type_string() == csinfo_.fused_conv2d) {
1183 n = e->dst();
1184 conv_node = n;
1185 break;
1186 }
1187 }
1188 } else {
1189 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
1190       // If m is _FusedConv2D, go over all input edges
1191       // and search for a Pad node.
1192 for (const Edge* e : m->in_edges()) {
1193 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1194 n = e->src();
1195 conv_node = m;
1196 break;
1197 }
1198 }
1199 }
1200 // Check if only VALID type of padding is used or not.
1201 if (n != nullptr) {
1202 string padding;
1203 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1204 if (padding != "VALID") {
1205 // Then do not merge.
1206 n = nullptr;
1207 VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
1208 << "nodes but cannot merge them. Only conv ops with padding "
1209                 << "type VALID can be merged with Pad op. Input node: "
1210 << m->DebugString();
1211 }
1212 } else {
1213 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1214 << "Pad and _FusedConv2D node for merging. Input node: "
1215 << m->DebugString();
1216 }
1217
1218 return n;
1219 }
1220
1221 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1222 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1223 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1224 // then check if there exists Conv2DBackpropFilter node that can be merged
1225 // with 'm'.
1226 //
1227 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1228 // would look like:
1229 //
1230 // _ = Conv2DBackpropFilter(F, _, G)
1231 // _ = BiasAddGrad(G)
1232 //
1233 // So 1st input of BiasAddGrad connects with 3rd input of
1234 // Conv2DBackpropFilter and vice versa.
1235   static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1236 DCHECK(m);
1237 Node* n = nullptr;
1238 const Node* conv2d_backprop_filter = nullptr;
1239
1240 DataType T_m;
1241 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1242
1243 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1244 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1245
1246 if (m->type_string() == csinfo_.bias_add_grad) {
1247 // Get 1st input 'g' of BiasAddGrad.
1248 Node* g = nullptr;
1249 TF_CHECK_OK(m->input_node(0, &g));
1250 // Now traverse all outgoing edges from g that have destination node as
1251 // Conv2DBackpropFilter.
1252 for (const Edge* e : g->out_edges()) {
1253 if (!e->IsControlEdge() &&
1254 e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1255 e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1256 n = e->dst();
1257 conv2d_backprop_filter = n;
1258 break;
1259 }
1260 }
1261 } else {
1262 conv2d_backprop_filter = m;
1263 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1264 // Get 3rd input 'g' of Conv2DBackpropFilter.
1265 Node* g = nullptr;
1266 TF_CHECK_OK(m->input_node(2, &g));
1267 // Now traverse all outgoing edges from g that have destination node as
1268 // BiasAddGrad.
1269 for (const Edge* e : g->out_edges()) {
1270 if (!e->IsControlEdge() &&
1271 e->dst()->type_string() == csinfo_.bias_add_grad &&
1272 e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1273 n = e->dst();
1274 break;
1275 }
1276 }
1277 }
1278
1279 // Do not merge if padding type is EXPLICIT.
1280 // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1281 if (conv2d_backprop_filter != nullptr) {
1282 string padding;
1283 TF_CHECK_OK(
1284 GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1285 if (padding == "EXPLICIT") {
1286 // Then do not merge.
1287 VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1288 << "and BiasAddGrad nodes but cannot merge them. "
1289 << "EXPLICIT padding is not supported now. "
1290 << conv2d_backprop_filter->DebugString();
1291 return nullptr;
1292 }
1293 }
1294
1295 if (n == nullptr) {
1296 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1297 << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1298 << "Input node: " << m->DebugString();
1299 }
1300 return n;
1301 }
1302
1303 // Return a node that can be fused with input node 'n'
1304 //
1305   // @return a tuple. If we can find such nodes, the first
1306   // element of the tuple is true; otherwise, it is false.
1307 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1308 CheckForNodeFusion(Node* n) const;
1309
1310 // Fuse nodes in the vector "nodes"
1311 Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1312 const MklLayoutRewritePass::FusionInfo fi);
1313
1314 // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1315 // mklop("NCHW").
1316 // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1317 static Status FuseTransposeMklOpTranspose(
1318 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1319 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1320 string data_format);
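
  // For example (illustrative), this fusion collapses the pattern
  //
  //   X_nhwc = Transpose(X, {0, 2, 3, 1})       // NCHW -> NHWC
  //   Y_nhwc = Conv2D(X_nhwc, F)                // data_format = "NHWC"
  //   Y      = Transpose(Y_nhwc, {0, 3, 1, 2})  // NHWC -> NCHW
  //
  // into a single convolution that runs directly on the NCHW data,
  //
  //   Y = Conv2D(X, F)                          // data_format = "NCHW"
  //
  // (built in its Mkl form by this pass), with both Transpose nodes removed.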
1321
1322   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1323 // Check if node's type is "Transpose"
1324 if (node->type_string() != "Transpose") return false;
1325
1326 // If "Transpose" has multiple output data edges, also don't fuse it.
1327 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1328
1329     // Check if the node has an outgoing control edge. If so, this is a
1330     // training graph; currently we focus on inference and do no fusion in
1331     // training.
1332     // Note: this constraint will eventually be removed once we enable this
1333     // fusion for training.
1334 for (const Edge* e : node->out_edges()) {
1335 if (e->IsControlEdge()) {
1336 return false;
1337 }
1338 }
1339
1340 // If "Transpose" has input control edges, don't fuse on it.
1341 for (const Edge* e : node->in_edges()) {
1342 if (e->IsControlEdge()) {
1343 return false;
1344 }
1345 }
1346
1347     // We compare the tensor containing the permutation order ("perm_node")
1348     // with our desired order ("perm"). If they match exactly, this check
1349     // succeeds and returns true.
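    // For example (illustrative), the NCHW-to-NHWC permutation would be
    // perm = {0, 2, 3, 1}, and the reverse NHWC-to-NCHW permutation would be
    // perm = {0, 3, 1, 2}.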
1350 for (const Edge* e : node->in_edges()) {
1351 if (!e->IsControlEdge()) {
1352 const Node* perm_node = e->src();
1353
1354 const int kPermTensorIndex = 1;
1355 if (perm_node->type_string() == "Const" &&
1356 e->dst_input() == kPermTensorIndex) {
1357           // We found the "perm" node; now try to retrieve its value.
1358 const TensorProto* proto = nullptr;
1359 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1360
1361 DataType type;
1362 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1363
1364 Tensor tensor;
1365 if (!tensor.FromProto(*proto)) {
1366 TF_CHECK_OK(errors::InvalidArgument(
1367 "Could not construct Tensor from TensorProto in node: ",
1368 node->name()));
1369 return false;
1370 }
1371           // The current fusion only supports 4D or 5D tensors, as given by
1372           // the `perm` vector; return false otherwise.
1373 if (tensor.dim_size(0) != perm.size()) return false;
1374 DCHECK_EQ(tensor.dims(), 1);
1375 if (type == DT_INT32) {
1376 const auto tensor_content = tensor.flat<int>().data();
1377 for (int i = 0; i < perm.size(); ++i)
1378 if (tensor_content[i] != perm[i]) return false;
1379 return true;
1380 } else if (type == DT_INT64) {
1381 const auto tensor_content = tensor.flat<int64_t>().data();
1382 for (int i = 0; i < perm.size(); ++i)
1383 if (tensor_content[i] != perm[i]) return false;
1384 return true;
1385 }
1386 return false;
1387 }
1388 }
1389 }
1390 return false;
1391 }
1392
1393   static bool CheckForMklOp(const Node* node, string name = "") {
1394 if (node == nullptr) return false;
1395
1396 if (!name.empty() && node->type_string() != name) {
1397 return false;
1398 }
1399
1400 // if mklop has multiple outputs, don't fuse it.
1401 if (node->num_outputs() > 1) return false;
1402
1403 if (node->out_edges().size() > 1) return false;
1404
1405 DataType T;
1406 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1407 return mkl_op_registry::IsMklOp(
1408 mkl_op_registry::GetMklOpName(node->type_string()), T);
1409 }
1410
1411 // Check if the node 'n' has any applicable rewrite rule
1412 // We check for 2 scenarios for rewrite.
1413 //
1414 // @return RewriteInfo* for the applicable rewrite rule
1415 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1416 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1417
1418 // Default rewrite rule to be used in scenario 1 for rewrite.
1419 // @return - true (since we want to always rewrite)
1420   static bool AlwaysRewrite(const Node* n) { return true; }
1421
1422 // Rewrite rule which considers "context" of the current node to decide if we
1423 // should rewrite. By "context" we currently mean all the inputs of current
1424   // node. The idea is that if none of the inputs of the current node are MKL
1425   // nodes, then rewriting the current node to an MKL node _may not_ offer any
1426   // performance improvement.
1427 //
1428 // One such case is element-wise ops. For such ops, we reuse the Eigen
1429 // implementation and pass the MKL metadata tensor through so we can avoid
1430 // conversions. However, if all incoming edges are in TF format, we don't
1431 // need all this overhead, so replace the elementwise node only if at least
1432 // one of its parents is a MKL node.
1433 //
1434 // More generally, all memory- or IO-bound ops (such as Identity) may fall
1435 // under this category.
1436 //
1437 // @input - Input graph node to be rewritten
1438 // @return - true if node is to be rewritten as MKL node; false otherwise.
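  // For illustration (not part of the original comment): for Z = Add(X, Y),
  // Add is element-wise, so it is rewritten to its MKL variant only if X or Y
  // is already an MKL node; if both X and Y produce plain TF-format tensors,
  // Add is left unchanged.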
1439   static bool RewriteIfAtleastOneMklInput(const Node* n) {
1440 DataType T;
1441 if (GetNodeAttr(n->def(), "T", &T).ok() &&
1442 mkl_op_registry::IsMklOp(
1443 mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1444 for (auto e : n->in_edges()) {
1445 if (e->IsControlEdge()) continue;
1446 if (mkl_op_registry::IsMklOp(e->src())) {
1447 return true;
1448 }
1449 }
1450 }
1451 return false;
1452 }
1453
1454   static bool MatMulRewrite(const Node* n) {
1455 DataType T;
1456 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1457 if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1458 VLOG(2) << "Rewriting MatMul to _MklMatMul";
1459 return true;
1460 }
1461 return false;
1462 }
1463 // For oneDNN, only int32 is supported for axis data type
1464   static bool ConcatV2Rewrite(const Node* n) {
1465 DataType T;
1466 TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1467 return (T == DT_INT32);
1468 }
1469
1470   static bool DequantizeRewrite(const Node* n) {
1471 DCHECK(n);
1472 Node* input = nullptr;
1473 TF_CHECK_OK(n->input_node(0, &input));
1474 string mode_string;
1475 int axis = -1;
1476 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1477 TF_CHECK_OK(GetNodeAttr(n->def(), "axis", &axis));
1478 if (mode_string != "SCALED") {
1479 VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1480 << "This case is not optimized by Intel MKL kernel, thus using "
1481 "Eigen op for Dequantize op.";
1482 return false;
1483 }
1484 if (input->IsConstant()) {
1485 VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1486 << "could possibly be a filter. "
1487 << "This case is not supported by Intel MKL kernel, thus using "
1488 "Eigen op for Dequantize op.";
1489 return false;
1490 }
1491
1492 if (axis != -1) {
1493 VLOG(1) << "DequantizeRewrite: Using Eigen op for Dequantize op because "
1494 << "dimension is specified for per slice dequantization. ";
1495 return false;
1496 }
1497 return true;
1498 }
1499
1500 // Rewrite rule for _FusedMatMul.
1501   // @return - true if the first input is not transposed (transpose_a is
1502   //           false); false otherwise.
1503   static bool FusedMatMulRewrite(const Node* n) {
1504 bool trans_a;
1505
1506 // Do not rewrite with transpose attribute because reorder has performance
1507 // impact.
1508 TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1509
1510 return !trans_a;
1511 }
1512
1513   // Check if we are performing pooling on depth or batch. If so, then we
1514   // do not rewrite the MaxPool node to its Mkl version.
1515 // @return - true (if it is not a depth/batch wise pooling case);
1516 // false otherwise.
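  // Illustrative example: with NHWC data_format, ksize = {1, 3, 3, 1} and
  // strides = {1, 2, 2, 1} have N and C dimensions equal to 1, so the rewrite
  // is allowed; ksize = {1, 1, 1, 2} pools over depth (C) and is rejected.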
1517   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1518 DCHECK(n);
1519
1520 string data_format_str;
1521 TensorFormat data_format;
1522 std::vector<int32> ksize, strides;
1523 TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1524 TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1525 TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1526 bool result = FormatFromString(data_format_str, &data_format);
1527 DCHECK(result);
1528
1529 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1530 if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1531 GetTensorDim(strides, data_format, 'N') == 1 &&
1532 GetTensorDim(ksize, data_format, 'C') == 1 &&
1533 GetTensorDim(strides, data_format, 'C') == 1) {
1534 return true;
1535 }
1536
1537 return false;
1538 }
1539
1540   // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
1541   // path. The unoptimized path is slow, so we don't rewrite the node and use
1542   // the default Eigen implementation instead. For depth_radius=2, the MKL DNN
1543   // optimized path is taken, i.e., the Eigen node is rewritten to an MKL DNN node.
1544   static bool LrnRewrite(const Node* n) {
1545 DCHECK(n);
1546
1547 int depth_radius;
1548 TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1549
1550     // If the depth_radius of LRN is not 2, don't rewrite the node to MKL DNN
1551     // and use the Eigen node instead.
1552 if (depth_radius == 2) {
1553 return true;
1554 }
1555     VLOG(1) << "LrnRewrite: The model sets depth_radius to a value other "
1556             << "than 2, which is not optimized by Intel MKL, thus using the "
1557             << "Eigen op for LRN.";
1558
1559 return false;
1560 }
1561
1562   static bool LrnGradRewrite(const Node* n) {
1563 DCHECK(n);
1564 bool do_rewrite = false;
1565
1566 for (const Edge* e : n->in_edges()) {
1567       // Rewrite only if there is a corresponding LRN, i.e., workspace is available.
1568 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1569 e->src()->type_string() ==
1570 mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1571 e->src_output() == 0) {
1572 do_rewrite = true;
1573 break;
1574 }
1575 }
1576 return do_rewrite;
1577 }
1578
1579 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
1580 // feature * alpha (otherwise),
1581 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1582 // These two algorithms are not consistent when alpha > 1,
1583 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
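  // Worked example (illustrative): for feature = -1.0 and alpha = 2.0,
  // MKL-DNN computes -1.0 * 2.0 = -2.0 while TensorFlow computes
  // max(-1.0, -2.0) = -1.0, so the results differ; for 0 <= alpha <= 1 the
  // two definitions coincide, hence the rewrite is restricted to alpha <= 1.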
1584   static bool LeakyReluRewrite(const Node* n) {
1585 DCHECK(n);
1586
1587 float alpha;
1588 bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1589 DCHECK(has_attr);
1590
1591     // If the alpha of LeakyRelu is less than or equal to 1, rewrite the node.
1592     // Otherwise the Eigen node is used instead.
1593 if (alpha <= 1) {
1594 return true;
1595 }
1596     VLOG(1) << "LeakyReluRewrite: The model sets alpha greater than 1, "
1597             << "which is not optimized by Intel MKL, thus using the Eigen "
1598             << "op for LeakyRelu.";
1599
1600 return false;
1601 }
1602
1603   static bool QuantizeOpRewrite(const Node* n) {
1604 DCHECK(n);
1605 Node* filter_node = nullptr;
1606 TF_CHECK_OK(n->input_node(0, &filter_node));
1607 bool narrow_range = false;
1608 int axis = -1;
1609 string mode_string;
1610 string round_mode_string;
1611 DataType type;
1612 TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1613 TryGetNodeAttr(n->def(), "axis", &axis);
1614 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1615 TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1616 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1617
1618 if (narrow_range) {
1619       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for "
1620               << "quantization. This case is not optimized by Intel MKL, "
1621               << "thus using Eigen op for Quantize op.";
1622 return false;
1623 }
1624 if (axis != -1) {
1625 VLOG(1) << "QuantizeOpRewrite: dimension is specified for "
1626               << "per slice quantization. "
1627               << "This case is not optimized by Intel MKL, "
1628               << "thus using Eigen op for Quantize op.";
1629 return false;
1630 }
1631 if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1632 (mode_string == "MIN_FIRST"))) {
1633       VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or "
1634               << "rounding mode is not HALF_TO_EVEN. "
1635               << "This case is not optimized by Intel MKL, thus using Eigen "
1636               << "op for Quantize op.";
1637 return false;
1638 }
1639 if (filter_node->IsConstant()) {
1640 VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1641 << "is a constant. "
1642               << "This case is not supported by the kernel, thus using Eigen "
1643               << "op for Quantize op.";
1644
1645 return false;
1646 }
1647 if (mode_string == "MIN_FIRST") {
1648 if (type != DT_QUINT8) {
1649 VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1650                 << "not DT_QUINT8. This case is not optimized by Intel MKL, "
1651                 << "thus using Eigen op for Quantize op.";
1652 return false;
1653 }
1654 }
1655 return true;
1656 }
1657
1658   static bool MaxpoolGradRewrite(const Node* n) {
1659 DCHECK(n);
1660 bool do_rewrite = false;
1661 for (const Edge* e : n->in_edges()) {
1662       // Rewrite only if there is a corresponding Maxpool, i.e., workspace is
1663       // available.
1664 if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1665 e->dst_input() == 1 &&
1666 e->src()->type_string() ==
1667 mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1668 e->src_output() == 0) {
1669 do_rewrite = true;
1670 break;
1671 }
1672 }
1673 return do_rewrite;
1674 }
1675
1676   static bool Maxpool3DGradRewrite(const Node* n) {
1677 DCHECK(n);
1678 for (const Edge* e : n->in_edges()) {
1679       // Rewrite only if there is a corresponding Maxpool3D, i.e., workspace is
1680       // available.
1681 if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1682 e->dst_input() == 1 &&
1683 e->src()->type_string() ==
1684 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1685 e->src_output() == 0) {
1686 return true;
1687 }
1688 }
1689 return false;
1690 }
1691
1692   static bool FusedBatchNormV3Rewrite(const Node* n) {
1693 DCHECK(n);
1694 if (Check5DFormat(n->def())) {
1695 VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1696 << "support 5D tensors.";
1697 return false;
1698 }
1699 return true;
1700 }
1701
1702   static bool FusedBatchNormExRewrite(const Node* n) {
1703 DCHECK(n);
1704
1705 int num_side_inputs;
1706 TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1707 string activation_mode;
1708 TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1709
1710 // if the num_side_inputs is not 0, don't rewrite the node.
1711 if (num_side_inputs != 0) {
1712       VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs "
1713               << "larger than 0, which is not optimized by Intel MKL.";
1714 return false;
1715 }
1716
1717 // if the activation_mode is not 'Relu', don't rewrite the node.
1718 if (activation_mode != "Relu") {
1719       VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is "
1720               << "supported by Intel MKL.";
1721 return false;
1722 }
1723
1724 return true;
1725 }
1726
1727   static bool FusedConv2DRewrite(const Node* n) {
1728 // MKL DNN currently doesn't support all fusions that grappler fuses
1729 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1730 // it includes those we support.
1731 DataType T;
1732 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1733 !mkl_op_registry::IsMklOp(NativeFormatEnabled()
1734 ? csinfo_.mkl_native_fused_conv2d
1735 : csinfo_.mkl_fused_conv2d,
1736 T)) {
1737 return false;
1738 }
1739
1740 std::vector<string> fused_ops;
1741 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1742 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1743 fused_ops == std::vector<string>{"Relu"} ||
1744 fused_ops == std::vector<string>{"Relu6"} ||
1745 fused_ops == std::vector<string>{"Elu"} ||
1746 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1747 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1748 fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1749 fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1750 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1751 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1752 fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1753 fused_ops == std::vector<string>{"LeakyRelu"} ||
1754 fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1755 fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"} ||
1756 fused_ops == std::vector<string>{"FusedBatchNorm"} ||
1757 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"} ||
1758 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"} ||
1759 fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"} ||
1760 fused_ops == std::vector<string>{"FusedBatchNorm", "LeakyRelu"});
1761 }
1762
1763   static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1764 // MKL DNN currently doesn't support all fusions that grappler fuses
1765 // together with DepthwiseConv2D (ex. batchnorm). We rewrite
1766 // _FusedDepthwiseConv2DNative only if it includes those we support.
1767 DataType T;
1768 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1769 !mkl_op_registry::IsMklOp(
1770 NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d
1771 : csinfo_.mkl_fused_depthwise_conv2d,
1772 T)) {
1773 return false;
1774 }
1775
1776 std::vector<string> fused_ops;
1777 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1778 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1779 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1780 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1781 fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1782 }
1783
1784 // Rewrites input node to a new node specified by its matching rewrite info.
1785 //
1786 // Method first searches matching rewrite info for input node and then
1787 // uses that info to rewrite.
1788 //
1789   // Input node may be deleted in case of rewrite. Attempting to use the node
1790   // after the call can result in undefined behavior.
1791 //
1792 // @input g - input graph, n - Node to be rewritten,
1793 // ri - matching rewriteinfo
1794 // @return Status::OK(), if the input node is rewritten;
1795 // Returns appropriate Status error code otherwise.
1796 // Graph is updated in case the input node is rewritten.
1797 // Otherwise, it is not updated.
1798 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1799
1800 // Rewrites input node to just change its operator name. The number of
1801 // inputs to the node and the number of outputs remain the same. Attributes
1802 // of the new node could be copied from attributes of the old node or
1803 // modified. copy_attrs field of RewriteInfo controls this.
1804 //
1805 // Conceptually, it allows us to rewrite:
1806 //
1807 // f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1808 //
1809 // Attributes can be altered without any restrictions --- they could be
1810 // copied, modified, or deleted completely.
1811 //
1812 // @input g - input graph, orig_node - Node to be rewritten,
1813 // ri - matching rewriteinfo
1814 // @output new_node - points to newly created node
1815 // @return Status::OK(), if the input node is rewritten;
1816 // Returns appropriate Status error code otherwise.
1817 // Graph is only updated when the input node is rewritten.
1818 Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1819 const Node* orig_node, Node** new_node,
1820 const RewriteInfo* ri);
1821
1822 // Rewrites input node to enable MKL layout propagation. Please also refer to
1823 // documentation for the function RewriteNodeForJustOpNameChange() to
1824 // understand what it means.
1825 //
1826 // @input g - input graph, orig_node - Node to be rewritten,
1827 // ri - matching rewriteinfo
1828 // @output new_node - points to newly created node
1829 // @return Status::OK(), if the input node is rewritten;
1830 // Returns appropriate Status error code otherwise.
1831 // Graph is updated in case the input node is rewritten.
1832 // Otherwise, it is not updated.
1833 Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1834 const Node* orig_node, Node** new_node,
1835 const RewriteInfo* ri);
1836
1837 // Get nodes that will feed a list of TF tensors to the new
1838 // node that we are constructing.
1839 //
1840 // @input g - input graph,
1841 // @input inputs - inputs to old node that we are using for constructing
1842 // new inputs,
1843 // @input input_idx - the index in the 'inputs' vector pointing to the
1844 // current input that we have processed so far
1845 // @output input_idx - index will be incremented by the number of nodes
1846 // from 'inputs' that are processed
1847 // @input list_length - The expected length of list of TF tensors
1848 // @output output_nodes - the list of new nodes creating TF tensors
1849 //
1850 // @return None
1851 void GetNodesProducingTFTensorList(
1852 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1853 int* input_idx, int list_length,
1854 std::vector<NodeBuilder::NodeOut>* output_nodes);
1855
1856 // Get nodes that will feed a list of Mkl tensors to the new
1857 // node that we are constructing.
1858 //
1859 // @input g - input graph,
1860 // @input orig_node - Original node that we are rewriting
1861 // @input inputs - inputs to old node that we are using for constructing
1862 // new inputs,
1863 // @input input_idx - the index in the 'inputs' vector pointing to the
1864 // current input that we have processed so far
1865 // @output input_idx - index will be incremented by the number of nodes
1866 // from 'inputs' that are processed
1867 // @input list_length - The expected length of list of Mkl tensors
1868 // @output output_nodes - the list of new nodes creating Mkl tensors
1869 //
1870 // @return None
1871 void GetNodesProducingMklTensorList(
1872 std::unique_ptr<Graph>* g, const Node* orig_node,
1873 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1874 int* input_idx, int list_length,
1875 std::vector<NodeBuilder::NodeOut>* output_nodes);
1876
1877 // Get a node that will feed an Mkl tensor to the new
1878 // node that we are constructing. The output node could be (1) 'n'
1879 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1880 // if 'n' is not an Mkl layer.
1881 //
1882 // @input g - input graph,
1883 // @input orig_node - Original node that we are rewriting,
1884 // @input n - Node based on which we are creating Mkl node,
1885 // @input n_output_slot - the output slot of node 'n'
1886 // which is feeding to the node that we are constructing
1887 // @output mkl_node - the new node that will feed Mkl tensor
1888 // @output mkl_node_output_slot - the slot number of mkl_node that
1889 // will feed the tensor
1890 // @return None
1891 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1892 const Node* orig_node, Node* n,
1893 int n_output_slot, Node** mkl_node,
1894 int* mkl_node_output_slot);
1895
1896 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1897 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1898 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1899 // producing workspace edges if 'are_workspace_tensors_available' is true.
1900 // Otherwise, 'workspace_tensors' is empty vector.
1901 //
1902 // For details, refer to 'Ordering of inputs after rewriting' section in the
1903 // documentation above.
1904 //
1905 // Returns Status::OK() if setting up inputs is successful, otherwise
1906 // returns appropriate status code.
1907 int SetUpContiguousInputs(
1908 std::unique_ptr<Graph>* g,
1909 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1910 NodeBuilder* nb, const Node* old_node,
1911 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1912 bool are_workspace_tensors_available);
1913
1914 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1915 // in graph 'g'. Original node is input in 'orig_node'.
1916 //
1917 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1918 // section in the documentation above.
1919 //
1920 // Returns Status::OK() if setting up inputs is successful, otherwise
1921 // returns appropriate status code.
1922 Status SetUpInputs(std::unique_ptr<Graph>* g,
1923 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1924 NodeBuilder* nb, const Node* orig_node);
1925
1926 // Create new inputs by copying old inputs 'inputs' for the rewritten node
1927 // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1928 // used in the context of rewrite for just operator name change in which
1929 // inputs of old operator and new operator are same.
1930 //
1931 // Returns Status::OK() if setting up inputs is successful, otherwise
1932 // returns appropriate status code.
1933 Status CopyInputs(const Node* orig_node,
1934 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1935 NodeBuilder* nb);
1936
1937 // Add workspace edge on the input or output side of Node 'orig_node' by using
1938 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
1939   // adding a workspace edge then do not add it. Workspace Tensorflow and Mkl
1940   // tensors, if they need to be added, are appended to 'ws_tensors'.
1941   // If we add workspace tensors, then 'are_ws_tensors_added' is set to true.
1942 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1943 const Node* orig_node, NodeBuilder* nb,
1944 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1945 bool* are_ws_tensors_added);
1946
1947 // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1948 // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1949 // 'g'. Returns true if fixup was done; otherwise, it returns false.
1950 bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1951 const Edge* e_metadata);
1952
1953 // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1954 // connected? If not, then fix them. This is needed because a graph may have
1955 // some input Mkl metadata edges incorrectly setup after node merge and
1956 // rewrite passes. This could happen because GetReversePostOrder function may
1957 // not provide topologically sorted order if a graph contains cycles. The
1958 // function returns true if at least one Mkl metadata edge for node 'n' was
1959 // fixed. Otherwise, it returns false.
1960 //
1961 // Example:
1962 //
1963 // X = MklConv2D(_, _, _)
1964 // Y = MklConv2DWithBias(_, _, _, _, _, _)
1965 // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
1966 //
1967 // For a graph such as shown above, note that 3rd argument of MklAdd contains
1968 // DummyMklTensor. Actually, it should be getting the Mkl metadata from
1969 // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
1970 // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
1971 // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
1972 // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
1973 // data edges (1st and 2nd arguments of MklAdd).
1974 bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
1975
1976 // Functions specific to operators to copy attributes
1977 // We need operator-specific function to copy attributes because the framework
1978 // does not provide any generic function for it.
1979 // NOTE: names are alphabetically sorted.
1980 static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
1981 bool change_format = false);
1982 static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
1983 NodeBuilder* nb,
1984 bool change_format = false);
1985
1986 static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
1987 bool change_format = false);
1988 static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
1989 NodeBuilder* nb,
1990 bool change_format = false);
1991 static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
1992 const Node* orig_node2, NodeBuilder* nb,
1993 bool change_format = false);
1994 static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
1995 const Node* orig_node2,
1996 NodeBuilder* nb,
1997 bool change_format = false);
1998 static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
1999 bool change_format = false);
2000 static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2001 const std::vector<int32>& strides,
2002 const std::vector<int32>& dilations,
2003 bool change_format = false);
2004
2005 static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2006 NodeBuilder* nb,
2007 bool change_format = false);
2008 static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2009 const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2010 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2011 bool change_format = false);
2012
2013 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2014 // using node for original node 'orig_node' and return it in '*out'.
2015 // TODO(nhasabni) We should move this to mkl_util.h
2016 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2017 const Node* orig_node);
2018 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2019 const Node* orig_node);
2020 };
2021
2022 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2023
2024 // We register Mkl rewrite pass for phase 1 in post partitioning group.
2025 // We register it here so that we get a complete picture of all users of Mkl
2026 // nodes. Do not change the ordering of the Mkl passes.
2027 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2028 OptimizationPassRegistry::POST_PARTITIONING;
2029 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2030
2031 //////////////////////////////////////////////////////////////////////////
2032 // Helper functions for creating new node
2033 //////////////////////////////////////////////////////////////////////////
2034
2035 static void FillInputs(const Node* n,
2036 gtl::InlinedVector<Node*, 4>* control_edges,
2037 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2038 control_edges->clear();
2039 for (const Edge* e : n->in_edges()) {
2040 if (e->IsControlEdge()) {
2041 control_edges->push_back(e->src());
2042 } else {
2043 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2044 }
2045 }
2046 std::sort(control_edges->begin(), control_edges->end());
2047 }
2048
2049 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2050 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2051 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2052 CHECK_LT(*input_idx, inputs.size());
2053 CHECK_GT(list_length, 0);
2054 DCHECK(output_nodes);
2055 output_nodes->reserve(list_length);
2056
2057 while (list_length != 0) {
2058 CHECK_GT(list_length, 0);
2059 CHECK_LT(*input_idx, inputs.size());
2060 Node* n = inputs[*input_idx].first;
2061 int slot = inputs[*input_idx].second;
2062 // If input node 'n' is just producing a single tensor at
2063 // output slot 'slot' then we just add that single node.
2064 output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2065 (*input_idx)++;
2066 list_length--;
2067 }
2068 }
2069
2070 // TODO(nhasabni) We should move this to mkl_util.h.
2071 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2072 Node** out,
2073 const Node* orig_node) {
2074 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2075 // dummy Mkl tensor. 8 = 2*size_t.
2076 const DataType dt = DataTypeToEnum<uint8>::v();
2077 TensorProto proto;
2078 proto.set_dtype(dt);
2079 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2080 proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2081 TensorShape dummy_shape({8});
2082 dummy_shape.AsProto(proto.mutable_tensor_shape());
2083 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2084 .Attr("value", proto)
2085 .Attr("dtype", dt)
2086 .Device(orig_node->def().device()) // We place this node on
2087 // the same device as the
2088 // device of the original
2089 // node.
2090 .Finalize(&**g, out));
2091 DCHECK(*out); // Make sure we got a valid object before using it
2092
2093 // If number of inputs to the original node is > 0, then we add
2094 // control dependency between 1st input (index 0) of the original node and
2095 // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2096 // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2097 // rewritten node. Adding control edge between 1st input of the original node
2098 // and the dummy Mkl node ensures that the dummy node is in the same frame
2099 // as the original node. Choosing 1st input is not necessary - any input of
2100 // the original node is fine because all the inputs of a node are always in
2101 // the same frame.
2102 if (orig_node->num_inputs() > 0) {
2103 Node* orig_input0 = nullptr;
2104 TF_CHECK_OK(
2105 orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2106 auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2107 DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2108 }
2109
2110 (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2111 }
2112
2113 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2114 std::unique_ptr<Graph>* g, const Node* orig_node,
2115 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2116 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2117 CHECK_LT(*input_idx, inputs.size());
2118 CHECK_GT(list_length, 0);
2119 DCHECK(output_nodes);
2120 output_nodes->reserve(list_length);
2121
2122 while (list_length != 0) {
2123 CHECK_GT(list_length, 0);
2124 CHECK_LT(*input_idx, inputs.size());
2125 Node* n = inputs[*input_idx].first;
2126 int slot = inputs[*input_idx].second;
2127 // If 'n' is producing a single tensor, then create a single Mkl tensor
2128 // node.
2129 Node* mkl_node = nullptr;
2130 int mkl_node_output_slot = 0;
2131 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2132 &mkl_node_output_slot);
2133 output_nodes->push_back(
2134 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2135 (*input_idx)++;
2136 list_length--;
2137 }
2138 }
2139
2140 // Get an input node that will feed Mkl tensor to the new
2141 // node that we are constructing. An input node could be (1) 'n'
2142 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
2143 // if 'n' is not an Mkl layer.
2144 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2145 std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2146 int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2147 DCHECK(n);
2148 DCHECK(mkl_node);
2149 DCHECK(mkl_node_output_slot);
2150
2151 // If this is an MKL op, then it will create extra output for MKL layout.
2152 DataType T;
2153 if (TryGetNodeAttr(n->def(), "T", &T) &&
2154 mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
2155 // If this is an MKL op, then it will generate an edge that will receive
2156 // Mkl tensor from a node.
2157 // output slot number for Mkl tensor would be N+slot number of TensorFlow
2158 // tensor, where N is total number of TensorFlow tensors.
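    // Illustrative example (assuming contiguous ordering): if 'n' has
    // 2 TensorFlow outputs, the Mkl metadata tensor for TF output 0 is at
    // slot 2 and the one for TF output 1 is at slot 3.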
2159 *mkl_node = n;
2160 *mkl_node_output_slot =
2161 GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2162 } else {
2163 // If we have not visited the node and rewritten it, then we need
2164 // to create a dummy node that will feed a dummy Mkl tensor to this node.
2165 // DummyMklTensor node has no input and generates only 1 output
2166 // (dummy Mkl tensor) as output slot number 0.
2167 GetDummyMklTensorNode(g, mkl_node, orig_node);
2168 DCHECK(*mkl_node);
2169 *mkl_node_output_slot = 0;
2170 }
2171 }
2172
2173 int MklLayoutRewritePass::SetUpContiguousInputs(
2174 std::unique_ptr<Graph>* g,
2175 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2176 NodeBuilder* nb, const Node* old_node,
2177 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2178 bool are_workspace_tensors_available) {
2179 DCHECK(workspace_tensors);
2180 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2181
2182 // TODO(nhasabni): Temporary solution to connect filter input of
2183 // BackpropInput with the converted filter from Conv2D.
2184 bool do_connect_conv2d_backprop_input_filter = false;
2185 Node* conv2d_node = nullptr;
2186 // Filter node is 2nd input (slot index 1) of Conv2D.
2187 int kConv2DFilterInputSlotIdx = 1;
2188 int kConv2DBackpropInputFilterInputSlotIdx = 1;
2189 int kConv2DFilterOutputSlotIdx = 1;
2190 if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2191 // We need to find Conv2D node from Conv2DBackpropInput.
2192 // For that let's first find filter node that is 2nd input (slot 1)
2193 // of BackpropInput.
2194 Node* filter_node = nullptr;
2195 TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2196 &filter_node));
2197 DCHECK(filter_node);
2198
2199 // Now check which nodes receive from filter_node. Filter feeds as
2200 // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2201 // _MklFusedConv2D.
2202 for (const Edge* e : filter_node->out_edges()) {
2203 if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2204 e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2205 e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2206 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2207 e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2208 e->dst_input() == kConv2DFilterInputSlotIdx
2209 /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2210 if (conv2d_node != nullptr) {
2211 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2212 << " feeding multiple Conv2D nodes: "
2213 << filter_node->DebugString();
2214 // We will not connect filter input of Conv2DBackpropInput
2215 // to be safe here.
2216 do_connect_conv2d_backprop_input_filter = false;
2217 break;
2218 } else {
2219 conv2d_node = e->dst();
2220 do_connect_conv2d_backprop_input_filter = true;
2221 }
2222 }
2223 }
2224 }
2225
2226 // Number of input slots to original op
2227 // Input slots are represented by .Input() calls in REGISTER_OP.
2228 int old_node_input_slots = old_node->op_def().input_arg_size();
2229 int nn_slot_idx = 0; // slot index for inputs of new node
2230
2231 // Let's copy all inputs (TF tensors) of original node to new node.
2232 int iidx = 0;
2233 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2234 // An input slot could be a single tensor or a list. We need
2235 // to handle this case accordingly.
2236 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2237 if (ArgIsList(arg)) {
2238 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2239 int tensor_list_length = GetTensorListLength(arg, old_node);
2240 if (tensor_list_length != 0) {
2241 GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2242 tensor_list_length, &new_node_inputs);
2243 }
2244 nb->Input(new_node_inputs);
2245 nn_slot_idx++;
2246 } else {
2247 // Special case for connecting filter input of Conv2DBackpropInput
2248 if (do_connect_conv2d_backprop_input_filter &&
2249 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2250 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2251 } else {
2252 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2253 }
2254 iidx++;
2255 nn_slot_idx++;
2256 }
2257 }
2258
2259 // If workspace tensors are available for this op and we are using
2260 // contiguous ordering then we need to add Tensorflow tensor for
2261 // workspace here because Tensorflow tensor for workspace is the
2262 // last tensor in the list of Tensorflow tensors.
2263 if (are_workspace_tensors_available) {
2264 CHECK_EQ(workspace_tensors->size(), 2);
2265 // Tensorflow tensor
2266 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2267 nn_slot_idx++;
2268 }
2269
2270 // Let's now setup all Mkl inputs to a new node.
2271 // Number of Mkl inputs must be same as number of TF inputs.
2272 iidx = 0;
2273 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2274 // An input slot could be a single tensor or a list. We need
2275 // to handle this case accordingly.
2276 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2277 if (ArgIsList(arg)) {
2278 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2279 int tensor_list_length = GetTensorListLength(arg, old_node);
2280 if (tensor_list_length != 0) {
2281 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2282 tensor_list_length, &new_node_inputs);
2283 }
2284 nb->Input(new_node_inputs);
2285 nn_slot_idx++;
2286 } else {
2287 Node* mkl_node = nullptr;
2288 int mkl_node_output_slot = 0;
2289 // Special case for connecting filter input of Conv2DBackpropInput
2290 if (do_connect_conv2d_backprop_input_filter &&
2291 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2292 GetNodeProducingMklTensor(g, old_node, conv2d_node,
2293 kConv2DFilterOutputSlotIdx, &mkl_node,
2294 &mkl_node_output_slot);
2295 } else {
2296 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2297 old_node_inputs[iidx].second, &mkl_node,
2298 &mkl_node_output_slot);
2299 }
2300 nb->Input(mkl_node, mkl_node_output_slot);
2301 iidx++;
2302 nn_slot_idx++;
2303 }
2304 }
2305
2306 // If workspace tensors are available for this op and we are using
2307 // contiguous ordering then we need to add Mkl tensor for
2308 // workspace here because Mkl tensor for workspace is the
2309 // last tensor in the list of Mkl tensors.
2310 if (are_workspace_tensors_available) {
2311 CHECK_EQ(workspace_tensors->size(), 2);
2312 // Mkl tensor
2313 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2314 nn_slot_idx++;
2315 }
2316
2317 return nn_slot_idx;
2318 }
2319
2320 // This method determines whether the workspace check is needed. Workspace is
2321 // not used in quantized ops, so the check would fail because quantized ops
2322 // don't have the attribute "T".
2323 bool IsWorkspaceCheckNeeded(const Node* node) {
2324 std::vector<string> quant_ops{
2325 "Dequantize",
2326 "QuantizeV2",
2327 "QuantizedConv2D",
2328 "QuantizedConv2DWithBias",
2329 "QuantizedConv2DAndRelu",
2330 "QuantizedConv2DWithBiasAndRelu",
2331 "QuantizedConv2DWithBiasSumAndRelu",
2332 "QuantizedConv2DPerChannel",
2333 "QuantizedConv2DAndRequantize",
2334 "QuantizedConv2DWithBiasAndRequantize",
2335 "QuantizedConv2DAndReluAndRequantize",
2336 "QuantizedConv2DWithBiasAndReluAndRequantize",
2337 "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2338 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2339 "QuantizedMatMulWithBias",
2340 "QuantizedMatMulWithBiasAndRequantize",
2341 "QuantizedMatMulWithBiasAndDequantize",
2342 "QuantizedMatMulWithBiasAndRelu",
2343 "QuantizedMatMulWithBiasAndReluAndRequantize",
2344 "QuantizedDepthwiseConv2D",
2345 "QuantizedDepthwiseConv2DWithBias",
2346 "QuantizedDepthwiseConv2DWithBiasAndRelu",
2347 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2348 return std::find(std::begin(quant_ops), std::end(quant_ops),
2349 node->type_string()) == std::end(quant_ops);
2350 }
2351
2352 Status MklLayoutRewritePass::SetUpInputs(
2353 std::unique_ptr<Graph>* g,
2354 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2355 NodeBuilder* nb, const Node* old_node) {
2356 // Let's check if we need to add workspace tensors for this node.
2357 // We add workspace edge only for MaxPool, LRN and BatchNorm.
2358 std::vector<NodeBuilder::NodeOut> workspace_tensors;
2359 bool are_workspace_tensors_available = false;
2360
2361 if (IsWorkspaceCheckNeeded(old_node)) {
2362 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2363 &are_workspace_tensors_available);
2364 }
2365
2366 int new_node_input_slots = 0;
2367 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
2368     // TODO(nhasabni): implement this function just for the sake of completeness.
2369 // We do not use interleaved ordering right now.
2370 return Status(
2371 error::Code::UNIMPLEMENTED,
2372 "Interleaved ordering of tensors is currently not supported.");
2373 } else {
2374 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2375 new_node_input_slots = SetUpContiguousInputs(
2376 g, old_node_inputs, nb, old_node, &workspace_tensors,
2377 are_workspace_tensors_available);
2378 }
2379
2380 // Sanity check
2381 int old_node_input_slots = old_node->op_def().input_arg_size();
2382 if (!are_workspace_tensors_available) {
2383 // If we are not adding workspace tensors for this op, then the total
2384 // number of input slots to the new node _must_ be 2 times the number
2385 // of input slots to the original node: N original Tensorflow tensors and
2386 // N for Mkl tensors corresponding to each Tensorflow tensors.
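    // Illustrative example: an original Conv2D with 2 data inputs is expected
    // to become an MKL node with 4 inputs (2 TF tensors + 2 Mkl metadata
    // tensors).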
2387 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2388 } else {
2389     // If we are adding workspace tensors for this op, then the total number
2390     // of input slots to the new node _must_ be 2 times the number
2391 // of input slots to the original node: N original Tensorflow tensors and
2392 // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
2393 // (for workspace Tensorflow tensor and workspace Mkl tensor).
2394 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2395 }
2396
2397 return Status::OK();
2398 }
2399
2400 Status MklLayoutRewritePass::CopyInputs(
2401 const Node* old_node,
2402 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2403 NodeBuilder* nb) {
2404 // Number of input slots to old node
2405 // Input slots are represented by .Input() calls in REGISTER_OP.
2406 int old_node_input_slots = old_node->op_def().input_arg_size();
2407 // Actual number of inputs can be greater than or equal to number
2408 // of Input slots because inputs of type list could be unfolded.
2409 auto old_node_input_size = old_node_inputs.size();
2410 DCHECK_GE(old_node_input_size, old_node_input_slots);
2411
2412 // Let's copy all inputs of old node to new node.
2413 int iidx = 0;
2414 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2415 // An input slot could be a single tensor or a list. We need
2416 // to handle this case accordingly.
2417 DCHECK_LT(iidx, old_node_input_size);
2418 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2419 if (ArgIsList(arg)) {
2420 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2421 int N = GetTensorListLength(arg, old_node);
2422 if (N != 0) {
2423 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2424 &new_node_inputs);
2425 }
2426 nb->Input(new_node_inputs);
2427 } else {
2428 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2429 iidx++;
2430 }
2431 }
2432 return Status::OK();
2433 }
2434
2435 //////////////////////////////////////////////////////////////////////////
2436 // Helper functions related to workspace pass
2437 //////////////////////////////////////////////////////////////////////////
2438
2439 // TODO(nhasabni) We should move this to mkl_util.h.
2440 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2441 std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
2442   // We use a uint8 tensor of shape {8} with content {0,0,0,0,0,0,0,0} to
2443   // represent the dummy workspace tensor.
2444 GetDummyMklTensorNode(g, out, orig_node);
2445 }
2446
2447 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2448 std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2449 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2450 bool workspace_edge_added = false; // Default initializer
2451 DCHECK(are_ws_tensors_added);
2452 *are_ws_tensors_added = false; // Default initializer
2453
2454 DataType T;
2455 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2456 for (auto ws : wsinfo_) {
2457 if (orig_node->type_string() == ws.fwd_op &&
2458 mkl_op_registry::IsMklOp(
2459 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2460 // If this op is a fwd op, then we need to check if there is an
2461 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
2462 // an edge, then we just add an attribute on this node for setting
2463 // workspace_passed to true. We don't add actual workspace edge
2464 // in this node. Actual workspace edge gets added in the backward
2465 // op for this node.
2466 for (const Edge* e : orig_node->out_edges()) {
2467 if (e->src_output() == ws.fwd_slot &&
2468 e->dst()->type_string() == ws.bwd_op &&
2469 e->dst_input() == ws.bwd_slot) {
2470 nb->Attr("workspace_enabled", true);
2471 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2472 << orig_node->type_string();
2473 workspace_edge_added = true;
2474 // We found the edge that we were looking for, so break.
2475 break;
2476 }
2477 }
2478
2479 if (!workspace_edge_added) {
2480 // If we are here, then we did not find backward operator for this
2481 // node.
2482 nb->Attr("workspace_enabled", false);
2483 }
2484 } else if (orig_node->type_string() == ws.bwd_op &&
2485 mkl_op_registry::IsMklOp(
2486 mkl_op_registry::GetMklOpName(orig_node->type_string()),
2487 T)) {
2488       // If this op is a bwd op, then we need to add a workspace edge and
2489       // its Mkl tensor edge between its corresponding fwd op and this
2490 // op. Corresponding fwd op is specified in 'fwd_op' field of
2491 // workspace info. fwd_slot and bwd_slot in workspace info specify
2492 // an edge between which slots connect forward and backward op.
2493 // Once all these criteria match, we add a workspace edge between
2494 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
2495 // determined by interleaved/contiguous ordering. Function
2496 // DataIndexToMetaDataIndex tells us the location of Mkl tensor
2497 // from the location of the Tensorflow tensor.
2498 for (const Edge* e : orig_node->in_edges()) {
2499 if (e->src_output() == ws.fwd_slot &&
2500 // We would have rewritten the forward op, so we need to use
2501 // GetMklOpName call to get its Mkl name.
2502 e->src()->type_string() ==
2503 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2504 e->dst_input() == ws.bwd_slot) {
2505 nb->Attr("workspace_enabled", true);
2506 DCHECK(ws_tensors);
2507 // Add workspace edge between fwd op and bwd op.
2508 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2509 // Check if we are running in native format mode. If so,
2510 // we don't need to have an Mkl metadata tensor for the workspace.
2511 if (!NativeFormatEnabled()) {
2512 // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2513 ws_tensors->push_back(NodeBuilder::NodeOut(
2514 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2515 e->src()->num_outputs())));
2516 }
2517 *are_ws_tensors_added = true;
2518 // In terms of input ordering, we add these calls to add Input
2519 // here because workspace edge (and its Mkl tensor) is the last
2520 // edge in the fwdop and bwdop. So all inputs before workspace
2521 // tensor have been added by SetUpInputs function.
2522 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2523 << orig_node->type_string();
2524 workspace_edge_added = true;
2525 // We found the edge that we were looking for, so break.
2526 break;
2527 }
2528 }
2529
2530       // If we are here, it means we did not find a fwd op that feeds this
2531       // bwd op. So in this case, we need to generate dummy tensors for
2532 // workspace input and Mkl tensor for workspace, and set
2533 // workspace_enabled to false.
2534 if (!workspace_edge_added) {
2535 nb->Attr("workspace_enabled", false);
2536 Node* dmt_ws = nullptr; // Dummy tensor for workspace
2537 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
2538 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2539 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2540 DCHECK(dmt_ws);
2541 DCHECK(dmt_mkl_ws);
2542 DCHECK(ws_tensors);
2543 // We add dummy tensor as workspace tensor.
2544 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2545 // We add dummy tensor as Mkl tensor for workspace tensor.
2546 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2547 *are_ws_tensors_added = true;
2548 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2549 << orig_node->type_string();
2550 }
2551 } else {
2552 // If this node does not match any workspace info, then we do not
2553 // do anything special for workspace propagation for it.
2554 }
2555 }
2556 }
2557
2558 //////////////////////////////////////////////////////////////////////////
2559 // Op-specific functions to copy attributes from old node to new node
2560 //////////////////////////////////////////////////////////////////////////
2561
2562 // Generic function to copy all attributes from original node to target.
2563 void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2564 bool change_format) {
2565 string name;
2566 AttrSlice attr_list(orig_node->def());
2567
2568 auto iter = attr_list.begin();
2569 while (iter != attr_list.end()) {
2570 name = iter->first;
2571 auto attr = iter->second;
2572 nb->Attr(name, attr);
2573 ++iter;
2574 }
2575 }
2576
2577 // Generic function to copy all attributes and check if filter is const.
2578 void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2579 NodeBuilder* nb,
2580 bool change_format) {
2581 CopyAttrsAll(orig_node, nb, change_format);
2582
2583 // Check and set filter attribute.
2584 Node* filter_node = nullptr;
2585 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2586
2587 bool is_filter_const = false;
2588 if (HasNodeAttr(orig_node->def(), "is_filter_const")) {
2589 GetNodeAttr(orig_node->def(), "is_filter_const", &is_filter_const);
2590 }
2591
2592   // If (1) orig_node does not have the attribute 'is_filter_const', or
2593   // (2) it has the attribute but its value is false, then we set the
2594   // attribute on 'nb' with a value based on whether filter_node is const.
2595 // If is_filter_const == true, then there is no need to call nb->Attr() as
2596 // CopyAttrsAll() has already copied the attribute from orig_node to nb.
2597 if (!is_filter_const) {
2598 nb->Attr("is_filter_const", filter_node->IsConstant());
2599 }
2600 }
2601
2602 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2603 NodeBuilder* nb,
2604 bool change_format) {
2605 CopyAttrsConv(orig_node, nb, change_format);
2606
2607 // Check and set filter attribute.
2608 Node* filter_node = nullptr;
2609 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2610 nb->Attr("is_filter_const", filter_node->IsConstant());
2611 }
2612
2613 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2614 bool change_format) {
2615 DataType T;
2616 string padding;
2617 std::vector<int32> strides;
2618 std::vector<int32> dilations;
2619 std::vector<int32> explicit_paddings;
2620
2621 // Get all attributes from old node.
2622 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2623 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2624 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2625 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2626
2627 // Check `explicit_paddings` first because some Conv ops don't have
2628 // this attribute.
2629 if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2630 &explicit_paddings) &&
2631 !explicit_paddings.empty()) {
2632 nb->Attr("explicit_paddings", explicit_paddings);
2633 }
2634
2635 // Add attributes to new node.
2636 nb->Attr("T", T);
2637 nb->Attr("padding", padding);
2638
2639 // Add attributes related to `data_format`.
2640 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2641 }
2642
2643 // Used with MergePadWithConv2D
2644 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2645 const Node* orig_node2,
2646 NodeBuilder* nb,
2647 bool change_format) {
2648 DataType Tpaddings;
2649 DataType T;
2650 string data_format;
2651 string padding;
2652 std::vector<int32> strides;
2653 std::vector<int32> dilations;
2654 bool use_cudnn_on_gpu;
2655
2656 // Get all attributes from old node 1.
2657 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2658 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2659 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2660 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2661 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2662 TF_CHECK_OK(
2663 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2664 // Get all attributes from old node 2.
2665 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2666
2667 // Add attributes to new node.
2668 nb->Attr("T", T);
2669 nb->Attr("strides", strides);
2670 nb->Attr("dilations", dilations);
2671 nb->Attr("padding", padding);
2672 nb->Attr("data_format", data_format);
2673 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2674 nb->Attr("Tpaddings", Tpaddings);
2675 }
2676
2677 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2678 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2679 bool change_format) {
2680 DataType T;
2681 int num_args;
2682 string data_format;
2683 string padding;
2684 std::vector<int32> strides;
2685 std::vector<int32> dilations;
2686 float epsilon;
2687 std::vector<string> fused_ops;
2688 DataType Tpaddings;
2689 float leakyrelu_alpha;
2690
2691 // Get all attributes from old node.
2692 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2693 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2694 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2695 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2696 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2697 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2698 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2699 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2700 TF_CHECK_OK(
2701 GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2702 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2703
2704 // Add attributes to new node.
2705 nb->Attr("T", T);
2706 nb->Attr("num_args", num_args);
2707 nb->Attr("strides", strides);
2708 nb->Attr("padding", padding);
2709 nb->Attr("data_format", data_format);
2710 nb->Attr("dilations", dilations);
2711 nb->Attr("epsilon", epsilon);
2712 nb->Attr("Tpaddings", Tpaddings);
2713 nb->Attr("fused_ops", fused_ops);
2714 nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2715 }
2716
2717 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2718 NodeBuilder* nb,
2719 bool change_format) {
2720 DataType Tinput, Tfilter, out_type;
2721 string padding;
2722 string data_format("NHWC");
2723 std::vector<int32> strides, dilations, padding_list;
2724 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2725
2726 // Get all attributes from old node.
2727 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2728 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2729 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2730 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2731 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2732 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2733 if (has_padding_list) {
2734 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2735 }
2736
2737 Node* filter_node = nullptr;
2738 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2739
2740 // Add attributes to new node.
2741 nb->Attr("Tinput", Tinput);
2742 nb->Attr("Tfilter", Tfilter);
2743 nb->Attr("out_type", out_type);
2744 nb->Attr("padding", padding);
2745 nb->Attr("is_filter_const", filter_node->IsConstant());
2746 nb->Attr("strides", strides);
2747 nb->Attr("dilations", dilations);
2748 nb->Attr("data_format", data_format);
2749 if (has_padding_list) {
2750 nb->Attr("padding_list", padding_list);
2751 }
2752
2753 // Requantization attr Tbias.
2754 DataType Tbias;
2755 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2756 if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2757 }
2758
2759 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2760 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2761 CopyAttrsAll(orig_node, nb, change_format);
2762
2763 // Check and set filter attribute.
2764 Node* filter_node = nullptr;
2765 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2766 nb->Attr("is_weight_const", filter_node->IsConstant());
2767 }
2768
2769 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2770 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2771 DataType T1, T2, Toutput;
2772
2773 // Get all attributes from old node.
2774 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2775 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2776 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2777
2778 Node* weight_node = nullptr;
2779 TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2780
2781 // Add attributes to new node.
2782 nb->Attr("T1", T1);
2783 nb->Attr("T2", T2);
2784 nb->Attr("Toutput", Toutput);
2785 nb->Attr("is_weight_const", weight_node->IsConstant());
2786
2787 // Requantization attr Tbias
2788 DataType Tbias;
2789 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2790 if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2791 }
2792
2793 void MklLayoutRewritePass::CopyFormatAttrsConv(
2794 const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2795 const std::vector<int32>& dilations, bool change_format) {
2796 string data_format;
2797
2798 if (!change_format) {
2799 nb->Attr("strides", strides);
2800 nb->Attr("dilations", dilations);
2801
2802 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2803 nb->Attr("data_format", data_format);
2804 } else {
2805 std::vector<int32> new_strides;
2806 std::vector<int32> new_dilations;
2807 if (strides.size() == 5) {
2808 // `strides` and `dilations` also need to be changed according to
2809 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2810 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2811 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2812 strides[NDHWC::dim::W]};
2813
2814 new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2815 dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2816 dilations[NDHWC::dim::W]};
2817 } else {
2818 // `strides` and `dilations` also need to be changed according to
2819 // `data_format`. In this case, from `NHWC` to `NCHW`.
2820
2821 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2822 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2823
2824 new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2825 dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
2826 }
2827 nb->Attr("strides", new_strides);
2828 nb->Attr("dilations", new_dilations);
2829 }
2830 }
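// Worked example for the change_format == true path above, with illustrative
// 4D NHWC values: strides = {1, 2, 3, 1} and dilations = {1, 1, 2, 1} are
// reordered to NCHW as new_strides = {1, 1, 2, 3} and new_dilations =
// {1, 1, 1, 2}, i.e. values picked in {N, C, H, W} order from the NHWC input.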
2831
2832 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2833 NodeBuilder* nb,
2834 bool change_format) {
2835 DataType T;
2836 string data_format;
2837 string padding;
2838 std::vector<int32> ksize, strides;
2839
2840 // Get all attributes from old node.
2841 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2842 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2843 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2844 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2845 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2846
2847 // Add attributes to new node.
2848 nb->Attr("T", T);
2849 nb->Attr("padding", padding);
2850
2851 if (!change_format) {
2852 nb->Attr("strides", strides);
2853 nb->Attr("ksize", ksize);
2854
2855 nb->Attr("data_format", data_format);
2856 } else {
2857 std::vector<int32> new_strides;
2858 std::vector<int32> new_ksize;
2859 if (strides.size() == 5) {
2860 DCHECK(data_format == "NCDHW");
2861 // `strides` and `ksize` also need to be changed according to
2862 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2863 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2864 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2865 strides[NDHWC::dim::W]};
2866
2867 new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2868 ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2869 ksize[NDHWC::dim::W]};
2870
2871 } else {
2872 // `strides` and `ksize` also need to be changed according to
2873 // `data_format`. In this case, from `NHWC` to `NCHW`.
2874 DCHECK(data_format == "NCHW");
2875 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2876 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2877
2878 new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2879 ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
2880 }
2881 nb->Attr("strides", new_strides);
2882 nb->Attr("ksize", new_ksize);
2883 }
2884 }
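// Worked example for the 5D change_format == true path above, with
// illustrative NDHWC values: ksize = {1, 2, 2, 2, 1} and
// strides = {1, 1, 2, 2, 1} become NCDHW new_ksize = {1, 1, 2, 2, 2} and
// new_strides = {1, 1, 1, 2, 2}, i.e. values picked in {N, C, D, H, W} order
// from the NDHWC input.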
2885
2886 //////////////////////////////////////////////////////////////////////////
2887 // Helper functions related to node merge pass
2888 //////////////////////////////////////////////////////////////////////////
2889
2890 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
2891 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
2892 // once we support BiasAddGrad as Mkl layer.
2893
2894 // Search for all matching mergeinfo.
2895 // We allow more than one match for extensibility.
2896 std::vector<const MergeInfo*> matching_mi;
2897 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
2898 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
2899 matching_mi.push_back(&*mi);
2900 }
2901 }
2902
2903 for (const MergeInfo* mi : matching_mi) {
2904 // Get the operand with which 'a' can be merged.
2905 Node* b = nullptr;
2906 if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
2907 continue;
2908 }
2909
2910 // Get the control edges and input of node
2911 const int N_in = a->num_inputs();
2912 gtl::InlinedVector<Node*, 4> a_control_edges;
2913 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
2914 FillInputs(a, &a_control_edges, &a_in);
2915
2916 const int B_in = b->num_inputs();
2917 gtl::InlinedVector<Node*, 4> b_control_edges;
2918 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
2919 FillInputs(b, &b_control_edges, &b_in);
2920
2921 // Shouldn't merge if a and b have different control edges.
2922 if (a_control_edges != b_control_edges) {
2923 continue;
2924 } else {
2925 // We found a match.
2926 return b;
2927 }
2928 }
2929
2930 return nullptr;
2931 }
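// Illustrative use of the check above: if 'a' is a BiasAdd and its matching
// MergeInfo returns the Conv2D feeding it, the pair is reported as mergeable
// only when both nodes carry identical sets of incoming control edges;
// otherwise any remaining matching MergeInfo entries are tried.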
2932
2933 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2934 Node* m, Node* n) {
2935 CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2936 n->type_string() == csinfo_.conv2d)) ||
2937 ((n->type_string() == csinfo_.bias_add &&
2938 m->type_string() == csinfo_.conv2d)),
2939 true);
2940
2941 // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2942 // BiasAdd is the successor node, and Conv2D the predecessor node.
2943 Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2944 Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2945
2946 // 1. Get all attributes from input nodes.
2947 DataType T_pred, T_succ;
2948 string padding;
2949 std::vector<int32> strides;
2950 std::vector<int32> dilations;
2951 string data_format_pred, data_format_succ;
2952 bool use_cudnn_on_gpu;
2953 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2954 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2955 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2956 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2957 TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2958 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2959 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2960 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2961 // We check that the data formats of both succ and pred are the same.
2962 // We expect them to match, so this could be enforced with an assert, but
2963 // an assert can be too strict, so we enforce it as a check instead.
2964 // If the check fails, then we do not merge the two nodes.
2965 // We do the same check for the devices.
2966 if (data_format_pred != data_format_succ || T_pred != T_succ ||
2967 pred->assigned_device_name() != succ->assigned_device_name() ||
2968 pred->def().device() != succ->def().device()) {
2969 return Status(error::Code::INVALID_ARGUMENT,
2970 "data_format or T attribute or devices of Conv2D and "
2971 "BiasAdd do not match. Will skip node merge optimization");
2972 }
2973
2974 const int succ_num = succ->num_inputs();
2975 gtl::InlinedVector<Node*, 4> succ_control_edges;
2976 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
2977 FillInputs(succ, &succ_control_edges, &succ_in);
2978
2979 const int pred_num = pred->num_inputs();
2980 gtl::InlinedVector<Node*, 4> pred_control_edges;
2981 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
2982 FillInputs(pred, &pred_control_edges, &pred_in);
2983
2984 // We need to ensure that Conv2D only feeds BiasAdd (i.e., no other operator
2985 // consumes the output of Conv2D). If this is not the case, then we cannot
2986 // merge Conv2D with BiasAdd.
2987 const int kFirstOutputSlot = 0;
2988 for (const Edge* e : pred->out_edges()) {
2989 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
2990 return Status(error::Code::INVALID_ARGUMENT,
2991 "Conv2D does not feed to BiasAdd, or "
2992 "it feeds BiasAdd but has multiple outputs. "
2993 "Will skip node merge optimization");
2994 }
2995 }
2996
2997 // 2. Get inputs from both the nodes.
2998 // Find the 2 inputs from the conv and the bias from the add Bias.
2999 // Get operand 0, 1 of conv2D.
3000 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs.
3001 // Get operand 1 of add_bias
3002 // BiasAdd must have 2 inputs: Conv, bias
3003 CHECK_EQ(succ->in_edges().size(), 2);
3004
3005 // We will use the node name of BiasAdd as the name of new node
3006 // Build new node. We use same name as original node, but change the op
3007 // name.
3008 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3009 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
3010 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3011 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
3012 // In1 of BiasAdd is same as output of Conv2D.
3013 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
3014
3015 // Copy attributes from Conv2D to Conv2DWithBias.
3016 CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3017
3018 // Copy the device assigned to old node to new node.
3019 nb.Device(succ->def().device());
3020
3021 // Create node.
3022 Node* new_node;
3023 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3024
3025 // In the rest of this function, an unordered set is used to make sure no
3026 // duplicate edges are added to the new node. Therefore, we can pass
3027 // allow_duplicates = true to AddControlEdge and skip the O(#edges) check
3028 // inside that routine.
3029
3030 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3031 // node are already copied in BuildNode. We handle control edges now.
3032 std::unordered_set<Node*> unique_node;
3033 for (const Edge* e : pred->in_edges()) {
3034 if (e->IsControlEdge()) {
3035 auto result = unique_node.insert(e->src());
3036 if (result.second) {
3037 (*g)->AddControlEdge(e->src(), new_node, true);
3038 }
3039 }
3040 }
3041 unique_node.clear();
3042
3043 for (const Edge* e : succ->in_edges()) {
3044 if (e->IsControlEdge()) {
3045 auto result = unique_node.insert(e->src());
3046 if (result.second) {
3047 (*g)->AddControlEdge(e->src(), new_node, true);
3048 }
3049 }
3050 }
3051 unique_node.clear();
3052
3053 // Incoming edges are fixed, we will fix the outgoing edges now.
3054 // First, we will fix outgoing control edges from 'pred' node.
3055 for (const Edge* e : pred->out_edges()) {
3056 if (e->IsControlEdge()) {
3057 auto result = unique_node.insert(e->dst());
3058 if (result.second) {
3059 (*g)->AddControlEdge(new_node, e->dst(), true);
3060 }
3061 }
3062 }
3063 unique_node.clear();
3064
3065 // Second, we will fix outgoing control and data edges from 'succ' node.
3066 for (const Edge* e : succ->out_edges()) {
3067 if (e->IsControlEdge()) {
3068 auto result = unique_node.insert(e->dst());
3069 if (result.second) {
3070 (*g)->AddControlEdge(new_node, e->dst(), true);
3071 }
3072 } else {
3073 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3074 // output (at slot 0).
3075 const int kConv2DWithBiasOutputSlot = 0;
3076 auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3077 e->dst(), e->dst_input());
3078 DCHECK(new_edge);
3079 }
3080 }
3081
3082 // Copy device assigned to old node to new node.
3083 // It's ok to use pred or succ as we have enforced a check that
3084 // both have same device assigned.
3085 new_node->set_assigned_device_name(pred->assigned_device_name());
3086
3087 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3088 << ", and node: " << succ->DebugString()
3089 << ", into node:" << new_node->DebugString();
3090
3091 (*g)->RemoveNode(succ);
3092 (*g)->RemoveNode(pred);
3093
3094 return Status::OK();
3095 }
3096
3097 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3098 Node* m, Node* n) {
3099 DCHECK((m->type_string() == csinfo_.pad &&
3100 (n->type_string() == csinfo_.conv2d ||
3101 n->type_string() == csinfo_.fused_conv2d)) ||
3102 (n->type_string() == csinfo_.pad &&
3103 (m->type_string() == csinfo_.conv2d ||
3104 m->type_string() == csinfo_.fused_conv2d)));
3105
3106 bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3107 m->type_string() == csinfo_.fused_conv2d;
3108 // Conv2D is the successor node, and Pad the predecessor node.
3109 Node* pred = m->type_string() == csinfo_.pad ? m : n;
3110 Node* succ = m->type_string() == csinfo_.pad ? n : m;
3111
3112 // 1. Get all attributes from input nodes.
3113 DataType T_pred, T_succ;
3114 string padding;
3115 std::vector<int32> strides;
3116 std::vector<int32> dilations;
3117 string data_format_pred, data_format_succ;
3118
3119 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3120 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3121 TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3122 TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3123 TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3124 // Check if the devices of both succ and pred are the same.
3125 // Assert is not used because it can be too strict.
3126 // Don't need to check for data formats because it is not available in Pad.
3127 if (T_pred != T_succ ||
3128 pred->assigned_device_name() != succ->assigned_device_name() ||
3129 pred->def().device() != succ->def().device()) {
3130 return Status(error::Code::INVALID_ARGUMENT,
3131 "T attribute or devices of Conv2D and "
3132 "Pad do not match. Will skip node merge optimization");
3133 }
3134
3135 const int succ_num = succ->num_inputs();
3136 gtl::InlinedVector<Node*, 4> succ_control_edges;
3137 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3138 FillInputs(succ, &succ_control_edges, &succ_in);
3139
3140 const int pred_num = pred->num_inputs();
3141 gtl::InlinedVector<Node*, 4> pred_control_edges;
3142 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3143 FillInputs(pred, &pred_control_edges, &pred_in);
3144
3145 // We need to ensure that Pad only feeds Conv2D (i.e., no other operator
3146 // consumes the output of Pad). If this is not the case, then we cannot
3147 // merge Conv2D with Pad.
3148 const int kFirstOutputSlot = 0;
3149 for (const Edge* e : pred->out_edges()) {
3150 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3151 return Status(error::Code::INVALID_ARGUMENT,
3152 "Pad does not feed to Conv2D, or "
3153 "it feeds Conv2D but has multiple outputs. "
3154 "Will skip node merge optimization");
3155 }
3156 }
3157
3158 // 2. Get inputs from both the nodes.
3159
3160 // Pad must have 2 data inputs: "input" and paddings.
3161 int PadDataInputEdges = 0;
3162 for (const Edge* e : pred->in_edges()) {
3163 if (!e->IsControlEdge()) {
3164 PadDataInputEdges++;
3165 }
3166 }
3167 DCHECK_EQ(PadDataInputEdges, 2);
3168
3169 // Conv2D must have 2 data inputs: Pad output and Filter
3170 // FusedConv2D has 3 data inputs: Pad output, Filter and Args;
3171 int ConvDataInputEdges = 0;
3172 for (const Edge* e : succ->in_edges()) {
3173 if (!e->IsControlEdge()) {
3174 ConvDataInputEdges++;
3175 }
3176 }
3177
3178 DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
3179
3180 // We will use the node name of Conv2D as the name of new node
3181 // Build new node. We use same name as original node, but change the op
3182 // name.
3183
3184 NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3185 : csinfo_.pad_with_conv2d);
3186 nb.Input(pred_in[0].first, pred_in[0].second); // In1 (input data) of Pad
3187 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3188 nb.Input(succ_in[1].first, succ_in[1].second); // In2 (filter) of conv2d
3189 // In1 of Conv2D is same as output of Pad.
3190 // Thus, only need to add In2 of Conv2D
3191
3192 if (is_fused_conv2d) {
3193 // FusedConv2D has one additional input, args
3194 std::vector<NodeBuilder::NodeOut> args;
3195 int num_args = 1;
3196 TF_CHECK_OK(GetNodeAttr(succ->def(), "num_args", &num_args));
3197 for (int i = 0; i < num_args; i++) {
3198 args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second);
3199 }
3200 nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3201 args}); // In3 (args) of FusedConv2D
3202 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3203 // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3204 CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3205 const_cast<const Node*>(pred), &nb);
3206 } else {
3207 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3208 // Copy attributes from Pad and conv2D to PadWithConv2D.
3209 CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3210 const_cast<const Node*>(pred), &nb);
3211 }
3212
3213 // Copy the device assigned to old node to new node.
3214 nb.Device(succ->def().device());
3215
3216 // Create node.
3217 Node* new_node;
3218 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3219 // No need to check if new_node is null because it will be null only when
3220 // Finalize fails.
3221
3222 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3223 // node are already copied in BuildNode.
3224 // We handle control edges now.
3225 for (const Edge* e : pred->in_edges()) {
3226 if (e->IsControlEdge()) {
3227 // Don't allow duplicate edge
3228 (*g)->AddControlEdge(e->src(), new_node, false);
3229 }
3230 }
3231 for (const Edge* e : succ->in_edges()) {
3232 if (e->IsControlEdge()) {
3233 // Don't allow duplicate edge
3234 (*g)->AddControlEdge(e->src(), new_node, false);
3235 }
3236 }
3237
3238 // Incoming edges are fixed, we will fix the outgoing edges now.
3239 // First, we will fix outgoing control edges from 'pred' node.
3240 for (const Edge* e : pred->out_edges()) {
3241 if (e->IsControlEdge()) {
3242 // Don't allow duplicate edge
3243 (*g)->AddControlEdge(new_node, e->dst(), false);
3244 }
3245 }
3246
3247 // Second, we will fix outgoing control and data edges from 'succ' node.
3248 for (const Edge* e : succ->out_edges()) {
3249 if (e->IsControlEdge()) {
3250 // Don't allow duplicates: if the control edge already exists,
3251 // AddControlEdge simply returns NULL without adding another edge.
3252 (*g)->AddControlEdge(new_node, e->dst(), false);
3253 } else {
3254 // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3255 // output (at slot 0).
3256 const int kPadWithConv2DOutputSlot = 0;
3257 (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3258 e->dst_input());
3259 }
3260 }
3261
3262 // Copy device assigned to old node to new node.
3263 // It's ok to use pred or succ as we have enforced a check that
3264 // both have same device assigned.
3265 new_node->set_assigned_device_name(pred->assigned_device_name());
3266
3267 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3268 << ", and node: " << succ->DebugString()
3269 << ", into node:" << new_node->DebugString();
3270
3271 (*g)->RemoveNode(succ);
3272 (*g)->RemoveNode(pred);
3273
3274 return Status::OK();
3275 }
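// Illustrative before/after for this merge (data edges only, non-fused case):
//   before:  Q = Pad(X, paddings);  R = Conv2D(Q, filter)
//   after:   R = <csinfo_.pad_with_conv2d op>(X, filter, paddings)
// The input order matches the nb.Input() calls above: Pad's data input,
// Conv2D's filter, then Pad's paddings.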
3276
3277 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3278 std::unique_ptr<Graph>* g, Node* m, Node* n) {
3279 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3280 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3281 ((n->type_string() == csinfo_.bias_add_grad &&
3282 m->type_string() == csinfo_.conv2d_grad_filter)),
3283 true);
3284
3285 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3286 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3287 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3288
3289 // Sanity check for attributes from input nodes.
3290 DataType T_b, T_f;
3291 string data_format_b, data_format_f;
3292 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3293 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3294 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3295 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3296 if (data_format_b != data_format_f || T_b != T_f ||
3297 badd->assigned_device_name() != fltr->assigned_device_name() ||
3298 badd->def().device() != fltr->def().device()) {
3299 return Status(error::Code::INVALID_ARGUMENT,
3300 "data_format or T attribute or devices of "
3301 "Conv2DBackpropFilter and BiasAddGrad do not match. "
3302 "Will skip node merge optimization");
3303 }
3304
3305 // We will use the node name of Conv2DBackpropFilter as the name of new node.
3306 // This is because BackpropFilterWithBias also emits the bias output.
3307 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3308 // Since Conv2DBackpropFilterWithBias has same number of inputs as
3309 // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3310 // to copy any data input of BiasAddGrad because that input also goes to
3311 // Conv2DBackpropFilter.
3312 const int fltr_ins = fltr->num_inputs();
3313 gtl::InlinedVector<Node*, 4> fltr_control_edges;
3314 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3315 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3316 for (int idx = 0; idx < fltr_ins; idx++) {
3317 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3318 }
3319
3320 // Copy attributes from Conv2DBackpropFilter.
3321 CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3322
3323 // Copy the device assigned to old node to new node.
3324 nb.Device(fltr->def().device());
3325
3326 // Create node.
3327 Node* new_node;
3328 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3329
3330 // In the rest of this function, an unordered set is used to make sure no
3331 // duplicate edges are added to the new node. Therefore, we can pass
3332 // allow_duplicates = true to AddControlEdge and skip the O(#edges) check
3333 // inside that routine.
3334
3335 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3336 // new 'new_node' node are already copied in BuildNode. We handle control
3337 // edges now.
3338 std::unordered_set<Node*> unique_node;
3339 for (const Edge* e : badd->in_edges()) {
3340 if (e->IsControlEdge()) {
3341 auto result = unique_node.insert(e->src());
3342 if (result.second) {
3343 (*g)->AddControlEdge(e->src(), new_node, true);
3344 }
3345 }
3346 }
3347 unique_node.clear();
3348 for (const Edge* e : fltr->in_edges()) {
3349 if (e->IsControlEdge()) {
3350 auto result = unique_node.insert(e->src());
3351 if (result.second) {
3352 (*g)->AddControlEdge(e->src(), new_node, true);
3353 }
3354 }
3355 }
3356 unique_node.clear();
3357
3358 // Incoming edges are fixed, we will fix the outgoing edges now.
3359 // First, we will fix outgoing control edges from 'badd' node.
3360 // Conv2DBackpropFilter has 1 output -- filter_grad.
3361 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3362 // bias_grad. But filter_grad is at same slot number (0) in both the
3363 // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3364 // it is at slot number 0 in BiasAddGrad.
3365 const int kMergedNodeFilterGradOutputIdx = 0;
3366 const int kMergedNodeBiasGradOutputIdx = 1;
3367
3368 for (const Edge* e : badd->out_edges()) {
3369 if (e->IsControlEdge()) {
3370 auto result = unique_node.insert(e->dst());
3371 if (result.second) {
3372 (*g)->AddControlEdge(new_node, e->dst(), true);
3373 }
3374 } else {
3375 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3376 e->dst(), e->dst_input());
3377 DCHECK(new_edge);
3378 }
3379 }
3380 unique_node.clear();
3381
3382 // Second, we will fix outgoing control and data edges from 'fltr' node.
3383 for (const Edge* e : fltr->out_edges()) {
3384 if (e->IsControlEdge()) {
3385 auto result = unique_node.insert(e->dst());
3386 if (result.second) {
3387 (*g)->AddControlEdge(new_node, e->dst(), true);
3388 }
3389 } else {
3390 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3391 e->dst(), e->dst_input());
3392 DCHECK(new_edge);
3393 }
3394 }
3395
3396 // Copy device assigned to old node to new node.
3397 // It's ok to use badd or fltr as we have enforced a check that
3398 // both have same device assigned.
3399 new_node->set_assigned_device_name(badd->assigned_device_name());
3400
3401 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3402 << ", and node: " << fltr->DebugString()
3403 << ", into node:" << new_node->DebugString();
3404
3405 (*g)->RemoveNode(badd);
3406 (*g)->RemoveNode(fltr);
3407
3408 return Status::OK();
3409 }
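// Worked example of the output-slot remapping above (illustrative): a data
// edge that used to leave BiasAddGrad at slot 0 is re-created from slot 1
// (bias_grad) of the merged node, while a data edge that used to leave
// Conv2DBackpropFilter at slot 0 is re-created from slot 0 (filter_grad).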
3410
3411 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3412 Node* n) {
3413 DCHECK(m);
3414 DCHECK(n);
3415
3416 if (((m->type_string() == csinfo_.bias_add &&
3417 n->type_string() == csinfo_.conv2d)) ||
3418 ((n->type_string() == csinfo_.bias_add &&
3419 m->type_string() == csinfo_.conv2d))) {
3420 return this->MergeConv2DWithBiasAdd(g, m, n);
3421 }
3422 if ((m->type_string() == csinfo_.pad &&
3423 (n->type_string() == csinfo_.conv2d ||
3424 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3425 (n->type_string() == csinfo_.pad &&
3426 (m->type_string() == csinfo_.conv2d ||
3427 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3428 return this->MergePadWithConv2D(g, m, n);
3429 }
3430
3431 if (((m->type_string() == csinfo_.bias_add_grad &&
3432 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3433 ((n->type_string() == csinfo_.bias_add_grad &&
3434 m->type_string() == csinfo_.conv2d_grad_filter))) {
3435 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3436 }
3437
3438 return Status(error::Code::UNIMPLEMENTED,
3439 "Unimplemented case for node merge optimization.");
3440 }
3441
3442 //////////////////////////////////////////////////////////////////////////
3443 // Helper functions for node rewrite
3444 //////////////////////////////////////////////////////////////////////////
3445
3446 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3447 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3448 const RewriteInfo* ri) {
3449 // Get all data inputs.
3450 int num_data_inputs = orig_node->in_edges().size();
3451 // Drop count for control edges from inputs
3452 for (const Edge* e : orig_node->in_edges()) {
3453 if (e->IsControlEdge()) {
3454 num_data_inputs--;
3455 }
3456 }
3457
3458 gtl::InlinedVector<Node*, 4> control_edges;
3459 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3460 FillInputs(orig_node, &control_edges, &inputs);
3461
3462 // Build new node. We use same name as original node, but change the op name.
3463 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3464 // Copy user-specified device assigned to original node to new node.
3465 nb.Device(orig_node->def().device());
3466 // Set up new inputs to the rewritten node.
3467 Status s = SetUpInputs(g, inputs, &nb, orig_node);
3468 if (s != Status::OK()) {
3469 return s;
3470 }
3471
3472 const bool kPartialCopyAttrs = false;
3473 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3474
3475 // Set the Mkl layer label for this op.
3476 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3477 DataTypeIsQuantized(orig_node->output_type(0))) {
3478 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3479 } else {
3480 nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3481 }
3482 // Finalize graph and get new node.
3483 s = nb.Finalize(&**g, new_node);
3484 if (s != Status::OK()) {
3485 return s;
3486 }
3487
3488 // In the rest of this function, an unordered set is used to make sure no
3489 // duplicate edges are added to the new node. Therefore, we can pass
3490 // allow_duplicates = true to AddControlEdge and skip the O(#edges) check
3491 // inside that routine.
3492
3493 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3494 // already copied in BuildNode. We need to handle control edges now.
3495 std::unordered_set<Node*> unique_node;
3496 for (const Edge* e : orig_node->in_edges()) {
3497 if (e->IsControlEdge()) {
3498 auto result = unique_node.insert(e->src());
3499 if (result.second) {
3500 (*g)->AddControlEdge(e->src(), *new_node, true);
3501 }
3502 }
3503 }
3504 unique_node.clear();
3505
3506 // Copy outgoing edges from 'orig_node' node to new
3507 // 'new_node' node, since the output also follows same ordering among
3508 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
3509 // tensors appropriately. Specifically, nth output of the original node
3510 // will become 2*nth output of the Mkl node for the interleaved ordering
3511 // of the tensors. For the contiguous ordering of the tensors, it will be n.
3512 // GetTensorDataIndex provides this mapping function.
3513 for (const Edge* e : orig_node->out_edges()) {
3514 if (e->IsControlEdge()) {
3515 auto result = unique_node.insert(e->dst());
3516 if (result.second) {
3517 (*g)->AddControlEdge(*new_node, e->dst(), true);
3518 }
3519 } else {
3520 auto new_edge = (*g)->AddEdge(
3521 *new_node,
3522 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3523 e->dst(), e->dst_input());
3524 DCHECK(new_edge);
3525 }
3526 }
3527 return Status::OK();
3528 }
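// Worked example of the GetTensorDataIndex mapping used above (illustrative):
// for an original node with 2 outputs, output 1 feeds its consumers from
// slot 2 of the Mkl node under interleaved ordering (2 * 1), and from slot 1
// under contiguous ordering.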
3529
3530 Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3531 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3532 const RewriteInfo* ri) {
3533 // Get all data inputs.
3534 int num_data_inputs = orig_node->in_edges().size();
3535 // Drop count for control edges from inputs
3536 for (const Edge* e : orig_node->in_edges()) {
3537 if (e->IsControlEdge()) {
3538 num_data_inputs--;
3539 }
3540 }
3541 gtl::InlinedVector<Node*, 4> control_edges;
3542 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3543 FillInputs(orig_node, &control_edges, &inputs);
3544
3545 // Build new node. We use same name as original node, but change the op name.
3546 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3547 // Copy user-specified device assigned to original node to new node.
3548 nb.Device(orig_node->def().device());
3549
3550 Status s = CopyInputs(orig_node, inputs, &nb);
3551 if (s != Status::OK()) {
3552 return s;
3553 }
3554
3555 std::vector<NodeBuilder::NodeOut> workspace_tensors;
3556 bool are_workspace_tensors_available = false;
3557 if (IsWorkspaceCheckNeeded(orig_node)) {
3558 AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3559 &are_workspace_tensors_available);
3560 if (are_workspace_tensors_available) {
3561 DCHECK_EQ(workspace_tensors.size(), 1);
3562 nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3563 }
3564 }
3565
3566 if (!NativeFormatEnabled()) {
3567 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3568 } else {
3569 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3570 }
3571
3572 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3573 DataTypeIsQuantized(orig_node->output_type(0))) {
3574 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3575 } else {
3576 nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3577 }
3578
3579 // Finalize graph and get new node.
3580 s = nb.Finalize(&**g, new_node);
3581 if (s != Status::OK()) {
3582 return s;
3583 }
3584
3585 // In the rest of this function, an unordered set is used to make sure no
3586 // duplicate edges are added to the new node. Therefore, we can pass
3587 // allow_duplicates = true to AddControlEdge and skip the O(#edges) check
3588 // inside that routine.
3589
3590 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3591 // already copied in BuildNode. We need to handle control edges now.
3592 std::unordered_set<Node*> unique_node;
3593 for (const Edge* e : orig_node->in_edges()) {
3594 if (e->IsControlEdge()) {
3595 auto result = unique_node.insert(e->src());
3596 if (result.second) {
3597 (*g)->AddControlEdge(e->src(), *new_node, true);
3598 }
3599 }
3600 }
3601 unique_node.clear();
3602
3603 // Transfer outgoing edges from 'orig_node' node to new 'new_node' node.
3604 for (const Edge* e : orig_node->out_edges()) {
3605 if (e->IsControlEdge()) {
3606 auto result = unique_node.insert(e->dst());
3607 if (result.second) {
3608 (*g)->AddControlEdge(*new_node, e->dst(), true);
3609 }
3610 } else {
3611 auto result =
3612 (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3613 DCHECK(result != nullptr);
3614 }
3615 }
3616
3617 return Status::OK();
3618 }
3619
3620 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3621 Node* orig_node,
3622 const RewriteInfo* ri) {
3623 DCHECK(ri != nullptr);
3624 DCHECK(orig_node != nullptr);
3625
3626 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3627
3628 Status ret_status = Status::OK();
3629 Node* new_node = nullptr;
3630 if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3631 ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3632 } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3633 ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3634 } else {
3635 ret_status = Status(error::Code::INVALID_ARGUMENT,
3636 "Unsupported rewrite cause found. "
3637 "RewriteNode will fail.");
3638 }
3639 TF_CHECK_OK(ret_status);
3640
3641 // Copy the runtime device assigned from original code to new node.
3642 new_node->set_assigned_device_name(orig_node->assigned_device_name());
3643
3644 // Delete original node and mark new node as rewritten.
3645 (*g)->RemoveNode(orig_node);
3646
3647 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3648 return ret_status;
3649 }
3650
3651 // TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3652 // having attributes other than "T"?
3653 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3654 const MklLayoutRewritePass::RewriteInfo*
3655 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3656 DataType T1, T2;
3657 DataType Tinput, Tfilter;
3658 bool type_attrs_present = false;
3659
3660 if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3661 TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3662 mkl_op_registry::IsMklQuantizedOp(
3663 mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3664 type_attrs_present = true;
3665 } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3666 TryGetNodeAttr(n->def(), "T2", &T2) &&
3667 mkl_op_registry::IsMklQuantizedOp(
3668 mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3669 type_attrs_present = true;
3670 }
3671
3672 if (type_attrs_present) {
3673 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3674 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3675 return &*ri;
3676 }
3677 }
3678 }
3679
3680 return nullptr;
3681 }
3682
3683 const MklLayoutRewritePass::RewriteInfo*
3684 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3685 DCHECK(n);
3686
3687 // Quantized ops may have attributes other than "T", so that check is
3688 // decoupled into a separate function, CheckForQuantizedNodeRewrite(const Node*).
3689 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3690 if (ri != nullptr) return ri;
3691
3692 // First check if node along with its type is supported by MKL layer.
3693 // We do not want to rewrite an op into Mkl op if types are not supported.
3694 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3695 // MklRelu if type is INT32.
3696 DataType T;
3697 if (!TryGetNodeAttr(n->def(), "T", &T)) {
3698 return nullptr;
3699 }
3700
3701 // We make an exception for Conv2DGrad and MaxPool related ops as
3702 // the corresponding MKL ops currently do not support the case
3703 // of padding == EXPLICIT yet.
3704 // TODO(intel): support `EXPLICIT` padding for ConvGrad
3705 if (n->type_string() == csinfo_.conv2d_grad_input ||
3706 n->type_string() == csinfo_.conv2d_grad_filter ||
3707 n->type_string() == csinfo_.depthwise_conv2d_grad_filter ||
3708 n->type_string() == csinfo_.depthwise_conv2d_grad_input ||
3709 n->type_string() == csinfo_.conv3d_grad_filter ||
3710 n->type_string() == csinfo_.conv3d_grad_input ||
3711 n->type_string() == csinfo_.max_pool ||
3712 n->type_string() == csinfo_.max_pool_grad ||
3713 n->type_string() == csinfo_.max_pool3d ||
3714 n->type_string() == csinfo_.max_pool3d_grad) {
3715 string padding;
3716 TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3717 if (padding == "EXPLICIT") return nullptr;
3718 }
3719
3720 // We make an exception for __MklDummyConv2DWithBias,
3721 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3722 // names do not match Mkl node names.
3723 if (n->type_string() != csinfo_.conv2d_with_bias &&
3724 n->type_string() != csinfo_.pad_with_conv2d &&
3725 n->type_string() != csinfo_.pad_with_fused_conv2d &&
3726 n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3727 n->type_string() != csinfo_.fused_batch_norm_ex &&
3728 n->type_string() != csinfo_.fused_conv2d &&
3729 n->type_string() != csinfo_.fused_depthwise_conv2d &&
3730 n->type_string() != csinfo_.fused_matmul &&
3731 n->type_string() != csinfo_.fused_conv3d &&
3732 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3733 T)) {
3734 return nullptr;
3735 }
3736
3737 // We now check if rewrite rule applies for this op. If rewrite rule passes
3738 // for this op, then we rewrite it to Mkl op.
3739 // Find matching RewriteInfo and then check that rewrite rule applies.
3740 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3741 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3742 return &*ri;
3743 }
3744 }
3745
3746 // Else return not found.
3747 return nullptr;
3748 }
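// For readability, the decision sequence above is: quantized-op check first;
// then require a "T" attribute; then skip the listed grad/pooling ops when
// padding == EXPLICIT; then require either one of the special dummy/fused op
// names or an Mkl kernel registered for (op, T); and finally return the first
// rinfo_ entry whose name matches and whose rewrite_rule passes.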
3749
3750 //////////////////////////////////////////////////////////////////////////
3751 // Helper functions for node fusion
3752 //////////////////////////////////////////////////////////////////////////
3753 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3754 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3755 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3756 string data_format) {
3757 Node* transpose_to_nhwc = nodes[0];
3758 Node* mklop = nodes[1];
3759 Node* transpose_to_nchw = nodes[2];
3760
3761 const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3762 gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3763 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3764 transpose_nhwc_num_inputs);
3765 FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3766 &transpose_nhwc_in);
3767
3768 const int mklop_num_inputs = mklop->num_inputs();
3769 gtl::InlinedVector<Node*, 4> mklop_control_edges;
3770 gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3771 FillInputs(mklop, &mklop_control_edges, &mklop_in);
3772
3773 const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3774 gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3775 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3776 transpose_nchw_num_inputs);
3777 FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3778 &transpose_nchw_in);
3779
3780 // We use same name as original node, but change the op
3781 // type.
3782 NodeBuilder nb(mklop->name(), mklop->type_string());
3783
3784 // Storing the output slots of the input nodes.
3785 for (int i = 0; i < mklop_num_inputs; i++) {
3786 if (mklop_in[i].first == transpose_to_nhwc) {
3787 // Fill "x":
3788 nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3789 } else {
3790 // Fill inputs other than "x":
3791 nb.Input(mklop_in[i].first, mklop_in[i].second);
3792 }
3793 }
3794
3795 copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3796 nb.Attr("data_format", data_format);
3797
3798 // Copy the device assigned to old node to new node.
3799 nb.Device(mklop->def().device());
3800
3801 // Create node.
3802 Node* new_node;
3803 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3804 // No need to check if new_node is null because it will be null only when
3805 // Finalize fails.
3806
3807 // Fill outputs.
3808 for (const Edge* e : transpose_to_nchw->out_edges()) {
3809 if (!e->IsControlEdge()) {
3810 const int kTransposeWithMklOpOutputSlot = 0;
3811 auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3812 e->dst(), e->dst_input());
3813 DCHECK(new_edge);
3814 }
3815 }
3816
3817 // Copy device assigned to old node to new node.
3818 new_node->set_assigned_device_name(mklop->assigned_device_name());
3819
3820 // Copy requested_device and assigned_device_name_index
3821 new_node->set_requested_device(mklop->requested_device());
3822 new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3823
3824 (*g)->RemoveNode(transpose_to_nhwc);
3825 (*g)->RemoveNode(mklop);
3826 (*g)->RemoveNode(transpose_to_nchw);
3827
3828 return Status::OK();
3829 }
3830
3831 Status MklLayoutRewritePass::FuseNode(
3832 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3833 const MklLayoutRewritePass::FusionInfo fi) {
3834 return fi.fuse_func(g, nodes, fi.copy_attrs);
3835 }
3836
3837 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
3838 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3839 // Stores matched nodes, in the same order as node_checkers.
3840 std::vector<Node*> nodes;
3841
3842 for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3843 //
3844 // Make sure node "a" and its succeeding nodes (b, c, ...) match the pattern
3845 // defined in the fusion info (ops[0], ops[1], ...),
3846 // i.e. "a->b->c" matches "op1->op2->op3".
3847 //
3848
3849 // Stores the first unvisited outgoing edge of each matched node in "nodes".
3850 std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3851 nodes.clear();
3852
3853 auto node_checker = fi->node_checkers.begin();
3854 if (a != nullptr && (*node_checker)(a)) {
3855 nodes.push_back(a);
3856 current_neighbor_stack.push(a->out_edges().begin());
3857 ++node_checker;
3858 }
3859
3860 while (!nodes.empty()) {
3861 auto& current_neighbor_iter = current_neighbor_stack.top();
3862
3863 if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3864 // Found an unvisited edge. Goes through the edge to get the neighbor.
3865 Node* neighbor_node = (*current_neighbor_iter)->dst();
3866 ++current_neighbor_stack.top(); // Retrieves the next unvisited edge.
3867
3868 if ((*node_checker)(neighbor_node)) {
3869 // Found a match. Stores the node and moves to the next checker.
3870 nodes.push_back(neighbor_node);
3871 current_neighbor_stack.push(neighbor_node->out_edges().begin());
3872 if (++node_checker == fi->node_checkers.end()) {
3873 return make_tuple(true, nodes, *fi);
3874 }
3875 }
3876 } else {
3877 // Removes the current node since none of its neighbors leads to a
3878 // further match.
3879 nodes.pop_back();
3880 current_neighbor_stack.pop();
3881 --node_checker;
3882 }
3883 }
3884 }
3885
3886 return make_tuple(false, std::vector<Node*>(), FusionInfo());
3887 }
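// Illustrative walk of the search above: starting from a node 'a' accepted by
// node_checkers[0], the DFS follows each outgoing edge in turn; if a
// destination node passes node_checkers[1], it is pushed and its own edges
// are explored, and so on, until every checker has matched (success) or all
// neighbors of a partially matched chain are exhausted (backtrack).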
3888
3889 ///////////////////////////////////////////////////////////////////////////////
3890 // Post-rewrite Mkl metadata fixup pass
3891 ///////////////////////////////////////////////////////////////////////////////
3892 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3893 const Edge* e_data,
3894 const Edge* e_metadata) {
3895 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3896 return false;
3897 }
3898
3899 Node* n_data = e_data->src();
3900 int n_data_op_slot = e_data->src_output();
3901 int n_metadata_op_slot =
3902 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3903
3904 // If the source of the meta edge is a constant node (producing a dummy Mkl
3905 // metadata tensor), then we need to fix it up.
3906 if (IsConstant(e_metadata->src())) {
3907 Node* e_metadata_dst = e_metadata->dst();
3908 int e_metadata_in_slot = e_metadata->dst_input();
3909 auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3910 e_metadata_in_slot);
3911 DCHECK(new_edge);
3912
3913 (*g)->RemoveEdge(e_metadata);
3914 return true;
3915 }
3916
3917 return false;
3918 }
3919
3920 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3921 Node* n) {
3922 bool result = false;
3923
3924 // If graph node is not Mkl node, then return.
3925 DataType T = DT_INVALID;
3926 if (!TryGetNodeAttr(n->def(), "T", &T) ||
3927 !mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
3928 return result;
3929 }
3930
3931 // If it is Mkl node, then check if the input edges to this node that carry
3932 // Mkl metadata are linked up correctly with the source node.
3933
3934 // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3935 // data tensors + n for Mkl metadata tensors). We need to check for correct
3936 // connection of n metadata tensors only.
3937 int num_data_inputs = n->num_inputs() / 2;
3938 for (int idx = 0; idx < num_data_inputs; idx++) {
3939 // Get the edge connecting input slot with index (idx).
3940 const Edge* e = nullptr;
3941 TF_CHECK_OK(n->input_edge(idx, &e));
3942
3943 // If e is control edge, then skip.
3944 if (e->IsControlEdge()) {
3945 continue;
3946 }
3947
3948 // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
3949 // node, then we don't need to do anything.
3950 Node* e_src = e->src();
3951 if (TryGetNodeAttr(e_src->def(), "T", &T) &&
3952 mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) {
3953 // Source node for edge 'e' is Mkl node.
3954 // Destination node and destination input slot of e is node 'n' and 'idx'
3955 // resp.
3956 CHECK_EQ(e->dst(), n);
3957 CHECK_EQ(e->dst_input(), idx);
3958
3959 // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
3960 // 'e'. For that, let's first get the input slot of 'n' where the meta
3961 // edge will feed the value.
3962 int e_meta_in_slot =
3963 GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3964 const Edge* e_meta = nullptr;
3965 TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3966
3967 // Let's check if we need to fix this meta edge.
3968 if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
3969 result = true;
3970 }
3971 }
3972 }
3973
3974 return result;
3975 }
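// Example of the fixup above (illustrative): suppose Mkl node X feeds data
// input 'idx' of 'n' from X:0, but the corresponding metadata input of 'n'
// (slot GetTensorMetaDataIndex(idx, n->num_inputs())) is fed by a dummy
// Const. FixMklMetaDataEdgeIfNeeded then rewires that metadata input to
// X:GetTensorMetaDataIndex(0, X->num_outputs()) and removes the Const edge.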
3976
3977 ///////////////////////////////////////////////////////////////////////////////
3978 // Run function for the pass
3979 ///////////////////////////////////////////////////////////////////////////////
3980
3981 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
3982 bool result = false;
3983 DCHECK(g);
3984
3985 DumpGraph("Before running MklLayoutRewritePass", &**g);
3986
3987 std::vector<Node*> order;
3988 GetReversePostOrder(**g, &order); // This will give us topological sort.
3989 for (Node* n : order) {
3990 // If node is not an op or it cannot run on CPU device, then skip.
3991 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
3992 continue;
3993 }
3994
3995 Node* m = nullptr;
3996 if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
3997 // Check if the node 'n' can be merged with any other node. If it can
3998 // be, 'm' contains the node with which it can be merged.
3999 string n1_name = n->name();
4000 string n2_name = m->name();
4001
4002 VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
4003 << n2_name << " for merging";
4004
4005 if (MergeNode(g, n, m) == Status::OK()) {
4006 VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
4007 << n2_name;
4008 result = true;
4009 }
4010 }
4011 }
4012
4013 DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
4014
4015 order.clear();
4016 GetReversePostOrder(**g, &order); // This will give us topological sort.
4017 for (Node* n : order) {
4018 // If node is not an op or it cannot run on CPU device, then skip.
4019 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4020 continue;
4021 }
4022
4023 auto check_result = CheckForNodeFusion(n);
4024 bool found_pattern = std::get<0>(check_result);
4025 std::vector<Node*> nodes = std::get<1>(check_result);
4026 const FusionInfo fi = std::get<2>(check_result);
4027
4028 // if "found_pattern" is true, we can do the fusion.
4029 if (found_pattern) {
4030 if (FuseNode(g, nodes, fi) == Status::OK()) {
4031 result = true;
4032 }
4033 }
4034 }
4035 DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
4036
4037 order.clear();
4038 GetReversePostOrder(**g, &order); // This will give us topological sort.
4039 for (Node* n : order) {
4040 // If node is not an op or it cannot run on CPU device, then skip.
4041 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4042 continue;
4043 }
4044
4045 const RewriteInfo* ri = nullptr;
4046 // We will first search if node is to be rewritten.
4047 if ((ri = CheckForNodeRewrite(n)) != nullptr) {
4048 string node_name = n->name();
4049 string op_name = n->type_string();
4050
4051 VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
4052 << " with op " << op_name << " for rewrite using"
4053 << " layout optimization.";
4054
4055 if (RewriteNode(g, n, ri) == Status::OK()) {
4056 VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
4057 << " with op " << op_name << " for Mkl layout optimization.";
4058 result = true;
4059 }
4060 }
4061 }
4062
4063 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
4064
4065 if (!NativeFormatEnabled()) {
4066 order.clear();
4067 GetReversePostOrder(**g, &order); // This will give us topological sort.
4068 for (Node* n : order) {
4069 // If node is not an op or it cannot run on CPU device, then skip.
4070 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4071 continue;
4072 }
4073 if (FixMklMetaDataEdges(g, n)) {
4074 string node_name = n->name();
4075 string op_name = n->type_string();
4076
4077 VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
4078 << node_name << " with op " << op_name;
4079 result = true;
4080 }
4081 }
4082 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
4083 &**g);
4084 }
4085
4086 return result;
4087 }
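// For reference, RunPass above applies its phases in this order, each over a
// freshly computed reverse post-order of the graph: (1) node merge,
// (2) node fusion, (3) rewrite of candidate ops to Mkl ops, and (4) Mkl
// metadata-edge fixup, where the last phase runs only when the native format
// is not enabled.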
4088
4089 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
4090 return MklLayoutRewritePass().RunPass(g);
4091 }
4092
4093 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
4094 if (options.graph == nullptr && options.partition_graphs == nullptr) {
4095 return Status::OK();
4096 }
4097 if (!IsMKLEnabled()) {
4098 VLOG(2) << "TF-MKL: MKL is not enabled";
4099 return Status::OK();
4100 }
4101
4102 auto process_graph = [&](std::unique_ptr<Graph>* g) {
4103 // Get the ownership of a graph
4104 std::unique_ptr<Graph>* ng = std::move(g);
4105 RunPass(ng);
4106 // Return the ownership of a graph back
4107 g->reset(ng->release());
4108 };
4109
4110 if (kMklLayoutRewritePassGroup !=
4111 OptimizationPassRegistry::POST_PARTITIONING) {
4112 // For any pre-partitioning phase, a graph is stored in options.graph.
4113 process_graph(options.graph);
4114 } else {
4115 // For post partitioning phase, graphs are stored in
4116 // options.partition_graphs.
4117 for (auto& pg : *options.partition_graphs) {
4118 process_graph(&pg.second);
4119 }
4120 }
4121
4122 return Status::OK();
4123 }
4124
4125 } // namespace tensorflow
4126
4127 #endif // INTEL_MKL
4128