1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19
20 #include "tensorflow/core/common_runtime/mkl_layout_pass.h"
21
22 #include <algorithm>
23 #include <functional>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <stack>
28 #include <tuple>
29 #include <unordered_set>
30 #include <utility>
31 #include <vector>
32
33 #include "absl/base/call_once.h"
34 #include "tensorflow/core/common_runtime/function.h"
35 #include "tensorflow/core/common_runtime/optimization_registry.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.pb.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph.h"
40 #include "tensorflow/core/graph/mkl_graph_util.h"
41 #include "tensorflow/core/graph/node_builder.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/array_slice.h"
44 #include "tensorflow/core/lib/gtl/map_util.h"
45 #include "tensorflow/core/lib/hash/hash.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/util/tensor_format.h"
48 #include "tensorflow/core/util/util.h"
49
50 namespace tensorflow {
51
// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
//     Rewriting happens under the following scenario:
//     - Propagating the Mkl layout as an additional output tensor
//       (henceforth, we will loosely call a tensor that carries the Mkl
//       layout an "Mkl tensor") from every Mkl-supported NN layer.
59 //
60 // Example of A : Merging nodes in the graph
61 // -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
63 //
64 // O = Conv2D(A, B)
65 // P = BiasAdd(O, C)
66 //
67 // We merge them into Conv2DWithBias as:
68 // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
69 //
70 // The meaning of A_m, B_m and C_m is explained in B.1.
71 //
72 // Merge rules:
73 // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
74 // goes to BiasAdd.
//  - Also, the attributes common to both nodes must have the same values.
//  - Both nodes must have been assigned to the same device (if any).
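//
// For illustration only, the first rule can be sketched as the following
// check (OnlyFeedsBiasAdd is a hypothetical helper, not part of this pass;
// the actual merge logic lives in GetConv2DOrBiasAdd() and MergeNode()
// below):
//
//   bool OnlyFeedsBiasAdd(const Node* conv) {
//     int num_data_edges = 0;
//     for (const Edge* e : conv->out_edges()) {
//       if (e->IsControlEdge()) continue;
//       if (e->dst()->type_string() != "BiasAdd") return false;
//       ++num_data_edges;
//     }
//     return num_data_edges == 1;
//   }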
78 //
79 // Example of B.1 : Rewriting nodes to Mkl nodes
80 // ---------------------------------------------
81 // Consider a Relu node. Current definition of Relu node looks like:
82 //
83 // O = Relu(A)
84 //
85 // Relu has 1 input (A), and 1 output (O).
86 //
87 // This rewrite pass will generate a new graph node for Relu (new node is
88 // called MklRelu) as:
89 //
90 // O, O_m = MklRelu(A, A_m)
91 //
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is some
// kind of metadata for O. Likewise, A_m is an additional input of MklRelu,
// and it represents metadata for A -- as O_m is metadata for O, A_m is
// metadata for A. MklRelu receives this metadata from the previous node in
// the graph.
99 //
100 // When a previous node in the graph is an Mkl node, A_m will represent a valid
101 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
102 // a dummy Mkl tensor.
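//
// For example (illustrative, in the notation above):
//
//   X = SomeOp(...)          // SomeOp is not an Mkl op, so there is no X_m
//   O, O_m = MklRelu(X, D)   // D is a dummy Mkl tensor created by this pass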
103 //
104 // Rewriting rules:
105 // - Selection of a node for rewriting happens by registering the op type of
106 // the node with the rewriting pass. If the op type is not registered, then
//    nodes of this op type will not be rewritten.
108 // - Number of inputs after rewriting:
//      Since the rewritten node gets an Mkl tensor for every input Tensorflow
//      tensor, the rewritten node has 2*N inputs, where N is the number of
//      inputs of the original node.
112 // - Number of outputs after rewriting:
113 // Since for every output Tensorflow tensor, the rewritten node generates
114 // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
115 // number of outputs of the original node.
116 // - Ordering of Tensorflow tensors and Mkl tensors:
117 // Since every rewritten node generates twice the number of inputs and
118 // outputs, one could imagine various orderings among Tensorflow tensors
119 // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
120 // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
121 // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
122 // order. Among N inputs one can get N! permutations.
123 //
124 // So the question is: which order do we follow? We support 2 types of
125 // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
126 // follows an intuitive order where an Mkl tensor follows the
127 // corresponding Tensorflow tensor immediately. In the context of the
128 // above example, it will be: A, A_m, B, B_m. Note that the ordering rule
129 // applies to both the inputs and outputs. Contiguous ordering means
130 // all the Tensorflow tensors are contiguous followed by all the Mkl
131 // tensors. We use contiguous ordering as default.
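//
//      With the default contiguous ordering, the rewritten Conv2D from the
//      example above would thus look like (illustrative):
//
//        O, O_m = _MklConv2D(A, B, A_m, B_m)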
132 //
133 // Graph rewrite algorithm:
134 // Algorithm: Graph Rewrite
135 // Input: Graph G, Names of the nodes to rewrite and their new names
136 // Output: Modified Graph G' if the nodes are modified, G otherwise.
137 // Start:
138 // N = Topological_Sort(G) // N is a set of nodes in toposort order.
139 // foreach node n in N
140 // do
141 // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
142 // then
143 // E = set of <incoming edge and its src_output slot> of n
144 // E' = {} // a new set of edges for rewritten node
145 // foreach <e,s> in E
146 // do
147 // E' U {<e,s>} // First copy edge which generates Tensorflow
148 // // tensor as it is
149 // m = Source node of edge e
150 // if Is_Rewritten(m) // Did we rewrite this node in this pass?
151 // then
152 // E' U {<m,s+1>} // If yes, then m will generate an Mkl
153 // // tensor as an additional output.
154 // else
155 // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
156 // // Mkl tensor.
157 // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
158 // fi
159 // done
160 // n' = Build_New_Node(G,new_name,E')
161 // Mark_Rewritten(n') // Mark the new node as being rewritten.
162 // fi
163 // done
164 //
165 // Explanation:
//      For graph rewrite, we visit the nodes of the input graph in
//      topological sort order, i.e., in top-to-bottom fashion. We need this
//      order because, while visiting a node, we want all of its input nodes
//      to have already been visited (and rewritten, if applicable). This is
//      because if we need to rewrite a given node, then all of its input
//      nodes must already be final (in other words, they cannot be deleted
//      later).
173 //
174 // While visiting a node, we first check if the op type of the node is
175 // an Mkl op. If it is, then we rewrite that node after constructing
//      new inputs to the node. If the op type of the node is not an Mkl op,
//      then we do not rewrite that node.
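//
//      A minimal sketch of this visit loop (illustrative only; the actual
//      implementation is RunPass() below, which also handles node merging
//      and fusion, and CheckForNodeRewrite/RewriteNode stand for the
//      per-node lookup and rewrite steps):
//
//        std::vector<Node*> order;
//        GetReversePostOrder(**g, &order);  // visit in topological order
//        for (Node* n : order) {
//          if (const RewriteInfo* ri = CheckForNodeRewrite(n)) {
//            RewriteNode(g, n, ri);
//          }
//        }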
178 //
179 // Handling workspace propagation for certain ops:
180 //
181 // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
182 // passing of a workspace from their respective forward ops. Workspace
183 // tensors provide memory for storing results of intermediate operations
184 // which are helpful in backward propagation. TensorFlow does not have
185 // a notion of a workspace and as a result does not allow producing
186 // additional outputs from these forward ops. For these ops, we need
187 // to add 2 extra edges between forward ops and their corresponding
188 // backward ops - the first extra edge carries a workspace tensor and
189 // the second one carries an Mkl tensor for the workspace tensor.
190 //
191 // Example:
192 //
193 // Typical graph for MaxPool and its gradient looks like:
194 //
195 // A = MaxPool(T)
196 // B = MaxPoolGrad(X, A, Y)
197 //
198 // We will transform this graph to propagate the workspace as:
199 // (with the contiguous ordering)
200 //
201 // A, W, A_m, W_m = MklMaxPool(T, T_m)
202 // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
203 //
204 // Here W is the workspace tensor. Transformed tensor names with the
205 // suffix _m are Mkl tensors, and this transformation has been done
206 // using the algorithm discussed earlier. The transformation for
207 // workspace propagation only adds extra outputs (W, W_m) for a forward
208 // op and connects them to the corresponding backward ops.
209 //
210 // Terms:
211 //
212 // Forward op name = name of the op in the forward pass
213 // where a workspace tensor originates (MaxPool in this example)
214 // Backward op name = name of the op in the backward pass that receives
215 // a workspace tensor from the forward op (MaxPoolGrad in the example)
216 // Slot = Position of the output or input slot that will be
217 // used by the workspace tensor (1 for MklMaxPool as W is the 2nd
218 // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
219 //
220 // Question:
221 //
//  How do we associate a backward op with a forward op? There can be more
//  than one node of the exact same op type in the graph.
224 //
//  In this example, we associate MaxPoolGrad with MaxPool. But there
//  could be more than one MaxPool op. To solve this problem, we look
//  for a _direct_ edge between a forward op and a backward op (tensor A
//  flows along this edge in the example).
229 //
//  How do we transform forward and backward ops when there is no direct
//  edge between them? In such a case, we generate dummy tensors for the
//  workspace tensors. For the example, the transformation of MaxPool will
//  be exactly the same as when there is a direct edge between the forward
//  and the backward op --- it is just that MaxPool won't generate any
//  workspace tensor. For MaxPoolGrad, the transformation will also be the
//  same, but instead of connecting W and W_m to the outputs of MaxPool, we
//  will produce dummy tensors for them, and we will set the
//  workspace_enabled attribute to false.
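//
//  For example (illustrative), a MaxPoolGrad that has no matching MaxPool
//  is transformed as:
//
//    B, B_m = MklMaxPoolGrad(X, A, Y, W_dummy, X_m, A_m, Y_m, W_m_dummy)
//
//  where W_dummy and W_m_dummy are dummy tensors generated by this pass,
//  and the workspace_enabled attribute is set to false.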
239 //
240 class MklLayoutRewritePass : public GraphOptimizationPass {
241 public:
  MklLayoutRewritePass() {
243 // NOTE: names are alphabetically sorted.
244 csinfo_.addn = "AddN";
245 csinfo_.avg_pool = "AvgPool";
246 csinfo_.avg_pool_grad = "AvgPoolGrad";
247 csinfo_.avg_pool3d = "AvgPool3D";
248 csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
249 csinfo_.batch_matmul = "BatchMatMul";
250 csinfo_.batch_matmul_v2 = "BatchMatMulV2";
251 csinfo_.bias_add = "BiasAdd";
252 csinfo_.bias_add_grad = "BiasAddGrad";
253 csinfo_.concat = "Concat";
254 csinfo_.concatv2 = "ConcatV2";
255 csinfo_.conjugate_transpose = "ConjugateTranspose";
256 csinfo_.conv2d = "Conv2D";
257 csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
258 csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
259 csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
260 csinfo_.conv2d_grad_filter_with_bias =
261 "__MklDummyConv2DBackpropFilterWithBias";
262 csinfo_.conv3d = "Conv3D";
263 csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
264 csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
265 csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
266 csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
267 csinfo_.depthwise_conv2d_grad_filter =
268 "DepthwiseConv2dNativeBackpropFilter";
269 csinfo_.dequantize = "Dequantize";
270 csinfo_.einsum = "Einsum";
271 csinfo_.fused_batch_norm = "FusedBatchNorm";
272 csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
273 csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
274 csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
275 csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
276 csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
277 csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
278 csinfo_.fused_conv2d = "_FusedConv2D";
279 csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
280 csinfo_.fused_matmul = "_FusedMatMul";
281 csinfo_.identity = "Identity";
282 csinfo_.leakyrelu = "LeakyRelu";
283 csinfo_.leakyrelu_grad = "LeakyReluGrad";
284 csinfo_.lrn = "LRN";
285 csinfo_.lrn_grad = "LRNGrad";
286 csinfo_.matmul = "MatMul";
287 csinfo_.max_pool = "MaxPool";
288 csinfo_.max_pool_grad = "MaxPoolGrad";
289 csinfo_.max_pool3d = "MaxPool3D";
290 csinfo_.max_pool3d_grad = "MaxPool3DGrad";
291 csinfo_.mkl_conv2d = "_MklConv2D";
292 csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
293 csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
294 csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
295 csinfo_.mkl_conv2d_grad_filter_with_bias =
296 "_MklConv2DBackpropFilterWithBias";
297 csinfo_.mkl_depthwise_conv2d_grad_input =
298 "_MklDepthwiseConv2dNativeBackpropInput";
299 csinfo_.mkl_depthwise_conv2d_grad_filter =
300 "_MklDepthwiseConv2dNativeBackpropFilter";
301 csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
302 csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
303 csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
304 csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
305 csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
306 csinfo_.mkl_native_conv2d_grad_filter_with_bias =
307 "_MklNativeConv2DBackpropFilterWithBias";
308 csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
309 csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
310 csinfo_.mkl_native_fused_depthwise_conv2d =
311 "_MklNativeFusedDepthwiseConv2dNative";
312 csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
313 csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
314 csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
315 csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
316 csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
317 csinfo_.pad = "Pad";
318 csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
319 csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
320 csinfo_.quantized_avg_pool = "QuantizedAvgPool";
321 csinfo_.quantized_concatv2 = "QuantizedConcatV2";
322 csinfo_.quantized_conv2d = "QuantizedConv2D";
323 csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
324 csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
325 csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
326 csinfo_.quantized_conv2d_with_bias_and_requantize =
327 "QuantizedConv2DWithBiasAndRequantize";
328 csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
329 csinfo_.quantized_conv2d_and_relu_and_requantize =
330 "QuantizedConv2DAndReluAndRequantize";
331 csinfo_.quantized_conv2d_with_bias_and_relu =
332 "QuantizedConv2DWithBiasAndRelu";
333 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
334 "QuantizedConv2DWithBiasAndReluAndRequantize";
335 csinfo_.quantized_max_pool = "QuantizedMaxPool";
336 csinfo_.quantized_conv2d_with_bias_sum_and_relu =
337 "QuantizedConv2DWithBiasSumAndRelu";
338 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
339 "QuantizedConv2DWithBiasSumAndReluAndRequantize";
340 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
341 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
342 csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
343 csinfo_.quantized_matmul_with_bias_and_relu =
344 "QuantizedMatMulWithBiasAndRelu";
345 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
346 "QuantizedMatMulWithBiasAndReluAndRequantize";
347 csinfo_.quantized_matmul_with_bias_and_dequantize =
348 "QuantizedMatMulWithBiasAndDequantize";
349 csinfo_.quantized_matmul_with_bias_and_requantize =
350 "QuantizedMatMulWithBiasAndRequantize";
351 csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
352 csinfo_.quantized_depthwise_conv2d_with_bias =
353 "QuantizedDepthwiseConv2DWithBias";
354 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
355 "QuantizedDepthwiseConv2DWithBiasAndRelu";
356 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
357 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
358 csinfo_.quantize_v2 = "QuantizeV2";
359 csinfo_.relu = "Relu";
360 csinfo_.relu_grad = "ReluGrad";
361 csinfo_.relu6 = "Relu6";
362 csinfo_.relu6_grad = "Relu6Grad";
363 csinfo_.requantize = "Requantize";
364 csinfo_.tanh = "Tanh";
365 csinfo_.tanh_grad = "TanhGrad";
366 csinfo_.reshape = "Reshape";
367 csinfo_.slice = "Slice";
368 csinfo_.softmax = "Softmax";
369 csinfo_.split = "Split";
370 csinfo_.transpose = "Transpose";
371 // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
372 // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
373 // MklInputConversion op is added before it.
374 csinfo_.add = "Add";
375 csinfo_.add_v2 = "AddV2";
376 csinfo_.maximum = "Maximum";
377 csinfo_.mul = "Mul";
378 csinfo_.squared_difference = "SquaredDifference";
379 csinfo_.sub = "Sub";
380 // End - element-wise ops. See note above.
381
382 const bool native_fmt = NativeFormatEnabled();
383 // NOTE: names are alphabetically sorted.
384 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
385 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
386 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
387 CopyAttrsAll, RewriteIfAtleastOneMklInput,
388 GetRewriteCause()});
389 rinfo_.push_back(
390 {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
391 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
392 rinfo_.push_back({csinfo_.avg_pool,
393 mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
394 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
395 rinfo_.push_back({csinfo_.avg_pool_grad,
396 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
397 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
398 rinfo_.push_back({csinfo_.avg_pool3d,
399 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
400 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
401 rinfo_.push_back({csinfo_.avg_pool3d_grad,
402 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
403 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
404 rinfo_.push_back({csinfo_.batch_matmul,
405 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
406 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
407 rinfo_.push_back({csinfo_.einsum,
408 mkl_op_registry::GetMklOpName(csinfo_.einsum),
409 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
410 rinfo_.push_back({csinfo_.batch_matmul_v2,
411 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
412 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
413 rinfo_.push_back({csinfo_.concat,
414 mkl_op_registry::GetMklOpName(csinfo_.concat),
415 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
416 rinfo_.push_back({csinfo_.concatv2,
417 mkl_op_registry::GetMklOpName(csinfo_.concatv2),
418 CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
419 rinfo_.push_back(
420 {csinfo_.conjugate_transpose,
421 mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
422 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
423 rinfo_.push_back(
424 {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
425 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
426 rinfo_.push_back({csinfo_.conv2d_with_bias,
427 native_fmt ? csinfo_.mkl_native_conv2d_with_bias
428 : csinfo_.mkl_conv2d_with_bias,
429 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
430 GetRewriteCause()});
431 rinfo_.push_back({csinfo_.conv2d_grad_filter,
432 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
433 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
434 rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
435 native_fmt
436 ? csinfo_.mkl_native_conv2d_grad_filter_with_bias
437 : csinfo_.mkl_conv2d_grad_filter_with_bias,
438 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
439 rinfo_.push_back({csinfo_.conv2d_grad_input,
440 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
441 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
442 rinfo_.push_back(
443 {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
444 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
445 rinfo_.push_back({csinfo_.conv3d_grad_filter,
446 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
447 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
448 rinfo_.push_back({csinfo_.conv3d_grad_input,
449 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
450 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
451 rinfo_.push_back({csinfo_.depthwise_conv2d,
452 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
453 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
454 GetRewriteCause()});
455 rinfo_.push_back(
456 {csinfo_.depthwise_conv2d_grad_input,
457 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
458 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
459 rinfo_.push_back(
460 {csinfo_.depthwise_conv2d_grad_filter,
461 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
462 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
463 rinfo_.push_back(
464 {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
465 CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
466 rinfo_.push_back({csinfo_.fused_batch_norm,
467 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
468 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
469 rinfo_.push_back(
470 {csinfo_.fused_batch_norm_grad,
471 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
472 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
473 rinfo_.push_back(
474 {csinfo_.fused_batch_norm_v2,
475 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
476 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
477 rinfo_.push_back(
478 {csinfo_.fused_batch_norm_grad_v2,
479 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
480 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
481
482 // Using CopyAttrsAll for V3 on CPU, as there are no additional
483 // attributes.
484 rinfo_.push_back(
485 {csinfo_.fused_batch_norm_v3,
486 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
487 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
488 rinfo_.push_back(
489 {csinfo_.fused_batch_norm_grad_v3,
490 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
491 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
492 rinfo_.push_back({csinfo_.fused_batch_norm_ex,
493 native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
494 : csinfo_.mkl_fused_batch_norm_ex,
495 CopyAttrsAll, FusedBatchNormExRewrite,
496 GetRewriteCause()});
497 rinfo_.push_back({csinfo_.fused_conv2d,
498 native_fmt ? csinfo_.mkl_native_fused_conv2d
499 : csinfo_.mkl_fused_conv2d,
500 CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
501 GetRewriteCause()});
502 rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
503 native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
504 : csinfo_.mkl_fused_depthwise_conv2d,
505 CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
506 GetRewriteCause()});
507 rinfo_.push_back({csinfo_.fused_matmul,
508 native_fmt ? csinfo_.mkl_native_fused_matmul
509 : csinfo_.mkl_fused_matmul,
510 CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
511 GetRewriteCause()});
512
513 rinfo_.push_back(
514 {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
515 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
516 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
517 CopyAttrsAll, LrnRewrite, GetRewriteCause()});
518 rinfo_.push_back({csinfo_.lrn_grad,
519 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
520 CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
521 rinfo_.push_back({csinfo_.matmul,
522 mkl_op_registry::GetMklOpName(csinfo_.matmul),
523 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
524 rinfo_.push_back({csinfo_.leakyrelu,
525 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
526 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
527 rinfo_.push_back({csinfo_.leakyrelu_grad,
528 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
529 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
530 rinfo_.push_back(
531 {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
532 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
533 rinfo_.push_back({csinfo_.max_pool_grad,
534 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
535 CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
536 rinfo_.push_back(
537 {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
538 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
539 rinfo_.push_back({csinfo_.max_pool3d_grad,
540 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
541 CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
542 rinfo_.push_back(
543 {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
544 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
545 rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
546 CopyAttrsAll, RewriteIfAtleastOneMklInput,
547 GetRewriteCause()});
548 rinfo_.push_back({csinfo_.pad_with_conv2d,
549 native_fmt ? csinfo_.mkl_native_pad_with_conv2d
550 : csinfo_.mkl_pad_with_conv2d,
551 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
552 GetRewriteCause()});
553 rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
554 native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
555 : csinfo_.mkl_pad_with_fused_conv2d,
556 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
557 GetRewriteCause()});
558 rinfo_.push_back({csinfo_.quantized_avg_pool,
559 mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
560 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
561 rinfo_.push_back({csinfo_.quantized_concatv2,
562 mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
563 CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
564 rinfo_.push_back({csinfo_.quantized_conv2d,
565 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
566 CopyAttrsQuantizedConv2D, AlwaysRewrite,
567 kRewriteForOpNameChange});
568 rinfo_.push_back(
569 {csinfo_.quantized_conv2d_per_channel,
570 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
571 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
572 rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
573 mkl_op_registry::GetMklOpName(
574 csinfo_.quantized_conv2d_with_requantize),
575 CopyAttrsQuantizedConv2D, AlwaysRewrite,
576 kRewriteForOpNameChange});
577 rinfo_.push_back(
578 {csinfo_.quantized_conv2d_with_bias,
579 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
580 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
581 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
582 mkl_op_registry::GetMklOpName(
583 csinfo_.quantized_conv2d_with_bias_and_requantize),
584 CopyAttrsQuantizedConv2D, AlwaysRewrite,
585 kRewriteForOpNameChange});
586 rinfo_.push_back(
587 {csinfo_.quantized_conv2d_and_relu,
588 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
589 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
590 rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
591 mkl_op_registry::GetMklOpName(
592 csinfo_.quantized_conv2d_and_relu_and_requantize),
593 CopyAttrsQuantizedConv2D, AlwaysRewrite,
594 kRewriteForOpNameChange});
595 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
596 mkl_op_registry::GetMklOpName(
597 csinfo_.quantized_conv2d_with_bias_and_relu),
598 CopyAttrsQuantizedConv2D, AlwaysRewrite,
599 kRewriteForOpNameChange});
600 rinfo_.push_back(
601 {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
602 mkl_op_registry::GetMklOpName(
603 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
604 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
605 rinfo_.push_back({csinfo_.quantized_max_pool,
606 mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
607 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
608 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
609 mkl_op_registry::GetMklOpName(
610 csinfo_.quantized_conv2d_with_bias_sum_and_relu),
611 CopyAttrsQuantizedConv2D, AlwaysRewrite,
612 kRewriteForOpNameChange});
613 rinfo_.push_back(
614 {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
615 mkl_op_registry::GetMklOpName(
616 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
617 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
618 rinfo_.push_back(
619 {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
620 mkl_op_registry::GetMklOpName(
621 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
622 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
623 rinfo_.push_back(
624 {csinfo_.quantized_matmul_with_bias,
625 mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
626 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
627 kRewriteForOpNameChange});
628 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
629 mkl_op_registry::GetMklOpName(
630 csinfo_.quantized_matmul_with_bias_and_relu),
631 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
632 kRewriteForOpNameChange});
633 rinfo_.push_back(
634 {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
635 mkl_op_registry::GetMklOpName(
636 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
637 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
638 kRewriteForOpNameChange});
639 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
640 mkl_op_registry::GetMklOpName(
641 csinfo_.quantized_matmul_with_bias_and_requantize),
642 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
643 kRewriteForOpNameChange});
644 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
645 mkl_op_registry::GetMklOpName(
646 csinfo_.quantized_matmul_with_bias_and_dequantize),
647 CopyAttrsQuantizedMatMulWithBiasAndDequantize,
648 AlwaysRewrite, kRewriteForOpNameChange});
649 rinfo_.push_back(
650 {csinfo_.quantized_depthwise_conv2d,
651 mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
652 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
653 rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
654 mkl_op_registry::GetMklOpName(
655 csinfo_.quantized_depthwise_conv2d_with_bias),
656 CopyAttrsQuantizedConv2D, AlwaysRewrite,
657 kRewriteForOpNameChange});
658 rinfo_.push_back(
659 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
660 mkl_op_registry::GetMklOpName(
661 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
662 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
663 rinfo_.push_back(
664 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
665 mkl_op_registry::GetMklOpName(
666 csinfo_
667 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
668 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
669 rinfo_.push_back({csinfo_.quantize_v2,
670 mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
671 CopyAttrsAll, QuantizeOpRewrite,
672 kRewriteForOpNameChange});
673 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
674 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
675 rinfo_.push_back({csinfo_.relu_grad,
676 mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
677 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
678 rinfo_.push_back({csinfo_.relu6,
679 mkl_op_registry::GetMklOpName(csinfo_.relu6),
680 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
681 rinfo_.push_back({csinfo_.relu6_grad,
682 mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
683 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
684 rinfo_.push_back({csinfo_.requantize,
685 mkl_op_registry::GetMklOpName(csinfo_.requantize),
686 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
687 // Optimized TanhGrad support exists only in DNNL 1.x.
688 rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
689 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
690 rinfo_.push_back({csinfo_.tanh_grad,
691 mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
692 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
693 rinfo_.push_back({csinfo_.reshape,
694 mkl_op_registry::GetMklOpName(csinfo_.reshape),
695 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
696 rinfo_.push_back(
697 {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
698 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
699 rinfo_.push_back({csinfo_.softmax,
700 mkl_op_registry::GetMklOpName(csinfo_.softmax),
701 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
702
703 rinfo_.push_back({csinfo_.squared_difference,
704 mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
705 CopyAttrsAll, RewriteIfAtleastOneMklInput,
706 GetRewriteCause()});
707 rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
708 CopyAttrsAll, RewriteIfAtleastOneMklInput,
709 GetRewriteCause()});
710 rinfo_.push_back({csinfo_.transpose,
711 mkl_op_registry::GetMklOpName(csinfo_.transpose),
712 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
713
714 // Add info about which ops to add workspace edge to and the slots.
715 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
716 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
717 wsinfo_.push_back(
718 {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
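    // Fields in each wsinfo_ entry above (see WorkSpaceInfo below) are
    // {fwd_op, bwd_op, fwd_slot, bwd_slot, ws_fwd_slot, ws_bwd_slot}. For
    // example, for MaxPool/MaxPoolGrad the workspace is emitted at output
    // slot 1 of the forward op and consumed at input slot 3 of the backward
    // op.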
719
720 // Add a rule for merging nodes
721 minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
722 csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
723
    // Merge Pad and Conv2D only if the pad op is "Pad";
    // don't merge if the pad op is "PadV2" or "MirrorPad".
726 minfo_.push_back(
727 {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
728
729 minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
730 csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
731
732 minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
733 csinfo_.conv2d_grad_filter_with_bias,
734 GetConv2DBackpropFilterOrBiasAddGrad});
735
    // The fusion patterns in "finfo_" that show up first will get applied
    // first. For example, given the graph "A->B->C->D" and finfo_ =
    // {A->B->C to ABC, A->B->C->D to ABCD}, the first pattern is applied
    // first, so the final graph will be ABC->D.
740
741 //
    // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
    // (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
    // occur frequently in Keras.
    // Note: we use the term "merge" for combining (exactly) 2 nodes into one,
    // while "fusion" refers to the 3+ node case.
747 //
748
749 // Transpose + Conv2d + Transpose:
750 std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
751 NCHW::dim::W, NCHW::dim::C};
752 std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
753 NHWC::dim::H, NHWC::dim::W};
754 auto CheckForTransposeToNHWC =
755 std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
756 auto CheckForConv2dOp =
757 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
758 auto CheckForTransposeToNCHW =
759 std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
760 auto FuseConv2D =
761 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
762 std::placeholders::_2, std::placeholders::_3, "NCHW");
763 finfo_.push_back(
764 {"transpose-elimination for Conv2D",
765 {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
766 FuseConv2D,
767 CopyAttrsConv});
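    // Illustration (conceptual) of the pattern eliminated by the entry above:
    //
    //   X' = Transpose(X, {0, 2, 3, 1})   // NCHW -> NHWC
    //   Y' = Conv2D(X', F)                // data_format = "NHWC"
    //   Y  = Transpose(Y', {0, 3, 1, 2})  // NHWC -> NCHW
    //
    // After fusion, the convolution runs directly on NCHW data and both
    // Transpose nodes are eliminated.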
768
769 // Transpose + Conv3d + Transpose:
770 std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
771 NCDHW::dim::H, NCDHW::dim::W,
772 NCDHW::dim::C};
773 std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
774 NDHWC::dim::D, NDHWC::dim::H,
775 NDHWC::dim::W};
776
777 auto CheckForTransposeToNDHWC =
778 std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
779 auto CheckForConv3dOp =
780 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
781 auto CheckForTransposeToNCDHW =
782 std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
783 auto FuseConv3D =
784 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
785 std::placeholders::_2, std::placeholders::_3, "NCDHW");
786
787 finfo_.push_back(
788 {"transpose-elimination for Conv3D",
789 {CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
790 FuseConv3D,
791 CopyAttrsConv});
792
793 auto CheckForMaxPool3DOp =
794 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
795 auto FuseMaxPool3D =
796 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
797 std::placeholders::_2, std::placeholders::_3, "NCDHW");
798 finfo_.push_back({"transpose-elimination for MaxPool3D",
799 {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
800 CheckForTransposeToNCDHW},
801 FuseMaxPool3D,
802 CopyAttrsPooling});
803 }
804
805 // Standard interface to run pass
806 Status Run(const GraphOptimizationPassOptions& options);
807
  // Helper function which does most of the heavy lifting for rewriting
  // Mkl nodes to propagate the Mkl tensor as an additional output
810 //
811 // Extracts common functionality between Run public interface and
812 // test interface.
813 //
814 // @return true, if and only if graph is mutated; false otherwise.
815 bool RunPass(std::unique_ptr<Graph>* g);
816
817 /// Cause for rewrite
818 /// Currently, we only support 2 causes - either for Mkl layout propagation
819 /// which is the most common case, or for just a name change (used in case
820 /// of ops like MatMul, Transpose, which do not support Mkl layout)
821 enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };
822
823 // Get the op rewrite cause depending on whether native format mode
824 // is enabled or not.
  RewriteCause GetRewriteCause() {
826 if (NativeFormatEnabled()) {
827 return kRewriteForOpNameChange;
828 } else {
829 return kRewriteForLayoutPropagation;
830 }
831 }
832
  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, the
  /// rule (if any) which must hold for rewriting the node, and the cause
  /// of the rewrite
837 typedef struct {
838 string name; // Original name of op of the node in the graph
839 string new_name; // New name of the op of the node in the graph
840 // A function handler to copy attributes from an old node to a new node.
841 std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
842 // A rule under which to rewrite this node
843 std::function<bool(const Node*)> rewrite_rule;
844 // Why are we rewriting?
845 RewriteCause rewrite_cause;
846 } RewriteInfo;
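
  /// Example (from the constructor above): the entry
  ///   {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
  ///    CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()}
  /// rewrites eligible "Conv2D" nodes to the corresponding MKL op
  /// (e.g. "_MklConv2D"), copying their attributes with
  /// CopyAttrsConvCheckConstFilter.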
847
848 /// Structure to specify a forward op, a backward op, and the slot numbers
849 /// in the forward and backward ops where we will add a workspace edge.
850 typedef struct {
851 string fwd_op; // Name of a forward op in the graph
852 string bwd_op; // Name of a backward op in the graph
853 int fwd_slot; // Output slot in the forward op node where actual
854 // output tensor resides
855 int bwd_slot; // Input slot in the backward op node where actual
856 // input tensor resides
857 int ws_fwd_slot; // Output slot in the forward op node where workspace
858 // edge is added
859 int ws_bwd_slot; // Input slot in the backward op node where workspace
860 // edge is added
861 } WorkSpaceInfo;
862
863 /// Structure to specify information used in node merge of 2 operators
864 typedef struct {
865 string op1; // Node string for one operator.
866 string op2; // Node string for second operator.
867 string new_node; // Name of the node after merge
868 // Function that enables user of the node merger to specify how to find
869 // second operator given the first operator.
870 std::function<Node*(const Node*)> get_node_to_be_merged;
871 } MergeInfo;
872
873 // Structure to specify information used in node fusion of 3+ operators
874 typedef struct {
875 std::string pattern_name; // Name to describe this pattern, such as
876 // "Transpose_Mklop_Transpose".
877 std::vector<std::function<bool(const Node*)> >
878 node_checkers; // Extra restriction checker for these ops
879 std::function<Status(
880 std::unique_ptr<Graph>*, std::vector<Node*>&,
881 std::function<void(const Node*, NodeBuilder* nb, bool)>)>
882 fuse_func;
883 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
884 } FusionInfo;
885
886 //
  // Dimension indices for 2D (spatial) ops, i.e., 4D NCHW/NHWC tensors.
888 //
889 struct NCHW {
890 enum dim { N = 0, C = 1, H = 2, W = 3 };
891 };
892
893 struct NHWC {
894 enum dim { N = 0, H = 1, W = 2, C = 3 };
895 };
896
897 //
  // Dimension indices for 3D (spatial) ops, i.e., 5D NCDHW/NDHWC tensors.
899 //
900 struct NCDHW {
901 enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
902 };
903
904 struct NDHWC {
905 enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
906 };
907
908 /// Structure to store all constant strings
909 /// NOTE: names are alphabetically sorted.
910 typedef struct {
911 string addn;
912 string add;
913 string add_v2;
914 string avg_pool;
915 string avg_pool_grad;
916 string avg_pool3d;
917 string avg_pool3d_grad;
918 string batch_matmul;
919 string batch_matmul_v2;
920 string bias_add;
921 string bias_add_grad;
922 string concat;
923 string concatv2;
924 string conjugate_transpose;
925 string conv2d;
926 string conv2d_with_bias;
927 string conv2d_grad_input;
928 string conv2d_grad_filter;
929 string conv2d_grad_filter_with_bias;
930 string conv3d;
931 string conv3d_grad_input;
932 string conv3d_grad_filter;
933 string depthwise_conv2d;
934 string depthwise_conv2d_grad_input;
935 string depthwise_conv2d_grad_filter;
936 string dequantize;
937 string einsum;
938 string fused_batch_norm;
939 string fused_batch_norm_grad;
940 string fused_batch_norm_ex;
941 string fused_batch_norm_v2;
942 string fused_batch_norm_grad_v2;
943 string fused_batch_norm_v3;
944 string fused_batch_norm_grad_v3;
945 string fused_conv2d;
946 string fused_depthwise_conv2d;
947 string fused_matmul;
948 string identity;
949 string leakyrelu;
950 string leakyrelu_grad;
951 string lrn;
952 string lrn_grad;
953 string matmul;
954 string max_pool;
955 string max_pool_grad;
956 string max_pool3d;
957 string max_pool3d_grad;
958 string maximum;
959 string mkl_conv2d;
960 string mkl_conv2d_grad_input;
961 string mkl_conv2d_grad_filter;
962 string mkl_conv2d_grad_filter_with_bias;
963 string mkl_conv2d_with_bias;
964 string mkl_depthwise_conv2d_grad_input;
965 string mkl_depthwise_conv2d_grad_filter;
966 string mkl_fused_batch_norm_ex;
967 string mkl_fused_conv2d;
968 string mkl_fused_depthwise_conv2d;
969 string mkl_fused_matmul;
970 string mkl_native_conv2d_with_bias;
971 string mkl_native_conv2d_grad_filter_with_bias;
972 string mkl_native_fused_batch_norm_ex;
973 string mkl_native_fused_conv2d;
974 string mkl_native_fused_depthwise_conv2d;
975 string mkl_native_fused_matmul;
976 string mkl_native_pad_with_conv2d;
977 string mkl_native_pad_with_fused_conv2d;
978 string mkl_pad_with_conv2d;
979 string mkl_pad_with_fused_conv2d;
980 string mul;
981 string pad;
982 string pad_with_conv2d;
983 string pad_with_fused_conv2d;
984 string quantized_avg_pool;
985 string quantized_conv2d;
986 string quantized_conv2d_per_channel;
987 string quantized_conv2d_with_requantize;
988 string quantized_conv2d_with_bias;
989 string quantized_conv2d_with_bias_and_requantize;
990 string quantized_conv2d_and_relu;
991 string quantized_conv2d_and_relu_and_requantize;
992 string quantized_conv2d_with_bias_and_relu;
993 string quantized_conv2d_with_bias_and_relu_and_requantize;
994 string quantized_concatv2;
995 string quantized_max_pool;
996 string quantized_conv2d_with_bias_sum_and_relu;
997 string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
998 string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
999 string quantized_matmul_with_bias;
1000 string quantized_matmul_with_bias_and_relu;
1001 string quantized_matmul_with_bias_and_relu_and_requantize;
1002 string quantized_matmul_with_bias_and_requantize;
1003 string quantized_matmul_with_bias_and_dequantize;
1004 string quantized_depthwise_conv2d;
1005 string quantized_depthwise_conv2d_with_bias;
1006 string quantized_depthwise_conv2d_with_bias_and_relu;
1007 string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
1008 string quantize_v2;
1009 string relu;
1010 string relu_grad;
1011 string relu6;
1012 string relu6_grad;
1013 string requantize;
1014 string tanh;
1015 string tanh_grad;
1016 string transpose;
1017 string reshape;
1018 string slice;
1019 string softmax;
1020 string split;
1021 string squared_difference;
1022 string sub;
1023 } ConstStringsInfo;
1024
1025 private:
1026 /// Maintain info about nodes to rewrite
1027 std::vector<RewriteInfo> rinfo_;
1028
1029 /// Maintain info about nodes to add workspace edge
1030 std::vector<WorkSpaceInfo> wsinfo_;
1031
1032 /// Maintain info about nodes to be merged
1033 std::vector<MergeInfo> minfo_;
1034
1035 /// Maintain info about nodes to be fused
1036 std::vector<FusionInfo> finfo_;
1037
1038 /// Maintain structure of constant strings
1039 static ConstStringsInfo csinfo_;
1040
1041 private:
1042 // Is OpDef::ArgDef a list type? It could be N * T or list(type).
1043 // Refer to opdef.proto for details of list type.
  inline bool ArgIsList(const OpDef::ArgDef& arg) const {
1045 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
1046 }
1047
1048 // Get length of a list in 'n' if 'arg' is of list type. Refer to
1049 // description of ArgIsList for definition of list type.
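  //
  // Illustrative usage (not part of this pass): for an op like ConcatV2,
  // whose "values" input is declared as "values: N * T", the list length is
  // the value of attr "N":
  //
  //   const OpDef::ArgDef& arg = n->op_def().input_arg(0);
  //   if (ArgIsList(arg)) {
  //     int len = GetTensorListLength(arg, n);  // == value of attr "N"
  //   }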
  inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
1051 CHECK_EQ(ArgIsList(arg), true);
1052 int N = 0;
1053 const string attr_name = !arg.type_list_attr().empty()
1054 ? arg.type_list_attr()
1055 : arg.number_attr();
1056 if (!arg.type_list_attr().empty()) {
1057 std::vector<DataType> value;
1058 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
1059 N = value.size();
1060 } else {
1061 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
1062 }
1063 return N;
1064 }
1065
1066 // Can op represented by node 'n' run on DEVICE_CPU?
1067 // Op can run on CPU with MKL if the runtime assigned device or the
1068 // user requested device contains device CPU, or both are empty.
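  //
  // For example (illustrative), a node whose assigned device is
  // "/job:localhost/replica:0/task:0/device:CPU:0" can run on CPU here,
  // while one assigned to ".../device:GPU:0" or ".../device:XLA_CPU:0"
  // cannot.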
  bool CanOpRunOnCPUDevice(const Node* n) {
1070 bool result = true;
1071 string reason;
1072
1073 // Substring that should be checked for in device name for CPU device.
1074 const char* const kCPUDeviceSubStr = "CPU";
1075 const char* const kXLACPUDeviceSubStr = "XLA_CPU";
1076
1077 // If Op has been specifically assigned to a non-CPU or XLA_CPU device, then
1078 // No.
1079 if (!n->assigned_device_name().empty() &&
1080 (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) ||
1081 absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) {
1082 result = false;
1083 reason = "Op has been assigned a runtime device that is not CPU.";
1084 }
1085
1086 // If user has specifically assigned this op to a non-CPU or XLA_CPU device,
1087 // then No.
1088 if (!n->def().device().empty() &&
1089 (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) ||
1090 absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) {
1091 result = false;
1092 reason = "User has assigned a device that is not CPU.";
1093 }
1094
1095 if (result == false) {
1096 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
1097 << n->type_string() << ", reason: " << reason;
1098 }
1099
1100 // Otherwise Yes.
1101 return result;
1102 }
1103
1104 // Return a node that can be merged with input node 'n'
1105 //
1106 // @return pointer to the node if we can find such a
1107 // node. Otherwise, it returns nullptr.
1108 Node* CheckForNodeMerge(const Node* n) const;
1109
1110 // Merge node 'm' with node 'n'.
1111 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
1112 // Conv2DBackpropFilter.
1113 //
1114 // Input nodes m and n may be deleted if the call to
1115 // this function is successful. Attempt to use the pointers
1116 // after the call to function may result in undefined behaviors.
1117 //
1118 // @input g - input graph, m - graph node, n - graph node to be merged with m
1119 // @return Status::OK(), if merging is successful and supported.
1120 // Returns appropriate Status error code otherwise.
1121 // Graph is updated in case nodes are merged. Otherwise, it is
1122 // not updated.
1123 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1124
1125 // Helper function to merge different nodes
1126 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1127 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1128 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1129 Node* m, Node* n);
1130
1131 // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1132 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1133 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1134 // node that can be merged with 'm'.
  static Node* GetConv2DOrBiasAdd(const Node* m) {
1136 DCHECK(m);
1137 Node* n = nullptr;
1138
1139 DataType T_m;
1140 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1141
1142 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1143 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1144
1145 if (m->type_string() == csinfo_.bias_add) {
      // If m is BiasAdd, then Conv2D is the 0th input of BiasAdd.
1147 TF_CHECK_OK(m->input_node(0, &n));
1148 } else {
1149 CHECK_EQ(m->type_string(), csinfo_.conv2d);
1150 // Go over all output edges and search for BiasAdd Node.
1151 // 0th input of BiasAdd is Conv2D.
1152 for (const Edge* e : m->out_edges()) {
1153 if (!e->IsControlEdge() &&
1154 e->dst()->type_string() == csinfo_.bias_add &&
1155 e->dst_input() == 0) {
1156 n = e->dst();
1157 break;
1158 }
1159 }
1160 }
1161
1162 if (n == nullptr) {
1163 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1164 << "Conv2D and BiasAdd node for merging. Input node: "
1165 << m->DebugString();
1166 }
1167
1168 return n;
1169 }
1170
1171 // Find Pad or Conv2D node that can be merged with input node 'm'.
1172 // If input 'm' is Pad, then check if there exists Conv2D node that can be
1173 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1174 // node that can be merged with 'm'.
  static Node* GetPadOrConv2D(const Node* m) {
1176 DCHECK(m);
1177 Node* n = nullptr;
1178
1179 DataType T_m;
1180 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1181
1182 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1183 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1184
1185 const Node* conv_node;
1186 if (m->type_string() == csinfo_.pad) {
1187 // If m is Pad, then Conv2D is the output of Pad.
1188 for (const Edge* e : m->out_edges()) {
1189 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1190 n = e->dst();
1191 conv_node = n;
1192 break;
1193 }
1194 }
1195 } else {
1196 DCHECK_EQ(m->type_string(), csinfo_.conv2d);
      // If m is Conv2D, go over all input edges
      // and search for a Pad node.
1199 for (const Edge* e : m->in_edges()) {
1200 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1201 n = e->src();
1202 conv_node = m;
1203 break;
1204 }
1205 }
1206 }
    // Check whether only the VALID type of padding is used.
1209 if (n != nullptr) {
1210 string padding;
1211 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1212 if (padding != "VALID")
1213 // Then do not merge.
1214 // Only VALID type of padding in conv op can be
1215 // merged with Pad op.
1216 n = nullptr;
1217 } else {
1218 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1219 << "Pad and Conv2D node for merging. Input node: "
1220 << m->DebugString();
1221 }
1222
1223 return n;
1224 }
1225
1226 // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1227 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1228 // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1229 // exists Pad node that can be merged with 'm'.
  static Node* GetPadOrFusedConv2D(const Node* m) {
1231 DCHECK(m);
1232 Node* n = nullptr;
1233
1234 const Node* conv_node;
1235 if (m->type_string() == csinfo_.pad) {
1236 // If m is Pad, then _FusedConv2D is the output of Pad.
1237 for (const Edge* e : m->out_edges()) {
1238 if (!e->IsControlEdge() &&
1239 e->dst()->type_string() == csinfo_.fused_conv2d) {
1240 n = e->dst();
1241 conv_node = n;
1242 break;
1243 }
1244 }
1245 } else {
1246 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
      // If m is _FusedConv2D, go over all input edges
      // and search for a Pad node.
1249 for (const Edge* e : m->in_edges()) {
1250 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1251 n = e->src();
1252 conv_node = m;
1253 break;
1254 }
1255 }
1256 }
1257 // Check if only VALID type of padding is used or not.
1258 if (n != nullptr) {
1259 string padding;
1260 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1261 if (padding != "VALID") {
1262 // Then do not merge.
1263 n = nullptr;
1264 VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
1265 << "nodes but cannot merge them. Only conv ops with padding "
                << "type VALID can be merged with Pad op. Input node: "
1267 << m->DebugString();
1268 }
1269 } else {
1270 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1271 << "Pad and _FusedConv2D node for merging. Input node: "
1272 << m->DebugString();
1273 }
1274
1275 return n;
1276 }
1277
1278 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1279 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1280 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1281 // then check if there exists Conv2DBackpropFilter node that can be merged
1282 // with 'm'.
1283 //
1284 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1285 // would look like:
1286 //
1287 // _ = Conv2DBackpropFilter(F, _, G)
1288 // _ = BiasAddGrad(G)
1289 //
1290 // So 1st input of BiasAddGrad connects with 3rd input of
1291 // Conv2DBackpropFilter and vice versa.
  static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1293 DCHECK(m);
1294 Node* n = nullptr;
1295 const Node* conv2d_backprop_filter = nullptr;
1296
1297 DataType T_m;
1298 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1299
1300 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1301 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1302
1303 if (m->type_string() == csinfo_.bias_add_grad) {
1304 // Get 1st input 'g' of BiasAddGrad.
1305 Node* g = nullptr;
1306 TF_CHECK_OK(m->input_node(0, &g));
1307 // Now traverse all outgoing edges from g that have destination node as
1308 // Conv2DBackpropFilter.
1309 for (const Edge* e : g->out_edges()) {
1310 if (!e->IsControlEdge() &&
1311 e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1312 e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1313 n = e->dst();
1314 conv2d_backprop_filter = n;
1315 break;
1316 }
1317 }
1318 } else {
1319 conv2d_backprop_filter = m;
1320 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1321 // Get 3rd input 'g' of Conv2DBackpropFilter.
1322 Node* g = nullptr;
1323 TF_CHECK_OK(m->input_node(2, &g));
1324 // Now traverse all outgoing edges from g that have destination node as
1325 // BiasAddGrad.
1326 for (const Edge* e : g->out_edges()) {
1327 if (!e->IsControlEdge() &&
1328 e->dst()->type_string() == csinfo_.bias_add_grad &&
1329 e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1330 n = e->dst();
1331 break;
1332 }
1333 }
1334 }
1335
1336 // Do not merge if padding type is EXPLICIT.
1337 // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1338 if (conv2d_backprop_filter != nullptr) {
1339 string padding;
1340 TF_CHECK_OK(
1341 GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1342 if (padding == "EXPLICIT") {
1343 // Then do not merge.
1344 VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1345 << "and BiasAddGrad nodes but cannot merge them. "
1346 << "EXPLICIT padding is not supported now. "
1347 << conv2d_backprop_filter->DebugString();
1348 return nullptr;
1349 }
1350 }
1351
1352 if (n == nullptr) {
1353 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1354 << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1355 << "Input node: " << m->DebugString();
1356 }
1357 return n;
1358 }
1359
1360 // Find the set of nodes that can be fused with input node 'n'.
1361 //
1362 // @return tuple. If such fusible nodes are found, the first
1363 // element of the tuple is true. Otherwise, it's false.
1364 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1365 CheckForNodeFusion(Node* n) const;
1366
1367 // Fuse nodes in the vector "nodes"
1368 Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1369 const MklLayoutRewritePass::FusionInfo fi);
1370
1371 // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1372 // mklop("NCHW").
1373 // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1374 static Status FuseTransposeMklOpTranspose(
1375 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1376 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1377 string data_format);
1378
1379 static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1380 // Check if node's type is "Transpose"
1381 if (node->type_string() != "Transpose") return false;
1382
1383 // If "Transpose" has multiple output data edges, also don't fuse it.
1384 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1385
1386 // Check if the node has any outgoing control edges. If so, this is likely
1387 // a training graph. Currently we focus on inference and do no fusion in
1388 // training.
1389 // Note: this constraint will eventually be removed once this fusion is
1390 // enabled for training.
1391 for (const Edge* e : node->out_edges()) {
1392 if (e->IsControlEdge()) {
1393 return false;
1394 }
1395 }
1396
1397 // If "Transpose" has input control edges, don't fuse on it.
1398 for (const Edge* e : node->in_edges()) {
1399 if (e->IsControlEdge()) {
1400 return false;
1401 }
1402 }
1403
1404 // Compare the tensor containing the permutation order ("perm_node")
1405 // with our desired order ("perm"). If they match exactly, this check
1406 // succeeds and returns true.
1407 for (const Edge* e : node->in_edges()) {
1408 if (!e->IsControlEdge()) {
1409 const Node* perm_node = e->src();
1410
1411 const int kPermTensorIndex = 1;
1412 if (perm_node->type_string() == "Const" &&
1413 e->dst_input() == kPermTensorIndex) {
1414 // We found the "perm" node; now try to retrieve its value.
1415 const TensorProto* proto = nullptr;
1416 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1417
1418 DataType type;
1419 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1420
1421 Tensor tensor;
1422 if (!tensor.FromProto(*proto)) {
1423 TF_CHECK_OK(errors::InvalidArgument(
1424 "Could not construct Tensor from TensorProto in node: ",
1425 node->name()));
1426 return false;
1427 }
1428 // The current fusion only supports 4D or 5D tensors, i.e., cases where
1429 // the permutation length matches the `perm` vector; return false otherwise.
1430 if (tensor.dim_size(0) != perm.size()) return false;
1431 DCHECK_EQ(tensor.dims(), 1);
1432 if (type == DT_INT32) {
1433 const auto tensor_content = tensor.flat<int>().data();
1434 for (int i = 0; i < perm.size(); ++i)
1435 if (tensor_content[i] != perm[i]) return false;
1436 return true;
1437 } else if (type == DT_INT64) {
1438 const auto tensor_content = tensor.flat<int64>().data();
1439 for (int i = 0; i < perm.size(); ++i)
1440 if (tensor_content[i] != perm[i]) return false;
1441 return true;
1442 }
1443 return false;
1444 }
1445 }
1446 }
1447 return false;
1448 }
1449
1450 static bool CheckForMklOp(const Node* node, string name = "") {
1451 if (node == nullptr) return false;
1452
1453 if (!name.empty() && node->type_string() != name) {
1454 return false;
1455 }
1456
1457 // If mklop has multiple outputs, don't fuse it.
1458 if (node->num_outputs() > 1) return false;
1459
1460 if (node->out_edges().size() > 1) return false;
1461
1462 DataType T;
1463 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1464 return mkl_op_registry::IsMklOp(
1465 mkl_op_registry::GetMklOpName(node->type_string()), T);
1466 }
1467
1468 // Check if the node 'n' has any applicable rewrite rule
1469 // We check for 2 scenarios for rewrite.
1470 //
1471 // @return RewriteInfo* for the applicable rewrite rule
1472 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1473 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1474
1475 // Default rewrite rule to be used in scenario 1 for rewrite.
1476 // @return - true (since we want to always rewrite)
1477 static bool AlwaysRewrite(const Node* n) { return true; }
1478
1479 // Rewrite rule which considers "context" of the current node to decide if we
1480 // should rewrite. By "context" we currently mean all the inputs of current
1481 // node. The idea is that if none of the inputs of the current node are MKL
1482 // nodes, then rewriting the current node to an MKL node _may not_ offer any
1483 // performance improvement.
1484 //
1485 // One such case is element-wise ops. For such ops, we reuse the Eigen
1486 // implementation and pass the MKL metadata tensor through so we can avoid
1487 // conversions. However, if all incoming edges are in TF format, we don't
1488 // need all this overhead, so replace the elementwise node only if at least
1489 // one of its parents is an MKL node.
1490 //
1491 // More generally, all memory- or IO-bound ops (such as Identity) may fall
1492 // under this category.
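// For example, an element-wise Add whose inputs all come from plain TF ops is
// left unrewritten, whereas an Add with at least one _MklConv2D (or other MKL
// op) input is rewritten so that the Mkl metadata can flow through it.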
1493 //
1494 // @input - Input graph node to be rewritten
1495 // @return - true if node is to be rewritten as MKL node; false otherwise.
1496 static bool RewriteIfAtleastOneMklInput(const Node* n) {
1497 DataType T;
1498 if (GetNodeAttr(n->def(), "T", &T).ok() &&
1499 mkl_op_registry::IsMklOp(
1500 mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1501 for (auto e : n->in_edges()) {
1502 if (e->IsControlEdge()) continue;
1503 if (mkl_op_registry::IsMklOp(e->src())) {
1504 return true;
1505 }
1506 }
1507 }
1508 return false;
1509 }
1510
1511 static bool MatMulRewrite(const Node* n) {
1512 DataType T;
1513 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1514 if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1515 VLOG(2) << "Rewriting MatMul to _MklMatMul";
1516 return true;
1517 }
1518 return false;
1519 }
1520 // For oneDNN, only int32 is supported for axis data type
1521 static bool ConcatV2Rewrite(const Node* n) {
1522 DataType T;
1523 TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1524 return (T == DT_INT32);
1525 }
1526
1527 static bool DequantizeRewrite(const Node* n) {
1528 DCHECK(n);
1529 Node* input = nullptr;
1530 TF_CHECK_OK(n->input_node(0, &input));
1531 string mode_string;
1532 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1533 if (mode_string != "SCALED") {
1534 VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1535 << "This case is not optimized by Intel MKL kernel, thus using "
1536 "Eigen op for Dequantize op.";
1537 return false;
1538 }
1539 if (input->IsConstant()) {
1540 VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1541 << "could possibly be a filter. "
1542 << "This case is not supported by Intel MKL kernel, thus using "
1543 "Eigen op for Dequantize op.";
1544 return false;
1545 }
1546 return true;
1547 }
1548
1549 // Rewrite rule for _FusedMatMul.
1550 // @return - true if input 'a' is not transposed (transpose_a == false);
1551 // false otherwise.
1552 static bool FusedMatMulRewrite(const Node* n) {
1553 bool trans_a;
1554
1555 // Do not rewrite with transpose attribute because reorder has performance
1556 // impact.
1557 TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1558
1559 return !trans_a;
1560 }
1561
1562 // Check if we are performing pooling on depth or batch. If so, then we
1563 // do not rewrite the MaxPool node to the Mkl version.
1564 // @return - true (if it is not a depth/batch wise pooling case);
1565 // false otherwise.
1566 static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1567 DCHECK(n);
1568
1569 string data_format_str;
1570 TensorFormat data_format;
1571 std::vector<int32> ksize, strides;
1572 TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1573 TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1574 TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1575 bool result = FormatFromString(data_format_str, &data_format);
1576 DCHECK(result);
1577
1578 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1579 if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1580 GetTensorDim(strides, data_format, 'N') == 1 &&
1581 GetTensorDim(ksize, data_format, 'C') == 1 &&
1582 GetTensorDim(strides, data_format, 'C') == 1) {
1583 return true;
1584 }
1585
1586 return false;
1587 }
1588
1589 // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
1590 // path, which is slow. Thus we don't rewrite the node and use the default
1591 // Eigen implementation. But for depth_radius=2, the optimized MKL DNN path
1592 // is taken, i.e., the Eigen node is rewritten to the MKL DNN node.
1593 static bool LrnRewrite(const Node* n) {
1594 DCHECK(n);
1595
1596 int depth_radius;
1597 TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1598
1599 // If the depth_radius of LRN is not 2, don't rewrite the node with MKL DNN
1600 // and use the Eigen node instead.
1601 if (depth_radius == 2) {
1602 return true;
1603 }
1604 VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
1605 << "case is not optimized by Intel MKL, thus using Eigen op"
1606 << "for LRN ";
1607
1608 return false;
1609 }
1610
1611 static bool LrnGradRewrite(const Node* n) {
1612 DCHECK(n);
1613 bool do_rewrite = false;
1614
1615 for (const Edge* e : n->in_edges()) {
1616 // Rewrite only if there is a corresponding LRN, i.e., workspace is available.
1617 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1618 e->src()->type_string() ==
1619 mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1620 e->src_output() == 0) {
1621 do_rewrite = true;
1622 break;
1623 }
1624 }
1625 return do_rewrite;
1626 }
1627
1628 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
1629 // feature * alpha (otherwise),
1630 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1631 // These two algorithms are not consistent when alpha > 1,
1632 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
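// For example, with alpha = 2 and feature = -1:
//   MKL-DNN:    feature * alpha = -2
//   TensorFlow: max(feature, feature * alpha) = max(-1, -2) = -1
// The results differ, so such nodes are left to the Eigen implementation.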
1633 static bool LeakyReluRewrite(const Node* n) {
1634 DCHECK(n);
1635
1636 float alpha;
1637 bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1638 DCHECK(has_attr);
1639
1640 // If the alpha of LeakyRelu is less than or equal to 1, rewrite the node.
1641 // Otherwise the Eigen node is used instead.
1642 if (alpha <= 1) {
1643 return true;
1644 }
1645 VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
1646 << "which case is not optimized by Intel MKL, thus using Eigen op"
1647 << "for LeakyRelu ";
1648
1649 return false;
1650 }
1651
1652 static bool QuantizeOpRewrite(const Node* n) {
1653 DCHECK(n);
1654 Node* filter_node = nullptr;
1655 TF_CHECK_OK(n->input_node(0, &filter_node));
1656 bool narrow_range = false;
1657 int axis = -1;
1658 string mode_string;
1659 string round_mode_string;
1660 DataType type;
1661 TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1662 TryGetNodeAttr(n->def(), "axis", &axis);
1663 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1664 TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1665 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1666
1667 if (narrow_range) {
1668 VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
1669 << "This case is not optimized by Intel MKL, "
1670 << "thus using Eigen op for Quantize op ";
1671 return false;
1672 }
1673 if (axis != -1) {
1674 VLOG(1) << "QuantizeOpRewrite: dimension is specified for "
1675 << "per slice quantization."
1676 << "This case is not optimized by Intel MKL, "
1677 << "thus using Eigen op for Quantize op ";
1678 return false;
1679 }
1680 if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1681 (mode_string == "MIN_FIRST"))) {
1682 VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or"
1683 << "rounding mode is not HALF_TO_EVEN. "
1684 << "This case is not optimized by Intel MKL, thus using Eigen op"
1685 << "for Quantize op ";
1686 return false;
1687 }
1688 if (filter_node->IsConstant()) {
1689 VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1690 << "is a constant. "
1691 << "This case is not supported by the kernel, thus using Eigen op"
1692 << "for Quantize op ";
1693
1694 return false;
1695 }
1696 if (mode_string == "MIN_FIRST") {
1697 if (type != DT_QUINT8) {
1698 VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1699 << "not DT_UINT8. This case is not optimized by Intel MKL, "
1700 << "thus using Eigen op for Quantize op ";
1701 return false;
1702 }
1703 }
1704 return true;
1705 }
1706
1707 static bool MaxpoolGradRewrite(const Node* n) {
1708 DCHECK(n);
1709 bool do_rewrite = false;
1710 for (const Edge* e : n->in_edges()) {
1711 // Rewrite only if there is a corresponding MaxPool, i.e., workspace is
1712 // available.
1713 if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1714 e->dst_input() == 1 &&
1715 e->src()->type_string() ==
1716 mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1717 e->src_output() == 0) {
1718 do_rewrite = true;
1719 break;
1720 }
1721 }
1722 return do_rewrite;
1723 }
1724
1725 static bool Maxpool3DGradRewrite(const Node* n) {
1726 DCHECK(n);
1727 for (const Edge* e : n->in_edges()) {
1728 // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is
1729 // available
1730 if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1731 e->dst_input() == 1 &&
1732 e->src()->type_string() ==
1733 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1734 e->src_output() == 0) {
1735 return true;
1736 }
1737 }
1738 return false;
1739 }
1740
1741 static bool FusedBatchNormV3Rewrite(const Node* n) {
1742 DCHECK(n);
1743 if (Check5DFormat(n->def())) {
1744 VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1745 << "support 5D tensors.";
1746 return false;
1747 }
1748 return true;
1749 }
1750
1751 static bool FusedBatchNormExRewrite(const Node* n) {
1752 DCHECK(n);
1753
1754 int num_side_inputs;
1755 TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1756 string activation_mode;
1757 TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1758
1759 // if the num_side_inputs is not 0, don't rewrite the node.
1760 if (num_side_inputs != 0) {
1761 VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs"
1762 << "larger than 0 is not optimized by Intel MKL.";
1763 return false;
1764 }
1765
1766 // if the activation_mode is not 'Relu', don't rewrite the node.
1767 if (activation_mode != "Relu") {
1768 VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is"
1769 << "supported by Intel MKL.";
1770 return false;
1771 }
1772
1773 return true;
1774 }
1775
1776 static bool FusedConv2DRewrite(const Node* n) {
1777 // MKL DNN currently doesn't support all fusions that grappler fuses
1778 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1779 // it includes those we support.
1780 DataType T;
1781 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1782 !mkl_op_registry::IsMklOp(NativeFormatEnabled()
1783 ? csinfo_.mkl_native_fused_conv2d
1784 : csinfo_.mkl_fused_conv2d,
1785 T)) {
1786 return false;
1787 }
1788
1789 std::vector<string> fused_ops;
1790 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1791 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1792 fused_ops == std::vector<string>{"Relu"} ||
1793 fused_ops == std::vector<string>{"Relu6"} ||
1794 fused_ops == std::vector<string>{"Elu"} ||
1795 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1796 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1797 fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1798 fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1799 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1800 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1801 fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1802 fused_ops == std::vector<string>{"LeakyRelu"} ||
1803 fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1804 fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"} ||
1805 fused_ops == std::vector<string>{"FusedBatchNorm"} ||
1806 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"} ||
1807 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"} ||
1808 fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"} ||
1809 fused_ops == std::vector<string>{"FusedBatchNorm", "LeakyRelu"});
1810 }
1811
1812 static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1813 // MKL DNN currently doesn't support all fusions that grappler fuses
1814 // together with DepthwiseConv2D (ex. batchnorm). We rewrite
1815 // _FusedDepthwiseConv2DNative only if it includes those we support.
1816 DataType T;
1817 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1818 !mkl_op_registry::IsMklOp(
1819 NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d
1820 : csinfo_.mkl_fused_depthwise_conv2d,
1821 T)) {
1822 return false;
1823 }
1824
1825 std::vector<string> fused_ops;
1826 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1827 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1828 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1829 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1830 fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1831 }
1832
1833 // Rewrites input node to a new node specified by its matching rewrite info.
1834 //
1835 // Method first searches matching rewrite info for input node and then
1836 // uses that info to rewrite.
1837 //
1838 // Input node may be deleted in case of rewrite. Attempt to use the node
1839 // after the call can result in undefined behaviors.
1840 //
1841 // @input g - input graph, n - Node to be rewritten,
1842 // ri - matching rewriteinfo
1843 // @return Status::OK(), if the input node is rewritten;
1844 // Returns appropriate Status error code otherwise.
1845 // Graph is updated in case the input node is rewritten.
1846 // Otherwise, it is not updated.
1847 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1848
1849 // Rewrites input node to just change its operator name. The number of
1850 // inputs to the node and the number of outputs remain the same. Attributes
1851 // of the new node could be copied from attributes of the old node or
1852 // modified. copy_attrs field of RewriteInfo controls this.
1853 //
1854 // Conceptually, it allows us to rewrite:
1855 //
1856 // f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1857 //
1858 // Attributes can be altered without any restrictions --- they could be
1859 // copied, modified, or deleted completely.
1860 //
1861 // @input g - input graph, orig_node - Node to be rewritten,
1862 // ri - matching rewriteinfo
1863 // @output new_node - points to newly created node
1864 // @return Status::OK(), if the input node is rewritten;
1865 // Returns appropriate Status error code otherwise.
1866 // Graph is only updated when the input node is rewritten.
1867 Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1868 const Node* orig_node, Node** new_node,
1869 const RewriteInfo* ri);
1870
1871 // Rewrites input node to enable MKL layout propagation. Please also refer to
1872 // documentation for the function RewriteNodeForJustOpNameChange() to
1873 // understand what it means.
1874 //
1875 // @input g - input graph, orig_node - Node to be rewritten,
1876 // ri - matching rewriteinfo
1877 // @output new_node - points to newly created node
1878 // @return Status::OK(), if the input node is rewritten;
1879 // Returns appropriate Status error code otherwise.
1880 // Graph is updated in case the input node is rewritten.
1881 // Otherwise, it is not updated.
1882 Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1883 const Node* orig_node, Node** new_node,
1884 const RewriteInfo* ri);
1885
1886 // Get nodes that will feed a list of TF tensors to the new
1887 // node that we are constructing.
1888 //
1889 // @input g - input graph,
1890 // @input inputs - inputs to old node that we are using for constructing
1891 // new inputs,
1892 // @input input_idx - the index in the 'inputs' vector pointing to the
1893 // current input that we have processed so far
1894 // @output input_idx - index will be incremented by the number of nodes
1895 // from 'inputs' that are processed
1896 // @input list_length - The expected length of list of TF tensors
1897 // @output output_nodes - the list of new nodes creating TF tensors
1898 //
1899 // @return None
1900 void GetNodesProducingTFTensorList(
1901 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1902 int* input_idx, int list_length,
1903 std::vector<NodeBuilder::NodeOut>* output_nodes);
1904
1905 // Get nodes that will feed a list of Mkl tensors to the new
1906 // node that we are constructing.
1907 //
1908 // @input g - input graph,
1909 // @input orig_node - Original node that we are rewriting
1910 // @input inputs - inputs to old node that we are using for constructing
1911 // new inputs,
1912 // @input input_idx - the index in the 'inputs' vector pointing to the
1913 // current input that we have processed so far
1914 // @output input_idx - index will be incremented by the number of nodes
1915 // from 'inputs' that are processed
1916 // @input list_length - The expected length of list of Mkl tensors
1917 // @output output_nodes - the list of new nodes creating Mkl tensors
1918 //
1919 // @return None
1920 void GetNodesProducingMklTensorList(
1921 std::unique_ptr<Graph>* g, const Node* orig_node,
1922 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1923 int* input_idx, int list_length,
1924 std::vector<NodeBuilder::NodeOut>* output_nodes);
1925
1926 // Get a node that will feed an Mkl tensor to the new
1927 // node that we are constructing. The output node could be (1) 'n'
1928 // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
1929 // if 'n' is not an Mkl layer.
1930 //
1931 // @input g - input graph,
1932 // @input orig_node - Original node that we are rewriting,
1933 // @input n - Node based on which we are creating Mkl node,
1934 // @input n_output_slot - the output slot of node 'n'
1935 // which is feeding to the node that we are constructing
1936 // @output mkl_node - the new node that will feed Mkl tensor
1937 // @output mkl_node_output_slot - the slot number of mkl_node that
1938 // will feed the tensor
1939 // @return None
1940 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1941 const Node* orig_node, Node* n,
1942 int n_output_slot, Node** mkl_node,
1943 int* mkl_node_output_slot);
1944
1945 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1946 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1947 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1948 // producing workspace edges if 'are_workspace_tensors_available' is true.
1949 // Otherwise, 'workspace_tensors' is empty vector.
1950 //
1951 // For details, refer to 'Ordering of inputs after rewriting' section in the
1952 // documentation above.
1953 //
1954 // Returns Status::OK() if setting up inputs is successful, otherwise
1955 // returns appropriate status code.
1956 int SetUpContiguousInputs(
1957 std::unique_ptr<Graph>* g,
1958 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1959 NodeBuilder* nb, const Node* old_node,
1960 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1961 bool are_workspace_tensors_available);
1962
1963 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1964 // in graph 'g'. Original node is input in 'orig_node'.
1965 //
1966 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1967 // section in the documentation above.
1968 //
1969 // Returns Status::OK() if setting up inputs is successful, otherwise
1970 // returns appropriate status code.
1971 Status SetUpInputs(std::unique_ptr<Graph>* g,
1972 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1973 NodeBuilder* nb, const Node* orig_node);
1974
1975 // Create new inputs by copying old inputs 'inputs' for the rewritten node
1976 // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1977 // used in the context of rewrite for just operator name change in which
1978 // inputs of old operator and new operator are same.
1979 //
1980 // Returns Status::OK() if setting up inputs is successful, otherwise
1981 // returns appropriate status code.
1982 Status CopyInputs(const Node* orig_node,
1983 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1984 NodeBuilder* nb);
1985
1986 // Add a workspace edge on the input or output side of Node 'orig_node' by
1987 // using NodeBuilder 'nb' for the new node provided. If 'orig_node' does not
1988 // dictate adding a workspace edge then do not add it. The workspace
1989 // Tensorflow and Mkl tensors, if they need to be added, are appended to
1990 // 'ws_tensors'; in that case '*are_ws_tensors_added' is set to true.
1991 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1992 const Node* orig_node, NodeBuilder* nb,
1993 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1994 bool* are_ws_tensors_added);
1995
1996 // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1997 // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1998 // 'g'. Returns true if fixup was done; otherwise, it returns false.
1999 bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
2000 const Edge* e_metadata);
2001
2002 // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
2003 // connected? If not, then fix them. This is needed because a graph may have
2004 // some input Mkl metadata edges incorrectly setup after node merge and
2005 // rewrite passes. This could happen because GetReversePostOrder function may
2006 // not provide topologically sorted order if a graph contains cycles. The
2007 // function returns true if at least one Mkl metadata edge for node 'n' was
2008 // fixed. Otherwise, it returns false.
2009 //
2010 // Example:
2011 //
2012 // X = MklConv2D(_, _, _)
2013 // Y = MklConv2DWithBias(_, _, _, _, _, _)
2014 // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
2015 //
2016 // For a graph such as shown above, note that 3rd argument of MklAdd contains
2017 // DummyMklTensor. Actually, it should be getting the Mkl metadata from
2018 // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
2019 // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
2020 // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
2021 // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
2022 // data edges (1st and 2nd arguments of MklAdd).
2023 bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
2024
2025 // Functions specific to operators to copy attributes
2026 // We need operator-specific function to copy attributes because the framework
2027 // does not provide any generic function for it.
2028 // NOTE: names are alphabetically sorted.
2029 static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2030 bool change_format = false);
2031 static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
2032 NodeBuilder* nb,
2033 bool change_format = false);
2034
2035 static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2036 bool change_format = false);
2037 static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
2038 NodeBuilder* nb,
2039 bool change_format = false);
2040 static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2041 const Node* orig_node2, NodeBuilder* nb,
2042 bool change_format = false);
2043 static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
2044 const Node* orig_node2,
2045 NodeBuilder* nb,
2046 bool change_format = false);
2047 static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
2048 bool change_format = false);
2049 static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2050 const std::vector<int32>& strides,
2051 const std::vector<int32>& dilations,
2052 bool change_format = false);
2053
2054 static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2055 NodeBuilder* nb,
2056 bool change_format = false);
2057 static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2058 const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2059 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2060 bool change_format = false);
2061
2062 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2063 // using node for original node 'orig_node' and return it in '*out'.
2064 // TODO(nhasabni) We should move this to mkl_util.h
2065 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2066 const Node* orig_node);
2067 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2068 const Node* orig_node);
2069 };
2070
2071 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2072
2073 // We register Mkl rewrite pass for phase 1 in post partitioning group.
2074 // We register it here so that we get a complete picture of all users of Mkl
2075 // nodes. Do not change the ordering of the Mkl passes.
2076 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2077 OptimizationPassRegistry::POST_PARTITIONING;
2078 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2079
2080 //////////////////////////////////////////////////////////////////////////
2081 // Helper functions for creating new node
2082 //////////////////////////////////////////////////////////////////////////
2083
2084 static void FillInputs(const Node* n,
2085 gtl::InlinedVector<Node*, 4>* control_edges,
2086 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2087 control_edges->clear();
2088 for (const Edge* e : n->in_edges()) {
2089 if (e->IsControlEdge()) {
2090 control_edges->push_back(e->src());
2091 } else {
2092 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2093 }
2094 }
2095 std::sort(control_edges->begin(), control_edges->end());
2096 }
2097
2098 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2099 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2100 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2101 CHECK_LT(*input_idx, inputs.size());
2102 CHECK_GT(list_length, 0);
2103 DCHECK(output_nodes);
2104 output_nodes->reserve(list_length);
2105
2106 while (list_length != 0) {
2107 CHECK_GT(list_length, 0);
2108 CHECK_LT(*input_idx, inputs.size());
2109 Node* n = inputs[*input_idx].first;
2110 int slot = inputs[*input_idx].second;
2111 // If input node 'n' is just producing a single tensor at
2112 // output slot 'slot' then we just add that single node.
2113 output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2114 (*input_idx)++;
2115 list_length--;
2116 }
2117 }
2118
2119 // TODO(nhasabni) We should move this to mkl_util.h.
2120 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2121 Node** out,
2122 const Node* orig_node) {
2123 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2124 // dummy Mkl tensor. 8 = 2*size_t.
2125 const DataType dt = DataTypeToEnum<uint8>::v();
2126 TensorProto proto;
2127 proto.set_dtype(dt);
2128 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2129 proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2130 TensorShape dummy_shape({8});
2131 dummy_shape.AsProto(proto.mutable_tensor_shape());
2132 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2133 .Attr("value", proto)
2134 .Attr("dtype", dt)
2135 .Device(orig_node->def().device()) // We place this node on
2136 // the same device as the
2137 // device of the original
2138 // node.
2139 .Finalize(&**g, out));
2140 DCHECK(*out); // Make sure we got a valid object before using it
2141
2142 // If number of inputs to the original node is > 0, then we add
2143 // control dependency between 1st input (index 0) of the original node and
2144 // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2145 // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2146 // rewritten node. Adding control edge between 1st input of the original node
2147 // and the dummy Mkl node ensures that the dummy node is in the same frame
2148 // as the original node. Choosing 1st input is not necessary - any input of
2149 // the original node is fine because all the inputs of a node are always in
2150 // the same frame.
2151 if (orig_node->num_inputs() > 0) {
2152 Node* orig_input0 = nullptr;
2153 TF_CHECK_OK(
2154 orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2155 auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2156 DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2157 }
2158
2159 (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2160 }
2161
2162 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2163 std::unique_ptr<Graph>* g, const Node* orig_node,
2164 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2165 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2166 CHECK_LT(*input_idx, inputs.size());
2167 CHECK_GT(list_length, 0);
2168 DCHECK(output_nodes);
2169 output_nodes->reserve(list_length);
2170
2171 while (list_length != 0) {
2172 CHECK_GT(list_length, 0);
2173 CHECK_LT(*input_idx, inputs.size());
2174 Node* n = inputs[*input_idx].first;
2175 int slot = inputs[*input_idx].second;
2176 // If 'n' is producing a single tensor, then create a single Mkl tensor
2177 // node.
2178 Node* mkl_node = nullptr;
2179 int mkl_node_output_slot = 0;
2180 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2181 &mkl_node_output_slot);
2182 output_nodes->push_back(
2183 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2184 (*input_idx)++;
2185 list_length--;
2186 }
2187 }
2188
2189 // Get an input node that will feed Mkl tensor to the new
2190 // node that we are constructing. An input node could be (1) 'n'
2191 // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
2192 // if 'n' is not an Mkl layer.
2193 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2194 std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2195 int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2196 DCHECK(n);
2197 DCHECK(mkl_node);
2198 DCHECK(mkl_node_output_slot);
2199
2200 // If this is an MKL op, then it will create extra output for MKL layout.
2201 DataType T;
2202 if (TryGetNodeAttr(n->def(), "T", &T) &&
2203 mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
2204 // If this is an MKL op, then it already produces an additional output
2205 // that carries the Mkl tensor for this node.
2206 // The output slot number for the Mkl tensor is N + slot number of the
2207 // TensorFlow tensor, where N is the total number of TensorFlow tensors.
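// For example, with contiguous ordering, a node that has 2 TensorFlow
// outputs and 2 Mkl outputs (num_outputs() == 4) exposes the Mkl tensor
// for TensorFlow output slot 1 at slot 2 + 1 == 3.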
2208 *mkl_node = n;
2209 *mkl_node_output_slot =
2210 GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2211 } else {
2212 // If we have not visited the node and rewritten it, then we need
2213 // to create a dummy node that will feed a dummy Mkl tensor to this node.
2214 // DummyMklTensor node has no input and generates only 1 output
2215 // (dummy Mkl tensor) as output slot number 0.
2216 GetDummyMklTensorNode(g, mkl_node, orig_node);
2217 DCHECK(*mkl_node);
2218 *mkl_node_output_slot = 0;
2219 }
2220 }
2221
2222 int MklLayoutRewritePass::SetUpContiguousInputs(
2223 std::unique_ptr<Graph>* g,
2224 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2225 NodeBuilder* nb, const Node* old_node,
2226 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2227 bool are_workspace_tensors_available) {
2228 DCHECK(workspace_tensors);
2229 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2230
2231 // TODO(nhasabni): Temporary solution to connect filter input of
2232 // BackpropInput with the converted filter from Conv2D.
2233 bool do_connect_conv2d_backprop_input_filter = false;
2234 Node* conv2d_node = nullptr;
2235 // Filter node is 2nd input (slot index 1) of Conv2D.
2236 int kConv2DFilterInputSlotIdx = 1;
2237 int kConv2DBackpropInputFilterInputSlotIdx = 1;
2238 int kConv2DFilterOutputSlotIdx = 1;
2239 if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2240 // We need to find Conv2D node from Conv2DBackpropInput.
2241 // For that let's first find filter node that is 2nd input (slot 1)
2242 // of BackpropInput.
2243 Node* filter_node = nullptr;
2244 TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2245 &filter_node));
2246 DCHECK(filter_node);
2247
2248 // Now check which nodes receive from filter_node. Filter feeds as
2249 // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2250 // _MklFusedConv2D.
2251 for (const Edge* e : filter_node->out_edges()) {
2252 if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2253 e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2254 e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2255 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2256 e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2257 e->dst_input() == kConv2DFilterInputSlotIdx
2258 /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2259 if (conv2d_node != nullptr) {
2260 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2261 << " feeding multiple Conv2D nodes: "
2262 << filter_node->DebugString();
2263 // We will not connect filter input of Conv2DBackpropInput
2264 // to be safe here.
2265 do_connect_conv2d_backprop_input_filter = false;
2266 break;
2267 } else {
2268 conv2d_node = e->dst();
2269 do_connect_conv2d_backprop_input_filter = true;
2270 }
2271 }
2272 }
2273 }
2274
2275 // Number of input slots to original op
2276 // Input slots are represented by .Input() calls in REGISTER_OP.
2277 int old_node_input_slots = old_node->op_def().input_arg_size();
2278 int nn_slot_idx = 0; // slot index for inputs of new node
2279
2280 // Let's copy all inputs (TF tensors) of original node to new node.
2281 int iidx = 0;
2282 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2283 // An input slot could be a single tensor or a list. We need
2284 // to handle this case accordingly.
2285 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2286 if (ArgIsList(arg)) {
2287 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2288 int tensor_list_length = GetTensorListLength(arg, old_node);
2289 if (tensor_list_length != 0) {
2290 GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2291 tensor_list_length, &new_node_inputs);
2292 }
2293 nb->Input(new_node_inputs);
2294 nn_slot_idx++;
2295 } else {
2296 // Special case for connecting filter input of Conv2DBackpropInput
2297 if (do_connect_conv2d_backprop_input_filter &&
2298 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2299 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2300 } else {
2301 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2302 }
2303 iidx++;
2304 nn_slot_idx++;
2305 }
2306 }
2307
2308 // If workspace tensors are available for this op and we are using
2309 // contiguous ordering then we need to add Tensorflow tensor for
2310 // workspace here because Tensorflow tensor for workspace is the
2311 // last tensor in the list of Tensorflow tensors.
2312 if (are_workspace_tensors_available) {
2313 CHECK_EQ(workspace_tensors->size(), 2);
2314 // Tensorflow tensor
2315 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2316 nn_slot_idx++;
2317 }
2318
2319 // Let's now setup all Mkl inputs to a new node.
2320 // Number of Mkl inputs must be same as number of TF inputs.
2321 iidx = 0;
2322 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2323 // An input slot could be a single tensor or a list. We need
2324 // to handle this case accordingly.
2325 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2326 if (ArgIsList(arg)) {
2327 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2328 int tensor_list_length = GetTensorListLength(arg, old_node);
2329 if (tensor_list_length != 0) {
2330 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2331 tensor_list_length, &new_node_inputs);
2332 }
2333 nb->Input(new_node_inputs);
2334 nn_slot_idx++;
2335 } else {
2336 Node* mkl_node = nullptr;
2337 int mkl_node_output_slot = 0;
2338 // Special case for connecting filter input of Conv2DBackpropInput
2339 if (do_connect_conv2d_backprop_input_filter &&
2340 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2341 GetNodeProducingMklTensor(g, old_node, conv2d_node,
2342 kConv2DFilterOutputSlotIdx, &mkl_node,
2343 &mkl_node_output_slot);
2344 } else {
2345 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2346 old_node_inputs[iidx].second, &mkl_node,
2347 &mkl_node_output_slot);
2348 }
2349 nb->Input(mkl_node, mkl_node_output_slot);
2350 iidx++;
2351 nn_slot_idx++;
2352 }
2353 }
2354
2355 // If workspace tensors are available for this op and we are using
2356 // contiguous ordering then we need to add Mkl tensor for
2357 // workspace here because Mkl tensor for workspace is the
2358 // last tensor in the list of Mkl tensors.
2359 if (are_workspace_tensors_available) {
2360 CHECK_EQ(workspace_tensors->size(), 2);
2361 // Mkl tensor
2362 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2363 nn_slot_idx++;
2364 }
2365
2366 return nn_slot_idx;
2367 }
2368
2369 // This method finds out if checking workspace is needed or not. Workspace is
2370 // not used in quantized ops, so checking that would fail as quantized ops
2371 // don't have attribute: "T".
2372 bool IsWorkspaceCheckNeeded(const Node* node) {
2373 std::vector<string> quant_ops{
2374 "Dequantize",
2375 "QuantizeV2",
2376 "QuantizedConv2D",
2377 "QuantizedConv2DWithBias",
2378 "QuantizedConv2DAndRelu",
2379 "QuantizedConv2DWithBiasAndRelu",
2380 "QuantizedConv2DWithBiasSumAndRelu",
2381 "QuantizedConv2DPerChannel",
2382 "QuantizedConv2DAndRequantize",
2383 "QuantizedConv2DWithBiasAndRequantize",
2384 "QuantizedConv2DAndReluAndRequantize",
2385 "QuantizedConv2DWithBiasAndReluAndRequantize",
2386 "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2387 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2388 "QuantizedMatMulWithBias",
2389 "QuantizedMatMulWithBiasAndRequantize",
2390 "QuantizedMatMulWithBiasAndDequantize",
2391 "QuantizedMatMulWithBiasAndRelu",
2392 "QuantizedMatMulWithBiasAndReluAndRequantize",
2393 "QuantizedDepthwiseConv2D",
2394 "QuantizedDepthwiseConv2DWithBias",
2395 "QuantizedDepthwiseConv2DWithBiasAndRelu",
2396 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2397 return std::find(std::begin(quant_ops), std::end(quant_ops),
2398 node->type_string()) == std::end(quant_ops);
2399 }
2400
2401 Status MklLayoutRewritePass::SetUpInputs(
2402 std::unique_ptr<Graph>* g,
2403 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2404 NodeBuilder* nb, const Node* old_node) {
2405 // Let's check if we need to add workspace tensors for this node.
2406 // We add workspace edge only for MaxPool, LRN and BatchNorm.
2407 std::vector<NodeBuilder::NodeOut> workspace_tensors;
2408 bool are_workspace_tensors_available = false;
2409
2410 if (IsWorkspaceCheckNeeded(old_node)) {
2411 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2412 &are_workspace_tensors_available);
2413 }
2414
2415 int new_node_input_slots = 0;
2416 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
2417 // TODO(nhasabni): implement this function just for the sake of completeness.
2418 // We do not use interleaved ordering right now.
2419 return Status(
2420 error::Code::UNIMPLEMENTED,
2421 "Interleaved ordering of tensors is currently not supported.");
2422 } else {
2423 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2424 new_node_input_slots = SetUpContiguousInputs(
2425 g, old_node_inputs, nb, old_node, &workspace_tensors,
2426 are_workspace_tensors_available);
2427 }
2428
2429 // Sanity check
2430 int old_node_input_slots = old_node->op_def().input_arg_size();
2431 if (!are_workspace_tensors_available) {
2432 // If we are not adding workspace tensors for this op, then the total
2433 // number of input slots to the new node _must_ be 2 times the number
2434 // of input slots to the original node: N original Tensorflow tensors and
2435 // N Mkl tensors corresponding to each Tensorflow tensor.
2436 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2437 } else {
2438 // If we are adding workspace tensors for this op, then the total
2439 // number of input slots to the new node _must_ be 2 times the number
2440 // of input slots to the original node (N original Tensorflow tensors and
2441 // N Mkl tensors corresponding to each Tensorflow tensor) plus 2
2442 // (for the workspace Tensorflow tensor and the workspace Mkl tensor).
2443 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2444 }
2445
2446 return Status::OK();
2447 }
2448
2449 Status MklLayoutRewritePass::CopyInputs(
2450 const Node* old_node,
2451 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2452 NodeBuilder* nb) {
2453 // Number of input slots to old node
2454 // Input slots are represented by .Input() calls in REGISTER_OP.
2455 int old_node_input_slots = old_node->op_def().input_arg_size();
2456 // Actual number of inputs can be greater than or equal to number
2457 // of Input slots because inputs of type list could be unfolded.
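// For example, a ConcatV2 node with N == 3 has two input slots
// ("values: N * T" and "axis: Tidx") but four incoming data edges.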
2458 auto old_node_input_size = old_node_inputs.size();
2459 DCHECK_GE(old_node_input_size, old_node_input_slots);
2460
2461 // Let's copy all inputs of old node to new node.
2462 int iidx = 0;
2463 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2464 // An input slot could be a single tensor or a list. We need
2465 // to handle this case accordingly.
2466 DCHECK_LT(iidx, old_node_input_size);
2467 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2468 if (ArgIsList(arg)) {
2469 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2470 int N = GetTensorListLength(arg, old_node);
2471 if (N != 0) {
2472 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2473 &new_node_inputs);
2474 }
2475 nb->Input(new_node_inputs);
2476 } else {
2477 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2478 iidx++;
2479 }
2480 }
2481 return Status::OK();
2482 }
2483
2484 //////////////////////////////////////////////////////////////////////////
2485 // Helper functions related to workspace pass
2486 //////////////////////////////////////////////////////////////////////////
2487
2488 // TODO(nhasabni) We should move this to mkl_util.h.
2489 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2490 std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
2491 // We use a uint8 tensor of shape {8} with content {0,0,0,0,0,0,0,0} to
2492 // represent the dummy workspace tensor.
2493 GetDummyMklTensorNode(g, out, orig_node);
2494 }
2495
2496 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2497 std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2498 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2499 bool workspace_edge_added = false; // Default initializer
2500 DCHECK(are_ws_tensors_added);
2501 *are_ws_tensors_added = false; // Default initializer
2502
2503 DataType T;
2504 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2505 for (auto ws : wsinfo_) {
2506 if (orig_node->type_string() == ws.fwd_op &&
2507 mkl_op_registry::IsMklOp(
2508 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2509 // If this op is a fwd op, then we need to check if there is an
2510 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
2511 // an edge, then we just add an attribute on this node for setting
2512 // workspace_passed to true. We don't add actual workspace edge
2513 // in this node. Actual workspace edge gets added in the backward
2514 // op for this node.
2515 for (const Edge* e : orig_node->out_edges()) {
2516 if (e->src_output() == ws.fwd_slot &&
2517 e->dst()->type_string() == ws.bwd_op &&
2518 e->dst_input() == ws.bwd_slot) {
2519 nb->Attr("workspace_enabled", true);
2520 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2521 << orig_node->type_string();
2522 workspace_edge_added = true;
2523 // We found the edge that we were looking for, so break.
2524 break;
2525 }
2526 }
2527
2528 if (!workspace_edge_added) {
2529 // If we are here, then we did not find backward operator for this
2530 // node.
2531 nb->Attr("workspace_enabled", false);
2532 }
2533 } else if (orig_node->type_string() == ws.bwd_op &&
2534 mkl_op_registry::IsMklOp(
2535 mkl_op_registry::GetMklOpName(orig_node->type_string()),
2536 T)) {
2537 // If this op is a bwd op, then we need to add workspace edge and
2538 // its Mkl tensor edge between its corresponding fwd op and this
2539 // op. Corresponding fwd op is specified in 'fwd_op' field of
2540 // workspace info. fwd_slot and bwd_slot in workspace info specify
2541 // an edge between which slots connect forward and backward op.
2542 // Once all these criteria match, we add a workspace edge between
2543 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
2544 // determined by interleaved/contiguous ordering. Function
2545 // DataIndexToMetaDataIndex tells us the location of Mkl tensor
2546 // from the location of the Tensorflow tensor.
2547 for (const Edge* e : orig_node->in_edges()) {
2548 if (e->src_output() == ws.fwd_slot &&
2549 // We would have rewritten the forward op, so we need to use
2550 // GetMklOpName call to get its Mkl name.
2551 e->src()->type_string() ==
2552 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2553 e->dst_input() == ws.bwd_slot) {
2554 nb->Attr("workspace_enabled", true);
2555 DCHECK(ws_tensors);
2556 // Add workspace edge between fwd op and bwd op.
2557 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2558 // Check if we are running in native format mode. If so,
2559 // we don't need to have an Mkl metadata tensor for the workspace.
2560 if (!NativeFormatEnabled()) {
2561 // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2562 ws_tensors->push_back(NodeBuilder::NodeOut(
2563 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2564 e->src()->num_outputs())));
2565 }
2566 *are_ws_tensors_added = true;
2567 // In terms of input ordering, we add these calls to add Input
2568 // here because workspace edge (and its Mkl tensor) is the last
2569 // edge in the fwdop and bwdop. So all inputs before workspace
2570 // tensor have been added by SetUpInputs function.
2571 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2572 << orig_node->type_string();
2573 workspace_edge_added = true;
2574 // We found the edge that we were looking for, so break.
2575 break;
2576 }
2577 }
2578
2579 // If we are here, it means we did not find a fwd op that feeds this
2580 // bwd op. So in this case, we need to generate dummy tensors for
2581 // workspace input and Mkl tensor for workspace, and set
2582 // workspace_enabled to false.
2583 if (!workspace_edge_added) {
2584 nb->Attr("workspace_enabled", false);
2585 Node* dmt_ws = nullptr; // Dummy tensor for workspace
2586 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
2587 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2588 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2589 DCHECK(dmt_ws);
2590 DCHECK(dmt_mkl_ws);
2591 DCHECK(ws_tensors);
2592 // We add dummy tensor as workspace tensor.
2593 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2594 // We add dummy tensor as Mkl tensor for workspace tensor.
2595 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2596 *are_ws_tensors_added = true;
2597 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2598 << orig_node->type_string();
2599 }
2600 } else {
2601 // If this node does not match any workspace info, then we do not
2602 // do anything special for workspace propagation for it.
2603 }
2604 }
2605 }
2606
2607 //////////////////////////////////////////////////////////////////////////
2608 // Op-specific functions to copy attributes from old node to new node
2609 //////////////////////////////////////////////////////////////////////////
2610
2611 // Generic function to copy all attributes from original node to target.
2612 void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2613 bool change_format) {
2614 string name;
2615 AttrSlice attr_list(orig_node->def());
2616
2617 auto iter = attr_list.begin();
2618 while (iter != attr_list.end()) {
2619 name = iter->first;
2620 auto attr = iter->second;
2621 nb->Attr(name, attr);
2622 ++iter;
2623 }
2624 }
2625
2626 // Generic function to copy all attributes and check if filter is const.
2627 void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2628 NodeBuilder* nb,
2629 bool change_format) {
2630 CopyAttrsAll(orig_node, nb, change_format);
2631
2632 // Check and set filter attribute.
2633 Node* filter_node = nullptr;
2634 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2635 nb->Attr("is_filter_const", filter_node->IsConstant());
2636 }
2637
2638 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2639 NodeBuilder* nb,
2640 bool change_format) {
2641 CopyAttrsConv(orig_node, nb, change_format);
2642
2643 // Check and set filter attribute.
2644 Node* filter_node = nullptr;
2645 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2646 nb->Attr("is_filter_const", filter_node->IsConstant());
2647 }
2648
2649 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2650 bool change_format) {
2651 DataType T;
2652 string padding;
2653 std::vector<int32> strides;
2654 std::vector<int32> dilations;
2655 std::vector<int32> explicit_paddings;
2656
2657 // Get all attributes from old node.
2658 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2659 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2660 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2661 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2662
2663 // Check `explicit_paddings` first because some Conv ops don't have
2664 // this attribute.
2665 if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2666 &explicit_paddings) &&
2667 !explicit_paddings.empty()) {
2668 nb->Attr("explicit_paddings", explicit_paddings);
2669 }
2670
2671 // Add attributes to new node.
2672 nb->Attr("T", T);
2673 nb->Attr("padding", padding);
2674
2675 // Add attributes related to `data_format`.
2676 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2677 }
2678
2679 // Used with MergePadWithConv2D
2680 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2681 const Node* orig_node2,
2682 NodeBuilder* nb,
2683 bool change_format) {
2684 DataType Tpaddings;
2685 DataType T;
2686 string data_format;
2687 string padding;
2688 std::vector<int32> strides;
2689 std::vector<int32> dilations;
2690 bool use_cudnn_on_gpu;
2691
2692 // Get all attributes from old node 1.
2693 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2694 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2695 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2696 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2697 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2698 TF_CHECK_OK(
2699 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2700 // Get all attributes from old node 2.
2701 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2702
2703 // Add attributes to new node.
2704 nb->Attr("T", T);
2705 nb->Attr("strides", strides);
2706 nb->Attr("dilations", dilations);
2707 nb->Attr("padding", padding);
2708 nb->Attr("data_format", data_format);
2709 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2710 nb->Attr("Tpaddings", Tpaddings);
2711 }
2712
2713 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2714 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2715 bool change_format) {
2716 DataType T;
2717 int num_args;
2718 string data_format;
2719 string padding;
2720 std::vector<int32> strides;
2721 std::vector<int32> dilations;
2722 float epsilon;
2723 std::vector<string> fused_ops;
2724 DataType Tpaddings;
2725 float leakyrelu_alpha;
2726
2727 // Get all attributes from old node.
2728 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2729 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2730 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2731 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2732 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2733 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2734 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2735 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2736 TF_CHECK_OK(
2737 GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2738 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2739
2740 // Add attributes to new node.
2741 nb->Attr("T", T);
2742 nb->Attr("num_args", num_args);
2743 nb->Attr("strides", strides);
2744 nb->Attr("padding", padding);
2745 nb->Attr("data_format", data_format);
2746 nb->Attr("dilations", dilations);
2747 nb->Attr("epsilon", epsilon);
2748 nb->Attr("Tpaddings", Tpaddings);
2749 nb->Attr("fused_ops", fused_ops);
2750 nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2751 }
2752
2753 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2754 NodeBuilder* nb,
2755 bool change_format) {
2756 DataType Tinput, Tfilter, out_type;
2757 string padding;
2758 string data_format("NHWC");
2759 std::vector<int32> strides, dilations, padding_list;
2760 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2761
2762 // Get all attributes from old node.
2763 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2764 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2765 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2766 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2767 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2768 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2769 if (has_padding_list) {
2770 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2771 }
2772
2773 Node* filter_node = nullptr;
2774 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2775
2776 // Add attributes to new node.
2777 nb->Attr("Tinput", Tinput);
2778 nb->Attr("Tfilter", Tfilter);
2779 nb->Attr("out_type", out_type);
2780 nb->Attr("padding", padding);
2781 nb->Attr("is_filter_const", filter_node->IsConstant());
2782 nb->Attr("strides", strides);
2783 nb->Attr("dilations", dilations);
2784 nb->Attr("data_format", data_format);
2785 if (has_padding_list) {
2786 nb->Attr("padding_list", padding_list);
2787 }
2788
2789 // Requantization attr Tbias.
2790 DataType Tbias;
2791 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2792 if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2793 }
2794
2795 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2796 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2797 CopyAttrsAll(orig_node, nb, change_format);
2798
2799 // Check and set filter attribute.
2800 Node* filter_node = nullptr;
2801 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2802 nb->Attr("is_weight_const", filter_node->IsConstant());
2803 }
2804
2805 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2806 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2807 DataType T1, T2, Toutput;
2808
2809 // Get all attributes from old node.
2810 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2811 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2812 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2813
2814 Node* weight_node = nullptr;
2815 TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2816
2817 // Add attributes to new node.
2818 nb->Attr("T1", T1);
2819 nb->Attr("T2", T2);
2820 nb->Attr("Toutput", Toutput);
2821 nb->Attr("is_weight_const", weight_node->IsConstant());
2822
2823 // Requantization attr Tbias
2824 DataType Tbias;
2825 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2826 if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2827 }
2828
2829 void MklLayoutRewritePass::CopyFormatAttrsConv(
2830 const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2831 const std::vector<int32>& dilations, bool change_format) {
2832 string data_format;
2833
2834 if (!change_format) {
2835 nb->Attr("strides", strides);
2836 nb->Attr("dilations", dilations);
2837
2838 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2839 nb->Attr("data_format", data_format);
2840 } else {
2841 std::vector<int32> new_strides;
2842 std::vector<int32> new_dilations;
2843 if (strides.size() == 5) {
2844 // `strides` and `dilations` also need to be changed according to
2845 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2846 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2847 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2848 strides[NDHWC::dim::W]};
2849
2850 new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2851 dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2852 dilations[NDHWC::dim::W]};
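// Illustrative example (given the NDHWC indexing used above): strides of
// {1, 2, 3, 4, 5} in NDHWC order are reordered to {1, 5, 2, 3, 4} in
// NCDHW order; dilations follow the same permutation.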
2853 } else {
2854 // `strides` and `dilations` also need to be changed according to
2855 // `data_format`. In this case, from `NHWC` to `NCHW`.
2856
2857 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2858 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2859
2860 new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2861 dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
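// Illustrative example (given the NHWC indexing used above): strides of
// {1, 2, 3, 4} in NHWC order become {1, 4, 2, 3} in NCHW order.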
2862 }
2863 nb->Attr("strides", new_strides);
2864 nb->Attr("dilations", new_dilations);
2865 }
2866 }
2867
2868 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2869 NodeBuilder* nb,
2870 bool change_format) {
2871 DataType T;
2872 string data_format;
2873 string padding;
2874 std::vector<int32> ksize, strides;
2875
2876 // Get all attributes from old node.
2877 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2878 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2879 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2880 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2881 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2882
2883 // Add attributes to new node.
2884 nb->Attr("T", T);
2885 nb->Attr("padding", padding);
2886
2887 if (!change_format) {
2888 nb->Attr("strides", strides);
2889 nb->Attr("ksize", ksize);
2890
2891 nb->Attr("data_format", data_format);
2892 } else {
2893 std::vector<int32> new_strides;
2894 std::vector<int32> new_ksize;
2895 if (strides.size() == 5) {
2896 DCHECK(data_format == "NCDHW");
2897 // `strides` and `ksize` also need to be changed according to
2898 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2899 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2900 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2901 strides[NDHWC::dim::W]};
2902
2903 new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2904 ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2905 ksize[NDHWC::dim::W]};
2906
2907 } else {
2908 // `strides` and `ksize` also need to be changed according to
2909 // `data_format`. In this case, from `NHWC` to `NCHW`.
2910 DCHECK(data_format == "NCHW");
2911 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2912 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2913
2914 new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2915 ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
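// Illustrative example (given the NHWC indexing used above): an NHWC
// ksize of {1, 3, 3, 1} becomes {1, 1, 3, 3} in NCHW order, and strides
// are permuted the same way.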
2916 }
2917 nb->Attr("strides", new_strides);
2918 nb->Attr("ksize", new_ksize);
2919 }
2920 }
2921
2922 //////////////////////////////////////////////////////////////////////////
2923 // Helper functions related to node merge pass
2924 //////////////////////////////////////////////////////////////////////////
2925
2926 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
2927 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
2928 // once we support BiasAddGrad as Mkl layer.
2929
2930 // Search for all matching mergeinfo.
2931 // We allow more than one match for extensibility.
2932 std::vector<const MergeInfo*> matching_mi;
2933 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
2934 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
2935 matching_mi.push_back(&*mi);
2936 }
2937 }
2938
2939 for (const MergeInfo* mi : matching_mi) {
2940 // Get the operand with which 'a' can be merged.
2941 Node* b = nullptr;
2942 if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
2943 continue;
2944 }
2945
2946 // Get the control edges and input of node
2947 const int N_in = a->num_inputs();
2948 gtl::InlinedVector<Node*, 4> a_control_edges;
2949 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
2950 FillInputs(a, &a_control_edges, &a_in);
2951
2952 const int B_in = b->num_inputs();
2953 gtl::InlinedVector<Node*, 4> b_control_edges;
2954 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
2955 FillInputs(b, &b_control_edges, &b_in);
2956
2957 // Shouldn't merge if a and b have different control edges.
2958 if (a_control_edges != b_control_edges) {
2959 continue;
2960 } else {
2961 // We found a match.
2962 return b;
2963 }
2964 }
2965
2966 return nullptr;
2967 }
2968
2969 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2970 Node* m, Node* n) {
2971 CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2972 n->type_string() == csinfo_.conv2d)) ||
2973 ((n->type_string() == csinfo_.bias_add &&
2974 m->type_string() == csinfo_.conv2d)),
2975 true);
2976
2977 // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2978 // BiasAdd is successor node, and Conv2D predecessor node.
2979 Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2980 Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2981
2982 // 1. Get all attributes from input nodes.
2983 DataType T_pred, T_succ;
2984 string padding;
2985 std::vector<int32> strides;
2986 std::vector<int32> dilations;
2987 string data_format_pred, data_format_succ;
2988 bool use_cudnn_on_gpu;
2989 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2990 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2991 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2992 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2993 TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2994 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2995 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2996 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2997 // We check that the data formats of both succ and pred are the same.
2998 // We expect them to match, so this could be enforced as an assert;
2999 // however, an assert can be too strict, so we enforce it as a check.
3000 // If the check fails, we do not merge the two nodes.
3001 // We do the same check for devices.
3002 if (data_format_pred != data_format_succ || T_pred != T_succ ||
3003 pred->assigned_device_name() != succ->assigned_device_name() ||
3004 pred->def().device() != succ->def().device()) {
3005 return Status(error::Code::INVALID_ARGUMENT,
3006 "data_format or T attribute or devices of Conv2D and "
3007 "BiasAdd do not match. Will skip node merge optimization");
3008 }
3009
3010 const int succ_num = succ->num_inputs();
3011 gtl::InlinedVector<Node*, 4> succ_control_edges;
3012 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3013 FillInputs(succ, &succ_control_edges, &succ_in);
3014
3015 const int pred_num = pred->num_inputs();
3016 gtl::InlinedVector<Node*, 4> pred_control_edges;
3017 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3018 FillInputs(pred, &pred_control_edges, &pred_in);
3019
3020 // We need to ensure that Conv2D only feeds BiasAdd (no other operator
3021 // consumes the output of Conv2D). If this is not the case, then we cannot
3022 // merge Conv2D with BiasAdd.
3023 const int kFirstOutputSlot = 0;
3024 for (const Edge* e : pred->out_edges()) {
3025 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3026 return Status(error::Code::INVALID_ARGUMENT,
3027 "Conv2D does not feed to BiasAdd, or "
3028 "it feeds BiasAdd but has multiple outputs. "
3029 "Will skip node merge optimization");
3030 }
3031 }
3032
3033 // 2. Get inputs from both the nodes.
3034 // Find the 2 inputs from the conv and the bias from the add Bias.
3035 // Get operand 0, 1 of conv2D.
3036 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs.
3037 // Get operand 1 of add_bias
3038 // BiasAdd must have 2 inputs: Conv, bias
3039 CHECK_EQ(succ->in_edges().size(), 2);
3040
3041 // We will use the node name of BiasAdd as the name of the new node.
3042 // Build the new node. We use the same name as the original node, but
3043 // change the op name.
3044 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3045 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
3046 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3047 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
3048 // In1 of BiasAdd is same as output of Conv2D.
3049 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
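// At this point the merged node's data inputs are, in order: Conv2D's
// input, Conv2D's filter, and BiasAdd's bias. Mkl metadata inputs are
// added later, when this merged node is itself rewritten to an Mkl op.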
3050
3051 // Copy attributes from Conv2D to Conv2DWithBias.
3052 CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3053
3054 // Copy the device assigned to old node to new node.
3055 nb.Device(succ->def().device());
3056
3057 // Create node.
3058 Node* new_node;
3059 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3060
3061 // In the following code of this function, an unordered set is used to
3062 // make sure no duplicate edges are added to the new node. Therefore, we
3063 // can pass allow_duplicates = true in the AddControlEdge call to skip the
3064 // O(#edges) check in that routine.
3065
3066 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3067 // node are already copied in BuildNode. We handle control edges now.
3068 std::unordered_set<Node*> unique_node;
3069 for (const Edge* e : pred->in_edges()) {
3070 if (e->IsControlEdge()) {
3071 auto result = unique_node.insert(e->src());
3072 if (result.second) {
3073 (*g)->AddControlEdge(e->src(), new_node, true);
3074 }
3075 }
3076 }
3077 unique_node.clear();
3078
3079 for (const Edge* e : succ->in_edges()) {
3080 if (e->IsControlEdge()) {
3081 auto result = unique_node.insert(e->src());
3082 if (result.second) {
3083 (*g)->AddControlEdge(e->src(), new_node, true);
3084 }
3085 }
3086 }
3087 unique_node.clear();
3088
3089 // Incoming edges are fixed, we will fix the outgoing edges now.
3090 // First, we will fix outgoing control edges from 'pred' node.
3091 for (const Edge* e : pred->out_edges()) {
3092 if (e->IsControlEdge()) {
3093 auto result = unique_node.insert(e->dst());
3094 if (result.second) {
3095 (*g)->AddControlEdge(new_node, e->dst(), true);
3096 }
3097 }
3098 }
3099 unique_node.clear();
3100
3101 // Second, we will fix outgoing control and data edges from 'succ' node.
3102 for (const Edge* e : succ->out_edges()) {
3103 if (e->IsControlEdge()) {
3104 auto result = unique_node.insert(e->dst());
3105 if (result.second) {
3106 (*g)->AddControlEdge(new_node, e->dst(), true);
3107 }
3108 } else {
3109 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3110 // output (at slot 0).
3111 const int kConv2DWithBiasOutputSlot = 0;
3112 auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3113 e->dst(), e->dst_input());
3114 DCHECK(new_edge);
3115 }
3116 }
3117
3118 // Copy device assigned to old node to new node.
3119 // It's ok to use pred or succ as we have enforced a check that
3120 // both have same device assigned.
3121 new_node->set_assigned_device_name(pred->assigned_device_name());
3122
3123 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3124 << ", and node: " << succ->DebugString()
3125 << ", into node:" << new_node->DebugString();
3126
3127 (*g)->RemoveNode(succ);
3128 (*g)->RemoveNode(pred);
3129
3130 return Status::OK();
3131 }
3132
3133 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3134 Node* m, Node* n) {
3135 DCHECK((m->type_string() == csinfo_.pad &&
3136 (n->type_string() == csinfo_.conv2d ||
3137 n->type_string() == csinfo_.fused_conv2d)) ||
3138 (n->type_string() == csinfo_.pad &&
3139 (m->type_string() == csinfo_.conv2d ||
3140 m->type_string() == csinfo_.fused_conv2d)));
3141
3142 bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3143 m->type_string() == csinfo_.fused_conv2d;
3144 // Conv2D is successor node, and Pad predecessor node.
3145 Node* pred = m->type_string() == csinfo_.pad ? m : n;
3146 Node* succ = m->type_string() == csinfo_.pad ? n : m;
3147
3148 // 1. Get all attributes from input nodes.
3149 DataType T_pred, T_succ;
3150 string padding;
3151 std::vector<int32> strides;
3152 std::vector<int32> dilations;
3153 string data_format_pred, data_format_succ;
3154
3155 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3156 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3157 TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3158 TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3159 TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3160 // Check if the devices of both succ and pred are the same.
3161 // An assert is not used because it can be too strict.
3162 // Data formats need not be checked because Pad does not have that attribute.
3163 if (T_pred != T_succ ||
3164 pred->assigned_device_name() != succ->assigned_device_name() ||
3165 pred->def().device() != succ->def().device()) {
3166 return Status(error::Code::INVALID_ARGUMENT,
3167 "T attribute or devices of Conv2D and "
3168 "Pad do not match. Will skip node merge optimization");
3169 }
3170
3171 const int succ_num = succ->num_inputs();
3172 gtl::InlinedVector<Node*, 4> succ_control_edges;
3173 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3174 FillInputs(succ, &succ_control_edges, &succ_in);
3175
3176 const int pred_num = pred->num_inputs();
3177 gtl::InlinedVector<Node*, 4> pred_control_edges;
3178 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3179 FillInputs(pred, &pred_control_edges, &pred_in);
3180
3181 // We need to ensure that Pad only feeds Conv2D (no other operator
3182 // consumes the output of Pad). If this is not the case, then we cannot
3183 // merge Conv2D with Pad.
3184 const int kFirstOutputSlot = 0;
3185 for (const Edge* e : pred->out_edges()) {
3186 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3187 return Status(error::Code::INVALID_ARGUMENT,
3188 "Pad does not feed to Conv2D, or "
3189 "it feeds Conv2D but has multiple outputs. "
3190 "Will skip node merge optimization");
3191 }
3192 }
3193
3194 // 2. Get inputs from both the nodes.
3195
3196 // Pad must have 2 data inputs: "input" and paddings.
3197 int PadDataInputEdges = 0;
3198 for (const Edge* e : pred->in_edges()) {
3199 if (!e->IsControlEdge()) {
3200 PadDataInputEdges++;
3201 }
3202 }
3203 DCHECK_EQ(PadDataInputEdges, 2);
3204
3205 // Conv2D must have 2 data inputs: Pad output and Filter.
3206 // FusedConv2D has 3 data inputs: Pad output, Filter, and Args.
3207 int ConvDataInputEdges = 0;
3208 for (const Edge* e : succ->in_edges()) {
3209 if (!e->IsControlEdge()) {
3210 ConvDataInputEdges++;
3211 }
3212 }
3213
3214 DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
3215
3216 // We will use the node name of Conv2D as the name of the new node.
3217 // Build the new node. We use the same name as the original node, but
3218 // change the op name.
3219
3220 NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3221 : csinfo_.pad_with_conv2d);
3222 nb.Input(pred_in[0].first, pred_in[0].second); // In1 (input data) of Pad
3223 // pred_in[1] (Pad's paddings) is added below, after the filter (and args).
3224 nb.Input(succ_in[1].first, succ_in[1].second); // In2 (filter) of conv2d
3225 // In1 of Conv2D is same as output of Pad.
3226 // Thus, only need to add In2 of Conv2D
3227
3228 if (is_fused_conv2d) {
3229 // FusedConv2D has one additional input, args
3230 std::vector<NodeBuilder::NodeOut> args;
3231 int num_args = 1;
3232 GetNodeAttr(succ->def(), "num_args", &num_args);
3233 for (int i = 0; i < num_args; i++) {
3234 args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second);
3235 }
3236 nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3237 args}); // In3 (args) of FusedConv2D
3238 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3239 // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3240 CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3241 const_cast<const Node*>(pred), &nb);
3242 } else {
3243 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3244 // Copy attributes from Pad and conv2D to PadWithConv2D.
3245 CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3246 const_cast<const Node*>(pred), &nb);
3247 }
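// The merged node's data inputs are therefore, in order: Pad's input,
// Conv2D's filter, the fused args (FusedConv2D case only), and Pad's
// paddings.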
3248
3249 // Copy the device assigned to old node to new node.
3250 nb.Device(succ->def().device());
3251
3252 // Create node.
3253 Node* new_node;
3254 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3255 // No need to check if new_node is null because it will be null only when
3256 // Finalize fails.
3257
3258 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3259 // node are already copied in BuildNode.
3260 // We handle control edges now.
3261 for (const Edge* e : pred->in_edges()) {
3262 if (e->IsControlEdge()) {
3263 // Don't allow duplicate edge
3264 (*g)->AddControlEdge(e->src(), new_node, false);
3265 }
3266 }
3267 for (const Edge* e : succ->in_edges()) {
3268 if (e->IsControlEdge()) {
3269 // Don't allow duplicate edge
3270 (*g)->AddControlEdge(e->src(), new_node, false);
3271 }
3272 }
3273
3274 // Incoming edges are fixed, we will fix the outgoing edges now.
3275 // First, we will fix outgoing control edges from 'pred' node.
3276 for (const Edge* e : pred->out_edges()) {
3277 if (e->IsControlEdge()) {
3278 // Don't allow duplicate edge
3279 (*g)->AddControlEdge(new_node, e->dst(), false);
3280 }
3281 }
3282
3283 // Second, we will fix outgoing control and data edges from 'succ' node.
3284 for (const Edge* e : succ->out_edges()) {
3285 if (e->IsControlEdge()) {
3286 // Do not allow duplicates when adding a control edge: AddControlEdge
3287 // simply returns nullptr if the edge already exists, which is fine here.
3288 (*g)->AddControlEdge(new_node, e->dst(), false);
3289 } else {
3290 // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3291 // output (at slot 0).
3292 const int kPadWithConv2DOutputSlot = 0;
3293 (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3294 e->dst_input());
3295 }
3296 }
3297
3298 // Copy device assigned to old node to new node.
3299 // It's ok to use pred or succ as we have enforced a check that
3300 // both have same device assigned.
3301 new_node->set_assigned_device_name(pred->assigned_device_name());
3302
3303 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3304 << ", and node: " << succ->DebugString()
3305 << ", into node:" << new_node->DebugString();
3306
3307 (*g)->RemoveNode(succ);
3308 (*g)->RemoveNode(pred);
3309
3310 return Status::OK();
3311 }
3312
3313 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3314 std::unique_ptr<Graph>* g, Node* m, Node* n) {
3315 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3316 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3317 ((n->type_string() == csinfo_.bias_add_grad &&
3318 m->type_string() == csinfo_.conv2d_grad_filter)),
3319 true);
3320
3321 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3322 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3323 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3324
3325 // Sanity check for attributes from input nodes.
3326 DataType T_b, T_f;
3327 string data_format_b, data_format_f;
3328 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3329 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3330 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3331 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3332 if (data_format_b != data_format_f || T_b != T_f ||
3333 badd->assigned_device_name() != fltr->assigned_device_name() ||
3334 badd->def().device() != fltr->def().device()) {
3335 return Status(error::Code::INVALID_ARGUMENT,
3336 "data_format or T attribute or devices of "
3337 "Conv2DBackpropFilter and BiasAddGrad do not match. "
3338 "Will skip node merge optimization");
3339 }
3340
3341 // We will use the node name of Conv2DBackpropFilter as the name of new node.
3342 // This is because BackpropFilterWithBias is going to emit bias output also.
3343 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3344 // Since Conv2DBackpropFilterWithBias has same number of inputs as
3345 // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3346 // to copy any data input of BiasAddGrad because that input also goes to
3347 // Conv2DBackpropFilter.
3348 const int fltr_ins = fltr->num_inputs();
3349 gtl::InlinedVector<Node*, 4> fltr_control_edges;
3350 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3351 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3352 for (int idx = 0; idx < fltr_ins; idx++) {
3353 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3354 }
3355
3356 // Copy attributes from Conv2DBackpropFilter.
3357 CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3358
3359 // Copy the device assigned to old node to new node.
3360 nb.Device(fltr->def().device());
3361
3362 // Create node.
3363 Node* new_node;
3364 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3365
3366 // In the following code of this function, an unordered set is used to
3367 // make sure no duplicate edges are added to the new node. Therefore, we
3368 // can pass allow_duplicates = true in the AddControlEdge call to skip the
3369 // O(#edges) check in that routine.
3370
3371 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3372 // new 'new_node' node are already copied in BuildNode. We handle control
3373 // edges now.
3374 std::unordered_set<Node*> unique_node;
3375 for (const Edge* e : badd->in_edges()) {
3376 if (e->IsControlEdge()) {
3377 auto result = unique_node.insert(e->src());
3378 if (result.second) {
3379 (*g)->AddControlEdge(e->src(), new_node, true);
3380 }
3381 }
3382 }
3383 unique_node.clear();
3384 for (const Edge* e : fltr->in_edges()) {
3385 if (e->IsControlEdge()) {
3386 auto result = unique_node.insert(e->src());
3387 if (result.second) {
3388 (*g)->AddControlEdge(e->src(), new_node, true);
3389 }
3390 }
3391 }
3392 unique_node.clear();
3393
3394 // Incoming edges are fixed, we will fix the outgoing edges now.
3395 // First, we will fix outgoing control edges from 'badd' node.
3396 // Conv2DBackpropFilter has 1 output -- filter_grad.
3397 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3398 // bias_grad. But filter_grad is at same slot number (0) in both the
3399 // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3400 // it is at slot number 0 in BiasAddGrad.
3401 const int kMergedNodeFilterGradOutputIdx = 0;
3402 const int kMergedNodeBiasGradOutputIdx = 1;
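// For example, an edge that consumed BiasAddGrad:0 is rerouted below to
// new_node:1 (bias_grad), while an edge that consumed
// Conv2DBackpropFilter:0 is rerouted to new_node:0 (filter_grad).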
3403
3404 for (const Edge* e : badd->out_edges()) {
3405 if (e->IsControlEdge()) {
3406 auto result = unique_node.insert(e->dst());
3407 if (result.second) {
3408 (*g)->AddControlEdge(new_node, e->dst(), true);
3409 }
3410 } else {
3411 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3412 e->dst(), e->dst_input());
3413 DCHECK(new_edge);
3414 }
3415 }
3416 unique_node.clear();
3417
3418 // Second, we will fix outgoing control and data edges from 'fltr' node.
3419 for (const Edge* e : fltr->out_edges()) {
3420 if (e->IsControlEdge()) {
3421 auto result = unique_node.insert(e->dst());
3422 if (result.second) {
3423 (*g)->AddControlEdge(new_node, e->dst(), true);
3424 }
3425 } else {
3426 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3427 e->dst(), e->dst_input());
3428 DCHECK(new_edge);
3429 }
3430 }
3431
3432 // Copy device assigned to old node to new node.
3433 // It's ok to use badd or fltr as we have enforced a check that
3434 // both have same device assigned.
3435 new_node->set_assigned_device_name(badd->assigned_device_name());
3436
3437 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3438 << ", and node: " << fltr->DebugString()
3439 << ", into node:" << new_node->DebugString();
3440
3441 (*g)->RemoveNode(badd);
3442 (*g)->RemoveNode(fltr);
3443
3444 return Status::OK();
3445 }
3446
3447 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3448 Node* n) {
3449 DCHECK(m);
3450 DCHECK(n);
3451
3452 if (((m->type_string() == csinfo_.bias_add &&
3453 n->type_string() == csinfo_.conv2d)) ||
3454 ((n->type_string() == csinfo_.bias_add &&
3455 m->type_string() == csinfo_.conv2d))) {
3456 return this->MergeConv2DWithBiasAdd(g, m, n);
3457 }
3458 if ((m->type_string() == csinfo_.pad &&
3459 (n->type_string() == csinfo_.conv2d ||
3460 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3461 (n->type_string() == csinfo_.pad &&
3462 (m->type_string() == csinfo_.conv2d ||
3463 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3464 return this->MergePadWithConv2D(g, m, n);
3465 }
3466
3467 if (((m->type_string() == csinfo_.bias_add_grad &&
3468 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3469 ((n->type_string() == csinfo_.bias_add_grad &&
3470 m->type_string() == csinfo_.conv2d_grad_filter))) {
3471 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3472 }
3473
3474 return Status(error::Code::UNIMPLEMENTED,
3475 "Unimplemented case for node merge optimization.");
3476 }
3477
3478 //////////////////////////////////////////////////////////////////////////
3479 // Helper functions for node rewrite
3480 //////////////////////////////////////////////////////////////////////////
3481
3482 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3483 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3484 const RewriteInfo* ri) {
3485 // Get all data inputs.
3486 int num_data_inputs = orig_node->in_edges().size();
3487 // Subtract control edges from the input count.
3488 for (const Edge* e : orig_node->in_edges()) {
3489 if (e->IsControlEdge()) {
3490 num_data_inputs--;
3491 }
3492 }
3493
3494 gtl::InlinedVector<Node*, 4> control_edges;
3495 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3496 FillInputs(orig_node, &control_edges, &inputs);
3497
3498 // Build new node. We use same name as original node, but change the op name.
3499 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3500 // Copy user-specified device assigned to original node to new node.
3501 nb.Device(orig_node->def().device());
3502 // Set up new inputs to the rewritten node.
3503 Status s = SetUpInputs(g, inputs, &nb, orig_node);
3504 if (s != Status::OK()) {
3505 return s;
3506 }
3507
3508 const bool kPartialCopyAttrs = false;
3509 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3510
3511 // Set the Mkl layer label for this op.
3512 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3513 DataTypeIsQuantized(orig_node->output_type(0))) {
3514 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3515 } else {
3516 nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3517 }
3518 // Finalize graph and get new node.
3519 s = nb.Finalize(&**g, new_node);
3520 if (s != Status::OK()) {
3521 return s;
3522 }
3523
3524 // In the following code of this function, an unordered set is used to
3525 // make sure no duplicate edges are added to the new node. Therefore, we
3526 // can pass allow_duplicates = true in the AddControlEdge call to skip the
3527 // O(#edges) check in that routine.
3528
3529 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3530 // already copied in BuildNode. We need to handle control edges now.
3531 std::unordered_set<Node*> unique_node;
3532 for (const Edge* e : orig_node->in_edges()) {
3533 if (e->IsControlEdge()) {
3534 auto result = unique_node.insert(e->src());
3535 if (result.second) {
3536 (*g)->AddControlEdge(e->src(), *new_node, true);
3537 }
3538 }
3539 }
3540 unique_node.clear();
3541
3542 // Copy outgoing edges from 'orig_node' node to new
3543 // 'new_node' node, since the output also follows same ordering among
3544 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
3545 // tensors appropriately. Specifically, nth output of the original node
3546 // will become 2*nth output of the Mkl node for the interleaved ordering
3547 // of the tensors. For the contiguous ordering of the tensors, it will be n.
3548 // GetTensorDataIndex provides this mapping function.
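// Illustrative example: with interleaved ordering, output 1 of the
// original node becomes output 2 of the Mkl node (its Mkl metadata tensor
// would sit at output 3); with contiguous ordering it stays at output 1.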
3549 for (const Edge* e : orig_node->out_edges()) {
3550 if (e->IsControlEdge()) {
3551 auto result = unique_node.insert(e->dst());
3552 if (result.second) {
3553 (*g)->AddControlEdge(*new_node, e->dst(), true);
3554 }
3555 } else {
3556 auto new_edge = (*g)->AddEdge(
3557 *new_node,
3558 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3559 e->dst(), e->dst_input());
3560 DCHECK(new_edge);
3561 }
3562 }
3563 return Status::OK();
3564 }
3565
3566 Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3567 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3568 const RewriteInfo* ri) {
3569 // Get all data inputs.
3570 int num_data_inputs = orig_node->in_edges().size();
3571 // Subtract control edges from the input count.
3572 for (const Edge* e : orig_node->in_edges()) {
3573 if (e->IsControlEdge()) {
3574 num_data_inputs--;
3575 }
3576 }
3577 gtl::InlinedVector<Node*, 4> control_edges;
3578 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3579 FillInputs(orig_node, &control_edges, &inputs);
3580
3581 // Build new node. We use same name as original node, but change the op name.
3582 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3583 // Copy user-specified device assigned to original node to new node.
3584 nb.Device(orig_node->def().device());
3585
3586 Status s = CopyInputs(orig_node, inputs, &nb);
3587 if (s != Status::OK()) {
3588 return s;
3589 }
3590
3591 std::vector<NodeBuilder::NodeOut> workspace_tensors;
3592 bool are_workspace_tensors_available = false;
3593 if (IsWorkspaceCheckNeeded(orig_node)) {
3594 AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3595 &are_workspace_tensors_available);
3596 if (are_workspace_tensors_available) {
3597 DCHECK_EQ(workspace_tensors.size(), 1);
3598 nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3599 }
3600 }
3601
3602 if (!NativeFormatEnabled()) {
3603 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3604 } else {
3605 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3606 }
3607
3608 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3609 DataTypeIsQuantized(orig_node->output_type(0))) {
3610 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3611 } else {
3612 nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3613 }
3614
3615 // Finalize graph and get new node.
3616 s = nb.Finalize(&**g, new_node);
3617 if (s != Status::OK()) {
3618 return s;
3619 }
3620
3621 // In the following code of this function, an unordered set is used to
3622 // make sure no duplicate edges are added to the new node. Therefore, we
3623 // can pass allow_duplicates = true in the AddControlEdge call to skip the
3624 // O(#edges) check in that routine.
3625
3626 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3627 // already copied in BuildNode. We need to handle control edges now.
3628 std::unordered_set<Node*> unique_node;
3629 for (const Edge* e : orig_node->in_edges()) {
3630 if (e->IsControlEdge()) {
3631 auto result = unique_node.insert(e->src());
3632 if (result.second) {
3633 (*g)->AddControlEdge(e->src(), *new_node, true);
3634 }
3635 }
3636 }
3637 unique_node.clear();
3638
3639 // Transfer outgoing edges from 'orig_node' node to new 'new_node' node.
3640 for (const Edge* e : orig_node->out_edges()) {
3641 if (e->IsControlEdge()) {
3642 auto result = unique_node.insert(e->dst());
3643 if (result.second) {
3644 (*g)->AddControlEdge(*new_node, e->dst(), true);
3645 }
3646 } else {
3647 auto result =
3648 (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3649 DCHECK(result != nullptr);
3650 }
3651 }
3652
3653 return Status::OK();
3654 }
3655
3656 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3657 Node* orig_node,
3658 const RewriteInfo* ri) {
3659 DCHECK(ri != nullptr);
3660 DCHECK(orig_node != nullptr);
3661
3662 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3663
3664 Status ret_status = Status::OK();
3665 Node* new_node = nullptr;
3666 if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3667 ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3668 } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3669 ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3670 } else {
3671 ret_status = Status(error::Code::INVALID_ARGUMENT,
3672 "Unsupported rewrite cause found. "
3673 "RewriteNode will fail.");
3674 }
3675 TF_CHECK_OK(ret_status);
3676
3677 // Copy the runtime-assigned device from the original node to the new node.
3678 new_node->set_assigned_device_name(orig_node->assigned_device_name());
3679
3680 // Delete original node and mark new node as rewritten.
3681 (*g)->RemoveNode(orig_node);
3682
3683 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3684 return ret_status;
3685 }
3686
3687 // TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3688 // having attributes other than "T"?
3689 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3690 const MklLayoutRewritePass::RewriteInfo*
3691 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3692 DataType T1, T2;
3693 DataType Tinput, Tfilter;
3694 bool type_attrs_present = false;
3695
3696 if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3697 TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3698 mkl_op_registry::IsMklQuantizedOp(
3699 mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3700 type_attrs_present = true;
3701 } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3702 TryGetNodeAttr(n->def(), "T2", &T2) &&
3703 mkl_op_registry::IsMklQuantizedOp(
3704 mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3705 type_attrs_present = true;
3706 }
3707
3708 if (type_attrs_present) {
3709 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3710 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3711 return &*ri;
3712 }
3713 }
3714 }
3715
3716 return nullptr;
3717 }
3718
3719 const MklLayoutRewritePass::RewriteInfo*
3720 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3721 DCHECK(n);
3722
3723 // QuantizedOps may have attributes other than "T", so the check is
3724 // decoupled into a separate function, CheckForQuantizedNodeRewrite(const Node*).
3725 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3726 if (ri != nullptr) return ri;
3727
3728 // First check if node along with its type is supported by MKL layer.
3729 // We do not want to rewrite an op into Mkl op if types are not supported.
3730 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3731 // MklRelu if type is INT32.
3732 DataType T;
3733 if (!TryGetNodeAttr(n->def(), "T", &T)) {
3734 return nullptr;
3735 }
3736
3737 // We make an exception for Conv2DGrad and MaxPool related ops as
3738 // the corresponding MKL ops currently do not support the case
3739 // of padding == EXPLICIT yet.
3740 // TODO(intel): support `EXPLICIT` padding for ConvGrad
3741 if (n->type_string() == csinfo_.conv2d_grad_input ||
3742 n->type_string() == csinfo_.conv2d_grad_filter ||
3743 n->type_string() == csinfo_.max_pool ||
3744 n->type_string() == csinfo_.max_pool_grad ||
3745 n->type_string() == csinfo_.max_pool3d ||
3746 n->type_string() == csinfo_.max_pool3d_grad) {
3747 string padding;
3748 TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3749 if (padding == "EXPLICIT") return nullptr;
3750 }
3751
3752 // We make an exception for __MklDummyConv2DWithBias,
3753 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3754 // names do not match Mkl node names.
3755 if (n->type_string() != csinfo_.conv2d_with_bias &&
3756 n->type_string() != csinfo_.pad_with_conv2d &&
3757 n->type_string() != csinfo_.pad_with_fused_conv2d &&
3758 n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3759 n->type_string() != csinfo_.fused_batch_norm_ex &&
3760 n->type_string() != csinfo_.fused_conv2d &&
3761 n->type_string() != csinfo_.fused_depthwise_conv2d &&
3762 n->type_string() != csinfo_.fused_matmul &&
3763 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3764 T)) {
3765 return nullptr;
3766 }
3767
3768 // We now check if rewrite rule applies for this op. If rewrite rule passes
3769 // for this op, then we rewrite it to Mkl op.
3770 // Find matching RewriteInfo and then check that rewrite rule applies.
3771 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3772 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3773 return &*ri;
3774 }
3775 }
3776
3777 // Else return not found.
3778 return nullptr;
3779 }
3780
3781 //////////////////////////////////////////////////////////////////////////
3782 // Helper functions for node fusion
3783 //////////////////////////////////////////////////////////////////////////
3784 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3785 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3786 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3787 string data_format) {
3788 Node* transpose_to_nhwc = nodes[0];
3789 Node* mklop = nodes[1];
3790 Node* transpose_to_nchw = nodes[2];
3791
3792 const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3793 gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3794 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3795 transpose_nhwc_num_inputs);
3796 FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3797 &transpose_nhwc_in);
3798
3799 const int mklop_num_inputs = mklop->num_inputs();
3800 gtl::InlinedVector<Node*, 4> mklop_control_edges;
3801 gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3802 FillInputs(mklop, &mklop_control_edges, &mklop_in);
3803
3804 const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3805 gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3806 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3807 transpose_nchw_num_inputs);
3808 FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3809 &transpose_nchw_in);
3810
3811 // We use same name as original node, but change the op
3812 // type.
3813 NodeBuilder nb(mklop->name(), mklop->type_string());
3814
3815 // Storing the output slots of the input nodes.
3816 for (int i = 0; i < mklop_num_inputs; i++) {
3817 if (mklop_in[i].first == transpose_to_nhwc) {
3818 // Fill "x":
3819 nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3820 } else {
3821 // Fill inputs other than "x":
3822 nb.Input(mklop_in[i].first, mklop_in[i].second);
3823 }
3824 }
3825
3826 copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3827 nb.Attr("data_format", data_format);
3828
3829 // Copy the device assigned to old node to new node.
3830 nb.Device(mklop->def().device());
3831
3832 // Create node.
3833 Node* new_node;
3834 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3835 // No need to check if new_node is null because it will be null only when
3836 // Finalize fails.
3837
3838 // Fill outputs.
3839 for (const Edge* e : transpose_to_nchw->out_edges()) {
3840 if (!e->IsControlEdge()) {
3841 const int kTransposeWithMklOpOutputSlot = 0;
3842 auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3843 e->dst(), e->dst_input());
3844 DCHECK(new_edge);
3845 }
3846 }
3847
3848 // Copy device assigned to old node to new node.
3849 new_node->set_assigned_device_name(mklop->assigned_device_name());
3850
3851 // Copy requested_device and assigned_device_name_index
3852 new_node->set_requested_device(mklop->requested_device());
3853 new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3854
3855 (*g)->RemoveNode(transpose_to_nhwc);
3856 (*g)->RemoveNode(mklop);
3857 (*g)->RemoveNode(transpose_to_nchw);
3858
3859 return Status::OK();
3860 }
3861
3862 Status MklLayoutRewritePass::FuseNode(
3863 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3864 const MklLayoutRewritePass::FusionInfo fi) {
3865 return fi.fuse_func(g, nodes, fi.copy_attrs);
3866 }
3867
3868 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
3869 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3870 // Stores matched nodes, in the same order as node_checkers.
3871 std::vector<Node*> nodes;
3872
3873 for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3874 //
3875 // Make sure node "a" and its succeeding nodes (b, c, ...) match the
3876 // pattern defined in the fusion info (ops[0], ops[1], ...),
3877 // i.e. "a->b->c" matches "op1->op2->op3".
3878 //
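// For instance, a registered Transpose -> MklOp -> Transpose chain
// (handled by FuseTransposeMklOpTranspose above) is matched by walking
// outgoing edges depth-first using the stack below.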
3879
3880 // Stores the first unvisited outgoing edge of each matched node in "nodes".
3881 std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3882 nodes.clear();
3883
3884 auto node_checker = fi->node_checkers.begin();
3885 if (a != nullptr && (*node_checker)(a)) {
3886 nodes.push_back(a);
3887 current_neighbor_stack.push(a->out_edges().begin());
3888 ++node_checker;
3889 }
3890
3891 while (!nodes.empty()) {
3892 auto& current_neighbor_iter = current_neighbor_stack.top();
3893
3894 if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3895 // Found an unvisited edge. Goes through the edge to get the neighbor.
3896 Node* neighbor_node = (*current_neighbor_iter)->dst();
3897 ++current_neighbor_stack.top(); // Retrieves the next unvisited edge.
3898
3899 if ((*node_checker)(neighbor_node)) {
3900 // Found a match. Stores the node and moves to the next checker.
3901 nodes.push_back(neighbor_node);
3902 current_neighbor_stack.push(neighbor_node->out_edges().begin());
3903 if (++node_checker == fi->node_checkers.end()) {
3904 return make_tuple(true, nodes, *fi);
3905 }
3906 }
3907 } else {
3908 // Removes the current node since none of its neighbors leads to a
3909 // further match.
3910 nodes.pop_back();
3911 current_neighbor_stack.pop();
3912 --node_checker;
3913 }
3914 }
3915 }
3916
3917 return make_tuple(false, std::vector<Node*>(), FusionInfo());
3918 }
3919
3920 ///////////////////////////////////////////////////////////////////////////////
3921 // Post-rewrite Mkl metadata fixup pass
3922 ///////////////////////////////////////////////////////////////////////////////
3923 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3924 const Edge* e_data,
3925 const Edge* e_metadata) {
3926 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3927 return false;
3928 }
3929
3930 Node* n_data = e_data->src();
3931 int n_data_op_slot = e_data->src_output();
3932 int n_metadata_op_slot =
3933 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
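// For example (assuming the contiguous tensor ordering), if n_data has 2
// outputs and the data edge comes from slot 0, the matching metadata
// tensor is expected at slot 1 of the same node.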
3934
3935 // If the source of the meta edge is a constant node (producing a dummy Mkl
3936 // metadata tensor), then we need to fix the edge.
3937 if (IsConstant(e_metadata->src())) {
3938 Node* e_metadata_dst = e_metadata->dst();
3939 int e_metadata_in_slot = e_metadata->dst_input();
3940 auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3941 e_metadata_in_slot);
3942 DCHECK(new_edge);
3943
3944 (*g)->RemoveEdge(e_metadata);
3945 return true;
3946 }
3947
3948 return false;
3949 }
3950
3951 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3952 Node* n) {
3953 bool result = false;
3954
3955 // If graph node is not Mkl node, then return.
3956 DataType T = DT_INVALID;
3957 if (!TryGetNodeAttr(n->def(), "T", &T) ||
3958 !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
3959 return result;
3960 }
3961
3962 // If it is Mkl node, then check if the input edges to this node that carry
3963 // Mkl metadata are linked up correctly with the source node.
3964
3965 // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3966 // data tensors + n for Mkl metadata tensors). We need to check for correct
3967 // connection of n metadata tensors only.
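// For example (assuming the contiguous ordering), an Mkl node with 2 data
// inputs has 4 inputs in total: data tensors at slots 0 and 1, and their
// metadata tensors at slots 2 and 3.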
3968 int num_data_inputs = n->num_inputs() / 2;
3969 for (int idx = 0; idx < num_data_inputs; idx++) {
3970 // Get the edge connecting input slot with index (idx).
3971 const Edge* e = nullptr;
3972 TF_CHECK_OK(n->input_edge(idx, &e));
3973
3974 // If e is control edge, then skip.
3975 if (e->IsControlEdge()) {
3976 continue;
3977 }
3978
3979 // Check that the source node of edge 'e' is an Mkl node. If it is not an
3980 // Mkl node, then we don't need to do anything.
3981 Node* e_src = e->src();
3982 if (TryGetNodeAttr(e_src->def(), "T", &T) &&
3983 mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
3984 // Source node for edge 'e' is Mkl node.
3985 // The destination node and destination input slot of e are node 'n' and
3986 // slot 'idx', respectively.
3987 CHECK_EQ(e->dst(), n);
3988 CHECK_EQ(e->dst_input(), idx);
3989
3990 // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
3991 // 'e'. For that, let's first get the input slot of 'n' where the meta
3992 // edge will feed the value.
3993 int e_meta_in_slot =
3994 GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3995 const Edge* e_meta = nullptr;
3996 TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3997
3998 // Let's check if we need to fix this meta edge.
3999 if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
4000 result = true;
4001 }
4002 }
4003 }
4004
4005 return result;
4006 }
4007
4008 ///////////////////////////////////////////////////////////////////////////////
4009 // Run function for the pass
4010 ///////////////////////////////////////////////////////////////////////////////
4011
4012 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
4013 bool result = false;
4014 DCHECK(g);
4015
4016 DumpGraph("Before running MklLayoutRewritePass", &**g);
4017
4018 std::vector<Node*> order;
4019 GetReversePostOrder(**g, &order); // This will give us topological sort.
4020 for (Node* n : order) {
4021 // If node is not an op or it cannot run on CPU device, then skip.
4022 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4023 continue;
4024 }
4025
    Node* m = nullptr;
    if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      // Check if the node 'n' can be merged with any other node. If it can
      // be, 'm' contains the node with which it can be merged.
      string n1_name = n->name();
      string n2_name = m->name();

      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
              << n2_name << " for merging";

      if (MergeNode(g, n, m) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
                << n2_name;
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    auto check_result = CheckForNodeFusion(n);
    bool found_pattern = std::get<0>(check_result);
    std::vector<Node*> nodes = std::get<1>(check_result);
    const FusionInfo fi = std::get<2>(check_result);
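    // (CheckForNodeFusion returns a tuple: whether a fusion pattern involving
    // 'n' was found, the nodes that matched the pattern, and the FusionInfo
    // describing how to fuse them.)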

    // If "found_pattern" is true, we can do the fusion.
    if (found_pattern) {
      if (FuseNode(g, nodes, fi) == Status::OK()) {
        result = true;
      }
    }
  }
  DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    const RewriteInfo* ri = nullptr;
    // We will first check whether the node is to be rewritten.
    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
              << " with op " << op_name << " for rewrite using"
              << " layout optimization.";

      if (RewriteNode(g, n, ri) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
                << " with op " << op_name << " for Mkl layout optimization.";
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }
    if (FixMklMetaDataEdges(g, n)) {
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
              << node_name << " with op " << op_name;
      result = true;
    }
  }
  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
            &**g);

  return result;
}

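// Helper that constructs an MklLayoutRewritePass and runs it directly on the
// given graph, returning true if the graph was modified.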
bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
  return MklLayoutRewritePass().RunPass(g);
}

Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return Status::OK();
  }
  if (!IsMKLEnabled()) {
    VLOG(2) << "TF-MKL: MKL is not enabled";
    return Status::OK();
  }

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Take ownership of the graph.
    std::unique_ptr<Graph>* ng = std::move(g);
    RunPass(ng);
    // Return ownership of the graph back.
    g->reset(ng->release());
  };

  if (kMklLayoutRewritePassGroup !=
      OptimizationPassRegistry::POST_PARTITIONING) {
    // For any pre-partitioning phase, the graph is stored in options.graph.
    process_graph(options.graph);
  } else {
    // For the post-partitioning phase, graphs are stored in
    // options.partition_graphs.
    for (auto& pg : *options.partition_graphs) {
      process_graph(&pg.second);
    }
  }

  return Status::OK();
}

}  // namespace tensorflow

#endif  // INTEL_MKL