1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19
20 #include "tensorflow/core/common_runtime/mkl_layout_pass.h"
21
22 #include <algorithm>
23 #include <functional>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <stack>
28 #include <tuple>
29 #include <unordered_set>
30 #include <utility>
31 #include <vector>
32
33 #include "tensorflow/core/common_runtime/function.h"
34 #include "tensorflow/core/common_runtime/optimization_registry.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/mkl_graph_util.h"
40 #include "tensorflow/core/graph/node_builder.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/gtl/array_slice.h"
43 #include "tensorflow/core/lib/gtl/map_util.h"
44 #include "tensorflow/core/lib/hash/hash.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/util/tensor_format.h"
47 #include "tensorflow/core/util/util.h"
48
49 namespace tensorflow {
50
51 // This pass implements rewriting of the graph to support the following scenarios:
52 // (A) Merging nodes in the graph
53 // (B) Rewriting a node in the graph to a new node
54 // Rewriting happens under the following scenario:
55 // - Propagating the Mkl layout as an additional output tensor
56 // (we will loosely call a tensor that carries the Mkl layout an Mkl
57 // tensor henceforth) from every Mkl-supported NN layer.
58 //
59 // Example of A : Merging nodes in the graph
60 // -----------------------------------------
61 // Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
62 //
63 // O = Conv2D(A, B)
64 // P = BiasAdd(O, C)
65 //
66 // We merge them into Conv2DWithBias as:
67 // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
68 //
69 // The meaning of A_m, B_m and C_m is explained in B.1.
70 //
71 // Merge rules:
72 // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
73 // goes to BiasAdd.
74 // - Also, the attributes common to both nodes must have the same
75 // values.
76 // - Both nodes must have been assigned to the same device (if any).
77 //
78 // Example of B.1 : Rewriting nodes to Mkl nodes
79 // ---------------------------------------------
80 // Consider a Relu node. Current definition of Relu node looks like:
81 //
82 // O = Relu(A)
83 //
84 // Relu has 1 input (A), and 1 output (O).
85 //
86 // This rewrite pass will generate a new graph node for Relu (new node is
87 // called MklRelu) as:
88 //
89 // O, O_m = MklRelu(A, A_m)
90 //
91 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
92 // same as input A of Relu; output O is same as output O of Relu. O_m is the
93 // additional output tensor that will be set by MklRelu, and it represents
94 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
95 // metadata for O. A_m is an additional input of MklRelu, and it represents
96 // metadata for A - just as O_m is metadata for O, A_m is metadata for A.
97 // MklRelu receives this metadata from the previous node in the graph.
98 //
99 // When a previous node in the graph is an Mkl node, A_m will represent a valid
100 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
101 // a dummy Mkl tensor.
102 //
103 // Rewriting rules:
104 // - Selection of a node for rewriting happens by registering the op type of
105 // the node with the rewriting pass. If the op type is not registered, then
106 // no node of this op type will be rewritten.
107 // - Number of inputs after rewriting:
108 // Since for every input Tensorflow tensor the rewritten node also gets an
109 // Mkl tensor, the rewritten node gets 2*N inputs, where N is the number of
110 // inputs of the original node.
111 // - Number of outputs after rewriting:
112 // Since for every output Tensorflow tensor, the rewritten node generates
113 // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
114 // number of outputs of the original node.
115 // - Ordering of Tensorflow tensors and Mkl tensors:
116 // Since every rewritten node has twice the number of inputs and
117 // outputs, one could imagine various orderings among Tensorflow tensors
118 // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
119 // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
120 // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
121 // order. In general, many different orderings of the 2*N inputs are possible.
122 //
123 // So the question is: which order do we follow? We support 2 types of
124 // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
125 // follows an intuitive order where an Mkl tensor follows the
126 // corresponding Tensorflow tensor immediately. In the context of the
127 // above example, it will be: A, A_m, B, B_m. Note that the ordering rule
128 // applies to both the inputs and outputs. Contiguous ordering means
129 // all the Tensorflow tensors are contiguous followed by all the Mkl
130 // tensors. We use contiguous ordering as default.
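//
// For illustration (a sketch of the two orderings for the Conv2D example
// above), the rewritten op would read, with contiguous ordering:
//   O, O_m = _MklConv2D(A, B, A_m, B_m)
// and with interleaved ordering:
//   O, O_m = _MklConv2D(A, A_m, B, B_m)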
131 //
132 // Graph rewrite algorithm:
133 // Algorithm: Graph Rewrite
134 // Input: Graph G, Names of the nodes to rewrite and their new names
135 // Output: Modified Graph G' if the nodes are modified, G otherwise.
136 // Start:
137 // N = Topological_Sort(G) // N is a set of nodes in toposort order.
138 // foreach node n in N
139 // do
140 // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
141 // then
142 // E = set of <incoming edge and its src_output slot> of n
143 // E' = {} // a new set of edges for rewritten node
144 // foreach <e,s> in E
145 // do
146 // E' U {<e,s>} // First copy edge which generates Tensorflow
147 // // tensor as it is
148 // m = Source node of edge e
149 // if Is_Rewritten(m) // Did we rewrite this node in this pass?
150 // then
151 // E' U {<m,s+1>} // If yes, then m will generate an Mkl
152 // // tensor as an additional output.
153 // else
154 // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
155 // // Mkl tensor.
156 // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
157 // fi
158 // done
159 // n' = Build_New_Node(G,new_name,E')
160 // Mark_Rewritten(n') // Mark the new node as being rewritten.
161 // fi
162 // done
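//
// For example (a sketch of one iteration of the loop above): for
// n = Relu(A), where A comes from output slot 0 of node m,
//   if Is_Rewritten(m): E' = {<m,0>, <m,1>} // consume A and A_m from m
//   else:               E' = {<m,0>, <d,0>} // d is a dummy Mkl tensor
// and the rewritten node MklRelu is then built from inputs E'.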
163 //
164 // Explanation:
165 // For graph rewrite, we visit nodes of the input graph in the
166 // topological sort order. With this ordering, we visit nodes in a
167 // top-to-bottom fashion. We need this order because, while visiting a
168 // node, we want all of its input nodes to have already been visited and
169 // rewritten if applicable. This is because if we need to rewrite a given
170 // node, then all of its input nodes need to be fixed first (in other
171 // words, they cannot be deleted later.)
172 //
173 // While visiting a node, we first check if the op type of the node is
174 // an Mkl op. If it is, then we rewrite that node after constructing
175 // new inputs to the node. If the op type of the node is not an Mkl op,
176 // then we do not rewrite that node.
177 //
178 // Handling workspace propagation for certain ops:
179 //
180 // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
181 // passing of a workspace from their respective forward ops. Workspace
182 // tensors provide memory for storing results of intermediate operations
183 // which are helpful in backward propagation. TensorFlow does not have
184 // a notion of a workspace and as a result does not allow producing
185 // additional outputs from these forward ops. For these ops, we need
186 // to add 2 extra edges between forward ops and their corresponding
187 // backward ops - the first extra edge carries a workspace tensor and
188 // the second one carries an Mkl tensor for the workspace tensor.
189 //
190 // Example:
191 //
192 // Typical graph for MaxPool and its gradient looks like:
193 //
194 // A = MaxPool(T)
195 // B = MaxPoolGrad(X, A, Y)
196 //
197 // We will transform this graph to propagate the workspace as:
198 // (with the contiguous ordering)
199 //
200 // A, W, A_m, W_m = MklMaxPool(T, T_m)
201 // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
202 //
203 // Here W is the workspace tensor. Transformed tensor names with the
204 // suffix _m are Mkl tensors, and this transformation has been done
205 // using the algorithm discussed earlier. The transformation for
206 // workspace propagation only adds extra outputs (W, W_m) for a forward
207 // op and connects them to the corresponding backward ops.
208 //
209 // Terms:
210 //
211 // Forward op name = name of the op in the forward pass
212 // where a workspace tensor originates (MaxPool in this example)
213 // Backward op name = name of the op in the backward pass that receives
214 // a workspace tensor from the forward op (MaxPoolGrad in the example)
215 // Slot = Position of the output or input slot that will be
216 // used by the workspace tensor (1 for MklMaxPool as W is the 2nd
217 // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
218 //
219 // Question:
220 //
221 // How do we associate a backward op to a forward op? There can be more
222 // than one node with the exact same op type.
223 //
224 // In this example, we associate MaxPoolGrad with MaxPool. But there
225 // could be more than one MaxPool op. To solve this problem, we look
226 // for a _direct_ edge between a forward op and a backward op (tensor A is
227 // flowing along this edge in the example).
228 //
229 // How do we transform forward and backward ops when there is no direct
230 // edge between them? In such a case, we generate dummy tensors for
231 // workspace tensors. In the example, the transformation of MaxPool will
232 // be exactly the same as it would be when there is a direct edge between
233 // the forward and the backward op --- it is just that MaxPool won't
234 // generate any workspace tensor. For MaxPoolGrad, the transformation
235 // will also be the same, but instead of connecting W and W_m with the
236 // outputs of MaxPool, we will produce dummy tensors for them, and we
237 // will set workspace_enabled attribute to false.
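//
// A sketch of the no-direct-edge case (DW and DW_m denote the dummy
// tensors generated in place of the workspace outputs of MklMaxPool):
//
//   B, B_m = MklMaxPoolGrad(X, A, Y, DW, X_m, A_m, Y_m, DW_m)
//
// with the workspace_enabled attribute of MklMaxPoolGrad set to false.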
238 //
239 class MklLayoutRewritePass : public GraphOptimizationPass {
240 public:
241 MklLayoutRewritePass() {
242 // NOTE: names are alphabetically sorted.
243 csinfo_.addn = "AddN";
244 csinfo_.avg_pool = "AvgPool";
245 csinfo_.avg_pool_grad = "AvgPoolGrad";
246 csinfo_.avg_pool3d = "AvgPool3D";
247 csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
248 csinfo_.batch_matmul = "BatchMatMul";
249 csinfo_.batch_matmul_v2 = "BatchMatMulV2";
250 csinfo_.bias_add = "BiasAdd";
251 csinfo_.bias_add_grad = "BiasAddGrad";
252 csinfo_.concat = "Concat";
253 csinfo_.concatv2 = "ConcatV2";
254 csinfo_.conjugate_transpose = "ConjugateTranspose";
255 csinfo_.conv2d = "Conv2D";
256 csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
257 csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
258 csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
259 csinfo_.conv2d_grad_filter_with_bias =
260 "__MklDummyConv2DBackpropFilterWithBias";
261 csinfo_.conv3d = "Conv3D";
262 csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
263 csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
264 csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
265 csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
266 csinfo_.depthwise_conv2d_grad_filter =
267 "DepthwiseConv2dNativeBackpropFilter";
268 csinfo_.dequantize = "Dequantize";
269 csinfo_.fused_batch_norm = "FusedBatchNorm";
270 csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
271 csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
272 csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
273 csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
274 csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
275 csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
276 csinfo_.fused_conv2d = "_FusedConv2D";
277 csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
278 csinfo_.fused_matmul = "_FusedMatMul";
279 csinfo_.identity = "Identity";
280 csinfo_.leakyrelu = "LeakyRelu";
281 csinfo_.leakyrelu_grad = "LeakyReluGrad";
282 csinfo_.lrn = "LRN";
283 csinfo_.lrn_grad = "LRNGrad";
284 csinfo_.matmul = "MatMul";
285 csinfo_.max_pool = "MaxPool";
286 csinfo_.max_pool_grad = "MaxPoolGrad";
287 csinfo_.max_pool3d = "MaxPool3D";
288 csinfo_.max_pool3d_grad = "MaxPool3DGrad";
289 csinfo_.mkl_conv2d = "_MklConv2D";
290 csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
291 csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
292 csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
293 csinfo_.mkl_conv2d_grad_filter_with_bias =
294 "_MklConv2DBackpropFilterWithBias";
295 csinfo_.mkl_depthwise_conv2d_grad_input =
296 "_MklDepthwiseConv2dNativeBackpropInput";
297 csinfo_.mkl_depthwise_conv2d_grad_filter =
298 "_MklDepthwiseConv2dNativeBackpropFilter";
299 csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
300 csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
301 csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
302 csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
303 csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
304 csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
305 csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
306 csinfo_.mkl_native_fused_depthwise_conv2d =
307 "_MklNativeFusedDepthwiseConv2dNative";
308 csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
309 csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
310 csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
311 csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
312 csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
313 csinfo_.pad = "Pad";
314 csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
315 csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
316 csinfo_.quantized_avg_pool = "QuantizedAvgPool";
317 csinfo_.quantized_concatv2 = "QuantizedConcatV2";
318 csinfo_.quantized_conv2d = "QuantizedConv2D";
319 csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
320 csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
321 csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
322 csinfo_.quantized_conv2d_with_bias_and_requantize =
323 "QuantizedConv2DWithBiasAndRequantize";
324 csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
325 csinfo_.quantized_conv2d_and_relu_and_requantize =
326 "QuantizedConv2DAndReluAndRequantize";
327 csinfo_.quantized_conv2d_with_bias_and_relu =
328 "QuantizedConv2DWithBiasAndRelu";
329 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
330 "QuantizedConv2DWithBiasAndReluAndRequantize";
331 csinfo_.quantized_max_pool = "QuantizedMaxPool";
332 csinfo_.quantized_conv2d_with_bias_sum_and_relu =
333 "QuantizedConv2DWithBiasSumAndRelu";
334 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
335 "QuantizedConv2DWithBiasSumAndReluAndRequantize";
336 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
337 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
338 csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
339 csinfo_.quantized_matmul_with_bias_and_relu =
340 "QuantizedMatMulWithBiasAndRelu";
341 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
342 "QuantizedMatMulWithBiasAndReluAndRequantize";
343 csinfo_.quantized_matmul_with_bias_and_dequantize =
344 "QuantizedMatMulWithBiasAndDequantize";
345 csinfo_.quantized_matmul_with_bias_and_requantize =
346 "QuantizedMatMulWithBiasAndRequantize";
347 csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
348 csinfo_.quantized_depthwise_conv2d_with_bias =
349 "QuantizedDepthwiseConv2DWithBias";
350 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
351 "QuantizedDepthwiseConv2DWithBiasAndRelu";
352 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
353 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
354 csinfo_.quantize_v2 = "QuantizeV2";
355 csinfo_.relu = "Relu";
356 csinfo_.relu_grad = "ReluGrad";
357 csinfo_.relu6 = "Relu6";
358 csinfo_.relu6_grad = "Relu6Grad";
359 csinfo_.requantize = "Requantize";
360 csinfo_.tanh = "Tanh";
361 csinfo_.tanh_grad = "TanhGrad";
362 csinfo_.reshape = "Reshape";
363 csinfo_.slice = "Slice";
364 csinfo_.softmax = "Softmax";
365 csinfo_.split = "Split";
366 csinfo_.transpose = "Transpose";
367 // Element-wise ops. Ensure that any new op added here is also added to
368 // the IsMklElementWiseOp method in MklUtil.h so that the
369 // MklInputConversion op is added before it.
370 csinfo_.add = "Add";
371 csinfo_.add_v2 = "AddV2";
372 csinfo_.maximum = "Maximum";
373 csinfo_.mul = "Mul";
374 csinfo_.squared_difference = "SquaredDifference";
375 csinfo_.sub = "Sub";
376 // End - element-wise ops. See note above.
377
378 const bool native_fmt = NativeFormatEnabled();
379 // NOTE: names are alphabetically sorted.
380 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
381 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
382 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
383 CopyAttrsAll, RewriteIfAtleastOneMklInput,
384 GetRewriteCause()});
385 rinfo_.push_back(
386 {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
387 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
388 rinfo_.push_back({csinfo_.avg_pool,
389 mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
390 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
391 rinfo_.push_back({csinfo_.avg_pool_grad,
392 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
393 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
394 rinfo_.push_back({csinfo_.avg_pool3d,
395 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
396 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
397 rinfo_.push_back({csinfo_.avg_pool3d_grad,
398 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
399 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
400 rinfo_.push_back({csinfo_.batch_matmul,
401 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
402 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
403 rinfo_.push_back({csinfo_.batch_matmul_v2,
404 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
405 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
406 rinfo_.push_back({csinfo_.concat,
407 mkl_op_registry::GetMklOpName(csinfo_.concat),
408 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
409 rinfo_.push_back({csinfo_.concatv2,
410 mkl_op_registry::GetMklOpName(csinfo_.concatv2),
411 CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
412 rinfo_.push_back(
413 {csinfo_.conjugate_transpose,
414 mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
415 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
416 rinfo_.push_back(
417 {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
418 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
419 rinfo_.push_back({csinfo_.conv2d_with_bias,
420 native_fmt ? csinfo_.mkl_native_conv2d_with_bias
421 : csinfo_.mkl_conv2d_with_bias,
422 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
423 GetRewriteCause()});
424 rinfo_.push_back({csinfo_.conv2d_grad_filter,
425 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
426 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
427 rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
428 csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
429 AlwaysRewrite, GetRewriteCause()});
430 rinfo_.push_back({csinfo_.conv2d_grad_input,
431 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
432 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
433 rinfo_.push_back(
434 {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
435 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
436 rinfo_.push_back({csinfo_.conv3d_grad_filter,
437 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
438 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
439 rinfo_.push_back({csinfo_.conv3d_grad_input,
440 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
441 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
442 rinfo_.push_back({csinfo_.depthwise_conv2d,
443 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
444 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
445 GetRewriteCause()});
446 rinfo_.push_back(
447 {csinfo_.depthwise_conv2d_grad_input,
448 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
449 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
450 rinfo_.push_back(
451 {csinfo_.depthwise_conv2d_grad_filter,
452 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
453 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
454 rinfo_.push_back({csinfo_.dequantize,
455 mkl_op_registry::GetMklOpName(csinfo_.dequantize),
456 CopyAttrsAll, DequantizeRewrite, GetRewriteCause()});
457 rinfo_.push_back({csinfo_.fused_batch_norm,
458 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
459 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
460 rinfo_.push_back(
461 {csinfo_.fused_batch_norm_grad,
462 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
463 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
464 rinfo_.push_back(
465 {csinfo_.fused_batch_norm_v2,
466 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
467 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
468 rinfo_.push_back(
469 {csinfo_.fused_batch_norm_grad_v2,
470 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
471 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
472
473 // Using CopyAttrsAll for V3 on CPU, as there are no additional
474 // attributes.
475 rinfo_.push_back(
476 {csinfo_.fused_batch_norm_v3,
477 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
478 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
479 rinfo_.push_back(
480 {csinfo_.fused_batch_norm_grad_v3,
481 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
482 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
483 rinfo_.push_back({csinfo_.fused_batch_norm_ex,
484 native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
485 : csinfo_.mkl_fused_batch_norm_ex,
486 CopyAttrsAll, FusedBatchNormExRewrite,
487 GetRewriteCause()});
488 rinfo_.push_back({csinfo_.fused_conv2d,
489 native_fmt ? csinfo_.mkl_native_fused_conv2d
490 : csinfo_.mkl_fused_conv2d,
491 CopyAttrsAllCheckConstFilter, FusedConv2DRewrite,
492 GetRewriteCause()});
493 rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
494 native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
495 : csinfo_.mkl_fused_depthwise_conv2d,
496 CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
497 GetRewriteCause()});
498 rinfo_.push_back({csinfo_.fused_matmul,
499 native_fmt ? csinfo_.mkl_native_fused_matmul
500 : csinfo_.mkl_fused_matmul,
501 CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
502 GetRewriteCause()});
503
504 rinfo_.push_back(
505 {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
506 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
507 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
508 CopyAttrsAll, LrnRewrite, GetRewriteCause()});
509 rinfo_.push_back({csinfo_.lrn_grad,
510 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
511 CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
512 rinfo_.push_back({csinfo_.matmul,
513 mkl_op_registry::GetMklOpName(csinfo_.matmul),
514 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
515 rinfo_.push_back({csinfo_.leakyrelu,
516 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
517 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
518 rinfo_.push_back({csinfo_.leakyrelu_grad,
519 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
520 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
521 rinfo_.push_back(
522 {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
523 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
524 rinfo_.push_back({csinfo_.max_pool_grad,
525 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
526 CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
527 rinfo_.push_back(
528 {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
529 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
530 rinfo_.push_back({csinfo_.max_pool3d_grad,
531 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
532 CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
533 rinfo_.push_back(
534 {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
535 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
536 rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
537 CopyAttrsAll, RewriteIfAtleastOneMklInput,
538 GetRewriteCause()});
539 rinfo_.push_back({csinfo_.pad_with_conv2d,
540 native_fmt ? csinfo_.mkl_native_pad_with_conv2d
541 : csinfo_.mkl_pad_with_conv2d,
542 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
543 GetRewriteCause()});
544 rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
545 native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
546 : csinfo_.mkl_pad_with_fused_conv2d,
547 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
548 GetRewriteCause()});
549 rinfo_.push_back({csinfo_.quantized_avg_pool,
550 mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
551 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
552 rinfo_.push_back({csinfo_.quantized_concatv2,
553 mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
554 CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
555 rinfo_.push_back({csinfo_.quantized_conv2d,
556 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
557 CopyAttrsQuantizedConv2D, AlwaysRewrite,
558 GetRewriteCause()});
559 rinfo_.push_back(
560 {csinfo_.quantized_conv2d_per_channel,
561 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
562 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
563 rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
564 mkl_op_registry::GetMklOpName(
565 csinfo_.quantized_conv2d_with_requantize),
566 CopyAttrsQuantizedConv2D, AlwaysRewrite,
567 GetRewriteCause()});
568 rinfo_.push_back(
569 {csinfo_.quantized_conv2d_with_bias,
570 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
571 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
572 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
573 mkl_op_registry::GetMklOpName(
574 csinfo_.quantized_conv2d_with_bias_and_requantize),
575 CopyAttrsQuantizedConv2D, AlwaysRewrite,
576 GetRewriteCause()});
577 rinfo_.push_back(
578 {csinfo_.quantized_conv2d_and_relu,
579 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
580 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
581 rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
582 mkl_op_registry::GetMklOpName(
583 csinfo_.quantized_conv2d_and_relu_and_requantize),
584 CopyAttrsQuantizedConv2D, AlwaysRewrite,
585 GetRewriteCause()});
586 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
587 mkl_op_registry::GetMklOpName(
588 csinfo_.quantized_conv2d_with_bias_and_relu),
589 CopyAttrsQuantizedConv2D, AlwaysRewrite,
590 GetRewriteCause()});
591 rinfo_.push_back(
592 {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
593 mkl_op_registry::GetMklOpName(
594 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
595 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
596 rinfo_.push_back({csinfo_.quantized_max_pool,
597 mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
598 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
599 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
600 mkl_op_registry::GetMklOpName(
601 csinfo_.quantized_conv2d_with_bias_sum_and_relu),
602 CopyAttrsQuantizedConv2D, AlwaysRewrite,
603 GetRewriteCause()});
604 rinfo_.push_back(
605 {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
606 mkl_op_registry::GetMklOpName(
607 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
608 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
609 rinfo_.push_back(
610 {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
611 mkl_op_registry::GetMklOpName(
612 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
613 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
614 rinfo_.push_back(
615 {csinfo_.quantized_matmul_with_bias,
616 mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
617 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
618 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
619 mkl_op_registry::GetMklOpName(
620 csinfo_.quantized_matmul_with_bias_and_relu),
621 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
622 rinfo_.push_back(
623 {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
624 mkl_op_registry::GetMklOpName(
625 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
626 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
627 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
628 mkl_op_registry::GetMklOpName(
629 csinfo_.quantized_matmul_with_bias_and_requantize),
630 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite});
631 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
632 mkl_op_registry::GetMklOpName(
633 csinfo_.quantized_matmul_with_bias_and_dequantize),
634 CopyAttrsQuantizedMatMulWithBiasAndDequantize,
635 AlwaysRewrite});
636 rinfo_.push_back(
637 {csinfo_.quantized_depthwise_conv2d,
638 mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
639 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
640 rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
641 mkl_op_registry::GetMklOpName(
642 csinfo_.quantized_depthwise_conv2d_with_bias),
643 CopyAttrsQuantizedConv2D, AlwaysRewrite,
644 GetRewriteCause()});
645 rinfo_.push_back(
646 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
647 mkl_op_registry::GetMklOpName(
648 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
649 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
650 rinfo_.push_back(
651 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
652 mkl_op_registry::GetMklOpName(
653 csinfo_
654 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
655 CopyAttrsQuantizedConv2D, AlwaysRewrite, GetRewriteCause()});
656 rinfo_.push_back({csinfo_.quantize_v2,
657 mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
658 CopyAttrsAll, QuantizeOpRewrite, GetRewriteCause()});
659 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
660 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
661 rinfo_.push_back({csinfo_.relu_grad,
662 mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
663 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
664 rinfo_.push_back({csinfo_.relu6,
665 mkl_op_registry::GetMklOpName(csinfo_.relu6),
666 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
667 rinfo_.push_back({csinfo_.relu6_grad,
668 mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
669 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
670 rinfo_.push_back({csinfo_.requantize,
671 mkl_op_registry::GetMklOpName(csinfo_.requantize),
672 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
673 // Optimized TanhGrad support exists only in DNNL 1.x.
674 rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
675 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
676 rinfo_.push_back({csinfo_.tanh_grad,
677 mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
678 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
679 rinfo_.push_back({csinfo_.reshape,
680 mkl_op_registry::GetMklOpName(csinfo_.reshape),
681 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
682 rinfo_.push_back(
683 {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
684 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
685 rinfo_.push_back({csinfo_.softmax,
686 mkl_op_registry::GetMklOpName(csinfo_.softmax),
687 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
688
689 rinfo_.push_back({csinfo_.squared_difference,
690 mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
691 CopyAttrsAll, RewriteIfAtleastOneMklInput,
692 GetRewriteCause()});
693 rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
694 CopyAttrsAll, RewriteIfAtleastOneMklInput,
695 GetRewriteCause()});
696 rinfo_.push_back({csinfo_.transpose,
697 mkl_op_registry::GetMklOpName(csinfo_.transpose),
698 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
699
700 // Add info about which ops to add workspace edge to and the slots.
701 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
702 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
703 wsinfo_.push_back(
704 {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
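// For example, the MaxPool entry above reads: forward output slot 0 (A)
// feeds backward input slot 1, and the workspace uses forward output
// slot 1 (W) and backward input slot 3 (see the MaxPool example in the
// header comment and the WorkSpaceInfo struct below).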
705
706 // Add a rule for merging nodes
707 minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
708 csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
709
710 // Merge Pad and Conv2D only if the pad op is "Pad";
711 // don't merge if the pad op is "PadV2" or "MirrorPad".
712 minfo_.push_back(
713 {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
714
715 minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
716 csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
717
718 if (!native_fmt) {
719 minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
720 csinfo_.conv2d_grad_filter_with_bias,
721 GetConv2DBackpropFilterOrBiasAddGrad});
722
723 // The fusion patterns in "finfo_" that show up first will get applied
724 // first. For example, for graph "A->B->C->D" with finfo_ = {A->B->C to
725 // ABC, A->B->C->D to ABCD}, the first pattern gets applied first, so the
726 // final graph will be ABC->D.
727
728 //
729 // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
730 // (NHWC) + Transpose (NHWC -> NCHW)" into "Conv2D (NCHW)". Such patterns
731 // occur frequently in Keras.
732 // Note: we use the term "merge" for combining (exactly) 2 nodes into one,
733 // while "fusion" is for the 3+ node case.
734 //
735
736 // Transpose + Conv2d + Transpose:
737 std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
738 NCHW::dim::W, NCHW::dim::C};
739 std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
740 NHWC::dim::H, NHWC::dim::W};
741 auto CheckForTransposeToNHWC = std::bind(
742 CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
743 auto CheckForConv2dOp =
744 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
745 auto CheckForTransposeToNCHW = std::bind(
746 CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
747 auto FuseConv2D =
748 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
749 std::placeholders::_2, std::placeholders::_3, "NCHW");
750 finfo_.push_back(
751 {"transpose-elimination for Conv2D",
752 {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
753 // CheckForMklOp
754 FuseConv2D,
755 CopyAttrsConv});
756
757 // Transpose + Conv3d + Transpose:
758 std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
759 NCDHW::dim::H, NCDHW::dim::W,
760 NCDHW::dim::C};
761 std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
762 NDHWC::dim::D, NDHWC::dim::H,
763 NDHWC::dim::W};
764
765 auto CheckForTransposeToNDHWC = std::bind(
766 CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
767 auto CheckForConv3dOp =
768 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
769 auto CheckForTransposeToNCDHW = std::bind(
770 CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
771 auto FuseConv3D =
772 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
773 std::placeholders::_2, std::placeholders::_3, "NCDHW");
774
775 finfo_.push_back({"transpose-elimination for Conv3D",
776 {CheckForTransposeToNDHWC, CheckForConv3dOp,
777 CheckForTransposeToNCDHW},
778 // CheckForMklOp
779 FuseConv3D,
780 CopyAttrsConv});
781
782 auto CheckForMaxPool3DOp =
783 std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.max_pool3d);
784 auto FuseMaxPool3D =
785 std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
786 std::placeholders::_2, std::placeholders::_3, "NCDHW");
787 finfo_.push_back({"transpose-elimination for MaxPool3D",
788 {CheckForTransposeToNDHWC, CheckForMaxPool3DOp,
789 CheckForTransposeToNCDHW},
790 // CheckForMklOp
791 FuseMaxPool3D,
792 CopyAttrsPooling});
793 }
794 }
795
796 // Standard interface to run pass
797 Status Run(const GraphOptimizationPassOptions& options);
798
799 // Helper function which does most of the heavy lifting for rewriting
800 // Mkl nodes to propagate an Mkl tensor as an additional output
801 //
802 // Extracts common functionality between Run public interface and
803 // test interface.
804 //
805 // @return true, if and only if graph is mutated; false otherwise.
806 bool RunPass(std::unique_ptr<Graph>* g);
807
808 /// Cause for rewrite
809 /// Currently, we only support 2 causes - either for Mkl layout propagation
810 /// which is the most common case, or for just a name change (used in case
811 /// of ops like MatMul, Transpose, which do not support Mkl layout)
812 enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };
813
814 // Get the op rewrite cause depending on whether native format mode
815 // is enabled or not.
816 RewriteCause GetRewriteCause() {
817 if (NativeFormatEnabled()) {
818 return kRewriteForOpNameChange;
819 } else {
820 return kRewriteForLayoutPropagation;
821 }
822 }
823
824 /// Structure to specify the name of an original node, its new name after
825 /// rewrite, the function to be used to copy attributes for the op, the
826 /// rule (if any) which must hold for rewriting the node, and the reason
827 /// for the rewrite
828 typedef struct {
829 string name; // Original name of op of the node in the graph
830 string new_name; // New name of the op of the node in the graph
831 // A function handler to copy attributes from an old node to a new node.
832 std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
833 // A rule under which to rewrite this node
834 std::function<bool(const Node*)> rewrite_rule;
835 // Why are we rewriting?
836 RewriteCause rewrite_cause;
837 } RewriteInfo;
838
839 /// Structure to specify a forward op, a backward op, and the slot numbers
840 /// in the forward and backward ops where we will add a workspace edge.
841 typedef struct {
842 string fwd_op; // Name of a forward op in the graph
843 string bwd_op; // Name of a backward op in the graph
844 int fwd_slot; // Output slot in the forward op node where actual
845 // output tensor resides
846 int bwd_slot; // Input slot in the backward op node where actual
847 // input tensor resides
848 int ws_fwd_slot; // Output slot in the forward op node where workspace
849 // edge is added
850 int ws_bwd_slot; // Input slot in the backward op node where workspace
851 // edge is added
852 } WorkSpaceInfo;
853
854 /// Structure to specify information used in node merge of 2 operators
855 typedef struct {
856 string op1; // Node string for one operator.
857 string op2; // Node string for second operator.
858 string new_node; // Name of the node after merge
859 // Function that enables user of the node merger to specify how to find
860 // second operator given the first operator.
861 std::function<Node*(const Node*)> get_node_to_be_merged;
862 } MergeInfo;
863
864 // Structure to specify information used in node fusion of 3+ operators
865 typedef struct {
866 std::string pattern_name; // Name to describe this pattern, such as
867 // "Transpose_Mklop_Transpose".
868 std::vector<std::function<bool(const Node*)> >
869 node_checkers; // Extra restriction checker for these ops
870 std::function<Status(
871 std::unique_ptr<Graph>*, std::vector<Node*>&,
872 std::function<void(const Node*, NodeBuilder* nb, bool)>)>
873 fuse_func;
874 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
875 } FusionInfo;
876
877 //
878 // Dimension indices for 2D convolution tensor formats (NCHW, NHWC).
879 //
880 struct NCHW {
881 enum dim { N = 0, C = 1, H = 2, W = 3 };
882 };
883
884 struct NHWC {
885 enum dim { N = 0, H = 1, W = 2, C = 3 };
886 };
887
888 //
889 // Dimension indices for 3D convolution tensor formats (NCDHW, NDHWC).
890 //
891 struct NCDHW {
892 enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
893 };
894
895 struct NDHWC {
896 enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
897 };
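// These dimension indices are used to express the permutation vectors for
// the Transpose-elimination fusions registered in the constructor above,
// e.g. a Transpose to NHWC is written as {N, H, W, C} in NCHW indices.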
898
899 /// Structure to store all constant strings
900 /// NOTE: names are alphabetically sorted.
901 typedef struct {
902 string addn;
903 string add;
904 string add_v2;
905 string avg_pool;
906 string avg_pool_grad;
907 string avg_pool3d;
908 string avg_pool3d_grad;
909 string batch_matmul;
910 string batch_matmul_v2;
911 string bias_add;
912 string bias_add_grad;
913 string concat;
914 string concatv2;
915 string conjugate_transpose;
916 string conv2d;
917 string conv2d_with_bias;
918 string conv2d_grad_input;
919 string conv2d_grad_filter;
920 string conv2d_grad_filter_with_bias;
921 string conv3d;
922 string conv3d_grad_input;
923 string conv3d_grad_filter;
924 string depthwise_conv2d;
925 string depthwise_conv2d_grad_input;
926 string depthwise_conv2d_grad_filter;
927 string dequantize;
928 string fused_batch_norm;
929 string fused_batch_norm_grad;
930 string fused_batch_norm_ex;
931 string fused_batch_norm_v2;
932 string fused_batch_norm_grad_v2;
933 string fused_batch_norm_v3;
934 string fused_batch_norm_grad_v3;
935 string fused_conv2d;
936 string fused_depthwise_conv2d;
937 string fused_matmul;
938 string identity;
939 string leakyrelu;
940 string leakyrelu_grad;
941 string lrn;
942 string lrn_grad;
943 string matmul;
944 string max_pool;
945 string max_pool_grad;
946 string max_pool3d;
947 string max_pool3d_grad;
948 string maximum;
949 string mkl_conv2d;
950 string mkl_conv2d_grad_input;
951 string mkl_conv2d_grad_filter;
952 string mkl_conv2d_grad_filter_with_bias;
953 string mkl_conv2d_with_bias;
954 string mkl_depthwise_conv2d_grad_input;
955 string mkl_depthwise_conv2d_grad_filter;
956 string mkl_fused_batch_norm_ex;
957 string mkl_fused_conv2d;
958 string mkl_fused_depthwise_conv2d;
959 string mkl_fused_matmul;
960 string mkl_native_conv2d_with_bias;
961 string mkl_native_fused_batch_norm_ex;
962 string mkl_native_fused_conv2d;
963 string mkl_native_fused_depthwise_conv2d;
964 string mkl_native_fused_matmul;
965 string mkl_native_pad_with_conv2d;
966 string mkl_native_pad_with_fused_conv2d;
967 string mkl_pad_with_conv2d;
968 string mkl_pad_with_fused_conv2d;
969 string mul;
970 string pad;
971 string pad_with_conv2d;
972 string pad_with_fused_conv2d;
973 string quantized_avg_pool;
974 string quantized_conv2d;
975 string quantized_conv2d_per_channel;
976 string quantized_conv2d_with_requantize;
977 string quantized_conv2d_with_bias;
978 string quantized_conv2d_with_bias_and_requantize;
979 string quantized_conv2d_and_relu;
980 string quantized_conv2d_and_relu_and_requantize;
981 string quantized_conv2d_with_bias_and_relu;
982 string quantized_conv2d_with_bias_and_relu_and_requantize;
983 string quantized_concatv2;
984 string quantized_max_pool;
985 string quantized_conv2d_with_bias_sum_and_relu;
986 string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
987 string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
988 string quantized_matmul_with_bias;
989 string quantized_matmul_with_bias_and_relu;
990 string quantized_matmul_with_bias_and_relu_and_requantize;
991 string quantized_matmul_with_bias_and_requantize;
992 string quantized_matmul_with_bias_and_dequantize;
993 string quantized_depthwise_conv2d;
994 string quantized_depthwise_conv2d_with_bias;
995 string quantized_depthwise_conv2d_with_bias_and_relu;
996 string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
997 string quantize_v2;
998 string relu;
999 string relu_grad;
1000 string relu6;
1001 string relu6_grad;
1002 string requantize;
1003 string tanh;
1004 string tanh_grad;
1005 string transpose;
1006 string reshape;
1007 string slice;
1008 string softmax;
1009 string split;
1010 string squared_difference;
1011 string sub;
1012 } ConstStringsInfo;
1013
1014 private:
1015 /// Maintain info about nodes to rewrite
1016 std::vector<RewriteInfo> rinfo_;
1017
1018 /// Maintain info about nodes to add workspace edge
1019 std::vector<WorkSpaceInfo> wsinfo_;
1020
1021 /// Maintain info about nodes to be merged
1022 std::vector<MergeInfo> minfo_;
1023
1024 /// Maintain info about nodes to be fused
1025 std::vector<FusionInfo> finfo_;
1026
1027 /// Maintain structure of constant strings
1028 static ConstStringsInfo csinfo_;
1029
1030 private:
1031 // Is OpDef::ArgDef a list type? It could be N * T or list(type).
1032 // Refer to op_def.proto for details of list types.
1033 inline bool ArgIsList(const OpDef::ArgDef& arg) const {
1034 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
1035 }
1036
1037 // Get length of a list in 'n' if 'arg' is of list type. Refer to
1038 // description of ArgIsList for definition of list type.
1039 inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
1040 CHECK_EQ(ArgIsList(arg), true);
1041 int N = 0;
1042 const string attr_name = !arg.type_list_attr().empty()
1043 ? arg.type_list_attr()
1044 : arg.number_attr();
1045 if (!arg.type_list_attr().empty()) {
1046 std::vector<DataType> value;
1047 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
1048 N = value.size();
1049 } else {
1050 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
1051 }
1052 return N;
1053 }
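// For example, for an AddN node, whose op def declares its input as
// "inputs: N * T", GetTensorListLength returns the value of the "N" attr
// (an illustrative case; see the call sites later in this pass).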
1054
1055 // Can op represented by node 'n' run on DEVICE_CPU?
1056 // The op can run on CPU with MKL if the runtime-assigned device or the
1057 // user-requested device contains the CPU device, or if both are empty.
1058 bool CanOpRunOnCPUDevice(const Node* n) {
1059 bool result = true;
1060 string reason;
1061
1062 // Substring that should be checked for in device name for CPU device.
1063 const char* const kCPUDeviceSubStr = "CPU";
1064
1065 // If Op has been specifically assigned to a non-CPU device, then No.
1066 if (!n->assigned_device_name().empty() &&
1067 !absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
1068 result = false;
1069 reason = "Op has been assigned a runtime device that is not CPU.";
1070 }
1071
1072 // If user has specifically assigned this op to a non-CPU device, then No.
1073 if (!n->def().device().empty() &&
1074 !absl::StrContains(n->def().device(), kCPUDeviceSubStr)) {
1075 result = false;
1076 reason = "User has assigned a device that is not CPU.";
1077 }
1078
1079 if (result == false) {
1080 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
1081 << n->type_string() << ", reason: " << reason;
1082 }
1083
1084 // Otherwise Yes.
1085 return result;
1086 }
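// For example (illustrative device strings): a node assigned to
// "/job:worker/replica:0/task:0/device:GPU:0" fails this check, while an
// empty device string or one containing "CPU" (e.g. "/device:CPU:0")
// passes.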
1087
1088 // Return a node that can be merged with input node 'n'
1089 //
1090 // @return pointer to the node if we can find such a
1091 // node. Otherwise, it returns nullptr.
1092 Node* CheckForNodeMerge(const Node* n) const;
1093
1094 // Merge node 'm' with node 'n'.
1095 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
1096 // Conv2DBackpropFilter.
1097 //
1098 // Input nodes m and n may be deleted if the call to
1099 // this function is successful. Attempting to use those pointers
1100 // after the call may result in undefined behavior.
1101 //
1102 // @input g - input graph, m - graph node, n - graph node to be merged with m
1103 // @return Status::OK(), if merging is successful and supported.
1104 // Returns appropriate Status error code otherwise.
1105 // Graph is updated in case nodes are merged. Otherwise, it is
1106 // not updated.
1107 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1108
1109 // Helper function to merge different nodes
1110 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1111 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1112 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1113 Node* m, Node* n);
1114
1115 // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1116 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1117 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1118 // node that can be merged with 'm'.
1119 static Node* GetConv2DOrBiasAdd(const Node* m) {
1120 DCHECK(m);
1121 Node* n = nullptr;
1122
1123 DataType T_m;
1124 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1125
1126 #ifndef ENABLE_INTEL_MKL_BFLOAT16
1127 // Don't try to merge if datatype is not DT_FLOAT
1128 if (T_m != DT_FLOAT) return n;
1129 #else
1130 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1131 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1132 #endif
1133
1134 if (m->type_string() == csinfo_.bias_add) {
1135 // If m is BiasAdd, then Conv2D is the 0th input of BiasAdd.
1136 TF_CHECK_OK(m->input_node(0, &n));
1137 } else {
1138 CHECK_EQ(m->type_string(), csinfo_.conv2d);
1139 // Go over all output edges and search for BiasAdd Node.
1140 // 0th input of BiasAdd is Conv2D.
1141 for (const Edge* e : m->out_edges()) {
1142 if (!e->IsControlEdge() &&
1143 e->dst()->type_string() == csinfo_.bias_add &&
1144 e->dst_input() == 0) {
1145 n = e->dst();
1146 break;
1147 }
1148 }
1149 }
1150
1151 if (n == nullptr) {
1152 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1153 << "Conv2D and BiasAdd node for merging. Input node: "
1154 << m->DebugString();
1155 }
1156
1157 return n;
1158 }
1159
1160 // Find Pad or Conv2D node that can be merged with input node 'm'.
1161 // If input 'm' is Pad, then check if there exists Conv2D node that can be
1162 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1163 // node that can be merged with 'm'.
1164 static Node* GetPadOrConv2D(const Node* m) {
1165 DCHECK(m);
1166 Node* n = nullptr;
1167
1168 DataType T_m;
1169 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1170
1171 #ifndef ENABLE_INTEL_MKL_BFLOAT16
1172 // Don't try to merge if datatype is not DT_FLOAT
1173 if (T_m != DT_FLOAT) return n;
1174 #else
1175 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1176 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1177 #endif
1178
1179 const Node* conv_node;
1180 if (m->type_string() == csinfo_.pad) {
1181 // If m is Pad, then Conv2D is the output of Pad.
1182 for (const Edge* e : m->out_edges()) {
1183 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1184 n = e->dst();
1185 conv_node = n;
1186 break;
1187 }
1188 }
1189 } else {
1190 DCHECK_EQ(m->type_string(), csinfo_.conv2d);
1191 // If m is Conv2D, go over all input edges
1192 // and search for a Pad node.
1193 for (const Edge* e : m->in_edges()) {
1194 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1195 n = e->src();
1196 conv_node = m;
1197 break;
1198 }
1199 }
1200 }
1201 // Check if only VALID type of padding is used
1202 // or not.
1203 if (n != nullptr) {
1204 string padding;
1205 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1206 if (padding != "VALID")
1207 // Then do not merge.
1208 // Only VALID type of padding in conv op can be
1209 // merged with Pad op.
1210 n = nullptr;
1211 } else {
1212 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1213 << "Pad and Conv2D node for merging. Input node: "
1214 << m->DebugString();
1215 }
1216
1217 return n;
1218 }
1219
1220 // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1221 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1222 // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1223 // exists Pad node that can be merged with 'm'.
1224 static Node* GetPadOrFusedConv2D(const Node* m) {
1225 DCHECK(m);
1226 Node* n = nullptr;
1227
1228 const Node* conv_node;
1229 if (m->type_string() == csinfo_.pad) {
1230 // If m is Pad, then _FusedConv2D is the output of Pad.
1231 for (const Edge* e : m->out_edges()) {
1232 if (!e->IsControlEdge() &&
1233 e->dst()->type_string() == csinfo_.fused_conv2d) {
1234 n = e->dst();
1235 conv_node = n;
1236 break;
1237 }
1238 }
1239 } else {
1240 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
1241 // If m is _FusedConv2D, go over all input edges
1242 // and search for Pad node.
1243 for (const Edge* e : m->in_edges()) {
1244 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1245 n = e->src();
1246 conv_node = m;
1247 break;
1248 }
1249 }
1250 }
1251 // Check if only VALID type of padding is used or not.
1252 if (n != nullptr) {
1253 string padding;
1254 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1255 if (padding != "VALID") {
1256 // Then do not merge.
1257 n = nullptr;
1258 VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
1259 << "nodes but cannot merge them. Only conv ops with padding "
1260 << "type VALID can be merged with Pad op Input node: "
1261 << m->DebugString();
1262 }
1263 } else {
1264 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1265 << "Pad and _FusedConv2D node for merging. Input node: "
1266 << m->DebugString();
1267 }
1268
1269 return n;
1270 }
1271
1272 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1273 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1274 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1275 // then check if there exists Conv2DBackpropFilter node that can be merged
1276 // with 'm'.
1277 //
1278 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1279 // would look like:
1280 //
1281 // _ = Conv2DBackpropFilter(F, _, G)
1282 // _ = BiasAddGrad(G)
1283 //
1284 // So 1st input of BiasAddGrad connects with 3rd input of
1285 // Conv2DBackpropFilter and vice versa.
1286 static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1287 DCHECK(m);
1288 Node* n = nullptr;
1289 const Node* conv2d_backprop_filter = nullptr;
1290
1291 DataType T_m;
1292 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1293
1294 #ifndef ENABLE_INTEL_MKL_BFLOAT16
1295 // Don't try to merge if datatype is not DT_FLOAT
1296 if (T_m != DT_FLOAT) return n;
1297 #else
1298 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1299 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1300 #endif
1301
1302 if (m->type_string() == csinfo_.bias_add_grad) {
1303 // Get 1st input 'g' of BiasAddGrad.
1304 Node* g = nullptr;
1305 TF_CHECK_OK(m->input_node(0, &g));
1306 // Now traverse all outgoing edges from g that have destination node as
1307 // Conv2DBackpropFilter.
1308 for (const Edge* e : g->out_edges()) {
1309 if (!e->IsControlEdge() &&
1310 e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1311 e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1312 n = e->dst();
1313 conv2d_backprop_filter = n;
1314 break;
1315 }
1316 }
1317 } else {
1318 conv2d_backprop_filter = m;
1319 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1320 // Get 3rd input 'g' of Conv2DBackpropFilter.
1321 Node* g = nullptr;
1322 TF_CHECK_OK(m->input_node(2, &g));
1323 // Now traverse all outgoing edges from g that have destination node as
1324 // BiasAddGrad.
1325 for (const Edge* e : g->out_edges()) {
1326 if (!e->IsControlEdge() &&
1327 e->dst()->type_string() == csinfo_.bias_add_grad &&
1328 e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1329 n = e->dst();
1330 break;
1331 }
1332 }
1333 }
1334
1335 // Do not merge if padding type is EXPLICIT.
1336 // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1337 if (conv2d_backprop_filter != nullptr) {
1338 string padding;
1339 TF_CHECK_OK(
1340 GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1341 if (padding == "EXPLICIT") {
1342 // Then do not merge.
1343 VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1344 << "and BiasAddGrad nodes but cannot merge them. "
1345 << "EXPLICIT padding is not supported now. "
1346 << conv2d_backprop_filter->DebugString();
1347 return nullptr;
1348 }
1349 }
1350
1351 if (n == nullptr) {
1352 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1353 << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1354 << "Input node: " << m->DebugString();
1355 }
1356 return n;
1357 }
1358
1359   // Find nodes that can be fused with input node 'n'.
1360   //
1361   // @return tuple. If we can find such nodes, the first
1362   // element of the tuple is true. Otherwise, it is false.
1363 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1364 CheckForNodeFusion(Node* n) const;
1365
1366 // Fuse nodes in the vector "nodes"
1367 Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1368 const MklLayoutRewritePass::FusionInfo fi);
1369
1370 // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1371 // mklop("NCHW").
1372 // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
1373 static Status FuseTransposeMklOpTranspose(
1374 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1375 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1376 string data_format);
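
  // For illustration (a sketch, not a verbatim graph from this pass; the node
  // names and perm constants below are illustrative only):
  //
  //   X_nhwc = Transpose(X, perm={0, 2, 3, 1})       // NCHW -> NHWC
  //   Y_nhwc = mklop(X_nhwc)                         // runs with "NHWC"
  //   Y      = Transpose(Y_nhwc, perm={0, 3, 1, 2})  // NHWC -> NCHW
  //
  // is collapsed into a single mklop(X) running with data_format "NCHW",
  // eliminating both Transpose nodes.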
1377
1378   static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1379 // Check if node's type is "Transpose"
1380 if (node->type_string() != "Transpose") return false;
1381
1382 // If "Transpose" has multiple output data edges, also don't fuse it.
1383 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1384
1385     // Check if the node has an outgoing control edge. If so, this is a
1386     // training graph. Currently we focus on inference and do no fusion in
1387     // training.
1388     // Note: this constraint will eventually be removed once this fusion is
1389     // enabled for training.
1390 for (const Edge* e : node->out_edges()) {
1391 if (e->IsControlEdge()) {
1392 return false;
1393 }
1394 }
1395
1396     // If "Transpose" has input control edges, don't fuse it either.
1397 for (const Edge* e : node->in_edges()) {
1398 if (e->IsControlEdge()) {
1399 return false;
1400 }
1401 }
1402
1403     // We compare the tensor containing the permutation order ("perm_node")
1404     // with our desired order ("perm"). If they match exactly, this check
1405     // succeeds and returns true.
1406 for (const Edge* e : node->in_edges()) {
1407 if (!e->IsControlEdge()) {
1408 const Node* perm_node = e->src();
1409
1410 const int kPermTensorIndex = 1;
1411 if (perm_node->type_string() == "Const" &&
1412 e->dst_input() == kPermTensorIndex) {
1413           // We found the "perm" node; now try to retrieve its value.
1414 const TensorProto* proto = nullptr;
1415 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1416
1417 DataType type;
1418 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1419
1420 Tensor tensor;
1421 if (!tensor.FromProto(*proto)) {
1422 TF_CHECK_OK(errors::InvalidArgument(
1423 "Could not construct Tensor from TensorProto in node: ",
1424 node->name()));
1425 return false;
1426 }
1427           // The current fusion only supports permutations whose length
1428           // matches the `perm` vector (4D or 5D); return false otherwise.
1429 if (tensor.dim_size(0) != perm.size()) return false;
1430 DCHECK_EQ(tensor.dims(), 1);
1431 if (type == DT_INT32) {
1432 const auto tensor_content = tensor.flat<int>().data();
1433 for (int i = 0; i < perm.size(); ++i)
1434 if (tensor_content[i] != perm[i]) return false;
1435 return true;
1436 } else if (type == DT_INT64) {
1437 const auto tensor_content = tensor.flat<int64>().data();
1438 for (int i = 0; i < perm.size(); ++i)
1439 if (tensor_content[i] != perm[i]) return false;
1440 return true;
1441 }
1442 return false;
1443 }
1444 }
1445 }
1446 return false;
1447 }
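
  // Illustrative usage (a sketch; the perm values below are simply the
  // standard NCHW<->NHWC permutations, not taken from a specific call site):
  //
  //   CheckForTranspose(n, {0, 2, 3, 1});  // true iff n transposes NCHW->NHWC
  //   CheckForTranspose(n, {0, 3, 1, 2});  // true iff n transposes NHWC->NCHW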
1448
1449   static bool CheckForMklOp(const Node* node, string name = "") {
1450 if (node == nullptr) return false;
1451
1452 if (!name.empty() && node->type_string() != name) {
1453 return false;
1454 }
1455
1456     // If mklop has multiple outputs, don't fuse it.
1457 if (node->num_outputs() > 1) return false;
1458
1459 if (node->out_edges().size() > 1) return false;
1460
1461 DataType T;
1462 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1463 return mkl_op_registry::IsMklLayoutDependentOp(
1464 mkl_op_registry::GetMklOpName(node->type_string()), T);
1465 }
1466
1467 // Check if the node 'n' has any applicable rewrite rule
1468 // We check for 2 scenarios for rewrite.
1469 //
1470 // @return RewriteInfo* for the applicable rewrite rule
1471 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1472 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1473
1474 // Default rewrite rule to be used in scenario 1 for rewrite.
1475 // @return - true (since we want to always rewrite)
1476   static bool AlwaysRewrite(const Node* n) { return true; }
1477
1478   // Rewrite rule which considers the "context" of the current node to decide
1479   // whether we should rewrite. By "context" we currently mean all the inputs
1480   // of the current node. The idea is that if none of the inputs of the current
1481   // node are MKL nodes, then rewriting the current node to an MKL node _may
1482   // not_ offer any performance improvement.
1483 //
1484 // One such case is element-wise ops. For such ops, we reuse the Eigen
1485 // implementation and pass the MKL metadata tensor through so we can avoid
1486 // conversions. However, if all incoming edges are in TF format, we don't
1487   // need all this overhead, so we replace the element-wise node only if at
1488   // least one of its parents is an MKL node.
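  //
  // For illustration (a sketch; Add and _MklConv2D are used as examples):
  //
  //   Z = Add(X, Y) where X = _MklConv2D(...)       -> rewrite Add to Mkl form
  //   Z = Add(X, Y) where X, Y are plain TF tensors -> keep the Eigen Add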
1489 //
1490 // More generally, all memory- or IO-bound ops (such as Identity) may fall
1491 // under this category.
1492 //
1493 // @input - Input graph node to be rewritten
1494 // @return - true if node is to be rewritten as MKL node; false otherwise.
1495   static bool RewriteIfAtleastOneMklInput(const Node* n) {
1496 DataType T;
1497 if (GetNodeAttr(n->def(), "T", &T).ok() &&
1498 mkl_op_registry::IsMklOp(
1499 mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1500 for (auto e : n->in_edges()) {
1501 if (e->IsControlEdge()) continue;
1502 if (mkl_op_registry::IsMklOp(e->src())) {
1503 return true;
1504 }
1505 }
1506 }
1507 return false;
1508 }
1509
1510   static bool MatMulRewrite(const Node* n) {
1511 DataType T;
1512 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1513 if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1514 VLOG(2) << "Rewriting MatMul to _MklMatMul";
1515 return true;
1516 }
1517 return false;
1518 }
1519 // For oneDNN, only int32 is supported for axis data type
1520   static bool ConcatV2Rewrite(const Node* n) {
1521 DataType T;
1522 TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1523 return (T == DT_INT32);
1524 }
1525
1526   static bool DequantizeRewrite(const Node* n) {
1527 DCHECK(n);
1528 Node* input = nullptr;
1529 TF_CHECK_OK(n->input_node(0, &input));
1530 string mode_string;
1531 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1532 if (mode_string != "SCALED") {
1533 VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1534 << "This case is not optimized by Intel MKL kernel, thus using "
1535 "Eigen op for Dequantize op.";
1536 return false;
1537 }
1538 if (input->IsConstant()) {
1539 VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1540 << "could possibly be a filter. "
1541 << "This case is not supported by Intel MKL kernel, thus using "
1542 "Eigen op for Dequantize op.";
1543 return false;
1544 }
1545 return true;
1546 }
1547
1548 // Rewrite rule for _FusedMatMul.
1549   // @return - true if the first input is not transposed (i.e. transpose_a
1550   //           is false); false otherwise.
1551   static bool FusedMatMulRewrite(const Node* n) {
1552 bool trans_a;
1553
1554 // Do not rewrite with transpose attribute because reorder has performance
1555 // impact.
1556 TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1557
1558 return !trans_a;
1559 }
1560
1561   // Check if we are performing pooling on depth or batch. If so, then we
1562   // do not rewrite the MaxPool node to the Mkl version.
1563   // @return - true (if it is not a depth/batch-wise pooling case);
1564   //           false otherwise.
1565   static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1566 DCHECK(n);
1567
1568 string data_format_str;
1569 TensorFormat data_format;
1570 std::vector<int32> ksize, strides;
1571 TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1572 TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1573 TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1574 bool result = FormatFromString(data_format_str, &data_format);
1575 DCHECK(result);
1576
1577 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1578 if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1579 GetTensorDim(strides, data_format, 'N') == 1 &&
1580 GetTensorDim(ksize, data_format, 'C') == 1 &&
1581 GetTensorDim(strides, data_format, 'C') == 1) {
1582 return true;
1583 }
1584
1585 return false;
1586 }
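
  // For illustration (a sketch, assuming data_format == "NHWC"):
  //
  //   ksize = {1, 3, 3, 1}, strides = {1, 2, 2, 1} -> spatial pooling, rewrite
  //   ksize = {1, 1, 1, 2}                         -> depth-wise pooling, skip
  //   ksize = {2, 1, 1, 1}                         -> batch-wise pooling, skip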
1587
1588   // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
1589   // path. The unoptimized path is slow, so we don't rewrite the node and use
1590   // the default Eigen op. But for depth_radius=2, the MKL DNN optimized path
1591   // is taken, i.e., the Eigen node is rewritten to an MKL DNN node.
1592   static bool LrnRewrite(const Node* n) {
1593 DCHECK(n);
1594
1595 int depth_radius;
1596 TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1597
1598     // If the depth_radius of LRN is not 2, don't rewrite the node to MKL DNN
1599     // and use the Eigen node instead.
1600 if (depth_radius == 2) {
1601 return true;
1602 }
1603     VLOG(1) << "LrnRewrite: The model sets depth_radius to a value other "
1604             << "than 2, which is not optimized by Intel MKL, thus using the "
1605             << "Eigen op for LRN.";
1606
1607 return false;
1608 }
1609
1610   static bool LrnGradRewrite(const Node* n) {
1611 DCHECK(n);
1612 bool do_rewrite = false;
1613
1614 for (const Edge* e : n->in_edges()) {
1615 // Rewrite only if there is corresponding LRN, i.e workspace is available
1616 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1617 e->src()->type_string() ==
1618 mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1619 e->src_output() == 0) {
1620 do_rewrite = true;
1621 break;
1622 }
1623 }
1624 return do_rewrite;
1625 }
1626
1627 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
1628 // feature * alpha (otherwise),
1629 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1630 // These two algorithms are not consistent when alpha > 1,
1631 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
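  // Numeric illustration: with alpha = 2 and feature = -1, oneDNN computes
  // feature * alpha = -2 while TensorFlow computes max(-1, -2) = -1, so the
  // results diverge; hence nodes with alpha > 1 are left to the Eigen op.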
1632   static bool LeakyReluRewrite(const Node* n) {
1633 DCHECK(n);
1634
1635 float alpha;
1636 bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1637 DCHECK(has_attr);
1638
1639     // If the alpha of LeakyRelu is less than or equal to 1, rewrite the node.
1640     // Otherwise the Eigen node is used instead.
1641 if (alpha <= 1) {
1642 return true;
1643 }
1644     VLOG(1) << "LeakyReluRewrite: The model sets alpha greater than 1, "
1645             << "which is not optimized by Intel MKL, thus using the Eigen op "
1646             << "for LeakyRelu.";
1647
1648 return false;
1649 }
1650
1651   static bool QuantizeOpRewrite(const Node* n) {
1652 DCHECK(n);
1653 Node* filter_node = nullptr;
1654 TF_CHECK_OK(n->input_node(0, &filter_node));
1655 bool narrow_range = false;
1656 int axis = -1;
1657 string mode_string;
1658 string round_mode_string;
1659 DataType type;
1660 TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1661 TryGetNodeAttr(n->def(), "axis", &axis);
1662 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1663 TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1664 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1665
1666 if (narrow_range) {
1667       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization. "
1668 << "This case is not optimized by Intel MKL, "
1669 << "thus using Eigen op for Quantize op ";
1670 return false;
1671 }
1672 if (axis != -1) {
1673       VLOG(1) << "QuantizeOpRewrite: axis is specified for "
1674               << "per-slice quantization. "
1675 << "This case is not optimized by Intel MKL, "
1676 << "thus using Eigen op for Quantize op ";
1677 return false;
1678 }
1679 if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1680 (mode_string == "MIN_FIRST"))) {
1681       VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST, and/or "
1682               << "rounding mode is not HALF_TO_EVEN. "
1683               << "This case is not optimized by Intel MKL, thus using Eigen op "
1684               << "for Quantize op ";
1685 return false;
1686 }
1687 if (filter_node->IsConstant()) {
1688 VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
1689 << "is a constant. "
1690               << "This case is not supported by the kernel, thus using Eigen op "
1691 << "for Quantize op ";
1692
1693 return false;
1694 }
1695 if (mode_string == "MIN_FIRST") {
1696 if (type != DT_QUINT8) {
1697 VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
1698                 << "not DT_QUINT8. This case is not optimized by Intel MKL, "
1699 << "thus using Eigen op for Quantize op ";
1700 return false;
1701 }
1702 }
1703 return true;
1704 }
1705
1706   static bool MaxpoolGradRewrite(const Node* n) {
1707 DCHECK(n);
1708 bool do_rewrite = false;
1709 for (const Edge* e : n->in_edges()) {
1710 // Rewrite only if there is corresponding Maxpool, i.e workspace is
1711 // available
1712 if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1713 e->dst_input() == 1 &&
1714 e->src()->type_string() ==
1715 mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1716 e->src_output() == 0) {
1717 do_rewrite = true;
1718 break;
1719 }
1720 }
1721 return do_rewrite;
1722 }
1723
1724   static bool Maxpool3DGradRewrite(const Node* n) {
1725 DCHECK(n);
1726 for (const Edge* e : n->in_edges()) {
1727 // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is
1728 // available
1729 if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1730 e->dst_input() == 1 &&
1731 e->src()->type_string() ==
1732 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1733 e->src_output() == 0) {
1734 return true;
1735 }
1736 }
1737 return false;
1738 }
1739
1740   static bool FusedBatchNormV3Rewrite(const Node* n) {
1741 DCHECK(n);
1742 if (Check5DFormat(n->def())) {
1743 VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1744 << "support 5D tensors.";
1745 return false;
1746 }
1747 return true;
1748 }
1749
1750   static bool FusedBatchNormExRewrite(const Node* n) {
1751 DCHECK(n);
1752
1753 int num_side_inputs;
1754 TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1755 string activation_mode;
1756 TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1757
1758 // if the num_side_inputs is not 0, don't rewrite the node.
1759 if (num_side_inputs != 0) {
1760       VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs "
1761               << "larger than 0, which is not optimized by Intel MKL.";
1762 return false;
1763 }
1764
1765 // if the activation_mode is not 'Relu', don't rewrite the node.
1766 if (activation_mode != "Relu") {
1767       VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is "
1768 << "supported by Intel MKL.";
1769 return false;
1770 }
1771
1772 return true;
1773 }
1774
1775   static bool FusedConv2DRewrite(const Node* n) {
1776 // MKL DNN currently doesn't support all fusions that grappler fuses
1777 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1778 // it includes those we support.
1779 DataType T;
1780 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1781 !mkl_op_registry::IsMklLayoutDependentOp(csinfo_.mkl_fused_conv2d, T)) {
1782 return false;
1783 }
1784
1785 std::vector<string> fused_ops;
1786 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1787 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1788 fused_ops == std::vector<string>{"Relu"} ||
1789 fused_ops == std::vector<string>{"Relu6"} ||
1790 fused_ops == std::vector<string>{"Elu"} ||
1791 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1792 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1793 fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1794 fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1795 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1796 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1797 fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1798 fused_ops == std::vector<string>{"LeakyRelu"} ||
1799 fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1800 fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"});
1801 }
1802
1803   static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1804 // MKL DNN currently doesn't support all fusions that grappler fuses
1805 // together with DepthwiseConv2D (ex. batchnorm). We rewrite
1806 // _FusedDepthwiseConv2DNative only if it includes those we support.
1807 DataType T;
1808 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1809 !mkl_op_registry::IsMklLayoutDependentOp(
1810 csinfo_.mkl_fused_depthwise_conv2d, T)) {
1811 return false;
1812 }
1813
1814 std::vector<string> fused_ops;
1815 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1816 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1817 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1818 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1819 fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1820 }
1821
1822 // Rewrites input node to a new node specified by its matching rewrite info.
1823 //
1824 // Method first searches matching rewrite info for input node and then
1825 // uses that info to rewrite.
1826 //
1827 // Input node may be deleted in case of rewrite. Attempt to use the node
1828 // after the call can result in undefined behaviors.
1829 //
1830 // @input g - input graph, n - Node to be rewritten,
1831 // ri - matching rewriteinfo
1832 // @return Status::OK(), if the input node is rewritten;
1833 // Returns appropriate Status error code otherwise.
1834 // Graph is updated in case the input node is rewritten.
1835 // Otherwise, it is not updated.
1836 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1837
1838 // Rewrites input node to just change its operator name. The number of
1839 // inputs to the node and the number of outputs remain the same. Attributes
1840 // of the new node could be copied from attributes of the old node or
1841 // modified. copy_attrs field of RewriteInfo controls this.
1842 //
1843 // Conceptually, it allows us to rewrite:
1844 //
1845 // f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1846 //
1847 // Attributes can be altered without any restrictions --- they could be
1848 // copied, modified, or deleted completely.
1849 //
1850 // @input g - input graph, orig_node - Node to be rewritten,
1851 // ri - matching rewriteinfo
1852 // @output new_node - points to newly created node
1853 // @return Status::OK(), if the input node is rewritten;
1854 // Returns appropriate Status error code otherwise.
1855 // Graph is only updated when the input node is rewritten.
1856 Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1857 const Node* orig_node, Node** new_node,
1858 const RewriteInfo* ri);
1859
1860 // Rewrites input node to enable MKL layout propagation. Please also refer to
1861 // documentation for the function RewriteNodeForJustOpNameChange() to
1862 // understand what it means.
1863 //
1864 // @input g - input graph, orig_node - Node to be rewritten,
1865 // ri - matching rewriteinfo
1866 // @output new_node - points to newly created node
1867 // @return Status::OK(), if the input node is rewritten;
1868 // Returns appropriate Status error code otherwise.
1869 // Graph is updated in case the input node is rewritten.
1870 // Otherwise, it is not updated.
1871 Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1872 const Node* orig_node, Node** new_node,
1873 const RewriteInfo* ri);
1874
1875 // Get nodes that will feed a list of TF tensors to the new
1876 // node that we are constructing.
1877 //
1878 // @input g - input graph,
1879 // @input inputs - inputs to old node that we are using for constructing
1880 // new inputs,
1881 // @input input_idx - the index in the 'inputs' vector pointing to the
1882 // current input that we have processed so far
1883 // @output input_idx - index will be incremented by the number of nodes
1884 // from 'inputs' that are processed
1885 // @input list_length - The expected length of list of TF tensors
1886 // @output output_nodes - the list of new nodes creating TF tensors
1887 //
1888 // @return None
1889 void GetNodesProducingTFTensorList(
1890 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1891 int* input_idx, int list_length,
1892 std::vector<NodeBuilder::NodeOut>* output_nodes);
1893
1894 // Get nodes that will feed a list of Mkl tensors to the new
1895 // node that we are constructing.
1896 //
1897 // @input g - input graph,
1898 // @input orig_node - Original node that we are rewriting
1899 // @input inputs - inputs to old node that we are using for constructing
1900 // new inputs,
1901 // @input input_idx - the index in the 'inputs' vector pointing to the
1902 // current input that we have processed so far
1903 // @output input_idx - index will be incremented by the number of nodes
1904 // from 'inputs' that are processed
1905 // @input list_length - The expected length of list of Mkl tensors
1906 // @output output_nodes - the list of new nodes creating Mkl tensors
1907 //
1908 // @return None
1909 void GetNodesProducingMklTensorList(
1910 std::unique_ptr<Graph>* g, const Node* orig_node,
1911 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1912 int* input_idx, int list_length,
1913 std::vector<NodeBuilder::NodeOut>* output_nodes);
1914
1915 // Get a node that will feed an Mkl tensor to the new
1916 // node that we are constructing. The output node could be (1) 'n'
1917 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1918 // if 'n' is not an Mkl layer.
1919 //
1920 // @input g - input graph,
1921 // @input orig_node - Original node that we are rewriting,
1922 // @input n - Node based on which we are creating Mkl node,
1923 // @input n_output_slot - the output slot of node 'n'
1924 // which is feeding to the node that we are constructing
1925 // @output mkl_node - the new node that will feed Mkl tensor
1926 // @output mkl_node_output_slot - the slot number of mkl_node that
1927 // will feed the tensor
1928 // @return None
1929 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1930 const Node* orig_node, Node* n,
1931 int n_output_slot, Node** mkl_node,
1932 int* mkl_node_output_slot);
1933
1934 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1935 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1936 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1937 // producing workspace edges if 'are_workspace_tensors_available' is true.
1938 // Otherwise, 'workspace_tensors' is empty vector.
1939 //
1940 // For details, refer to 'Ordering of inputs after rewriting' section in the
1941 // documentation above.
1942 //
1943 // Returns Status::OK() if setting up inputs is successful, otherwise
1944 // returns appropriate status code.
1945 int SetUpContiguousInputs(
1946 std::unique_ptr<Graph>* g,
1947 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1948 NodeBuilder* nb, const Node* old_node,
1949 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1950 bool are_workspace_tensors_available);
1951
1952 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1953 // in graph 'g'. Original node is input in 'orig_node'.
1954 //
1955 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1956 // section in the documentation above.
1957 //
1958 // Returns Status::OK() if setting up inputs is successful, otherwise
1959 // returns appropriate status code.
1960 Status SetUpInputs(std::unique_ptr<Graph>* g,
1961 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1962 NodeBuilder* nb, const Node* orig_node);
1963
1964 // Create new inputs by copying old inputs 'inputs' for the rewritten node
1965 // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1966 // used in the context of rewrite for just operator name change in which
1967 // inputs of old operator and new operator are same.
1968 //
1969 // Returns Status::OK() if setting up inputs is successful, otherwise
1970 // returns appropriate status code.
1971 Status CopyInputs(const Node* orig_node,
1972 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1973 NodeBuilder* nb);
1974
1975 // Add workspace edge on the input or output side of Node 'orig_node' by using
1976 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
1977 // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
1978 // tensors, if they need to be added, will be set into these tensors.
1979 // If we set workspace tensors, then are_ws_tensors_added should be true.
1980 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1981 const Node* orig_node, NodeBuilder* nb,
1982 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1983 bool* are_ws_tensors_added);
1984
1985 // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1986 // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1987 // 'g'. Returns true if fixup was done; otherwise, it returns false.
1988 bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1989 const Edge* e_metadata);
1990
1991 // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1992 // connected? If not, then fix them. This is needed because a graph may have
1993 // some input Mkl metadata edges incorrectly setup after node merge and
1994 // rewrite passes. This could happen because GetReversePostOrder function may
1995 // not provide topologically sorted order if a graph contains cycles. The
1996 // function returns true if at least one Mkl metadata edge for node 'n' was
1997 // fixed. Otherwise, it returns false.
1998 //
1999 // Example:
2000 //
2001 // X = MklConv2D(_, _, _)
2002 // Y = MklConv2DWithBias(_, _, _, _, _, _)
2003 // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
2004 //
2005 // For a graph such as shown above, note that 3rd argument of MklAdd contains
2006 // DummyMklTensor. Actually, it should be getting the Mkl metadata from
2007 // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
2008 // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
2009 // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
2010 // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
2011 // data edges (1st and 2nd arguments of MklAdd).
2012 bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
2013
2014 // Functions specific to operators to copy attributes
2015 // We need operator-specific function to copy attributes because the framework
2016 // does not provide any generic function for it.
2017 // NOTE: names are alphabetically sorted.
2018 static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2019 bool change_format = false);
2020 static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
2021 NodeBuilder* nb,
2022 bool change_format = false);
2023
2024 static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2025 bool change_format = false);
2026 static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
2027 NodeBuilder* nb,
2028 bool change_format = false);
2029 static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2030 const Node* orig_node2, NodeBuilder* nb,
2031 bool change_format = false);
2032 static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
2033 const Node* orig_node2,
2034 NodeBuilder* nb,
2035 bool change_format = false);
2036 static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
2037 bool change_format = false);
2038 static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2039 const std::vector<int32>& strides,
2040 const std::vector<int32>& dilations,
2041 bool change_format = false);
2042
2043 static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2044 NodeBuilder* nb,
2045 bool change_format = false);
2046 static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2047 const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2048 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2049 bool change_format = false);
2050
2051 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2052 // using node for original node 'orig_node' and return it in '*out'.
2053 // TODO(nhasabni) We should move this to mkl_util.h
2054 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2055 const Node* orig_node);
2056 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2057 const Node* orig_node);
2058 };
2059
2060 MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2061
2062 // We register Mkl rewrite pass for phase 1 in post partitioning group.
2063 // We register it here so that we get a complete picture of all users of Mkl
2064 // nodes. Do not change the ordering of the Mkl passes.
2065 const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2066 OptimizationPassRegistry::POST_PARTITIONING;
2067 #ifdef ENABLE_MKL
2068 REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2069 #endif // ENABLE_MKL
2070
2071 //////////////////////////////////////////////////////////////////////////
2072 // Helper functions for creating new node
2073 //////////////////////////////////////////////////////////////////////////
2074
2075 static void FillInputs(const Node* n,
2076 gtl::InlinedVector<Node*, 4>* control_edges,
2077 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2078 control_edges->clear();
2079 for (const Edge* e : n->in_edges()) {
2080 if (e->IsControlEdge()) {
2081 control_edges->push_back(e->src());
2082 } else {
2083 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2084 }
2085 }
2086 std::sort(control_edges->begin(), control_edges->end());
2087 }
2088
2089 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2090 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2091 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2092 CHECK_LT(*input_idx, inputs.size());
2093 CHECK_GT(list_length, 0);
2094 DCHECK(output_nodes);
2095 output_nodes->reserve(list_length);
2096
2097 while (list_length != 0) {
2098 CHECK_GT(list_length, 0);
2099 CHECK_LT(*input_idx, inputs.size());
2100 Node* n = inputs[*input_idx].first;
2101 int slot = inputs[*input_idx].second;
2102 // If input node 'n' is just producing a single tensor at
2103 // output slot 'slot' then we just add that single node.
2104 output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2105 (*input_idx)++;
2106 list_length--;
2107 }
2108 }
2109
2110 // TODO(nhasabni) We should move this to mkl_util.h.
2111 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2112 Node** out,
2113 const Node* orig_node) {
2114 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2115 // dummy Mkl tensor. 8 = 2*size_t.
2116 const DataType dt = DataTypeToEnum<uint8>::v();
2117 TensorProto proto;
2118 proto.set_dtype(dt);
2119 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2120 proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2121 TensorShape dummy_shape({8});
2122 dummy_shape.AsProto(proto.mutable_tensor_shape());
2123 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2124 .Attr("value", proto)
2125 .Attr("dtype", dt)
2126 .Device(orig_node->def().device()) // We place this node on
2127 // the same device as the
2128 // device of the original
2129 // node.
2130 .Finalize(&**g, out));
2131 DCHECK(*out); // Make sure we got a valid object before using it
2132
2133 // If number of inputs to the original node is > 0, then we add
2134 // control dependency between 1st input (index 0) of the original node and
2135 // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2136 // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2137 // rewritten node. Adding control edge between 1st input of the original node
2138 // and the dummy Mkl node ensures that the dummy node is in the same frame
2139 // as the original node. Choosing 1st input is not necessary - any input of
2140 // the original node is fine because all the inputs of a node are always in
2141 // the same frame.
2142 if (orig_node->num_inputs() > 0) {
2143 Node* orig_input0 = nullptr;
2144 TF_CHECK_OK(
2145 orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2146 auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2147 DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2148 }
2149
2150 (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2151 }
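
// Rough shape of the node created above (a sketch; the exact node name comes
// from Graph::NewName("DMT") and is not fixed):
//
//   DMT = Const[dtype=DT_UINT8,
//               value=Tensor<uint8>(shape={8}, content=all zeros)]()
//   plus a control edge from the original node's first input to DMT,
//   so that DMT stays in the same frame as the rewritten node.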
2152
2153 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2154 std::unique_ptr<Graph>* g, const Node* orig_node,
2155 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2156 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2157 CHECK_LT(*input_idx, inputs.size());
2158 CHECK_GT(list_length, 0);
2159 DCHECK(output_nodes);
2160 output_nodes->reserve(list_length);
2161
2162 while (list_length != 0) {
2163 CHECK_GT(list_length, 0);
2164 CHECK_LT(*input_idx, inputs.size());
2165 Node* n = inputs[*input_idx].first;
2166 int slot = inputs[*input_idx].second;
2167 // If 'n' is producing a single tensor, then create a single Mkl tensor
2168 // node.
2169 Node* mkl_node = nullptr;
2170 int mkl_node_output_slot = 0;
2171 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2172 &mkl_node_output_slot);
2173 output_nodes->push_back(
2174 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2175 (*input_idx)++;
2176 list_length--;
2177 }
2178 }
2179
2180 // Get an input node that will feed Mkl tensor to the new
2181 // node that we are constructing. An input node could be (1) 'n'
2182 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
2183 // if 'n' is not an Mkl layer.
2184 void MklLayoutRewritePass::GetNodeProducingMklTensor(
2185 std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2186 int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2187 DCHECK(n);
2188 DCHECK(mkl_node);
2189 DCHECK(mkl_node_output_slot);
2190
2191 // If this is an MKL op, then it will create extra output for MKL layout.
2192 DataType T;
2193 if (TryGetNodeAttr(n->def(), "T", &T) &&
2194 mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
2195 // If this is an MKL op, then it will generate an edge that will receive
2196 // Mkl tensor from a node.
2197 // output slot number for Mkl tensor would be N+slot number of TensorFlow
2198 // tensor, where N is total number of TensorFlow tensors.
2199 *mkl_node = n;
2200 *mkl_node_output_slot =
2201 GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2202 } else {
2203 // If we have not visited the node and rewritten it, then we need
2204 // to create a dummy node that will feed a dummy Mkl tensor to this node.
2205 // DummyMklTensor node has no input and generates only 1 output
2206 // (dummy Mkl tensor) as output slot number 0.
2207 GetDummyMklTensorNode(g, mkl_node, orig_node);
2208 DCHECK(*mkl_node);
2209 *mkl_node_output_slot = 0;
2210 }
2211 }
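
// For illustration (a sketch): if 'n' is an Mkl layout-dependent op with two
// TensorFlow outputs (so four outputs in total), the metadata tensor for
// output slot 0 lives at slot 2 and the one for slot 1 at slot 3, which is
// what GetTensorMetaDataIndex(n_output_slot, n->num_outputs()) returns under
// contiguous ordering.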
2212
2213 int MklLayoutRewritePass::SetUpContiguousInputs(
2214 std::unique_ptr<Graph>* g,
2215 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2216 NodeBuilder* nb, const Node* old_node,
2217 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2218 bool are_workspace_tensors_available) {
2219 DCHECK(workspace_tensors);
2220 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2221
2222 // TODO(nhasabni): Temporary solution to connect filter input of
2223 // BackpropInput with the converted filter from Conv2D.
2224 bool do_connect_conv2d_backprop_input_filter = false;
2225 Node* conv2d_node = nullptr;
2226 // Filter node is 2nd input (slot index 1) of Conv2D.
2227 int kConv2DFilterInputSlotIdx = 1;
2228 int kConv2DBackpropInputFilterInputSlotIdx = 1;
2229 int kConv2DFilterOutputSlotIdx = 1;
2230 if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2231 // We need to find Conv2D node from Conv2DBackpropInput.
2232 // For that let's first find filter node that is 2nd input (slot 1)
2233 // of BackpropInput.
2234 Node* filter_node = nullptr;
2235 TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2236 &filter_node));
2237 DCHECK(filter_node);
2238
2239 // Now check which nodes receive from filter_node. Filter feeds as
2240 // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2241 // _MklFusedConv2D.
2242 for (const Edge* e : filter_node->out_edges()) {
2243 if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2244 e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2245 e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2246 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2247 e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2248 e->dst_input() == kConv2DFilterInputSlotIdx
2249 /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2250 if (conv2d_node != nullptr) {
2251 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2252 << " feeding multiple Conv2D nodes: "
2253 << filter_node->DebugString();
2254 // We will not connect filter input of Conv2DBackpropInput
2255 // to be safe here.
2256 do_connect_conv2d_backprop_input_filter = false;
2257 break;
2258 } else {
2259 conv2d_node = e->dst();
2260 do_connect_conv2d_backprop_input_filter = true;
2261 }
2262 }
2263 }
2264 }
2265
2266 // Number of input slots to original op
2267 // Input slots are represented by .Input() calls in REGISTER_OP.
2268 int old_node_input_slots = old_node->op_def().input_arg_size();
2269 int nn_slot_idx = 0; // slot index for inputs of new node
2270
2271 // Let's copy all inputs (TF tensors) of original node to new node.
2272 int iidx = 0;
2273 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2274 // An input slot could be a single tensor or a list. We need
2275 // to handle this case accordingly.
2276 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2277 if (ArgIsList(arg)) {
2278 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2279 int tensor_list_length = GetTensorListLength(arg, old_node);
2280 if (tensor_list_length != 0) {
2281 GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2282 tensor_list_length, &new_node_inputs);
2283 }
2284 nb->Input(new_node_inputs);
2285 nn_slot_idx++;
2286 } else {
2287 // Special case for connecting filter input of Conv2DBackpropInput
2288 if (do_connect_conv2d_backprop_input_filter &&
2289 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2290 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2291 } else {
2292 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2293 }
2294 iidx++;
2295 nn_slot_idx++;
2296 }
2297 }
2298
2299 // If workspace tensors are available for this op and we are using
2300 // contiguous ordering then we need to add Tensorflow tensor for
2301 // workspace here because Tensorflow tensor for workspace is the
2302 // last tensor in the list of Tensorflow tensors.
2303 if (are_workspace_tensors_available) {
2304 CHECK_EQ(workspace_tensors->size(), 2);
2305 // Tensorflow tensor
2306 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2307 nn_slot_idx++;
2308 }
2309
2310 // Let's now setup all Mkl inputs to a new node.
2311 // Number of Mkl inputs must be same as number of TF inputs.
2312 iidx = 0;
2313 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2314 // An input slot could be a single tensor or a list. We need
2315 // to handle this case accordingly.
2316 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2317 if (ArgIsList(arg)) {
2318 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2319 int tensor_list_length = GetTensorListLength(arg, old_node);
2320 if (tensor_list_length != 0) {
2321 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2322 tensor_list_length, &new_node_inputs);
2323 }
2324 nb->Input(new_node_inputs);
2325 nn_slot_idx++;
2326 } else {
2327 Node* mkl_node = nullptr;
2328 int mkl_node_output_slot = 0;
2329 // Special case for connecting filter input of Conv2DBackpropInput
2330 if (do_connect_conv2d_backprop_input_filter &&
2331 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2332 GetNodeProducingMklTensor(g, old_node, conv2d_node,
2333 kConv2DFilterOutputSlotIdx, &mkl_node,
2334 &mkl_node_output_slot);
2335 } else {
2336 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2337 old_node_inputs[iidx].second, &mkl_node,
2338 &mkl_node_output_slot);
2339 }
2340 nb->Input(mkl_node, mkl_node_output_slot);
2341 iidx++;
2342 nn_slot_idx++;
2343 }
2344 }
2345
2346 // If workspace tensors are available for this op and we are using
2347 // contiguous ordering then we need to add Mkl tensor for
2348 // workspace here because Mkl tensor for workspace is the
2349 // last tensor in the list of Mkl tensors.
2350 if (are_workspace_tensors_available) {
2351 CHECK_EQ(workspace_tensors->size(), 2);
2352 // Mkl tensor
2353 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2354 nn_slot_idx++;
2355 }
2356
2357 return nn_slot_idx;
2358 }
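
// Illustration of the contiguous ordering set up above (a sketch for a
// hypothetical two-input op; the Mkl tensor names are illustrative only):
//
//   old:  Z = Op(X, Y)
//   new:  Z, Z_m = _MklOp(X, Y, X_m, Y_m)
//
// i.e. all Tensorflow tensors come first, followed by the corresponding Mkl
// tensors; a workspace tensor (and its Mkl counterpart), if any, is appended
// at the end of its respective group.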
2359
2360 Status MklLayoutRewritePass::SetUpInputs(
2361 std::unique_ptr<Graph>* g,
2362 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2363 NodeBuilder* nb, const Node* old_node) {
2364 // Let's check if we need to add workspace tensors for this node.
2365 // We add workspace edge only for MaxPool, LRN and BatchNorm.
2366 std::vector<NodeBuilder::NodeOut> workspace_tensors;
2367 bool are_workspace_tensors_available = false;
2368
2369 // Avoid workspace check for QuantizedConv2D and the fused
2370 // Ops as they don't have attribute: "T".
2371 std::vector<string> quant_ops{
2372 "Dequantize",
2373 "QuantizeV2",
2374 "QuantizedConv2D",
2375 "QuantizedConv2DWithBias",
2376 "QuantizedConv2DAndRelu",
2377 "QuantizedConv2DWithBiasAndRelu",
2378 "QuantizedConv2DWithBiasSumAndRelu",
2379 "QuantizedConv2DPerChannel",
2380 "QuantizedConv2DAndRequantize",
2381 "QuantizedConv2DWithBiasAndRequantize",
2382 "QuantizedConv2DAndReluAndRequantize",
2383 "QuantizedConv2DWithBiasAndReluAndRequantize",
2384 "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2385 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2386 "QuantizedMatMulWithBias",
2387 "QuantizedMatMulWithBiasAndRequantize",
2388 "QuantizedMatMulWithBiasAndDequantize",
2389 "QuantizedMatMulWithBiasAndRelu",
2390 "QuantizedMatMulWithBiasAndReluAndRequantize",
2391 "QuantizedDepthwiseConv2D",
2392 "QuantizedDepthwiseConv2DWithBias",
2393 "QuantizedDepthwiseConv2DWithBiasAndRelu",
2394 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2395 bool should_check_workspace =
2396 std::find(std::begin(quant_ops), std::end(quant_ops),
2397 old_node->type_string()) == std::end(quant_ops);
2398 if (should_check_workspace)
2399 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2400 &are_workspace_tensors_available);
2401
2402 int new_node_input_slots = 0;
2403 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
2404     // TODO(nhasabni): implement this function just for the sake of completeness.
2405 // We do not use interleaved ordering right now.
2406 return Status(
2407 error::Code::UNIMPLEMENTED,
2408 "Interleaved ordering of tensors is currently not supported.");
2409 } else {
2410 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2411 new_node_input_slots = SetUpContiguousInputs(
2412 g, old_node_inputs, nb, old_node, &workspace_tensors,
2413 are_workspace_tensors_available);
2414 }
2415
2416 // Sanity check
2417 int old_node_input_slots = old_node->op_def().input_arg_size();
2418 if (!are_workspace_tensors_available) {
2419 // If we are not adding workspace tensors for this op, then the total
2420 // number of input slots to the new node _must_ be 2 times the number
2421 // of input slots to the original node: N original Tensorflow tensors and
2422     // N for Mkl tensors corresponding to the Tensorflow tensors.
2423 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2424 } else {
2425     // If we are adding workspace tensors for this op, then the total
2426     // number of input slots to the new node _must_ be 2 times the number
2427     // of input slots to the original node plus 2: N original Tensorflow
2428     // tensors, N Mkl tensors corresponding to the Tensorflow tensors, and
2429     // 2 for the workspace Tensorflow tensor and the workspace Mkl tensor.
2430 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2431 }
2432
2433 return Status::OK();
2434 }
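
// Worked example of the sanity check above (a sketch): an op with 3 input
// slots gets 3 Tensorflow tensors plus 3 Mkl tensors, i.e. 3 * 2 = 6 new
// input slots; with a workspace edge it gets 3 * 2 + 2 = 8.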
2435
2436 Status MklLayoutRewritePass::CopyInputs(
2437 const Node* old_node,
2438 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2439 NodeBuilder* nb) {
2440 // Number of input slots to old node
2441 // Input slots are represented by .Input() calls in REGISTER_OP.
2442 int old_node_input_slots = old_node->op_def().input_arg_size();
2443 // Actual number of inputs can be greater than or equal to number
2444 // of Input slots because inputs of type list could be unfolded.
2445 auto old_node_input_size = old_node_inputs.size();
2446 DCHECK_GE(old_node_input_size, old_node_input_slots);
2447
2448 // Let's copy all inputs of old node to new node.
2449 int iidx = 0;
2450 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2451 // An input slot could be a single tensor or a list. We need
2452 // to handle this case accordingly.
2453 DCHECK_LT(iidx, old_node_input_size);
2454 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2455 if (ArgIsList(arg)) {
2456 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2457 int N = GetTensorListLength(arg, old_node);
2458 if (N != 0) {
2459 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2460 &new_node_inputs);
2461 }
2462 nb->Input(new_node_inputs);
2463 } else {
2464 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2465 iidx++;
2466 }
2467 }
2468 return Status::OK();
2469 }
2470
2471 //////////////////////////////////////////////////////////////////////////
2472 // Helper functions related to workspace pass
2473 //////////////////////////////////////////////////////////////////////////
2474
2475 // TODO(nhasabni) We should move this to mkl_util.h.
2476 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2477 std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
2478 // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
2479 // workspace tensor.
2480 GetDummyMklTensorNode(g, out, orig_node);
2481 }
2482
2483 void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2484 std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2485 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2486 bool workspace_edge_added = false; // Default initializer
2487 DCHECK(are_ws_tensors_added);
2488 *are_ws_tensors_added = false; // Default initializer
2489
2490 DataType T;
2491 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2492 for (auto ws : wsinfo_) {
2493 if (orig_node->type_string() == ws.fwd_op &&
2494 mkl_op_registry::IsMklOp(
2495 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2496 // If this op is a fwd op, then we need to check if there is an
2497 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
2498 // an edge, then we just add an attribute on this node for setting
2499 // workspace_passed to true. We don't add actual workspace edge
2500 // in this node. Actual workspace edge gets added in the backward
2501 // op for this node.
2502 for (const Edge* e : orig_node->out_edges()) {
2503 if (e->src_output() == ws.fwd_slot &&
2504 e->dst()->type_string() == ws.bwd_op &&
2505 e->dst_input() == ws.bwd_slot) {
2506 nb->Attr("workspace_enabled", true);
2507 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2508 << orig_node->type_string();
2509 workspace_edge_added = true;
2510 // We found the edge that we were looking for, so break.
2511 break;
2512 }
2513 }
2514
2515 if (!workspace_edge_added) {
2516 // If we are here, then we did not find backward operator for this
2517 // node.
2518 nb->Attr("workspace_enabled", false);
2519 }
2520 } else if (orig_node->type_string() == ws.bwd_op &&
2521 mkl_op_registry::IsMklOp(
2522 mkl_op_registry::GetMklOpName(orig_node->type_string()),
2523 T)) {
2524 // If this op is a bwd op, then we need to add workspace edge and
2525       // its Mkl tensor edge between its corresponding fwd op and this
2526 // op. Corresponding fwd op is specified in 'fwd_op' field of
2527 // workspace info. fwd_slot and bwd_slot in workspace info specify
2528 // an edge between which slots connect forward and backward op.
2529 // Once all these criteria match, we add a workspace edge between
2530 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
2531 // determined by interleaved/contiguous ordering. Function
2532 // DataIndexToMetaDataIndex tells us the location of Mkl tensor
2533 // from the location of the Tensorflow tensor.
2534 for (const Edge* e : orig_node->in_edges()) {
2535 if (e->src_output() == ws.fwd_slot &&
2536 // We would have rewritten the forward op, so we need to use
2537 // GetMklOpName call to get its Mkl name.
2538 e->src()->type_string() ==
2539 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2540 e->dst_input() == ws.bwd_slot) {
2541 nb->Attr("workspace_enabled", true);
2542 DCHECK(ws_tensors);
2543 // Add workspace edge between fwd op and bwd op.
2544 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2545 // Check if we are running in native format mode. If so,
2546 // we don't need to have an Mkl metadata tensor for the workspace.
2547 if (!NativeFormatEnabled()) {
2548 // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2549 ws_tensors->push_back(NodeBuilder::NodeOut(
2550 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2551 e->src()->num_outputs())));
2552 }
2553 *are_ws_tensors_added = true;
2554 // In terms of input ordering, we add these calls to add Input
2555 // here because workspace edge (and its Mkl tensor) is the last
2556 // edge in the fwdop and bwdop. So all inputs before workspace
2557 // tensor have been added by SetUpInputs function.
2558 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2559 << orig_node->type_string();
2560 workspace_edge_added = true;
2561 // We found the edge that we were looking for, so break.
2562 break;
2563 }
2564 }
2565
2566       // If we are here, it means we did not find a fwd op that feeds this
2567 // bwd op. So in this case, we need to generate dummy tensors for
2568 // workspace input and Mkl tensor for workspace, and set
2569 // workspace_enabled to false.
2570 if (!workspace_edge_added) {
2571 nb->Attr("workspace_enabled", false);
2572 Node* dmt_ws = nullptr; // Dummy tensor for workspace
2573 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
2574 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2575 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2576 DCHECK(dmt_ws);
2577 DCHECK(dmt_mkl_ws);
2578 DCHECK(ws_tensors);
2579 // We add dummy tensor as workspace tensor.
2580 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2581 // We add dummy tensor as Mkl tensor for workspace tensor.
2582 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2583 *are_ws_tensors_added = true;
2584 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2585 << orig_node->type_string();
2586 }
2587 } else {
2588 // If this node does not match any workspace info, then we do not
2589 // do anything special for workspace propagation for it.
2590 }
2591 }
2592 }
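
// Illustration (a sketch; assumes wsinfo_ pairs a forward op such as MaxPool
// with its backward op MaxPoolGrad):
//
//   fwd:  Y = _MklMaxPool(X, ...)        // gets attr workspace_enabled=true
//   bwd:  _MklMaxPoolGrad(..., W, W_m)   // receives workspace tensor W from
//                                        // ws.ws_fwd_slot of the fwd op, plus
//                                        // its Mkl metadata tensor W_m
//
// If no matching fwd op is found, dummy workspace tensors are wired in and
// workspace_enabled is set to false.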
2593
2594 //////////////////////////////////////////////////////////////////////////
2595 // Op-specific functions to copy attributes from old node to new node
2596 //////////////////////////////////////////////////////////////////////////
2597
2598 // Generic function to copy all attributes from original node to target.
2599 void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2600 bool change_format) {
2601 string name;
2602 AttrSlice attr_list(orig_node->def());
2603
2604 auto iter = attr_list.begin();
2605 while (iter != attr_list.end()) {
2606 name = iter->first;
2607 auto attr = iter->second;
2608 nb->Attr(name, attr);
2609 ++iter;
2610 }
2611 }
2612
2613 // Generic function to copy all attributes and check if filter is const.
2614 void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2615 NodeBuilder* nb,
2616 bool change_format) {
2617 CopyAttrsAll(orig_node, nb, change_format);
2618
2619 // Check and set filter attribute.
2620 Node* filter_node = nullptr;
2621 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2622 nb->Attr("is_filter_const", filter_node->IsConstant());
2623 }
2624
2625 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2626 NodeBuilder* nb,
2627 bool change_format) {
2628 CopyAttrsConv(orig_node, nb, change_format);
2629
2630 // Check and set filter attribute.
2631 Node* filter_node = nullptr;
2632 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2633 nb->Attr("is_filter_const", filter_node->IsConstant());
2634 }
2635
2636 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2637 bool change_format) {
2638 DataType T;
2639 string padding;
2640 std::vector<int32> strides;
2641 std::vector<int32> dilations;
2642 std::vector<int32> explicit_paddings;
2643
2644 // Get all attributes from old node.
2645 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2646 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2647 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2648 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2649
2650 // Check `explicit_paddings` first because some Conv ops don't have
2651 // this attribute.
2652 if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2653 &explicit_paddings) &&
2654 !explicit_paddings.empty()) {
2655 nb->Attr("explicit_paddings", explicit_paddings);
2656 }
2657
2658 // Add attributes to new node.
2659 nb->Attr("T", T);
2660 nb->Attr("padding", padding);
2661
2662 // Add attributes related to `data_format`.
2663 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2664 }
2665
2666 // Used with MergePadWithConv2D
2667 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2668 const Node* orig_node2,
2669 NodeBuilder* nb,
2670 bool change_format) {
2671 DataType Tpaddings;
2672 DataType T;
2673 string data_format;
2674 string padding;
2675 std::vector<int32> strides;
2676 std::vector<int32> dilations;
2677 bool use_cudnn_on_gpu;
2678
2679 // Get all attributes from old node 1.
2680 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2681 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2682 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2683 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2684 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2685 TF_CHECK_OK(
2686 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2687 // Get all attributes from old node 2.
2688 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2689
2690 // Add attributes to new node.
2691 nb->Attr("T", T);
2692 nb->Attr("strides", strides);
2693 nb->Attr("dilations", dilations);
2694 nb->Attr("padding", padding);
2695 nb->Attr("data_format", data_format);
2696 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2697 nb->Attr("Tpaddings", Tpaddings);
2698 }
2699
2700 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2701 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2702 bool change_format) {
2703 DataType T;
2704 int num_args;
2705 string data_format;
2706 string padding;
2707 std::vector<int32> strides;
2708 std::vector<int32> dilations;
2709 float epsilon;
2710 std::vector<string> fused_ops;
2711 DataType Tpaddings;
2712 float leakyrelu_alpha;
2713
2714 // Get all attributes from old node.
2715 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2716 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2717 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2718 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2719 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2720 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2721 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2722 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2723 TF_CHECK_OK(
2724 GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2725 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2726
2727 // Add attributes to new node.
2728 nb->Attr("T", T);
2729 nb->Attr("num_args", num_args);
2730 nb->Attr("strides", strides);
2731 nb->Attr("padding", padding);
2732 nb->Attr("data_format", data_format);
2733 nb->Attr("dilations", dilations);
2734 nb->Attr("epsilon", epsilon);
2735 nb->Attr("Tpaddings", Tpaddings);
2736 nb->Attr("fused_ops", fused_ops);
2737 nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2738 }
2739
2740 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2741 NodeBuilder* nb,
2742 bool change_format) {
2743 DataType Tinput, Tfilter, out_type;
2744 string padding;
2745 string data_format("NHWC");
2746 std::vector<int32> strides, dilations, padding_list;
2747 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2748
2749 // Get all attributes from old node.
2750 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2751 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2752 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2753 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2754 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2755 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2756 if (has_padding_list) {
2757 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2758 }
2759
2760 Node* filter_node = nullptr;
2761 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2762
2763 // Add attributes to new node.
2764 nb->Attr("Tinput", Tinput);
2765 nb->Attr("Tfilter", Tfilter);
2766 nb->Attr("out_type", out_type);
2767 nb->Attr("padding", padding);
2768 nb->Attr("is_filter_const", filter_node->IsConstant());
2769 nb->Attr("strides", strides);
2770 nb->Attr("dilations", dilations);
2771 nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion.
2772 nb->Attr("data_format", data_format);
2773 if (has_padding_list) {
2774 nb->Attr("padding_list", padding_list);
2775 }
2776
2777 // Requantization attr Tbias.
2778 DataType Tbias;
2779 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2780   if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2781 }
2782
2783 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2784 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2785 DataType T1, T2, Toutput;
2786
2787 // Get all attributes from old node.
2788 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2789 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2790 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2791
2792 // Add attributes to new node.
2793 nb->Attr("T1", T1);
2794 nb->Attr("T2", T2);
2795 nb->Attr("Toutput", Toutput);
2796 nb->Attr("T", T1); // added "T" for facilitating MklToTf conversion.
2797
2798 // Requantization attr Tbias
2799 DataType Tbias;
2800 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2801   if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2802 }
2803
2804 void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2805 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2806 DataType T1, T2, Toutput;
2807
2808 // Get all attributes from old node.
2809 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2810 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2811 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2812
2813 Node* weight_node = nullptr;
2814 TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2815
2816 // Add attributes to new node.
2817 nb->Attr("T1", T1);
2818 nb->Attr("T2", T2);
2819 nb->Attr("Toutput", Toutput);
2820 nb->Attr("is_weight_const", weight_node->IsConstant());
2821 nb->Attr("T", Toutput); // added "T" for facilitating MklToTf conversion.
2822
2823 // Requantization attr Tbias
2824 DataType Tbias;
2825 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2826   if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2827 }
2828
2829 void MklLayoutRewritePass::CopyFormatAttrsConv(
2830 const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2831 const std::vector<int32>& dilations, bool change_format) {
2832 string data_format;
2833
2834 if (!change_format) {
2835 nb->Attr("strides", strides);
2836 nb->Attr("dilations", dilations);
2837
2838 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2839 nb->Attr("data_format", data_format);
2840 } else {
2841 std::vector<int32> new_strides;
2842 std::vector<int32> new_dilations;
2843 if (strides.size() == 5) {
2844 // `strides` and `dilations` also need to be changed according to
2845 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2846 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2847 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2848 strides[NDHWC::dim::W]};
2849
2850 new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2851 dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2852 dilations[NDHWC::dim::W]};
2853 } else {
2854 // `strides` and `dilations` also need to be changed according to
2855 // `data_format`. In this case, from `NHWC` to `NCHW`.
2856
2857 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2858 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2859
2860 new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2861 dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
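      // Illustrative example: NHWC strides {1, 2, 2, 1} (N=1, H=2, W=2, C=1)
      // become NCHW strides {1, 1, 2, 2} after this reordering; dilations are
      // permuted the same way.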
2862 }
2863 nb->Attr("strides", new_strides);
2864 nb->Attr("dilations", new_dilations);
2865 }
2866 }
2867
2868 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2869 NodeBuilder* nb,
2870 bool change_format) {
2871 DataType T;
2872 string data_format;
2873 string padding;
2874 std::vector<int32> ksize, strides;
2875
2876 // Get all attributes from old node.
2877 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2878 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2879 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2880 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2881 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2882
2883 // Add attributes to new node.
2884 nb->Attr("T", T);
2885 nb->Attr("padding", padding);
2886
2887 if (!change_format) {
2888 nb->Attr("strides", strides);
2889 nb->Attr("ksize", ksize);
2890
2891 nb->Attr("data_format", data_format);
2892 } else {
2893 std::vector<int32> new_strides;
2894 std::vector<int32> new_ksize;
2895 if (strides.size() == 5) {
2896 DCHECK(data_format == "NCDHW");
2897 // `strides` and `ksize` also need to be changed according to
2898 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2899 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2900 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2901 strides[NDHWC::dim::W]};
2902
2903 new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2904 ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2905 ksize[NDHWC::dim::W]};
2906
2907 } else {
2908 // `strides` and `ksize` also need to be changed according to
2909 // `data_format`. In this case, from `NHWC` to `NCHW`.
2910 DCHECK(data_format == "NCHW");
2911 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2912 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2913
2914 new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2915 ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
2916 }
2917 nb->Attr("strides", new_strides);
2918 nb->Attr("ksize", new_ksize);
2919 }
2920 }
2921
2922 //////////////////////////////////////////////////////////////////////////
2923 // Helper functions related to node merge pass
2924 //////////////////////////////////////////////////////////////////////////
2925
2926 Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
2927 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
2928 // once we support BiasAddGrad as Mkl layer.
2929
2930 // Search for all matching mergeinfo.
2931 // We allow more than one match for extensibility.
2932 std::vector<const MergeInfo*> matching_mi;
2933 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
2934 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
2935 matching_mi.push_back(&*mi);
2936 }
2937 }
2938
2939 for (const MergeInfo* mi : matching_mi) {
2940 // Get the operand with which 'a' can be merged.
2941 Node* b = nullptr;
2942 if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
2943 continue;
2944 }
2945
2946 // Get the control edges and input of node
2947 const int N_in = a->num_inputs();
2948 gtl::InlinedVector<Node*, 4> a_control_edges;
2949 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
2950 FillInputs(a, &a_control_edges, &a_in);
2951
2952 const int B_in = b->num_inputs();
2953 gtl::InlinedVector<Node*, 4> b_control_edges;
2954 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
2955 FillInputs(b, &b_control_edges, &b_in);
2956
2957 // Shouldn't merge if a and b have different control edges.
2958 if (a_control_edges != b_control_edges) {
2959 continue;
2960 } else {
2961 // We found a match.
2962 return b;
2963 }
2964 }
2965
2966 return nullptr;
2967 }
2968
2969 Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
2970 Node* m, Node* n) {
2971 CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
2972 n->type_string() == csinfo_.conv2d)) ||
2973 ((n->type_string() == csinfo_.bias_add &&
2974 m->type_string() == csinfo_.conv2d)),
2975 true);
2976
2977 // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
2978 // BiasAdd is successor node, and Conv2D predecessor node.
2979 Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
2980 Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
2981
2982 // 1. Get all attributes from input nodes.
2983 DataType T_pred, T_succ;
2984 string padding;
2985 std::vector<int32> strides;
2986 std::vector<int32> dilations;
2987 string data_format_pred, data_format_succ;
2988 bool use_cudnn_on_gpu;
2989 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
2990 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
2991 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
2992 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
2993 TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
2994 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
2995 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
2996 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2997   // We check that the data formats of succ and pred are the same.
2998   // We expect them to match, so this could be enforced as an assert; but an
2999   // assert can be too strict, so we enforce it as a check instead.
3000   // If the check fails, then we do not merge the two nodes.
3001   // We do the same check for devices.
3002 if (data_format_pred != data_format_succ || T_pred != T_succ ||
3003 pred->assigned_device_name() != succ->assigned_device_name() ||
3004 pred->def().device() != succ->def().device()) {
3005 return Status(error::Code::INVALID_ARGUMENT,
3006 "data_format or T attribute or devices of Conv2D and "
3007 "BiasAdd do not match. Will skip node merge optimization");
3008 }
3009
3010 const int succ_num = succ->num_inputs();
3011 gtl::InlinedVector<Node*, 4> succ_control_edges;
3012 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3013 FillInputs(succ, &succ_control_edges, &succ_in);
3014
3015 const int pred_num = pred->num_inputs();
3016 gtl::InlinedVector<Node*, 4> pred_control_edges;
3017 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3018 FillInputs(pred, &pred_control_edges, &pred_in);
3019
3020   // We need to ensure that Conv2D feeds only BiasAdd (i.e., no other
3021   // operator consumes the output of Conv2D). If this is not the case, then
3022   // we cannot merge Conv2D with BiasAdd.
3023 const int kFirstOutputSlot = 0;
3024 for (const Edge* e : pred->out_edges()) {
3025 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3026 return Status(error::Code::INVALID_ARGUMENT,
3027 "Conv2D does not feed to BiasAdd, or "
3028 "it feeds BiasAdd but has multiple outputs. "
3029 "Will skip node merge optimization");
3030 }
3031 }
3032
3033 // 2. Get inputs from both the nodes.
3034 // Find the 2 inputs from the conv and the bias from the add Bias.
3035 // Get operand 0, 1 of conv2D.
3036 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs.
3037 // Get operand 1 of add_bias
3038 // BiasAdd must have 2 inputs: Conv, bias
3039 CHECK_EQ(succ->in_edges().size(), 2);
3040
3041   // We will use the node name of BiasAdd as the name of the new node.
3042   // Build the new node. We use the same name as the original node, but
3043   // change the op name.
3044 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3045 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
3046 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3047 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
3048 // In1 of BiasAdd is same as output of Conv2D.
3049 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
3050
3051 // Copy attributes from Conv2D to Conv2DWithBias.
3052 CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3053
3054 // Copy the device assigned to old node to new node.
3055 nb.Device(succ->def().device());
3056
3057 // Create node.
3058 Node* new_node;
3059 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3060
3061   // In the following code of this function, an unordered set is used to
3062   // make sure no duplicate edges are added to the new node. Therefore, we
3063   // can pass allow_duplicates = true in the AddControlEdge call to skip the
3064   // O(#edges) check in the routine.
3065
3066 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3067 // node are already copied in BuildNode. We handle control edges now.
3068 std::unordered_set<Node*> unique_node;
3069 for (const Edge* e : pred->in_edges()) {
3070 if (e->IsControlEdge()) {
3071 auto result = unique_node.insert(e->src());
3072 if (result.second) {
3073 (*g)->AddControlEdge(e->src(), new_node, true);
3074 }
3075 }
3076 }
3077 unique_node.clear();
3078
3079 for (const Edge* e : succ->in_edges()) {
3080 if (e->IsControlEdge()) {
3081 auto result = unique_node.insert(e->src());
3082 if (result.second) {
3083 (*g)->AddControlEdge(e->src(), new_node, true);
3084 }
3085 }
3086 }
3087 unique_node.clear();
3088
3089 // Incoming edges are fixed, we will fix the outgoing edges now.
3090 // First, we will fix outgoing control edges from 'pred' node.
3091 for (const Edge* e : pred->out_edges()) {
3092 if (e->IsControlEdge()) {
3093 auto result = unique_node.insert(e->dst());
3094 if (result.second) {
3095 (*g)->AddControlEdge(new_node, e->dst(), true);
3096 }
3097 }
3098 }
3099 unique_node.clear();
3100
3101 // Second, we will fix outgoing control and data edges from 'succ' node.
3102 for (const Edge* e : succ->out_edges()) {
3103 if (e->IsControlEdge()) {
3104 auto result = unique_node.insert(e->dst());
3105 if (result.second) {
3106 (*g)->AddControlEdge(new_node, e->dst(), true);
3107 }
3108 } else {
3109 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3110 // output (at slot 0).
3111 const int kConv2DWithBiasOutputSlot = 0;
3112 auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3113 e->dst(), e->dst_input());
3114 DCHECK(new_edge);
3115 }
3116 }
3117
3118 // Copy device assigned to old node to new node.
3119   // It's ok to use pred or succ as we have enforced a check that
3120   // both have the same device assigned.
3121 new_node->set_assigned_device_name(pred->assigned_device_name());
3122
3123 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3124 << ", and node: " << succ->DebugString()
3125 << ", into node:" << new_node->DebugString();
3126
3127 (*g)->RemoveNode(succ);
3128 (*g)->RemoveNode(pred);
3129
3130 return Status::OK();
3131 }
3132
3133 Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3134 Node* m, Node* n) {
3135 DCHECK((m->type_string() == csinfo_.pad &&
3136 (n->type_string() == csinfo_.conv2d ||
3137 n->type_string() == csinfo_.fused_conv2d)) ||
3138 (n->type_string() == csinfo_.pad &&
3139 (m->type_string() == csinfo_.conv2d ||
3140 m->type_string() == csinfo_.fused_conv2d)));
3141
3142 bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3143 m->type_string() == csinfo_.fused_conv2d;
3144 // Conv2D is successor node, and Pad predecessor node.
3145 Node* pred = m->type_string() == csinfo_.pad ? m : n;
3146 Node* succ = m->type_string() == csinfo_.pad ? n : m;
3147
3148 // 1. Get all attributes from input nodes.
3149 DataType T_pred, T_succ;
3150 string padding;
3151 std::vector<int32> strides;
3152 std::vector<int32> dilations;
3153 string data_format_pred, data_format_succ;
3154
3155 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3156 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3157 TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3158 TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3159 TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3160   // Check if the devices of both succ and pred are the same.
3161   // An assert is not used because it can be too strict.
3162   // We don't need to check data formats because Pad has no data_format attribute.
3163 if (T_pred != T_succ ||
3164 pred->assigned_device_name() != succ->assigned_device_name() ||
3165 pred->def().device() != succ->def().device()) {
3166 return Status(error::Code::INVALID_ARGUMENT,
3167 "T attribute or devices of Conv2D and "
3168 "Pad do not match. Will skip node merge optimization");
3169 }
3170
3171 const int succ_num = succ->num_inputs();
3172 gtl::InlinedVector<Node*, 4> succ_control_edges;
3173 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3174 FillInputs(succ, &succ_control_edges, &succ_in);
3175
3176 const int pred_num = pred->num_inputs();
3177 gtl::InlinedVector<Node*, 4> pred_control_edges;
3178 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3179 FillInputs(pred, &pred_control_edges, &pred_in);
3180
3181   // We need to ensure that Pad feeds only Conv2D (i.e., no other operator
3182   // consumes the output of Pad). If this is not the case, then we cannot
3183   // merge Conv2D with Pad.
3184 const int kFirstOutputSlot = 0;
3185 for (const Edge* e : pred->out_edges()) {
3186 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3187 return Status(error::Code::INVALID_ARGUMENT,
3188 "Pad does not feed to Conv2D, or "
3189 "it feeds Conv2D but has multiple outputs. "
3190 "Will skip node merge optimization");
3191 }
3192 }
3193
3194 // 2. Get inputs from both the nodes.
3195
3196 // Pad must have 2 data inputs: "input" and paddings.
3197 int PadDataInputEdges = 0;
3198 for (const Edge* e : pred->in_edges()) {
3199 if (!e->IsControlEdge()) {
3200 PadDataInputEdges++;
3201 }
3202 }
3203 DCHECK_EQ(PadDataInputEdges, 2);
3204
3205 // Conv2D must have 2 data inputs: Pad output and Filter
3206 // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
3207 int ConvDataInputEdges = 0;
3208 for (const Edge* e : succ->in_edges()) {
3209 if (!e->IsControlEdge()) {
3210 ConvDataInputEdges++;
3211 }
3212 }
3213
3214 DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
3215
3216   // We will use the node name of Conv2D as the name of the new node.
3217   // Build the new node. We use the same name as the original node, but
3218   // change the op name.
3219
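  // Data inputs of the merged node, in order: the input of Pad, the filter of
  // Conv2D, the args of FusedConv2D (only for the fused variant), and the
  // paddings of Pad.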
3220 NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3221 : csinfo_.pad_with_conv2d);
3222 nb.Input(pred_in[0].first, pred_in[0].second); // In1 (input data) of Pad
3223   // succ_in[1] will be the 2nd Tensorflow tensor (the filter) for Conv2D.
3224 nb.Input(succ_in[1].first, succ_in[1].second); // In2 (filter) of conv2d
3225   // In1 of Conv2D is the same as the output of Pad.
3226   // Thus, we only need to add In2 of Conv2D here.
3227
3228 if (is_fused_conv2d) {
3229 // FusedConv2D has one additional input, args
3230 std::vector<NodeBuilder::NodeOut> args;
3231 args.emplace_back(succ_in[2].first, succ_in[2].second);
3232 nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3233 args}); // In3 (args) of FusedConv2D
3234 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3235 // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3236 CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3237 const_cast<const Node*>(pred), &nb);
3238 } else {
3239 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3240 // Copy attributes from Pad and conv2D to PadWithConv2D.
3241 CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3242 const_cast<const Node*>(pred), &nb);
3243 }
3244
3245 // Copy the device assigned to old node to new node.
3246 nb.Device(succ->def().device());
3247
3248 // Create node.
3249 Node* new_node;
3250 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3251 // No need to check if new_node is null because it will be null only when
3252 // Finalize fails.
3253
3254 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3255 // node are already copied in BuildNode.
3256 // We handle control edges now.
3257 for (const Edge* e : pred->in_edges()) {
3258 if (e->IsControlEdge()) {
3259 // Don't allow duplicate edge
3260 (*g)->AddControlEdge(e->src(), new_node, false);
3261 }
3262 }
3263 for (const Edge* e : succ->in_edges()) {
3264 if (e->IsControlEdge()) {
3265 // Don't allow duplicate edge
3266 (*g)->AddControlEdge(e->src(), new_node, false);
3267 }
3268 }
3269
3270 // Incoming edges are fixed, we will fix the outgoing edges now.
3271 // First, we will fix outgoing control edges from 'pred' node.
3272 for (const Edge* e : pred->out_edges()) {
3273 if (e->IsControlEdge()) {
3274 // Don't allow duplicate edge
3275 (*g)->AddControlEdge(new_node, e->dst(), false);
3276 }
3277 }
3278
3279 // Second, we will fix outgoing control and data edges from 'succ' node.
3280 for (const Edge* e : succ->out_edges()) {
3281 if (e->IsControlEdge()) {
3282       // With allow_duplicates = false, AddControlEdge does not add a
3283       // duplicate edge if one already exists (it returns null instead).
3284 (*g)->AddControlEdge(new_node, e->dst(), false);
3285 } else {
3286 // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3287 // output (at slot 0).
3288 const int kPadWithConv2DOutputSlot = 0;
3289 (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3290 e->dst_input());
3291 }
3292 }
3293
3294 // Copy device assigned to old node to new node.
3295   // It's ok to use pred or succ as we have enforced a check that
3296   // both have the same device assigned.
3297 new_node->set_assigned_device_name(pred->assigned_device_name());
3298
3299 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3300 << ", and node: " << succ->DebugString()
3301 << ", into node:" << new_node->DebugString();
3302
3303 (*g)->RemoveNode(succ);
3304 (*g)->RemoveNode(pred);
3305
3306 return Status::OK();
3307 }
3308
3309 Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3310 std::unique_ptr<Graph>* g, Node* m, Node* n) {
3311 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3312 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3313 ((n->type_string() == csinfo_.bias_add_grad &&
3314 m->type_string() == csinfo_.conv2d_grad_filter)),
3315 true);
3316
3317 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3318 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3319 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3320
3321 // Sanity check for attributes from input nodes.
3322 DataType T_b, T_f;
3323 string data_format_b, data_format_f;
3324 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3325 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3326 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3327 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3328 if (data_format_b != data_format_f || T_b != T_f ||
3329 badd->assigned_device_name() != fltr->assigned_device_name() ||
3330 badd->def().device() != fltr->def().device()) {
3331 return Status(error::Code::INVALID_ARGUMENT,
3332 "data_format or T attribute or devices of "
3333 "Conv2DBackpropFilter and BiasAddGrad do not match. "
3334 "Will skip node merge optimization");
3335 }
3336
3337   // We will use the node name of Conv2DBackpropFilter as the name of the new
3338   // node. This is because BackpropFilterWithBias also emits the bias output.
3339 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3340 // Since Conv2DBackpropFilterWithBias has same number of inputs as
3341 // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3342 // to copy any data input of BiasAddGrad because that input also goes to
3343 // Conv2DBackpropFilter.
3344 const int fltr_ins = fltr->num_inputs();
3345 gtl::InlinedVector<Node*, 4> fltr_control_edges;
3346 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3347 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3348 for (int idx = 0; idx < fltr_ins; idx++) {
3349 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3350 }
3351
3352 // Copy attributes from Conv2DBackpropFilter.
3353 CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3354
3355 // Copy the device assigned to old node to new node.
3356 nb.Device(fltr->def().device());
3357
3358 // Create node.
3359 Node* new_node;
3360 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3361
3362   // In the following code of this function, an unordered set is used to
3363   // make sure no duplicate edges are added to the new node. Therefore, we
3364   // can pass allow_duplicates = true in the AddControlEdge call to skip the
3365   // O(#edges) check in the routine.
3366
3367 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3368 // new 'new_node' node are already copied in BuildNode. We handle control
3369 // edges now.
3370 std::unordered_set<Node*> unique_node;
3371 for (const Edge* e : badd->in_edges()) {
3372 if (e->IsControlEdge()) {
3373 auto result = unique_node.insert(e->src());
3374 if (result.second) {
3375 (*g)->AddControlEdge(e->src(), new_node, true);
3376 }
3377 }
3378 }
3379 unique_node.clear();
3380 for (const Edge* e : fltr->in_edges()) {
3381 if (e->IsControlEdge()) {
3382 auto result = unique_node.insert(e->src());
3383 if (result.second) {
3384 (*g)->AddControlEdge(e->src(), new_node, true);
3385 }
3386 }
3387 }
3388 unique_node.clear();
3389
3390 // Incoming edges are fixed, we will fix the outgoing edges now.
3391 // First, we will fix outgoing control edges from 'badd' node.
3392 // Conv2DBackpropFilter has 1 output -- filter_grad.
3393 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3394   // bias_grad. But filter_grad is at the same slot number (0) in both
3395   // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias,
3396   // while it is at slot number 0 in BiasAddGrad.
3397 const int kMergedNodeFilterGradOutputIdx = 0;
3398 const int kMergedNodeBiasGradOutputIdx = 1;
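  // For example, a data edge that used to come from BiasAddGrad:0 is
  // redirected to come from new_node:1 (bias_grad), while a data edge from
  // Conv2DBackpropFilter:0 is redirected to come from new_node:0
  // (filter_grad), as done in the two loops below.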
3399
3400 for (const Edge* e : badd->out_edges()) {
3401 if (e->IsControlEdge()) {
3402 auto result = unique_node.insert(e->dst());
3403 if (result.second) {
3404 (*g)->AddControlEdge(new_node, e->dst(), true);
3405 }
3406 } else {
3407 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3408 e->dst(), e->dst_input());
3409 DCHECK(new_edge);
3410 }
3411 }
3412 unique_node.clear();
3413
3414 // Second, we will fix outgoing control and data edges from 'fltr' node.
3415 for (const Edge* e : fltr->out_edges()) {
3416 if (e->IsControlEdge()) {
3417 auto result = unique_node.insert(e->dst());
3418 if (result.second) {
3419 (*g)->AddControlEdge(new_node, e->dst(), true);
3420 }
3421 } else {
3422 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3423 e->dst(), e->dst_input());
3424 DCHECK(new_edge);
3425 }
3426 }
3427
3428 // Copy device assigned to old node to new node.
3429   // It's ok to use badd or fltr as we have enforced a check that
3430   // both have the same device assigned.
3431 new_node->set_assigned_device_name(badd->assigned_device_name());
3432
3433 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3434 << ", and node: " << fltr->DebugString()
3435 << ", into node:" << new_node->DebugString();
3436
3437 (*g)->RemoveNode(badd);
3438 (*g)->RemoveNode(fltr);
3439
3440 return Status::OK();
3441 }
3442
3443 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3444 Node* n) {
3445 DCHECK(m);
3446 DCHECK(n);
3447
3448 if (((m->type_string() == csinfo_.bias_add &&
3449 n->type_string() == csinfo_.conv2d)) ||
3450 ((n->type_string() == csinfo_.bias_add &&
3451 m->type_string() == csinfo_.conv2d))) {
3452 return this->MergeConv2DWithBiasAdd(g, m, n);
3453 }
3454 if ((m->type_string() == csinfo_.pad &&
3455 (n->type_string() == csinfo_.conv2d ||
3456 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3457 (n->type_string() == csinfo_.pad &&
3458 (m->type_string() == csinfo_.conv2d ||
3459 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3460 return this->MergePadWithConv2D(g, m, n);
3461 }
3462
3463 if (((m->type_string() == csinfo_.bias_add_grad &&
3464 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3465 ((n->type_string() == csinfo_.bias_add_grad &&
3466 m->type_string() == csinfo_.conv2d_grad_filter))) {
3467 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3468 }
3469
3470 return Status(error::Code::UNIMPLEMENTED,
3471 "Unimplemented case for node merge optimization.");
3472 }
3473
3474 //////////////////////////////////////////////////////////////////////////
3475 // Helper functions for node rewrite
3476 //////////////////////////////////////////////////////////////////////////
3477
3478 Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3479 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3480 const RewriteInfo* ri) {
3481 // Get all data inputs.
3482 int num_data_inputs = orig_node->in_edges().size();
3483 // Drop count for control edges from inputs
3484 for (const Edge* e : orig_node->in_edges()) {
3485 if (e->IsControlEdge()) {
3486 num_data_inputs--;
3487 }
3488 }
3489
3490 gtl::InlinedVector<Node*, 4> control_edges;
3491 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3492 FillInputs(orig_node, &control_edges, &inputs);
3493
3494 // Build new node. We use same name as original node, but change the op name.
3495 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3496 // Copy user-specified device assigned to original node to new node.
3497 nb.Device(orig_node->def().device());
3498 // Set up new inputs to the rewritten node.
3499 Status s = SetUpInputs(g, inputs, &nb, orig_node);
3500 if (s != Status::OK()) {
3501 return s;
3502 }
3503
3504 const bool kPartialCopyAttrs = false;
3505 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3506
3507 // Set the Mkl layer label for this op.
3508 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3509 DataTypeIsQuantized(orig_node->output_type(0))) {
3510 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3511 } else {
3512 nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3513 }
3514 // Finalize graph and get new node.
3515 s = nb.Finalize(&**g, new_node);
3516 if (s != Status::OK()) {
3517 return s;
3518 }
3519
3520   // In the following code of this function, an unordered set is used to
3521   // make sure no duplicate edges are added to the new node. Therefore, we
3522   // can pass allow_duplicates = true in the AddControlEdge call to skip the
3523   // O(#edges) check in the routine.
3524
3525 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3526 // already copied in BuildNode. We need to handle control edges now.
3527 std::unordered_set<Node*> unique_node;
3528 for (const Edge* e : orig_node->in_edges()) {
3529 if (e->IsControlEdge()) {
3530 auto result = unique_node.insert(e->src());
3531 if (result.second) {
3532 (*g)->AddControlEdge(e->src(), *new_node, true);
3533 }
3534 }
3535 }
3536 unique_node.clear();
3537
3538 // Copy outgoing edges from 'orig_node' node to new
3539   // 'new_node' node, since the output also follows the same ordering among
3540   // Tensorflow tensors and Mkl tensors. We need to connect the Tensorflow
3541   // tensors appropriately. Specifically, the nth output of the original node
3542   // becomes the 2*nth output of the Mkl node under the interleaved ordering
3543   // of the tensors; under the contiguous ordering, it remains n.
3544   // GetTensorDataIndex provides this mapping function.
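  // For example, if the original node has 2 outputs, output 1 maps to slot 2
  // of the Mkl node under the interleaved ordering (slots: data0, mkl0, data1,
  // mkl1) and to slot 1 under the contiguous ordering (slots: data0, data1,
  // mkl0, mkl1).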
3545 for (const Edge* e : orig_node->out_edges()) {
3546 if (e->IsControlEdge()) {
3547 auto result = unique_node.insert(e->dst());
3548 if (result.second) {
3549 (*g)->AddControlEdge(*new_node, e->dst(), true);
3550 }
3551 } else {
3552 auto new_edge = (*g)->AddEdge(
3553 *new_node,
3554 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3555 e->dst(), e->dst_input());
3556 DCHECK(new_edge);
3557 }
3558 }
3559 return Status::OK();
3560 }
3561
3562 Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3563 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3564 const RewriteInfo* ri) {
3565 // Get all data inputs.
3566 int num_data_inputs = orig_node->in_edges().size();
3567 // Drop count for control edges from inputs
3568 for (const Edge* e : orig_node->in_edges()) {
3569 if (e->IsControlEdge()) {
3570 num_data_inputs--;
3571 }
3572 }
3573 gtl::InlinedVector<Node*, 4> control_edges;
3574 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3575 FillInputs(orig_node, &control_edges, &inputs);
3576
3577 // Build new node. We use same name as original node, but change the op name.
3578 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3579 // Copy user-specified device assigned to original node to new node.
3580 nb.Device(orig_node->def().device());
3581
3582 Status s = CopyInputs(orig_node, inputs, &nb);
3583 if (s != Status::OK()) {
3584 return s;
3585 }
3586
3587 std::vector<NodeBuilder::NodeOut> workspace_tensors;
3588 bool are_workspace_tensors_available = false;
3589 AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3590 &are_workspace_tensors_available);
3591 if (are_workspace_tensors_available) {
3592 DCHECK_EQ(workspace_tensors.size(), 1);
3593 nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3594 }
3595
3596 if (!NativeFormatEnabled()) {
3597 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3598 } else {
3599 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3600 }
3601
3602 nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3603
3604 // Finalize graph and get new node.
3605 s = nb.Finalize(&**g, new_node);
3606 if (s != Status::OK()) {
3607 return s;
3608 }
3609
3610   // In the following code of this function, an unordered set is used to
3611   // make sure no duplicate edges are added to the new node. Therefore, we
3612   // can pass allow_duplicates = true in the AddControlEdge call to skip the
3613   // O(#edges) check in the routine.
3614
3615 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3616 // already copied in BuildNode. We need to handle control edges now.
3617 std::unordered_set<Node*> unique_node;
3618 for (const Edge* e : orig_node->in_edges()) {
3619 if (e->IsControlEdge()) {
3620 auto result = unique_node.insert(e->src());
3621 if (result.second) {
3622 (*g)->AddControlEdge(e->src(), *new_node, true);
3623 }
3624 }
3625 }
3626 unique_node.clear();
3627
3628 // Transfer outgoing edges from 'orig_node' node to new 'new_node' node.
3629 for (const Edge* e : orig_node->out_edges()) {
3630 if (e->IsControlEdge()) {
3631 auto result = unique_node.insert(e->dst());
3632 if (result.second) {
3633 (*g)->AddControlEdge(*new_node, e->dst(), true);
3634 }
3635 } else {
3636 auto result =
3637 (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3638 DCHECK(result != nullptr);
3639 }
3640 }
3641
3642 return Status::OK();
3643 }
3644
3645 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3646 Node* orig_node,
3647 const RewriteInfo* ri) {
3648 DCHECK(ri != nullptr);
3649 DCHECK(orig_node != nullptr);
3650
3651 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3652
3653 Status ret_status = Status::OK();
3654 Node* new_node = nullptr;
3655 if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3656 ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3657 } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3658 ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3659 } else {
3660 ret_status = Status(error::Code::INVALID_ARGUMENT,
3661 "Unsupported rewrite cause found."
3662 "RewriteNode will fail.");
3663 }
3664 TF_CHECK_OK(ret_status);
3665
3666   // Copy the runtime-assigned device from the original node to the new node.
3667 new_node->set_assigned_device_name(orig_node->assigned_device_name());
3668
3669 // Delete original node and mark new node as rewritten.
3670 (*g)->RemoveNode(orig_node);
3671
3672 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3673 return ret_status;
3674 }
3675
3676 // TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3677 // having attributes other than "T"?
3678 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3679 const MklLayoutRewritePass::RewriteInfo*
3680 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3681 DataType T1, T2;
3682 DataType Tinput, Tfilter;
3683 bool type_attrs_present = false;
3684
3685 if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3686 TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3687 mkl_op_registry::IsMklLayoutDependentOp(
3688 mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3689 type_attrs_present = true;
3690 } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3691 TryGetNodeAttr(n->def(), "T2", &T2) &&
3692 mkl_op_registry::IsMklLayoutDependentOp(
3693 mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3694 type_attrs_present = true;
3695 }
3696
3697 if (type_attrs_present) {
3698 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3699 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3700 return &*ri;
3701 }
3702 }
3703 }
3704
3705 return nullptr;
3706 }
3707
3708 const MklLayoutRewritePass::RewriteInfo*
3709 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3710 DCHECK(n);
3711
3712   // QuantizedOps may have attributes other than "T", so the check is
3713   // decoupled into a separate function, CheckForQuantizedNodeRewrite(const Node*).
3714 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3715 if (ri != nullptr) return ri;
3716
3717 // First check if node along with its type is supported by MKL layer.
3718 // We do not want to rewrite an op into Mkl op if types are not supported.
3719 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3720 // MklRelu if type is INT32.
3721 DataType T;
3722 if (!TryGetNodeAttr(n->def(), "T", &T)) {
3723 return nullptr;
3724 }
3725
3726 // We make an exception for Conv2DGrad and MaxPool related ops as
3727 // the corresponding MKL ops currently do not support the case
3728 // of padding == EXPLICIT yet.
3729 // TODO(intel): support `EXPLICIT` padding for ConvGrad
3730 if (n->type_string() == csinfo_.conv2d_grad_input ||
3731 n->type_string() == csinfo_.conv2d_grad_filter ||
3732 n->type_string() == csinfo_.max_pool ||
3733 n->type_string() == csinfo_.max_pool_grad ||
3734 n->type_string() == csinfo_.max_pool3d ||
3735 n->type_string() == csinfo_.max_pool3d_grad) {
3736 string padding;
3737 TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3738 if (padding == "EXPLICIT") return nullptr;
3739 }
3740
3741 // We make an exception for __MklDummyConv2DWithBias,
3742 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3743 // names do not match Mkl node names.
3744 if (n->type_string() != csinfo_.conv2d_with_bias &&
3745 n->type_string() != csinfo_.pad_with_conv2d &&
3746 n->type_string() != csinfo_.pad_with_fused_conv2d &&
3747 n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3748 n->type_string() != csinfo_.fused_batch_norm_ex &&
3749 n->type_string() != csinfo_.fused_conv2d &&
3750 n->type_string() != csinfo_.fused_depthwise_conv2d &&
3751 n->type_string() != csinfo_.fused_matmul &&
3752 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3753 T)) {
3754 return nullptr;
3755 }
3756
3757 // We now check if rewrite rule applies for this op. If rewrite rule passes
3758 // for this op, then we rewrite it to Mkl op.
3759 // Find matching RewriteInfo and then check that rewrite rule applies.
3760 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3761 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3762 return &*ri;
3763 }
3764 }
3765
3766 // Else return not found.
3767 return nullptr;
3768 }
3769
3770 //////////////////////////////////////////////////////////////////////////
3771 // Helper functions for node fusion
3772 //////////////////////////////////////////////////////////////////////////
3773 Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3774 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3775 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3776 string data_format) {
3777 Node* transpose_to_nhwc = nodes[0];
3778 Node* mklop = nodes[1];
3779 Node* transpose_to_nchw = nodes[2];
3780
3781 const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3782 gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3783 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3784 transpose_nhwc_num_inputs);
3785 FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3786 &transpose_nhwc_in);
3787
3788 const int mklop_num_inputs = mklop->num_inputs();
3789 gtl::InlinedVector<Node*, 4> mklop_control_edges;
3790 gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3791 FillInputs(mklop, &mklop_control_edges, &mklop_in);
3792
3793 const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3794 gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3795 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3796 transpose_nchw_num_inputs);
3797 FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3798 &transpose_nchw_in);
3799
3800 // We use same name as original node, but change the op
3801 // type.
3802 NodeBuilder nb(mklop->name(), mklop->type_string());
3803
3804 // Storing the output slots of the input nodes.
3805 for (int i = 0; i < mklop_num_inputs; i++) {
3806 if (mklop_in[i].first == transpose_to_nhwc) {
3807 // Fill "x":
3808 nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3809 } else {
3810 // Fill inputs other than "x":
3811 nb.Input(mklop_in[i].first, mklop_in[i].second);
3812 }
3813 }
3814
3815 copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3816 nb.Attr("data_format", data_format);
3817
3818 // Copy the device assigned to old node to new node.
3819 nb.Device(mklop->def().device());
3820
3821 // Create node.
3822 Node* new_node;
3823 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3824 // No need to check if new_node is null because it will be null only when
3825 // Finalize fails.
3826
3827 // Fill outputs.
3828 for (const Edge* e : transpose_to_nchw->out_edges()) {
3829 if (!e->IsControlEdge()) {
3830 const int kTransposeWithMklOpOutputSlot = 0;
3831 auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3832 e->dst(), e->dst_input());
3833 DCHECK(new_edge);
3834 }
3835 }
3836
3837 // Copy device assigned to old node to new node.
3838 new_node->set_assigned_device_name(mklop->assigned_device_name());
3839
3840 // Copy requested_device and assigned_device_name_index
3841 new_node->set_requested_device(mklop->requested_device());
3842 new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3843
3844 (*g)->RemoveNode(transpose_to_nhwc);
3845 (*g)->RemoveNode(mklop);
3846 (*g)->RemoveNode(transpose_to_nchw);
3847
3848 return Status::OK();
3849 }
3850
3851 Status MklLayoutRewritePass::FuseNode(
3852 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3853 const MklLayoutRewritePass::FusionInfo fi) {
3854 return fi.fuse_func(g, nodes, fi.copy_attrs);
3855 }
3856
3857 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
3858 MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3859 // Stores matched nodes, in the same order as node_checkers.
3860 std::vector<Node*> nodes;
3861
3862 for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3863 //
3864     // Make sure node "a" and its succeeding nodes (b, c ...) match the
3865     // pattern defined in the fusion info (ops[0], ops[1], ...),
3866     // i.e., "a->b->c" matches "op1->op2->op3".
3867 //
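    // For instance, one such pattern handled in this pass is the
    // Transpose -> MklOp -> Transpose sequence fused by
    // FuseTransposeMklOpTranspose above.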
3868
3869 // Stores the first unvisited outgoing edge of each matched node in "nodes".
3870 std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3871 nodes.clear();
3872
3873 auto node_checker = fi->node_checkers.begin();
3874 if (a != nullptr && (*node_checker)(a)) {
3875 nodes.push_back(a);
3876 current_neighbor_stack.push(a->out_edges().begin());
3877 ++node_checker;
3878 }
3879
3880 while (!nodes.empty()) {
3881 auto& current_neighbor_iter = current_neighbor_stack.top();
3882
3883 if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3884 // Found an unvisited edge. Goes through the edge to get the neighbor.
3885 Node* neighbor_node = (*current_neighbor_iter)->dst();
3886 ++current_neighbor_stack.top(); // Retrieves the next unvisited edge.
3887
3888 if ((*node_checker)(neighbor_node)) {
3889 // Found a match. Stores the node and moves to the next checker.
3890 nodes.push_back(neighbor_node);
3891 current_neighbor_stack.push(neighbor_node->out_edges().begin());
3892 if (++node_checker == fi->node_checkers.end()) {
3893 return make_tuple(true, nodes, *fi);
3894 }
3895 }
3896 } else {
3897         // Removes the current node since none of its neighbors leads to a
3898         // further match.
3899 nodes.pop_back();
3900 current_neighbor_stack.pop();
3901 --node_checker;
3902 }
3903 }
3904 }
3905
3906 return make_tuple(false, std::vector<Node*>(), FusionInfo());
3907 }
3908
3909 ///////////////////////////////////////////////////////////////////////////////
3910 // Post-rewrite Mkl metadata fixup pass
3911 ///////////////////////////////////////////////////////////////////////////////
3912 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3913 const Edge* e_data,
3914 const Edge* e_metadata) {
3915 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3916 return false;
3917 }
3918
3919 Node* n_data = e_data->src();
3920 int n_data_op_slot = e_data->src_output();
3921 int n_metadata_op_slot =
3922 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3923
3924   // If the source of the meta edge is a constant node (producing a dummy
3925   // Mkl metadata tensor), then we need to fix it.
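  // The fix reconnects the metadata input so that it comes from the Mkl op
  // that produces the corresponding data tensor (at slot n_metadata_op_slot)
  // instead of from the dummy constant.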
3926 if (IsConstant(e_metadata->src())) {
3927 Node* e_metadata_dst = e_metadata->dst();
3928 int e_metadata_in_slot = e_metadata->dst_input();
3929 auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3930 e_metadata_in_slot);
3931 DCHECK(new_edge);
3932
3933 (*g)->RemoveEdge(e_metadata);
3934 return true;
3935 }
3936
3937 return false;
3938 }
3939
3940 bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
3941 Node* n) {
3942 bool result = false;
3943
3944 // If graph node is not Mkl node, then return.
3945 DataType T = DT_INVALID;
3946 if (!TryGetNodeAttr(n->def(), "T", &T) ||
3947 !mkl_op_registry::IsMklLayoutDependentOp(n->type_string(), T)) {
3948 return result;
3949 }
3950
3951 // If it is Mkl node, then check if the input edges to this node that carry
3952 // Mkl metadata are linked up correctly with the source node.
3953
3954 // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
3955 // data tensors + n for Mkl metadata tensors). We need to check for correct
3956 // connection of n metadata tensors only.
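  // For example (assuming the contiguous ordering of tensors), an Mkl node
  // with 4 inputs has its data tensors at slots 0-1 and the corresponding Mkl
  // metadata tensors at slots 2-3.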
3957 int num_data_inputs = n->num_inputs() / 2;
3958 for (int idx = 0; idx < num_data_inputs; idx++) {
3959 // Get the edge connecting input slot with index (idx).
3960 const Edge* e = nullptr;
3961 TF_CHECK_OK(n->input_edge(idx, &e));
3962
3963 // If e is control edge, then skip.
3964 if (e->IsControlEdge()) {
3965 continue;
3966 }
3967
3968     // Check that the source node for edge 'e' is an Mkl node. If it is not
3969     // an Mkl node, then we don't need to do anything.
3970 Node* e_src = e->src();
3971 if (TryGetNodeAttr(e_src->def(), "T", &T) &&
3972 mkl_op_registry::IsMklLayoutDependentOp(e_src->type_string(), T)) {
3973 // Source node for edge 'e' is Mkl node.
3974 // Destination node and destination input slot of e is node 'n' and 'idx'
3975 // resp.
3976 CHECK_EQ(e->dst(), n);
3977 CHECK_EQ(e->dst_input(), idx);
3978
3979       // Let's get the edge that carries the Mkl metadata corresponding to
3980       // Mkl data edge 'e'. For that, let's first get the input slot of 'n'
3981       // where the meta edge will feed the value.
3982 int e_meta_in_slot =
3983 GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
3984 const Edge* e_meta = nullptr;
3985 TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
3986
3987 // Let's check if we need to fix this meta edge.
3988 if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
3989 result = true;
3990 }
3991 }
3992 }
3993
3994 return result;
3995 }
3996
3997 ///////////////////////////////////////////////////////////////////////////////
3998 // Run function for the pass
3999 ///////////////////////////////////////////////////////////////////////////////
4000
4001 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
4002 bool result = false;
4003 DCHECK(g);
4004
4005 DumpGraph("Before running MklLayoutRewritePass", &**g);
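  // The pass works in four phases, each over a reverse post-order
  // (topological) traversal of the graph: (1) node merge, (2) node fusion,
  // (3) node rewrite to Mkl ops, and (4) fixup of Mkl metadata edges.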
4006
4007 std::vector<Node*> order;
4008 GetReversePostOrder(**g, &order); // This will give us topological sort.
4009 for (Node* n : order) {
4010 // If node is not an op or it cannot run on CPU device, then skip.
4011 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4012 continue;
4013 }
4014
4015 Node* m = nullptr;
4016 if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
4017       // Node 'n' can be merged with some other node; 'm' holds the node
4018       // with which it will be merged.
4019 string n1_name = n->name();
4020 string n2_name = m->name();
4021
4022 VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
4023 << n2_name << " for merging";
4024
4025 if (MergeNode(g, n, m) == Status::OK()) {
4026 VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
4027 << n2_name;
4028 result = true;
4029 }
4030 }
4031 }
4032
4033 DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
4034
4035 order.clear();
4036 GetReversePostOrder(**g, &order); // This will give us topological sort.
4037 for (Node* n : order) {
4038 // If node is not an op or it cannot run on CPU device, then skip.
4039 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4040 continue;
4041 }
4042
4043 auto check_result = CheckForNodeFusion(n);
4044 bool found_pattern = std::get<0>(check_result);
4045 std::vector<Node*> nodes = std::get<1>(check_result);
4046 const FusionInfo fi = std::get<2>(check_result);
4047
4048 // if "found_pattern" is true, we can do the fusion.
4049 if (found_pattern) {
4050 if (FuseNode(g, nodes, fi) == Status::OK()) {
4051 result = true;
4052 }
4053 }
4054 }
4055 DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
4056
4057 order.clear();
4058 GetReversePostOrder(**g, &order); // This will give us topological sort.
4059 for (Node* n : order) {
4060 // If node is not an op or it cannot run on CPU device, then skip.
4061 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4062 continue;
4063 }
4064
4065 const RewriteInfo* ri = nullptr;
4066 // We will first search if node is to be rewritten.
4067 if ((ri = CheckForNodeRewrite(n)) != nullptr) {
4068 string node_name = n->name();
4069 string op_name = n->type_string();
4070
4071 VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
4072 << " with op " << op_name << " for rewrite using"
4073 << " layout optimization.";
4074
4075 if (RewriteNode(g, n, ri) == Status::OK()) {
4076 VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
4077 << " with op " << op_name << " for Mkl layout optimization.";
4078 result = true;
4079 }
4080 }
4081 }
4082
4083 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
4084
4085 order.clear();
4086 GetReversePostOrder(**g, &order); // This will give us topological sort.
4087 for (Node* n : order) {
4088 // If node is not an op or it cannot run on CPU device, then skip.
4089 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4090 continue;
4091 }
4092 if (FixMklMetaDataEdges(g, n)) {
4093 string node_name = n->name();
4094 string op_name = n->type_string();
4095
4096 VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
4097 << node_name << " with op " << op_name;
4098 result = true;
4099 }
4100 }
4101 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
4102 &**g);
4103
4104 return result;
4105 }
4106
4107 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
4108 return MklLayoutRewritePass().RunPass(g);
4109 }
4110
4111 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
4112 if (options.graph == nullptr && options.partition_graphs == nullptr) {
4113 return Status::OK();
4114 }
4115 if (DisableMKL()) {
4116 VLOG(2) << "TF-MKL: Disabling MKL";
4117 return Status::OK();
4118 }
4119
4120 auto process_graph = [&](std::unique_ptr<Graph>* g) {
4121 // Get the ownership of a graph
4122 std::unique_ptr<Graph>* ng = std::move(g);
4123 RunPass(ng);
4124 // Return the ownership of a graph back
4125 g->reset(ng->release());
4126 };
4127
4128 if (kMklLayoutRewritePassGroup !=
4129 OptimizationPassRegistry::POST_PARTITIONING) {
4130 // For any pre-partitioning phase, a graph is stored in options.graph.
4131 process_graph(options.graph);
4132 } else {
4133 // For post partitioning phase, graphs are stored in
4134 // options.partition_graphs.
4135 for (auto& pg : *options.partition_graphs) {
4136 process_graph(&pg.second);
4137 }
4138 }
4139
4140 return Status::OK();
4141 }
4142
4143 } // namespace tensorflow
4144
4145 #endif  // INTEL_MKL
4146