1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // TODO(intel): Improve error handling in this file; instead of CHECK failing
17 // all over the place, we should log an error and execute the original graph.
18 #ifdef INTEL_MKL
19
20 #include <algorithm>
21 #include <functional>
22 #include <memory>
23 #include <queue>
24 #include <set>
25 #include <stack>
26 #include <tuple>
27 #include <unordered_set>
28 #include <utility>
29 #include <vector>
30
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/optimization_registry.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/gtl/array_slice.h"
40 #include "tensorflow/core/lib/gtl/map_util.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/util/tensor_format.h"
44 #include "tensorflow/core/util/util.h"
45
46 #include "tensorflow/core/graph/mkl_graph_util.h"
47 #include "tensorflow/core/graph/mkl_layout_pass.h"
48
49 namespace tensorflow {
50
51 // This pass implements rewriting of graph to support following scenarios:
52 // (A) Merging nodes in the graph
53 // (B) Rewriting a node in the graph to a new node
54 // Rewrite happens under following scenario:
55 // - Propagating Mkl layout as an additional output tensor
56 // (we will loosely call a tensor that carries Mkl layout as Mkl tensor
57 // henceforth.) from every Mkl supported NN layer.
58 //
59 // Example of A : Merging nodes in the graph
60 // -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
62 //
63 // O = Conv2D(A, B)
64 // P = BiasAdd(O, C)
65 //
66 // We merge them into Conv2DWithBias as:
67 // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
68 //
69 // The meaning of A_m, B_m and C_m is explained in B.1.
70 //
71 // Merge rules:
72 // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
73 // goes to BiasAdd.
74 // - Also, the intersection of attributes of both the nodes must have same
75 // values.
76 // - Both the nodes must have been assigned to same device (if any).
77 //
78 // Example of B.1 : Rewriting nodes to Mkl nodes
79 // ---------------------------------------------
80 // Consider a Relu node. Current definition of Relu node looks like:
81 //
82 // O = Relu(A)
83 //
84 // Relu has 1 input (A), and 1 output (O).
85 //
86 // This rewrite pass will generate a new graph node for Relu (new node is
87 // called MklRelu) as:
88 //
89 // O, O_m = MklRelu(A, A_m)
90 //
91 // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
92 // same as input A of Relu; output O is same as output O of Relu. O_m is the
93 // additional output tensor that will be set by MklRelu, and it represents
94 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
95 // metadata for O. A_m is additional input of Relu, and it represents metadata
96 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
97 // this metadata from previous node in the graph.
98 //
99 // When a previous node in the graph is an Mkl node, A_m will represent a valid
100 // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
101 // a dummy Mkl tensor.
102 //
103 // Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    no node of that op type will be rewritten.
107 // - Number of inputs after rewriting:
108 // Since for every input Tensorflow tensor, the rewritten node gets Mkl
109 // tensor(s), rewritten node gets 2*N inputs, where N is the number of
110 // inputs for the original node.
111 // - Number of outputs after rewriting:
112 // Since for every output Tensorflow tensor, the rewritten node generates
113 // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
114 // number of outputs of the original node.
115 // - Ordering of Tensorflow tensors and Mkl tensors:
116 // Since every rewritten node generates twice the number of inputs and
117 // outputs, one could imagine various orderings among Tensorflow tensors
118 // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
119 // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
120 // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
121 // order. Among N inputs one can get N! permutations.
122 //
123 // So the question is: which order do we follow? We support 2 types of
124 // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
125 // follows an intuitive order where an Mkl tensor follows the
126 // corresponding Tensorflow tensor immediately. In the context of the
127 // above example, it will be: A, A_m, B, B_m. Note that the ordering rule
128 // applies to both the inputs and outputs. Contiguous ordering means
129 // all the Tensorflow tensors are contiguous followed by all the Mkl
130 // tensors. We use contiguous ordering as default.
131 //
132 // Graph rewrite algorithm:
133 // Algorithm: Graph Rewrite
134 // Input: Graph G, Names of the nodes to rewrite and their new names
135 // Output: Modified Graph G' if the nodes are modified, G otherwise.
136 // Start:
137 // N = Topological_Sort(G) // N is a set of nodes in toposort order.
138 // foreach node n in N
139 // do
140 // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
141 // then
142 // E = set of <incoming edge and its src_output slot> of n
143 // E' = {} // a new set of edges for rewritten node
144 // foreach <e,s> in E
145 // do
146 // E' U {<e,s>} // First copy edge which generates Tensorflow
147 // // tensor as it is
148 // m = Source node of edge e
149 // if Is_Rewritten(m) // Did we rewrite this node in this pass?
150 // then
151 // E' U {<m,s+1>} // If yes, then m will generate an Mkl
152 // // tensor as an additional output.
153 // else
154 // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
155 // // Mkl tensor.
156 // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
157 // fi
158 // done
159 // n' = Build_New_Node(G,new_name,E')
160 // Mark_Rewritten(n') // Mark the new node as being rewritten.
161 // fi
162 // done
163 //
164 // Explanation:
165 // For graph rewrite, we visit nodes of the input graph in the
166 // topological sort order. With this ordering, we visit nodes in the
167 // top-to-bottom fashion. We need this order because while visiting a
168 // node we want that all of its input nodes are visited and rewritten if
169 // applicable. This is because if we need to rewrite a given node
170 // then all of its input nodes need to be fixed (in other words they
171 // cannot be deleted later.)
172 //
173 // While visiting a node, we first check if the op type of the node is
174 // an Mkl op. If it is, then we rewrite that node after constructing
175 // new inputs to the node. If the op type of the node is not Mkl op,
176 // then we do not rewrite that node.
177 //
178 // Handling workspace propagation for certain ops:
179 //
180 // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
181 // passing of a workspace from their respective forward ops. Workspace
182 // tensors provide memory for storing results of intermediate operations
183 // which are helpful in backward propagation. TensorFlow does not have
184 // a notion of a workspace and as a result does not allow producing
185 // additional outputs from these forward ops. For these ops, we need
186 // to add 2 extra edges between forward ops and their corresponding
187 // backward ops - the first extra edge carries a workspace tensor and
188 // the second one carries an Mkl tensor for the workspace tensor.
189 //
190 // Example:
191 //
192 // Typical graph for MaxPool and its gradient looks like:
193 //
194 // A = MaxPool(T)
195 // B = MaxPoolGrad(X, A, Y)
196 //
197 // We will transform this graph to propagate the workspace as:
198 // (with the contiguous ordering)
199 //
200 // A, W, A_m, W_m = MklMaxPool(T, T_m)
201 // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
202 //
203 // Here W is the workspace tensor. Transformed tensor names with the
204 // suffix _m are Mkl tensors, and this transformation has been done
205 // using the algorithm discussed earlier. The transformation for
206 // workspace propagation only adds extra outputs (W, W_m) for a forward
207 // op and connects them to the corresponding backward ops.
208 //
209 // Terms:
210 //
211 // Forward op name = name of the op in the forward pass
212 // where a workspace tensor originates (MaxPool in this example)
213 // Backward op name = name of the op in the backward pass that receives
214 // a workspace tensor from the forward op (MaxPoolGrad in the example)
215 // Slot = Position of the output or input slot that will be
216 // used by the workspace tensor (1 for MklMaxPool as W is the 2nd
217 // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
218 //
219 // Question:
220 //
221 // How do we associate a backward op to a forward op? There can be more
222 // than one op with the exact same name.
223 //
224 // In this example, we associate MaxPoolGrad with MaxPool. But there
225 // could be more than one MaxPool ops. To solve this problem, we look
226 // for _direct_ edge between a forward op and a backward op (tensor A is
227 // flowing along this edge in the example).
228 //
229 // How do we transform forward and backward ops when there is no direct
230 // edge between them? In such a case, we generate dummy tensors for
231 // workspace tensors. For the example, transformation of MaxPool will
232 // be exactly same as it would be when there is a direct edge between
233 // the forward and the backward op --- it is just that MaxPool won't
234 // generate any workspace tensor. For MaxPoolGrad, the transformation
235 // will also be same, but instead of connecting W and W_m with the
236 // outputs of MaxPool, we will produce dummy tensors for them, and we
237 // will set workspace_enabled attribute to false.
238 //
239 class MklLayoutRewritePass : public GraphOptimizationPass {
240 public:
  // Populates the rewrite (rinfo_), workspace (wsinfo_), merge (minfo_) and
  // fusion (finfo_) tables that drive this pass. Order matters for minfo_ and
  // finfo_: entries that appear first are applied first (see comment below).
  MklLayoutRewritePass() {
    // Constant op-name strings used throughout the tables below.
    // NOTE: names are alphabetically sorted.
    csinfo_.addn = "AddN";
    csinfo_.avg_pool = "AvgPool";
    csinfo_.avg_pool_grad = "AvgPoolGrad";
    csinfo_.avg_pool3d = "AvgPool3D";
    csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
    csinfo_.bias_add = "BiasAdd";
    csinfo_.bias_add_grad = "BiasAddGrad";
    csinfo_.concat = "Concat";
    csinfo_.concatv2 = "ConcatV2";
    csinfo_.conv2d = "Conv2D";
    csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
    csinfo_.conv2d_grad_filter_with_bias =
        "__MklDummyConv2DBackpropFilterWithBias";
    csinfo_.conv3d = "Conv3D";
    csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
    csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
    csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
    csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
    csinfo_.depthwise_conv2d_grad_filter =
        "DepthwiseConv2dNativeBackpropFilter";
    csinfo_.fused_batch_norm = "FusedBatchNorm";
    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
    csinfo_.fused_conv2d = "_FusedConv2D";
    csinfo_.identity = "Identity";
    csinfo_.leakyrelu = "LeakyRelu";
    csinfo_.leakyrelu_grad = "LeakyReluGrad";
    csinfo_.lrn = "LRN";
    csinfo_.lrn_grad = "LRNGrad";
    csinfo_.matmul = "MatMul";
    csinfo_.max_pool = "MaxPool";
    csinfo_.max_pool_grad = "MaxPoolGrad";
    csinfo_.max_pool3d = "MaxPool3D";
    csinfo_.max_pool3d_grad = "MaxPool3DGrad";
    csinfo_.mkl_conv2d = "_MklConv2D";
    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
    csinfo_.mkl_conv2d_grad_filter_with_bias =
        "_MklConv2DBackpropFilterWithBias";
    csinfo_.mkl_depthwise_conv2d_grad_input =
        "_MklDepthwiseConv2dNativeBackpropInput";
    csinfo_.mkl_depthwise_conv2d_grad_filter =
        "_MklDepthwiseConv2dNativeBackpropFilter";
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
    csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
    csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
    csinfo_.quantized_avg_pool = "QuantizedAvgPool";
    csinfo_.quantized_concatv2 = "QuantizedConcatV2";
    csinfo_.quantized_conv2d = "QuantizedConv2D";
    csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
    csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
    csinfo_.quantized_conv2d_with_bias_and_requantize =
        "QuantizedConv2DWithBiasAndRequantize";
    csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
    csinfo_.quantized_conv2d_and_relu_and_requantize =
        "QuantizedConv2DAndReluAndRequantize";
    csinfo_.quantized_conv2d_with_bias_and_relu =
        "QuantizedConv2DWithBiasAndRelu";
    csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
        "QuantizedConv2DWithBiasAndReluAndRequantize";
    csinfo_.quantized_max_pool = "QuantizedMaxPool";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu =
        "QuantizedConv2DWithBiasSumAndRelu";
    csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSumAndReluAndRequantize";
    csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
        "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
    csinfo_.relu = "Relu";
    csinfo_.relu_grad = "ReluGrad";
    csinfo_.relu6 = "Relu6";
    csinfo_.relu6_grad = "Relu6Grad";
    csinfo_.requantize = "Requantize";
    csinfo_.tanh = "Tanh";
    csinfo_.tanh_grad = "TanhGrad";
    csinfo_.reshape = "Reshape";
    csinfo_.slice = "Slice";
    csinfo_.softmax = "Softmax";
    csinfo_.split = "Split";
    csinfo_.transpose = "Transpose";
    // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
    csinfo_.sub = "Sub";
    // End - element-wise ops. See note above.

    // Rewrite table. Each entry is {original op name, rewritten (Mkl) op
    // name, attribute-copying function, rewrite predicate}.
    // NOTE: names are alphabetically sorted.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAddN, AddNRewrite});
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.avg_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concat,
                      mkl_op_registry::GetMklOpName(csinfo_.concat),
                      CopyAttrsConcat, AlwaysRewrite});
    rinfo_.push_back({csinfo_.concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    // Dummy merge nodes (e.g. __MklDummyConv2DWithBias, created by the merge
    // step) are rewritten to their real Mkl counterparts here.
    rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
                      csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
                      AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv2d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d),
                      CopyAttrsConvCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d_grad_filter,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.conv3d_grad_input,
                      mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
                      CopyAttrsConv, AlwaysRewrite});
    rinfo_.push_back({csinfo_.depthwise_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
                      CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_input,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
         CopyAttrsConv2DDepthwise, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.depthwise_conv2d_grad_filter,
         mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
         CopyAttrsConv2DDepthwise, AlwaysRewrite});
    rinfo_.push_back({csinfo_.fused_batch_norm,
                      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.fused_batch_norm_grad,
         mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
         CopyAttrsFusedBatchNorm, AlwaysRewrite});
    rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                      CopyAttrsFusedConv2D, FusedConv2DRewrite});
    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                      CopyAttrsLRN, LrnRewrite});
    rinfo_.push_back({csinfo_.lrn_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                      CopyAttrsLRN, LrnGradRewrite});
    rinfo_.push_back({csinfo_.leakyrelu,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
                      CopyAttrsLeakyRelu, LeakyReluRewrite});
    rinfo_.push_back({csinfo_.leakyrelu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
                      CopyAttrsLeakyRelu, LeakyReluRewrite});
    rinfo_.push_back({csinfo_.max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    rinfo_.push_back({csinfo_.max_pool_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
                      CopyAttrsPooling, MaxpoolGradRewrite});
    rinfo_.push_back({csinfo_.max_pool3d,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
                      CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
    rinfo_.push_back({csinfo_.max_pool3d_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
                      CopyAttrsPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.maximum,
                      mkl_op_registry::GetMklOpName(csinfo_.maximum),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.pad_with_conv2d, csinfo_.mkl_pad_with_conv2d,
                      CopyAttrsPadWithConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
                      csinfo_.mkl_pad_with_fused_conv2d,
                      CopyAttrsPadWithFusedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_avg_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
                      CopyAttrsQuantizedPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_concatv2,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
                      CopyAttrsConcatV2, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_and_relu,
         mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_and_relu_and_requantize),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_max_pool,
                      mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
                      CopyAttrsQuantizedPooling, AlwaysRewrite});
    rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
                      mkl_op_registry::GetMklOpName(
                          csinfo_.quantized_conv2d_with_bias_sum_and_relu),
                      CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back(
        {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
         mkl_op_registry::GetMklOpName(
             csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
         CopyAttrsQuantizedConv2D, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu6,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.relu6_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.requantize,
                      mkl_op_registry::GetMklOpName(csinfo_.requantize),
                      CopyAttrsRequantize, AlwaysRewrite});
    // Disabled rewrite for Tanh/TanhGrad; left here intentionally.
    /*
    rinfo_.push_back({csinfo_.tanh,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.tanh_grad,
                      mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                      CopyAttrsDataType, AlwaysRewrite});
    */
    rinfo_.push_back({csinfo_.reshape,
                      mkl_op_registry::GetMklOpName(csinfo_.reshape),
                      CopyAttrsReshape, AlwaysRewrite});
    rinfo_.push_back({csinfo_.slice,
                      mkl_op_registry::GetMklOpName(csinfo_.slice),
                      CopyAttrsSlice, AlwaysRewrite});
    rinfo_.push_back({csinfo_.softmax,
                      mkl_op_registry::GetMklOpName(csinfo_.softmax),
                      CopyAttrsDataType, AlwaysRewrite});

    rinfo_.push_back({csinfo_.squared_difference,
                      mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
                      CopyAttrsDataType, AlwaysRewrite});
    rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
                      CopyAttrsDataType, AlwaysRewrite});

    // Add info about which ops to add workspace edge to and the slots.
    // Entry layout: {fwd op, bwd op, fwd_slot, bwd_slot, ws_fwd_slot,
    // ws_bwd_slot} -- see WorkSpaceInfo.
    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
    wsinfo_.push_back(
        {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});

    // Add a rule for merging nodes
    minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
                      csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});

    minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
                      csinfo_.conv2d_grad_filter_with_bias,
                      GetConv2DBackpropFilterOrBiasAddGrad});
    // Merge Pad and Conv2d, only if the pad op is "Pad"
    // Doesn't merge if pad op is "PadV2" or "MirrorPad"
    minfo_.push_back(
        {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});

    minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
                      csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});

    // The fusion patterns in "finfo_" that show up first will get applied
    // first, for example, graph "A->B->C-D" and finfo_ is {A->B->C to ABC,
    // A->B->C->D to ABCD}, since the first gets applied first, the final
    // graph will be ABC->D.

    //
    // Add rules to fuse sequences such as "Transpose (NCHW -> NHWC) + Conv2D
    // (NHWC) + Transpose (NHWC->
    // NCHW)" into "Conv2D (NCHW)". Such patterns occur frequently in Keras.
    // Note: we use the term "merge" to combine (exactly) 2 nodes into one,
    // while "fusion" is for 3+ nodes situation.
    //

    // Transpose + Conv2d + Transpose:
    // Permutation vectors below are the Transpose "perm" inputs we match.
    std::vector<int> transpose_to_nhwc = {NCHW::dim::N, NCHW::dim::H,
                                          NCHW::dim::W, NCHW::dim::C};
    std::vector<int> transpose_to_nchw = {NHWC::dim::N, NHWC::dim::C,
                                          NHWC::dim::H, NHWC::dim::W};
    auto CheckForTransposeToNHWC =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nhwc);
    auto CheckForConv2dOp =
        std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv2d);
    auto CheckForTransposeToNCHW =
        std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_nchw);
    auto FuseConv2D =
        std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, "NCHW");
    finfo_.push_back(
        {"transpose-elimination for Conv2D",
         {CheckForTransposeToNHWC, CheckForConv2dOp, CheckForTransposeToNCHW},
         // CheckForMklOp
         FuseConv2D,
         CopyAttrsConv});
  }
587
  // Standard interface to run pass (GraphOptimizationPass contract).
  Status Run(const GraphOptimizationPassOptions& options);

  // Helper function which does most of heavy lifting for rewriting
  // Mkl nodes to propagate Mkl tensor as additional output
  //
  // Extracts common functionality between Run public interface and
  // test interface.
  //
  // @return true, if and only if graph is mutated; false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);
599
600 /// Structure to specify the name of an original node, its new name after
601 /// rewrite, the number of inputs to the original node, the function to
602 /// be used to copy attributes for the op, and the rule (if any) which
603 /// must hold for rewriting the node
604 typedef struct {
605 string name; // Original name of op of the node in the graph
606 string new_name; // New name of the op of the node in the graph
607 // A function handler to copy attributes from an old node to a new node.
608 std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
609 // A rule under which to rewrite this node
610 std::function<bool(const Node*)> rewrite_rule;
611 } RewriteInfo;
612
613 /// Structure to specify a forward op, a backward op, and the slot numbers
614 /// in the forward and backward ops where we will add a workspace edge.
615 typedef struct {
616 string fwd_op; // Name of a forward op in the graph
617 string bwd_op; // Name of a backward op in the graph
618 int fwd_slot; // Output slot in the forward op node where actual
619 // output tensor resides
620 int bwd_slot; // Input slot in the backward op node where actual
621 // input tensor resides
622 int ws_fwd_slot; // Output slot in the forward op node where workspace
623 // edge is added
624 int ws_bwd_slot; // Input slot in the backward op node where workspace
625 // edge is added
626 } WorkSpaceInfo;
627
628 /// Structure to specify information used in node merge of 2 operators
629 typedef struct {
630 string op1; // Node string for one operator.
631 string op2; // Node string for second operator.
632 string new_node; // Name of the node after merge
633 // Function that enables user of the node merger to specify how to find
634 // second operator given the first operator.
635 std::function<Node*(const Node*)> get_node_to_be_merged;
636 } MergeInfo;
637
638 // Structure to specify information used in node fusion of 3+ operators
639 typedef struct {
640 std::string pattern_name; // Name to describe this pattern, such as
641 // "Transpose_Mklop_Transpose".
642 std::vector<std::function<bool(const Node*)> >
643 node_checkers; // Extra restriction checker for these ops
644 std::function<Status(
645 std::unique_ptr<Graph>*, std::vector<Node*>&,
646 std::function<void(const Node*, NodeBuilder* nb, bool)>)>
647 fuse_func;
648 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
649 } FusionInfo;
650
  //
  // Dimension indices for 2D tensor.
  //
  // These are deliberately unscoped enums (not `enum class`) so the
  // enumerators convert implicitly to int -- the constructor collects them
  // into std::vector<int> permutation vectors for Transpose matching.
  struct NCHW {
    enum dim { N = 0, C = 1, H = 2, W = 3 };
  };

  struct NHWC {
    enum dim { N = 0, H = 1, W = 2, C = 3 };
  };

  //
  // dimension indices for 3D tensor.
  //
  struct NCDHW {
    enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
  };

  struct NDHWC {
    enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
  };
672
673 /// Structure to store all constant strings
674 /// NOTE: names are alphabetically sorted.
// Each member holds the TensorFlow op type-string (or the Mkl op
// type-string, for the mkl_* members) that the pass matches against
// Node::type_string() when deciding on merges, fusions and rewrites.
typedef struct {
  string addn;
  string add;
  string avg_pool;
  string avg_pool_grad;
  string avg_pool3d;
  string avg_pool3d_grad;
  string bias_add;
  string bias_add_grad;
  string concat;
  string concatv2;
  string conv2d;
  string conv2d_with_bias;
  string conv2d_grad_input;
  string conv2d_grad_filter;
  string conv2d_grad_filter_with_bias;
  string conv3d;
  string conv3d_grad_input;
  string conv3d_grad_filter;
  string depthwise_conv2d;
  string depthwise_conv2d_grad_input;
  string depthwise_conv2d_grad_filter;
  string fused_batch_norm;
  string fused_batch_norm_grad;
  string fused_conv2d;
  string identity;
  string leakyrelu;
  string leakyrelu_grad;
  string lrn;
  string lrn_grad;
  string matmul;
  string max_pool;
  string max_pool_grad;
  string max_pool3d;
  string max_pool3d_grad;
  string maximum;
  // Names of the Mkl replacement ops (targets of the rewrite).
  string mkl_conv2d;
  string mkl_conv2d_grad_input;
  string mkl_conv2d_grad_filter;
  string mkl_conv2d_grad_filter_with_bias;
  string mkl_conv2d_with_bias;
  string mkl_depthwise_conv2d_grad_input;
  string mkl_depthwise_conv2d_grad_filter;
  string mkl_fused_conv2d;
  string mkl_pad_with_conv2d;
  string mkl_pad_with_fused_conv2d;
  string mul;
  string pad;
  string pad_with_conv2d;
  string pad_with_fused_conv2d;
  // Quantized op variants.
  string quantized_avg_pool;
  string quantized_conv2d;
  string quantized_conv2d_with_requantize;
  string quantized_conv2d_with_bias;
  string quantized_conv2d_with_bias_and_requantize;
  string quantized_conv2d_and_relu;
  string quantized_conv2d_and_relu_and_requantize;
  string quantized_conv2d_with_bias_and_relu;
  string quantized_conv2d_with_bias_and_relu_and_requantize;
  string quantized_concatv2;
  string quantized_max_pool;
  string quantized_conv2d_with_bias_sum_and_relu;
  string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
  string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
  string relu;
  string relu_grad;
  string relu6;
  string relu6_grad;
  string requantize;
  string tanh;
  string tanh_grad;
  string transpose;
  string reshape;
  string slice;
  string softmax;
  string split;
  string squared_difference;
  string sub;
} ConstStringsInfo;
754
 private:
  /// Maintain info about nodes to rewrite (op name -> Mkl op name mapping
  /// plus the rewrite predicate and attribute-copy function)
  std::vector<RewriteInfo> rinfo_;

  /// Maintain info about nodes to add workspace edge
  std::vector<WorkSpaceInfo> wsinfo_;

  /// Maintain info about nodes to be merged
  std::vector<MergeInfo> minfo_;

  /// Maintain info about nodes to be fused
  std::vector<FusionInfo> finfo_;

  /// Maintain structure of constant strings (static: shared by all passes)
  static ConstStringsInfo csinfo_;
770
771 private:
772 // Is OpDef::ArgDef a list type? It could be N * T or list(type).
773 // Refer to opdef.proto for details of list type.
ArgIsList(const OpDef::ArgDef & arg) const774 inline bool ArgIsList(const OpDef::ArgDef& arg) const {
775 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
776 }
777
778 // Get length of a list in 'n' if 'arg' is of list type. Refer to
779 // description of ArgIsList for definition of list type.
GetTensorListLength(const OpDef::ArgDef & arg,Node * n)780 inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
781 CHECK_EQ(ArgIsList(arg), true);
782 int N = 0;
783 const string attr_name = !arg.type_list_attr().empty()
784 ? arg.type_list_attr()
785 : arg.number_attr();
786 if (!arg.type_list_attr().empty()) {
787 std::vector<DataType> value;
788 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
789 N = value.size();
790 } else {
791 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
792 }
793 return N;
794 }
795
796 // Can op represented by node 'n' run on DEVICE_CPU?
797 // Op can run on CPU with MKL if the runtime assigned device or the
798 // user requested device contains device CPU, or both are empty.
CanOpRunOnCPUDevice(const Node * n)799 bool CanOpRunOnCPUDevice(const Node* n) {
800 bool result = true;
801 string reason;
802
803 // Substring that should be checked for in device name for CPU device.
804 const char* const kCPUDeviceSubStr = "CPU";
805
806 // If Op has been specifically assigned to a non-CPU device, then No.
807 if (!n->assigned_device_name().empty() &&
808 !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
809 result = false;
810 reason = "Op has been assigned a runtime device that is not CPU.";
811 }
812
813 // If user has specifically assigned this op to a non-CPU device, then No.
814 if (!n->def().device().empty() &&
815 !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
816 result = false;
817 reason = "User has assigned a device that is not CPU.";
818 }
819
820 if (result == false) {
821 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
822 << n->type_string() << ", reason: " << reason;
823 }
824
825 // Otherwise Yes.
826 return result;
827 }
828
829 // Return a node that can be merged with input node 'n'
830 //
831 // @return pointer to the node if we can find such a
832 // node. Otherwise, it returns nullptr.
833 Node* CheckForNodeMerge(const Node* n) const;
834
835 // Merge node 'm' with node 'n'.
836 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
837 // Conv2DBackpropFilter.
838 //
839 // Input nodes m and n may be deleted if the call to
840 // this function is successful. Attempt to use the pointers
841 // after the call to function may result in undefined behaviors.
842 //
843 // @input g - input graph, m - graph node, n - graph node to be merged with m
844 // @return Status::OK(), if merging is successful and supported.
845 // Returns appropriate Status error code otherwise.
846 // Graph is updated in case nodes are merged. Otherwise, it is
847 // not updated.
848 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
849
850 // Helper function to merge different nodes
851 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
852 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
853 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
854 Node* m, Node* n);
855
856 // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
857 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
858 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
859 // node that can be merged with 'm'.
GetConv2DOrBiasAdd(const Node * m)860 static Node* GetConv2DOrBiasAdd(const Node* m) {
861 CHECK_NOTNULL(m);
862 Node* n = nullptr;
863
864 DataType T_m;
865 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
866
867 // Don't try to merge if datatype is not DT_FLOAT
868 if (T_m != DT_FLOAT) return n;
869
870 if (m->type_string() == csinfo_.bias_add) {
871 // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
872 TF_CHECK_OK(m->input_node(0, &n));
873 } else {
874 CHECK_EQ(m->type_string(), csinfo_.conv2d);
875 // Go over all output edges and search for BiasAdd Node.
876 // 0th input of BiasAdd is Conv2D.
877 for (const Edge* e : m->out_edges()) {
878 if (!e->IsControlEdge() &&
879 e->dst()->type_string() == csinfo_.bias_add &&
880 e->dst_input() == 0) {
881 n = e->dst();
882 break;
883 }
884 }
885 }
886
887 if (n == nullptr) {
888 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
889 << "Conv2D and BiasAdd node for merging. Input node: "
890 << m->DebugString();
891 }
892
893 return n;
894 }
895
896 // Find Pad or Conv2D node that can be merged with input node 'm'.
897 // If input 'm' is Pad, then check if there exists Conv2D node that can be
898 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
899 // node that can be merged with 'm'.
GetPadOrConv2D(const Node * m)900 static Node* GetPadOrConv2D(const Node* m) {
901 DCHECK(m);
902 Node* n = nullptr;
903
904 DataType T_m;
905 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
906
907 // Don't try to merge if datatype is not DT_FLOAT
908 if (T_m != DT_FLOAT) return n;
909
910 const Node* conv_node;
911 if (m->type_string() == csinfo_.pad) {
912 // If m is Pad, then Conv2D is the output of Pad.
913 for (const Edge* e : m->out_edges()) {
914 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
915 n = e->dst();
916 conv_node = n;
917 break;
918 }
919 }
920 } else {
921 DCHECK_EQ(m->type_string(), csinfo_.conv2d);
922 // If m is conv2D, Go over all input edges
923 // and search for Pad Node.
924 for (const Edge* e : m->in_edges()) {
925 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
926 n = e->src();
927 conv_node = m;
928 break;
929 }
930 }
931 }
932 // Check if only VALID type of padding is used
933 // or not.
934 if (n != nullptr) {
935 string padding;
936 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
937 if (padding != "VALID")
938 // Then do not merge.
939 // Only VALID type of padding in conv op can be
940 // merged with Pad op.
941 n = nullptr;
942 } else {
943 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
944 << "Pad and Conv2D node for merging. Input node: "
945 << m->DebugString();
946 }
947
948 return n;
949 }
950
951 // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
952 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
953 // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
954 // exists Pad node that can be merged with 'm'.
GetPadOrFusedConv2D(const Node * m)955 static Node* GetPadOrFusedConv2D(const Node* m) {
956 DCHECK(m);
957 Node* n = nullptr;
958
959 const Node* conv_node;
960 if (m->type_string() == csinfo_.pad) {
961 // If m is Pad, then _FusedConv2D is the output of Pad.
962 for (const Edge* e : m->out_edges()) {
963 if (!e->IsControlEdge() &&
964 e->dst()->type_string() == csinfo_.fused_conv2d) {
965 n = e->dst();
966 conv_node = n;
967 break;
968 }
969 }
970 } else {
971 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
972 // If m is _FusedConv2D, Go over all input edges
973 // and search for Pad node.
974 for (const Edge* e : m->in_edges()) {
975 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
976 n = e->src();
977 conv_node = m;
978 break;
979 }
980 }
981 }
982 // Check if only VALID type of padding is used or not.
983 if (n != nullptr) {
984 string padding;
985 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
986 if (padding != "VALID") {
987 // Then do not merge.
988 n = nullptr;
989 VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
990 << "nodes but cannot merge them. Only conv ops with padding "
991 << "type VALID can be merged with Pad op Input node: "
992 << m->DebugString();
993 }
994 } else {
995 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
996 << "Pad and _FusedConv2D node for merging. Input node: "
997 << m->DebugString();
998 }
999
1000 return n;
1001 }
1002
1003 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1004 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1005 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1006 // then check if there exists Conv2DBackpropFilter node that can be merged
1007 // with 'm'.
1008 //
1009 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1010 // would look like:
1011 //
1012 // _ = Conv2DBackpropFilter(F, _, G)
1013 // _ = BiasAddGrad(G)
1014 //
1015 // So 1st input of BiasAddGrad connects with 3rd input of
1016 // Conv2DBackpropFilter and vice versa.
GetConv2DBackpropFilterOrBiasAddGrad(const Node * m)1017 static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1018 CHECK_NOTNULL(m);
1019 Node* n = nullptr;
1020
1021 DataType T_m;
1022 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1023
1024 // Don't try to merge if datatype is not DT_FLOAT
1025 if (T_m != DT_FLOAT) return n;
1026
1027 if (m->type_string() == csinfo_.bias_add_grad) {
1028 // Get 1st input 'g' of BiasAddGrad.
1029 Node* g = nullptr;
1030 TF_CHECK_OK(m->input_node(0, &g));
1031 // Now traverse all outgoing edges from g that have destination node as
1032 // Conv2DBackpropFilter.
1033 for (const Edge* e : g->out_edges()) {
1034 if (!e->IsControlEdge() &&
1035 e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1036 e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1037 n = e->dst();
1038 break;
1039 }
1040 }
1041 } else {
1042 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1043 // Get 3rd input 'g' of Conv2DBackpropFilter.
1044 Node* g = nullptr;
1045 TF_CHECK_OK(m->input_node(2, &g));
1046 // Now traverse all outgoing edges from g that have destination node as
1047 // BiasAddGrad.
1048 for (const Edge* e : g->out_edges()) {
1049 if (!e->IsControlEdge() &&
1050 e->dst()->type_string() == csinfo_.bias_add_grad &&
1051 e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1052 n = e->dst();
1053 break;
1054 }
1055 }
1056 }
1057
1058 if (n == nullptr) {
1059 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1060 << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1061 << "Input node: " << m->DebugString();
1062 }
1063 return n;
1064 }
1065
  // Return a node that can be fused with input node 'n'
  //
  // @return tuple. If we can find such nodes, the first
  // element of the tuple is a true. Otherwise, it's false.
  std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
  CheckForNodeFusion(Node* n) const;

  // Fuse nodes in the vector "nodes"
  Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
                  const MklLayoutRewritePass::FusionInfo fi);

  // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
  // mklop("NCHW").
  // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
  static Status FuseTransposeMklOpTranspose(
      std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
      std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
      string data_format);
1084
CheckForTranspose(const Node * node,std::vector<int> perm)1085 static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1086 // Check if node's type is "Transpose"
1087 if (node->type_string() != "Transpose") return false;
1088
1089 // If "Transpose" has multiple output data edges, also don't fuse it.
1090 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1091
1092 // Check if has out control edge. If true, this is a training graph.
1093 // Currently we focus on inference and do no fusion in training.
1094 // Note: this constraint will eventually be removed, if we enabled this
1095 // fusion for training
1096 // in the future.
1097 for (const Edge* e : node->out_edges()) {
1098 if (e->IsControlEdge()) {
1099 return false;
1100 }
1101 }
1102
1103 // If "Transpose" has input control edges, don't fuse on it.
1104 for (const Edge* e : node->in_edges()) {
1105 if (e->IsControlEdge()) {
1106 return false;
1107 }
1108 }
1109
1110 // We compared the tensor containing the permutation order ("perm_node")
1111 // with our desired order ("perm"). If they're exactly match, this check
1112 // succeed and returns true.
1113 for (const Edge* e : node->in_edges()) {
1114 if (!e->IsControlEdge()) {
1115 const Node* perm_node = e->src();
1116
1117 const int kPermTensorIndex = 1;
1118 if (perm_node->type_string() == "Const" &&
1119 e->dst_input() == kPermTensorIndex) {
1120 // we find the "perm" node, now try to retrieve its value.
1121 const TensorProto* proto = nullptr;
1122 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1123
1124 DataType type;
1125 GetNodeAttr(perm_node->def(), "dtype", &type);
1126
1127 // Here we directly access to the "tensor_content", rather than
1128 // "int_val". This is because we find "int_val" is
1129 // not set properly under some circumstances.
1130 if (type == DT_INT32) {
1131 const int type_size = 4;
1132 const int* tensor_content =
1133 reinterpret_cast<const int*>(proto->tensor_content().c_str());
1134 const int tensor_content_size =
1135 proto->tensor_content().size() / type_size;
1136
1137 std::vector<int> perm_value(tensor_content,
1138 tensor_content + tensor_content_size);
1139
1140 return perm_value == perm;
1141 } else if (type == DT_INT64) {
1142 const int type_size = 8;
1143 const long* tensor_content =
1144 reinterpret_cast<const long*>(proto->tensor_content().c_str());
1145 const int tensor_content_size =
1146 proto->tensor_content().size() / type_size;
1147
1148 std::vector<long> perm_value(tensor_content,
1149 tensor_content + tensor_content_size);
1150 std::vector<long> long_perm(perm.cbegin(), perm.cend());
1151
1152 return perm_value == long_perm;
1153 }
1154 return false;
1155 }
1156 }
1157 }
1158 return false;
1159 }
1160
CheckForMklOp(const Node * node,string name="")1161 static bool CheckForMklOp(const Node* node, string name = "") {
1162 if (node == nullptr) return false;
1163
1164 if (!name.empty() && node->type_string() != name) {
1165 return false;
1166 }
1167
1168 // if mklop has multiple outputs, don't fuse it.
1169 if (node->num_outputs() > 1) return false;
1170
1171 if (node->out_edges().size() > 1) return false;
1172
1173 DataType T;
1174 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1175 return mkl_op_registry::IsMklOp(
1176 mkl_op_registry::GetMklOpName(node->type_string()), T);
1177 }
1178
1179 // Check if the node 'n' has any applicable rewrite rule
1180 // We check for 2 scenarios for rewrite.
1181 //
1182 // @return RewriteInfo* for the applicable rewrite rule
1183 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1184 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1185
1186 // Default rewrite rule to be used in scenario 1 for rewrite.
1187 // @return - true (since we want to always rewrite)
AlwaysRewrite(const Node * n)1188 static bool AlwaysRewrite(const Node* n) { return true; }
1189
1190 // Check if we are performing pooling on depth or batch. If it is, then we
1191 // do not rewrite MaxPool node to Mkl version.
1192 // @return - true (if it is not a depth/batch wise pooling case);
1193 // false otherwise.
NonDepthBatchWisePoolRewrite(const Node * n)1194 static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1195 CHECK_NOTNULL(n);
1196
1197 string data_format_str;
1198 TensorFormat data_format;
1199 std::vector<int32> ksize, strides;
1200 CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
1201 CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
1202 CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
1203 CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
1204
1205 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1206 if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1207 GetTensorDim(strides, data_format, 'N') == 1 &&
1208 GetTensorDim(ksize, data_format, 'C') == 1 &&
1209 GetTensorDim(strides, data_format, 'C') == 1) {
1210 return true;
1211 }
1212
1213 return false;
1214 }
1215
1216 // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
1217 // path. The unoptimized path is slow. Thus we dont rewrite the node
1218 // and use default Eigen. But for depth_radius=2, MKL DNN optimized
1219 // path is taken, i.e., eigen node is rewritten by MKl DNN node.
LrnRewrite(const Node * n)1220 static bool LrnRewrite(const Node* n) {
1221 CHECK_NOTNULL(n);
1222
1223 int depth_radius;
1224 CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
1225
1226 // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
1227 // and use eigen node instead
1228 if (depth_radius == 2) {
1229 return true;
1230 }
1231 VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
1232 << "case is not optimized by Intel MKL, thus using Eigen op"
1233 << "for LRN ";
1234
1235 return false;
1236 }
1237
LrnGradRewrite(const Node * n)1238 static bool LrnGradRewrite(const Node* n) {
1239 CHECK_NOTNULL(n);
1240 bool do_rewrite = false;
1241
1242 for (const Edge* e : n->in_edges()) {
1243 // Rewrite only if there is corresponding LRN, i.e workspace is available
1244 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1245 e->src()->type_string() ==
1246 mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1247 e->src_output() == 0) {
1248 do_rewrite = true;
1249 break;
1250 }
1251 }
1252 return do_rewrite;
1253 }
1254
1255 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
1256 // feature * alpha (otherwise),
1257 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1258 // These two algorithms are not consistent when alpha > 1,
1259 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
LeakyReluRewrite(const Node * n)1260 static bool LeakyReluRewrite(const Node* n) {
1261 DCHECK(n);
1262
1263 float alpha;
1264 bool has_attr = GetNodeAttr(n->def(), "alpha", &alpha).ok();
1265 DCHECK(has_attr);
1266
1267 // If the alpha of LeakyRelu is less than 1, rewrite the node.
1268 // Otherwise eigen node is used instead.
1269 if (alpha <= 1) {
1270 return true;
1271 }
1272 VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 "
1273 << "which case is not optimized by Intel MKL, thus using Eigen op"
1274 << "for LeakyRelu ";
1275
1276 return false;
1277 }
1278
MaxpoolGradRewrite(const Node * n)1279 static bool MaxpoolGradRewrite(const Node* n) {
1280 CHECK_NOTNULL(n);
1281 bool do_rewrite = false;
1282 for (const Edge* e : n->in_edges()) {
1283 // Rewrite only if there is corresponding Maxpool, i.e workspace is
1284 // available
1285 if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1286 e->dst_input() == 1 &&
1287 e->src()->type_string() ==
1288 mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1289 e->src_output() == 0) {
1290 do_rewrite = true;
1291 break;
1292 }
1293 }
1294 return do_rewrite;
1295 }
1296
AddNRewrite(const Node * n)1297 static bool AddNRewrite(const Node* n) {
1298 CHECK_NOTNULL(n);
1299
1300 int num;
1301 CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
1302
1303 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1304 if (num == 2) {
1305 return true;
1306 }
1307
1308 return false;
1309 }
1310
FusedConv2DRewrite(const Node * n)1311 static bool FusedConv2DRewrite(const Node* n) {
1312 // MKL DNN currently doesn't support all fusions that grappler fuses
1313 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1314 // it includes those we support.
1315 DataType T;
1316 if (!GetNodeAttr(n->def(), "T", &T).ok() ||
1317 !mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) {
1318 return false;
1319 }
1320
1321 std::vector<string> fused_ops;
1322 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1323 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1324 fused_ops == std::vector<string>{"Relu"} ||
1325 fused_ops == std::vector<string>{"BiasAdd", "Relu"});
1326 }
1327
1328 // Rewrites input node to a new node specified by its matching rewrite info.
1329 //
1330 // Method first searches matching rewrite info for input node and then
1331 // uses that info to rewrite.
1332 //
1333 // Input node may be deleted in case of rewrite. Attempt to use the node
1334 // after the call can result in undefined behaviors.
1335 //
1336 // @input g - input graph, n - Node to be rewritten,
1337 // ri - matching rewriteinfo
1338 // @return Status::OK(), if the input node is rewritten;
1339 // Returns appropriate Status error code otherwise.
1340 // Graph is updated in case the input node is rewritten.
1341 // Otherwise, it is not updated.
1342 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1343
1344 // Get nodes that will feed a list of TF tensors to the new
1345 // node that we are constructing.
1346 //
1347 // @input g - input graph,
1348 // @input inputs - inputs to old node that we are using for constructing
1349 // new inputs,
1350 // @input input_idx - the index in the 'inputs' vector pointing to the
1351 // current input that we have processed so far
1352 // @output input_idx - index will be incremented by the number of nodes
1353 // from 'inputs' that are processed
1354 // @input list_length - The expected length of list of TF tensors
1355 // @output output_nodes - the list of new nodes creating TF tensors
1356 //
1357 // @return None
1358 void GetNodesProducingTFTensorList(
1359 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1360 int* input_idx, int list_length,
1361 std::vector<NodeBuilder::NodeOut>* output_nodes);
1362
1363 // Get nodes that will feed a list of Mkl tensors to the new
1364 // node that we are constructing.
1365 //
1366 // @input g - input graph,
1367 // @input orig_node - Original node that we are rewriting
1368 // @input inputs - inputs to old node that we are using for constructing
1369 // new inputs,
1370 // @input input_idx - the index in the 'inputs' vector pointing to the
1371 // current input that we have processed so far
1372 // @output input_idx - index will be incremented by the number of nodes
1373 // from 'inputs' that are processed
1374 // @input list_length - The expected length of list of Mkl tensors
1375 // @output output_nodes - the list of new nodes creating Mkl tensors
1376 //
1377 // @return None
1378 void GetNodesProducingMklTensorList(
1379 std::unique_ptr<Graph>* g, Node* orig_node,
1380 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1381 int* input_idx, int list_length,
1382 std::vector<NodeBuilder::NodeOut>* output_nodes);
1383
1384 // Get a node that will feed an Mkl tensor to the new
1385 // node that we are constructing. The output node could be (1) 'n'
1386 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1387 // if 'n' is not an Mkl layer.
1388 //
1389 // @input g - input graph,
1390 // @input orig_node - Original node that we are rewriting,
1391 // @input n - Node based on which we are creating Mkl node,
1392 // @input n_output_slot - the output slot of node 'n'
1393 // which is feeding to the node that we are constructing
1394 // @output mkl_node - the new node that will feed Mkl tensor
1395 // @output mkl_node_output_slot - the slot number of mkl_node that
1396 // will feed the tensor
1397 // @return None
1398 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
1399 Node* n, int n_output_slot, Node** mkl_node,
1400 int* mkl_node_output_slot);
1401
1402 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1403 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1404 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1405 // producing workspace edges if 'are_workspace_tensors_available' is true.
1406 // Otherwise, 'workspace_tensors' is empty vector.
1407 //
1408 // For details, refer to 'Ordering of inputs after rewriting' section in the
1409 // documentation above.
1410 //
1411 // Returns Status::OK() if setting up inputs is successful, otherwise
1412 // returns appropriate status code.
1413 int SetUpContiguousInputs(
1414 std::unique_ptr<Graph>* g,
1415 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1416 NodeBuilder* nb, Node* old_node,
1417 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1418 bool are_workspace_tensors_available);
1419
1420 // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1421 // in graph 'g'. Original node is input in 'orig_node'.
1422 //
1423 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1424 // section in the documentation above.
1425 //
1426 // Returns Status::OK() if setting up inputs is successful, otherwise
1427 // returns appropriate status code.
1428 Status SetUpInputs(std::unique_ptr<Graph>* g,
1429 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1430 NodeBuilder* nb, Node* orig_node);
1431
1432 // Add workspace edge on the input or output side of Node 'orig_node' by using
1433 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
1434 // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
1435 // tensors, if they need to be added, will be set into these tensors.
1436 // If we set workspace tensors, then are_ws_tensors_added should be true.
1437 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
1438 NodeBuilder* nb,
1439 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1440 bool* are_ws_tensors_added);
1441
1442 // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1443 // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1444 // 'g'. Returns true is fixup was done; otherwise, it returns false.
1445 bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1446 const Edge* e_metadata);
1447
1448 // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1449 // connected? If not, then fix them. This is needed because a graph may have
1450 // some input Mkl metadata edges incorrectly setup after node merge and
1451 // rewrite passes. This could happen because GetReversePostOrder function may
1452 // not provide topologically sorted order if a graph contains cycles. The
1453 // function returns true if at least one Mkl metadata edge for node 'n' was
1454 // fixed. Otherwise, it returns false.
1455 //
1456 // Example:
1457 //
1458 // X = MklConv2D(_, _, _)
1459 // Y = MklConv2DWithBias(_, _, _, _, _, _)
1460 // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
1461 //
1462 // For a graph such as shown above, note that 3rd argument of MklAdd contains
1463 // DummyMklTensor. Actually, it should be getting the Mkl metadata from
1464 // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
1465 // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
1466 // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
1467 // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
1468 // data edges (1st and 2nd arguments of MklAdd).
1469 bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
1470
1471 // Functions specific to operators to copy attributes
1472 // We need operator-specific function to copy attributes because the framework
1473 // does not provide any generic function for it.
1474 // NOTE: names are alphabetically sorted.
1475 static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
1476 bool change_format = false);
1477 static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb,
1478 bool change_format = false);
1479 static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb,
1480 bool change_format = false);
1481 static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb,
1482 bool change_format = false);
1483 static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
1484 bool change_format = false);
1485 static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
1486 bool change_format = false);
1487 static void CopyAttrsConv2DDepthwiseCheckConstFilter(
1488 const Node* orig_node, NodeBuilder* nb, bool change_format = false);
1489 static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
1490 NodeBuilder* nb,
1491 bool change_format = false);
1492 static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb,
1493 bool change_format = false);
1494 static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
1495 bool change_format = false);
1496 static void CopyAttrsLeakyRelu(const Node* orig_node, NodeBuilder* nb,
1497 bool change_format = false);
1498 static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
1499 bool change_format = false);
1500 static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
1501 bool change_format = false);
1502 static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
1503 bool change_format = false);
1504 static void CopyAttrsPadWithFusedConv2D(const Node* orig_node,
1505 NodeBuilder* nb,
1506 bool change_format = false);
1507 static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
1508 const Node* orig_node2, NodeBuilder* nb,
1509 bool change_format = false);
1510 static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
1511 const Node* orig_node2,
1512 NodeBuilder* nb,
1513 bool change_format = false);
1514 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
1515 bool change_format = false);
1516 static void CopyAttrsQuantizedPooling(const Node* orig_node, NodeBuilder* nb,
1517 bool change_format = false);
1518 static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
1519 bool change_format = false);
1520 static void CopyAttrsQuantizedConcat(const Node* orig_node, NodeBuilder* nb,
1521 bool change_format = false);
1522 static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb,
1523 bool change_format = false);
1524 static void CopyAttrsRequantize(const Node* orig_node, NodeBuilder* nb,
1525 bool change_format = false);
1526 static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb,
1527 bool change_format = false);
1528 static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb,
1529 bool change_format = false);
1530 static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
1531 const std::vector<int32>& strides,
1532 const std::vector<int32>& dilations,
1533 bool change_format = false);
1534
1535 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
1536 // using node for original node 'orig_node' and return it in '*out'.
1537 // TODO(nhasabni) We should move this to mkl_util.h
1538 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
1539 Node* orig_node);
1540 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
1541 Node* orig_node);
1542 };
1543
// Definition of the static member holding the cached op-name strings used
// throughout this pass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
#endif  // ENABLE_MKL
1554
1555 //////////////////////////////////////////////////////////////////////////
1556 // Helper functions for creating new node
1557 //////////////////////////////////////////////////////////////////////////
1558
FillInputs(const Node * n,gtl::InlinedVector<Node *,4> * control_edges,gtl::InlinedVector<std::pair<Node *,int>,4> * in)1559 static void FillInputs(const Node* n,
1560 gtl::InlinedVector<Node*, 4>* control_edges,
1561 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
1562 control_edges->clear();
1563 for (const Edge* e : n->in_edges()) {
1564 if (e->IsControlEdge()) {
1565 control_edges->push_back(e->src());
1566 } else {
1567 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
1568 }
1569 }
1570 std::sort(control_edges->begin(), control_edges->end());
1571 if (n->op_def().is_commutative()) {
1572 // For commutative inputs, we sort the input by the input Node*
1573 // to get a canonical ordering (so that add(a,b) and add(b, a) will
1574 // hash to the same value if is_commutative is true for 'add').
1575 std::sort(in->begin(), in->end());
1576 }
1577 }
1578
GetNodesProducingTFTensorList(const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)1579 void MklLayoutRewritePass::GetNodesProducingTFTensorList(
1580 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
1581 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
1582 CHECK_LT(*input_idx, inputs.size());
1583 CHECK_GT(list_length, 0);
1584 CHECK_NOTNULL(output_nodes);
1585 output_nodes->reserve(list_length);
1586
1587 while (list_length != 0) {
1588 CHECK_GT(list_length, 0);
1589 CHECK_LT(*input_idx, inputs.size());
1590 Node* n = inputs[*input_idx].first;
1591 int slot = inputs[*input_idx].second;
1592 // If input node 'n' is just producing a single tensor at
1593 // output slot 'slot' then we just add that single node.
1594 output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
1595 (*input_idx)++;
1596 list_length--;
1597 }
1598 }
1599
// TODO(nhasabni) We should move this to mkl_util.h.
// Creates a Const node in '*g' carrying the 8-byte all-zero uint8 tensor
// used as a stand-in Mkl metadata tensor, returns it in '*out', assigns it
// the same device as 'orig_node', and ties it into orig_node's frame via a
// control edge from orig_node's first input (when one exists).
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                                                 Node** out, Node* orig_node) {
  // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
  // dummy Mkl tensor. 8 = 2*size_t.
  const DataType dt = DataTypeToEnum<uint8>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
  TensorShape dummy_shape({8});
  dummy_shape.AsProto(proto.mutable_tensor_shape());
  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                  .Attr("value", proto)
                  .Attr("dtype", dt)
                  .Device(orig_node->def().device())  // We place this node on
                                                      // the same device as the
                                                      // device of the original
                                                      // node.
                  .Finalize(&**g, out));
  CHECK_NOTNULL(*out);  // Make sure we got a valid object before using it

  // If number of inputs to the original node is > 0, then we add
  // control dependency between 1st input (index 0) of the original node and
  // the dummy Mkl node. This is needed because control-flow ops such as Enter,
  // Merge, etc, require frame_name of the dummy Mkl node to be same as the
  // rewritten node. Adding control edge between 1st input of the original node
  // and the dummy Mkl node ensures that the dummy node is in the same frame
  // as the original node. Choosing 1st input is not necessary - any input of
  // the original node is fine because all the inputs of a node are always in
  // the same frame.
  if (orig_node->num_inputs() > 0) {
    Node* orig_input0 = nullptr;
    TF_CHECK_OK(
        orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
    // Allow duplicate while adding control edge as it would fail (return
    // NULL) if we try to add duplicate edge.
    CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
  }

  // Mirror the placed/assigned device as well, since this pass runs post
  // partitioning where nodes already carry assigned devices.
  (*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
1642
GetNodesProducingMklTensorList(std::unique_ptr<Graph> * g,Node * orig_node,const gtl::InlinedVector<std::pair<Node *,int>,4> & inputs,int * input_idx,int list_length,std::vector<NodeBuilder::NodeOut> * output_nodes)1643 void MklLayoutRewritePass::GetNodesProducingMklTensorList(
1644 std::unique_ptr<Graph>* g, Node* orig_node,
1645 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
1646 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
1647 CHECK_LT(*input_idx, inputs.size());
1648 CHECK_GT(list_length, 0);
1649 CHECK_NOTNULL(output_nodes);
1650 output_nodes->reserve(list_length);
1651
1652 while (list_length != 0) {
1653 CHECK_GT(list_length, 0);
1654 CHECK_LT(*input_idx, inputs.size());
1655 Node* n = inputs[*input_idx].first;
1656 int slot = inputs[*input_idx].second;
1657 // If 'n' is producing a single tensor, then create a single Mkl tensor
1658 // node.
1659 Node* mkl_node = nullptr;
1660 int mkl_node_output_slot = 0;
1661 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
1662 &mkl_node_output_slot);
1663 output_nodes->push_back(
1664 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
1665 (*input_idx)++;
1666 list_length--;
1667 }
1668 }
1669
1670 // Get an input node that will feed Mkl tensor to the new
1671 // node that we are constructing. An input node could be (1) 'n'
1672 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1673 // if 'n' is not an Mkl layer.
GetNodeProducingMklTensor(std::unique_ptr<Graph> * g,Node * orig_node,Node * n,int n_output_slot,Node ** mkl_node,int * mkl_node_output_slot)1674 void MklLayoutRewritePass::GetNodeProducingMklTensor(
1675 std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
1676 Node** mkl_node, int* mkl_node_output_slot) {
1677 CHECK_NOTNULL(n);
1678 CHECK_NOTNULL(mkl_node);
1679 CHECK_NOTNULL(mkl_node_output_slot);
1680
1681 // If this is an MKL op, then it will create extra output for MKL layout.
1682 DataType T;
1683 if (GetNodeAttr(n->def(), "T", &T).ok() &&
1684 mkl_op_registry::IsMklOp(n->type_string(), T)) {
1685 // If this is an MKL op, then it will generate an edge that will receive
1686 // Mkl tensor from a node.
1687 // output slot number for Mkl tensor would be N+slot number of TensorFlow
1688 // tensor, where N is total number of TensorFlow tensors.
1689 *mkl_node = n;
1690 *mkl_node_output_slot =
1691 GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
1692 } else {
1693 // If we have not visited the node and rewritten it, then we need
1694 // to create a dummy node that will feed a dummy Mkl tensor to this node.
1695 // DummyMklTensor node has no input and generates only 1 output
1696 // (dummy Mkl tensor) as output slot number 0.
1697 GetDummyMklTensorNode(g, mkl_node, orig_node);
1698 CHECK_NOTNULL(*mkl_node);
1699 *mkl_node_output_slot = 0;
1700 }
1701 }
1702
// Wires the inputs of the new Mkl node being built by 'nb' in
// TENSORS_CONTIGUOUS order: first all TensorFlow (data) tensors of
// 'old_node', then one Mkl metadata tensor per data tensor. When
// 'are_workspace_tensors_available' is true, the TF workspace tensor is
// appended at the end of the TF section and its Mkl twin at the end of the
// Mkl section. Returns the number of input slots populated on the new node.
int MklLayoutRewritePass::SetUpContiguousInputs(
    std::unique_ptr<Graph>* g,
    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
    NodeBuilder* nb, Node* old_node,
    std::vector<NodeBuilder::NodeOut>* workspace_tensors,
    bool are_workspace_tensors_available) {
  CHECK_NOTNULL(workspace_tensors);
  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);

  // TODO(nhasabni): Temporary solution to connect filter input of
  // BackpropInput with the converted filter from Conv2D.
  bool do_connect_conv2d_backprop_input_filter = false;
  Node* conv2d_node = nullptr;
  // Filter node is 2nd input (slot index 1) of Conv2D.
  int kConv2DFilterInputSlotIdx = 1;
  int kConv2DBackpropInputFilterInputSlotIdx = 1;
  int kConv2DFilterOutputSlotIdx = 1;
  if (old_node->type_string() == csinfo_.conv2d_grad_input) {
    // We need to find Conv2D node from Conv2DBackpropInput.
    // For that let's first find filter node that is 2nd input (slot 1)
    // of BackpropInput.
    Node* filter_node = nullptr;
    TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
                                     &filter_node));
    CHECK_NOTNULL(filter_node);

    // Now check which nodes receive from filter_node. Filter feeds as
    // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
    // _MklFusedConv2D.
    for (const Edge* e : filter_node->out_edges()) {
      if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
           e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
           e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
          e->dst_input() == kConv2DFilterInputSlotIdx
          /* filter is 2nd input of Conv2D and _MklConv2D. */) {
        if (conv2d_node != nullptr) {
          VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
                  << " feeding multiple Conv2D nodes: "
                  << filter_node->DebugString();
          // We will not connect filter input of Conv2DBackpropInput
          // to be safe here.
          do_connect_conv2d_backprop_input_filter = false;
          break;
        } else {
          conv2d_node = e->dst();
          do_connect_conv2d_backprop_input_filter = true;
        }
      }
    }
  }

  // Number of input slots to original op
  // Input slots are represented by .Input() calls in REGISTER_OP.
  int old_node_input_slots = old_node->op_def().input_arg_size();
  // Actual number of inputs can be greater than or equal to number
  // of Input slots because inputs of type list could be unfolded.
  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
  int nn_slot_idx = 0;  // slot index for inputs of new node

  // Let's copy all inputs (TF tensors) of original node to new node.
  int iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      // A list-typed slot consumes several consecutive entries of
      // 'old_node_inputs' but counts as a single .Input() call.
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
                                    &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
      } else {
        nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
      }
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Tensorflow tensor for
  // workspace here because Tensorflow tensor for workspace is the
  // last tensor in the list of Tensorflow tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Tensorflow tensor
    nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
    nn_slot_idx++;
  }

  // Let's now setup all Mkl inputs to a new node.
  // Number of Mkl inputs must be same as number of TF inputs.
  iidx = 0;
  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
    // An input slot could be a single tensor or a list. We need
    // to handle this case accordingly.
    CHECK_LT(iidx, old_node_inputs.size());
    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
    if (ArgIsList(arg)) {
      std::vector<NodeBuilder::NodeOut> new_node_inputs;
      int N = GetTensorListLength(arg, old_node);
      GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
                                     &new_node_inputs);
      nb->Input(new_node_inputs);
      nn_slot_idx++;
    } else {
      Node* mkl_node = nullptr;
      int mkl_node_output_slot = 0;
      // Special case for connecting filter input of Conv2DBackpropInput
      if (do_connect_conv2d_backprop_input_filter &&
          iidx == kConv2DBackpropInputFilterInputSlotIdx) {
        GetNodeProducingMklTensor(g, old_node, conv2d_node,
                                  kConv2DFilterOutputSlotIdx, &mkl_node,
                                  &mkl_node_output_slot);
      } else {
        GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
                                  old_node_inputs[iidx].second, &mkl_node,
                                  &mkl_node_output_slot);
      }
      nb->Input(mkl_node, mkl_node_output_slot);
      iidx++;
      nn_slot_idx++;
    }
  }

  // If workspace tensors are available for this op and we are using
  // contiguous ordering then we need to add Mkl tensor for
  // workspace here because Mkl tensor for workspace is the
  // last tensor in the list of Mkl tensors.
  if (are_workspace_tensors_available) {
    CHECK_EQ(workspace_tensors->size(), 2);
    // Mkl tensor
    nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
    nn_slot_idx++;
  }

  return nn_slot_idx;
}
1850
SetUpInputs(std::unique_ptr<Graph> * g,const gtl::InlinedVector<std::pair<Node *,int>,4> & old_node_inputs,NodeBuilder * nb,Node * old_node)1851 Status MklLayoutRewritePass::SetUpInputs(
1852 std::unique_ptr<Graph>* g,
1853 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1854 NodeBuilder* nb, Node* old_node) {
1855 // Let's check if we need to add workspace tensors for this node.
1856 // We add workspace edge only for MaxPool, LRN and BatchNorm.
1857 std::vector<NodeBuilder::NodeOut> workspace_tensors;
1858 bool are_workspace_tensors_available = false;
1859
1860 // Avoid workspace check for QuantizedConv2D and the fused
1861 // Ops as they don't have attribute: "T".
1862 std::vector<string> quant_ops{
1863 "QuantizedConv2D",
1864 "QuantizedConv2DWithBias",
1865 "QuantizedConv2DAndRelu",
1866 "QuantizedConv2DWithBiasAndRelu",
1867 "QuantizedConv2DWithBiasSumAndRelu",
1868 "QuantizedConv2DAndRequantize",
1869 "QuantizedConv2DWithBiasAndRequantize",
1870 "QuantizedConv2DAndReluAndRequantize",
1871 "QuantizedConv2DWithBiasAndReluAndRequantize",
1872 "QuantizedConv2DWithBiasSumAndReluAndRequantize",
1873 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"};
1874 bool should_check_workspace =
1875 std::find(std::begin(quant_ops), std::end(quant_ops),
1876 old_node->type_string()) == std::end(quant_ops);
1877 if (should_check_workspace)
1878 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
1879 &are_workspace_tensors_available);
1880
1881 int new_node_input_slots = 0;
1882 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
1883 // TODO(nhasabni): implement this function just for same of completion.
1884 // We do not use interleaved ordering right now.
1885 return Status(
1886 error::Code::UNIMPLEMENTED,
1887 "Interleaved ordering of tensors is currently not supported.");
1888 } else {
1889 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
1890 new_node_input_slots = SetUpContiguousInputs(
1891 g, old_node_inputs, nb, old_node, &workspace_tensors,
1892 are_workspace_tensors_available);
1893 }
1894
1895 // Sanity check
1896 int old_node_input_slots = old_node->op_def().input_arg_size();
1897 if (!are_workspace_tensors_available) {
1898 // If we are not adding workspace tensors for this op, then the total
1899 // number of input slots to the new node _must_ be 2 times the number
1900 // of input slots to the original node: N original Tensorflow tensors and
1901 // N for Mkl tensors corresponding to each Tensorflow tensors.
1902 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
1903 } else {
1904 // If we are adding workspace tensors for this op, then the total
1905 // The total number of input slots to new node _must_ be 2 times the number
1906 // of input slots to the original node: N original Tensorflow tensors and
1907 // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
1908 // (for workspace Tensorflow tensor and workspace Mkl tensor).
1909 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
1910 }
1911
1912 return Status::OK();
1913 }
1914
1915 //////////////////////////////////////////////////////////////////////////
1916 // Helper functions related to workspace pass
1917 //////////////////////////////////////////////////////////////////////////
1918
// TODO(nhasabni) We should move this to mkl_util.h.
// Creates a dummy workspace tensor node in '*g' and returns it in '*out'.
// The workspace placeholder has the same representation as the dummy Mkl
// tensor, so this simply delegates to GetDummyMklTensorNode.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
    std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
  // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
  // workspace tensor.
  GetDummyMklTensorNode(g, out, orig_node);
}
1926
// Handles workspace propagation for 'orig_node' when it matches an entry
// of wsinfo_:
//  - Forward op: only sets the "workspace_enabled" attribute on 'nb'
//    (true iff a matching backward op consumes this node's fwd_slot).
//  - Backward op: also pushes the forward op's workspace output and its
//    Mkl metadata twin into 'ws_tensors' (or dummy tensors when no matching
//    forward op feeds it) and sets '*are_ws_tensors_added' to true.
// Nodes matching no wsinfo_ entry are left untouched.
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
    std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
    std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
  bool workspace_edge_added = false;  // Default initializer
  CHECK_NOTNULL(are_ws_tensors_added);
  *are_ws_tensors_added = false;  // Default initializer

  DataType T;
  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
  for (auto ws : wsinfo_) {
    if (orig_node->type_string() == ws.fwd_op &&
        mkl_op_registry::IsMklOp(
            mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
      // If this op is a fwd op, then we need to check if there is an
      // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
      // an edge, then we just add an attribute on this node for setting
      // workspace_passed to true. We don't add actual workspace edge
      // in this node. Actual workspace edge gets added in the backward
      // op for this node.
      for (const Edge* e : orig_node->out_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            e->dst()->type_string() == ws.bwd_op &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      if (!workspace_edge_added) {
        // If we are here, then we did not find backward operator for this
        // node.
        nb->Attr("workspace_enabled", false);
      }
    } else if (orig_node->type_string() == ws.bwd_op &&
               mkl_op_registry::IsMklOp(
                   mkl_op_registry::GetMklOpName(orig_node->type_string()),
                   T)) {
      // If this op is a bwd op, then we need to add workspace edge and
      // it's Mkl tensor edge between its corresponding fwd op and this
      // op. Corresponding fwd op is specified in 'fwd_op' field of
      // workspace info. fwd_slot and bwd_slot in workspace info specify
      // an edge between which slots connect forward and backward op.
      // Once all these criteria match, we add a workspace edge between
      // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
      // determined by interleaved/contiguous ordering. Function
      // DataIndexToMetaDataIndex tells us the location of Mkl tensor
      // from the location of the Tensorflow tensor.
      for (const Edge* e : orig_node->in_edges()) {
        if (e->src_output() == ws.fwd_slot &&
            // We would have rewritten the forward op, so we need to use
            // GetMklOpName call to get its Mkl name.
            e->src()->type_string() ==
                mkl_op_registry::GetMklOpName(ws.fwd_op) &&
            e->dst_input() == ws.bwd_slot) {
          nb->Attr("workspace_enabled", true);
          CHECK_NOTNULL(ws_tensors);
          // Add workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
          // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
          ws_tensors->push_back(NodeBuilder::NodeOut(
              e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
                                                 e->src()->num_outputs())));
          *are_ws_tensors_added = true;
          // In terms of input ordering, we add these calls to add Input
          // here because workspace edge (and its Mkl tensor) is the last
          // edge in the fwdop and bwdop. So all inputs before workspace
          // tensor have been added by SetUpInputs function.
          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orig_node->type_string();
          workspace_edge_added = true;
          // We found the edge that we were looking for, so break.
          break;
        }
      }

      // If we are here means we did not find fwd op that feeds to this
      // bwd op. So in this case, we need to generate dummy tensors for
      // workspace input and Mkl tensor for workspace, and set
      // workspace_enabled to false.
      if (!workspace_edge_added) {
        nb->Attr("workspace_enabled", false);
        Node* dmt_ws = nullptr;      // Dummy tensor for workspace
        Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
        GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
        GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
        CHECK_NOTNULL(dmt_ws);
        CHECK_NOTNULL(dmt_mkl_ws);
        CHECK_NOTNULL(ws_tensors);
        // We add dummy tensor as workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
        // We add dummy tensor as Mkl tensor for workspace tensor.
        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
        *are_ws_tensors_added = true;
        VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
                << orig_node->type_string();
      }
    } else {
      // If this node does not match any workspace info, then we do not
      // do anything special for workspace propagation for it.
    }
  }
}
2033
2034 //////////////////////////////////////////////////////////////////////////
2035 // Op-specific functions to copy attributes from old node to new node
2036 //////////////////////////////////////////////////////////////////////////
2037
CopyAttrsConvCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2038 void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2039 NodeBuilder* nb,
2040 bool change_format) {
2041 DataType T;
2042 string padding;
2043 std::vector<int32> strides;
2044 std::vector<int32> dilations;
2045
2046 // Get all attributes from old node.
2047 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2048 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2049 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2050 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2051
2052 Node* filter_node = nullptr;
2053 orig_node->input_node(1, &filter_node);
2054
2055 // Add attributes to new node.
2056 nb->Attr("T", T);
2057 nb->Attr("padding", padding);
2058 nb->Attr("is_filter_const", filter_node->IsConstant());
2059
2060 // Add attributes related to `data_format`.
2061 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2062 }
2063
CopyAttrsConv(const Node * orig_node,NodeBuilder * nb,bool change_format)2064 void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2065 bool change_format) {
2066 DataType T;
2067 string padding;
2068 std::vector<int32> strides;
2069 std::vector<int32> dilations;
2070
2071 // Get all attributes from old node.
2072 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2073 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2074 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2075 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2076
2077 // Add attributes to new node.
2078 nb->Attr("T", T);
2079 nb->Attr("padding", padding);
2080
2081 // Add attributes related to `data_format`.
2082 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2083 }
2084
2085 // Used in rinfo when replacing __MklDummyPadWithConv2D by _MklPadWithConv2D
CopyAttrsPadWithConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2086 void MklLayoutRewritePass::CopyAttrsPadWithConv2D(const Node* orig_node,
2087 NodeBuilder* nb,
2088 bool change_format) {
2089 DataType Tpaddings;
2090 DataType T;
2091 string data_format;
2092 string padding;
2093 std::vector<int32> strides;
2094 std::vector<int32> dilations;
2095 bool use_cudnn_on_gpu;
2096
2097 // Get all attributes from old node.
2098 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2099 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2100 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2101 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2102 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2103 TF_CHECK_OK(
2104 GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2105 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
2106
2107 Node* filter_node = nullptr;
2108 orig_node->input_node(1, &filter_node);
2109
2110 // Add attributes to new node.
2111 nb->Attr("T", T);
2112 nb->Attr("strides", strides);
2113 nb->Attr("dilations", dilations);
2114 nb->Attr("padding", padding);
2115 nb->Attr("is_filter_const", filter_node->IsConstant());
2116 nb->Attr("data_format", data_format);
2117 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2118 nb->Attr("Tpaddings", Tpaddings);
2119 }
2120
CopyAttrsPadWithFusedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2121 void MklLayoutRewritePass::CopyAttrsPadWithFusedConv2D(const Node* orig_node,
2122 NodeBuilder* nb,
2123 bool change_format) {
2124 DataType Tpaddings;
2125
2126 CopyAttrsFusedConv2D(orig_node, nb, change_format);
2127
2128 // Get attributes from old node.
2129 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tpaddings", &Tpaddings));
2130 // Check if filter is a constant.
2131 Node* filter_node = nullptr;
2132 orig_node->input_node(1, &filter_node);
2133
2134 // Add attributes to new node.
2135 nb->Attr("Tpaddings", Tpaddings);
2136 nb->Attr("is_filter_const", filter_node->IsConstant());
2137 }
2138
2139 // Used with MergePadWithConv2D
CopyAttrsFromPadAndConv2D(const Node * orig_node1,const Node * orig_node2,NodeBuilder * nb,bool change_format)2140 void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2141 const Node* orig_node2,
2142 NodeBuilder* nb,
2143 bool change_format) {
2144 DataType Tpaddings;
2145 DataType T;
2146 string data_format;
2147 string padding;
2148 std::vector<int32> strides;
2149 std::vector<int32> dilations;
2150 bool use_cudnn_on_gpu;
2151
2152 // Get all attributes from old node 1.
2153 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2154 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2155 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2156 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2157 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2158 TF_CHECK_OK(
2159 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2160 // Get all attributes from old node 2.
2161 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2162
2163 // Add attributes to new node.
2164 nb->Attr("T", T);
2165 nb->Attr("strides", strides);
2166 nb->Attr("dilations", dilations);
2167 nb->Attr("padding", padding);
2168 nb->Attr("data_format", data_format);
2169 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2170 nb->Attr("Tpaddings", Tpaddings);
2171 }
2172
CopyAttrsFromPadAndFusedConv2D(const Node * fused_conv2d,const Node * pad,NodeBuilder * nb,bool change_format)2173 void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2174 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2175 bool change_format) {
2176 DataType T;
2177 int num_args;
2178 string data_format;
2179 string padding;
2180 std::vector<int32> strides;
2181 std::vector<int32> dilations;
2182 float epsilon;
2183 std::vector<string> fused_ops;
2184 DataType Tpaddings;
2185
2186 // Get all attributes from old node.
2187 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2188 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2189 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2190 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2191 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2192 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2193 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2194 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2195 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2196
2197 // Add attributes to new node.
2198 nb->Attr("T", T);
2199 nb->Attr("num_args", num_args);
2200 nb->Attr("strides", strides);
2201 nb->Attr("padding", padding);
2202 nb->Attr("data_format", data_format);
2203 nb->Attr("dilations", dilations);
2204 nb->Attr("epsilon", epsilon);
2205 nb->Attr("Tpaddings", Tpaddings);
2206 nb->Attr("fused_ops", fused_ops);
2207 }
2208
CopyAttrsConv2DDepthwise(const Node * orig_node,NodeBuilder * nb,bool change_format)2209 void MklLayoutRewritePass::CopyAttrsConv2DDepthwise(const Node* orig_node,
2210 NodeBuilder* nb,
2211 bool change_format) {
2212 DataType T;
2213 string data_format;
2214 string padding;
2215 std::vector<int32> strides;
2216 std::vector<int32> dilations;
2217
2218 // Get all attributes from old node.
2219 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2220 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2221 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2222 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2223 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2224
2225 // Add attributes to new node.
2226 nb->Attr("T", T);
2227 nb->Attr("strides", strides);
2228 nb->Attr("dilations", dilations);
2229 nb->Attr("padding", padding);
2230 nb->Attr("data_format", data_format);
2231 }
2232
CopyAttrsConv2DDepthwiseCheckConstFilter(const Node * orig_node,NodeBuilder * nb,bool change_format)2233 void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
2234 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2235 DataType T;
2236 string data_format;
2237 string padding;
2238 std::vector<int32> strides;
2239 std::vector<int32> dilations;
2240
2241 // Get all attributes from old node.
2242 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2243 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2244 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2245 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2246 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2247
2248 Node* filter_node = nullptr;
2249 orig_node->input_node(1, &filter_node);
2250
2251 // Add attributes to new node.
2252 nb->Attr("T", T);
2253 nb->Attr("strides", strides);
2254 nb->Attr("dilations", dilations);
2255 nb->Attr("padding", padding);
2256 nb->Attr("is_filter_const", filter_node->IsConstant());
2257 nb->Attr("data_format", data_format);
2258 }
2259
CopyAttrsAddN(const Node * orig_node,NodeBuilder * nb,bool change_format)2260 void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb,
2261 bool change_format) {
2262 DataType T;
2263 int N;
2264
2265 // Get all attributes from old node.
2266 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2267 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2268
2269 // Add attributes to new node.
2270 nb->Attr("T", T);
2271 nb->Attr("N", N);
2272 }
2273
CopyAttrsBiasAddGrad(const Node * orig_node,NodeBuilder * nb,bool change_format)2274 void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
2275 NodeBuilder* nb,
2276 bool change_format) {
2277 DataType T;
2278 string data_format;
2279 std::vector<int32> strides;
2280
2281 // Get all attributes from old node.
2282 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2283 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2284 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2285
2286 // Add attributes to new node.
2287 nb->Attr("T", T);
2288 nb->Attr("strides", strides);
2289 nb->Attr("data_format", data_format);
2290 }
2291
CopyAttrsLRN(const Node * orig_node,NodeBuilder * nb,bool change_format)2292 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
2293 bool change_format) {
2294 DataType T;
2295 int depth_radius;
2296 float bias;
2297 float alpha;
2298 float beta;
2299
2300 // Get all attributes from old node.
2301 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2302 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
2303 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
2304 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
2305 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
2306
2307 // Add attributes to new node.
2308 nb->Attr("T", T);
2309 nb->Attr("depth_radius", depth_radius);
2310 nb->Attr("bias", bias);
2311 nb->Attr("alpha", alpha);
2312 nb->Attr("beta", beta);
2313 }
2314
CopyAttrsLeakyRelu(const Node * orig_node,NodeBuilder * nb,bool change_format)2315 void MklLayoutRewritePass::CopyAttrsLeakyRelu(const Node* orig_node,
2316 NodeBuilder* nb,
2317 bool change_format) {
2318 DataType T;
2319 float alpha;
2320
2321 // Get all attributes from old node.
2322 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2323 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
2324
2325 // Add attributes to new node.
2326 nb->Attr("T", T);
2327 nb->Attr("alpha", alpha);
2328 }
2329
CopyAttrsPooling(const Node * orig_node,NodeBuilder * nb,bool change_format)2330 void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2331 NodeBuilder* nb,
2332 bool change_format) {
2333 DataType T;
2334 string data_format;
2335 string padding;
2336 std::vector<int32> ksize, strides;
2337
2338 // Get all attributes from old node.
2339 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2340 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2341 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2342 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2343 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2344
2345 // Add attributes to new node.
2346 nb->Attr("T", T);
2347 nb->Attr("ksize", ksize);
2348 nb->Attr("strides", strides);
2349 nb->Attr("padding", padding);
2350 nb->Attr("data_format", data_format);
2351 }
2352
CopyAttrsDataType(const Node * orig_node,NodeBuilder * nb,bool change_format)2353 void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
2354 NodeBuilder* nb,
2355 bool change_format) {
2356 DataType T;
2357
2358 // Get all attributes from old node.
2359 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2360
2361 // Add attributes to new node.
2362 nb->Attr("T", T);
2363 }
2364
CopyAttrsQuantizedPooling(const Node * orig_node,NodeBuilder * nb,bool change_format)2365 void MklLayoutRewritePass::CopyAttrsQuantizedPooling(const Node* orig_node,
2366 NodeBuilder* nb,
2367 bool change_format) {
2368 DataType T;
2369 string padding;
2370 std::vector<int32> ksize, strides;
2371
2372 // Get all attributes from old node.
2373 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2374 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2375 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2376 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2377
2378 // Add attributes to new node.
2379 nb->Attr("T", T);
2380 nb->Attr("ksize", ksize);
2381 nb->Attr("strides", strides);
2382 nb->Attr("padding", padding);
2383 }
2384
CopyAttrsQuantizedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2385 void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2386 NodeBuilder* nb,
2387 bool change_format) {
2388 DataType Tinput, Tfilter, out_type;
2389 string padding;
2390 string data_format("NHWC");
2391 std::vector<int32> strides, dilations, padding_list;
2392 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2393
2394 // Get all attributes from old node.
2395 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2396 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2397 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2398 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2399 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2400 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2401 if (has_padding_list) {
2402 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2403 }
2404
2405 Node* filter_node = nullptr;
2406 orig_node->input_node(1, &filter_node);
2407
2408 // Add attributes to new node.
2409 nb->Attr("Tinput", Tinput);
2410 nb->Attr("Tfilter", Tfilter);
2411 nb->Attr("out_type", out_type);
2412 nb->Attr("padding", padding);
2413 nb->Attr("is_filter_const", filter_node->IsConstant());
2414 nb->Attr("strides", strides);
2415 nb->Attr("dilations", dilations);
2416 nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion.
2417 nb->Attr("data_format", data_format);
2418 if (has_padding_list) {
2419 nb->Attr("padding_list", padding_list);
2420 }
2421
2422 // Requantization attr Tbias.
2423 DataType Tbias;
2424 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
2425 if (bias_status.ToString() == "OK") nb->Attr("Tbias", Tbias);
2426 }
2427
CopyAttrsRequantize(const Node * orig_node,NodeBuilder * nb,bool change_format)2428 void MklLayoutRewritePass::CopyAttrsRequantize(const Node* orig_node,
2429 NodeBuilder* nb,
2430 bool change_format) {
2431 DataType Tinput, out_type;
2432
2433 // Get all attributes from old node.
2434 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2435 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2436
2437 // Add attributes to new node.
2438 nb->Attr("Tinput", Tinput);
2439 nb->Attr("out_type", out_type);
2440 }
2441
CopyAttrsReshape(const Node * orig_node,NodeBuilder * nb,bool change_format)2442 void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
2443 NodeBuilder* nb,
2444 bool change_format) {
2445 DataType T;
2446 DataType Tshape;
2447
2448 // Get all attributes from old node.
2449 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2450 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
2451
2452 // Add attributes to new node.
2453 nb->Attr("T", T);
2454 nb->Attr("Tshape", Tshape);
2455 }
2456
CopyAttrsSlice(const Node * orig_node,NodeBuilder * nb,bool change_format)2457 void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
2458 NodeBuilder* nb, bool change_format) {
2459 DataType T;
2460 DataType Index;
2461
2462 // Get all attributes from old node.
2463 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2464 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
2465
2466 // Add attributes to new node.
2467 nb->Attr("T", T);
2468 nb->Attr("Index", Index);
2469 }
2470
CopyAttrsSplit(const Node * orig_node,NodeBuilder * nb,bool change_format)2471 void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
2472 NodeBuilder* nb, bool change_format) {
2473 DataType T;
2474 string data_format;
2475 int num_split;
2476
2477 // Get all attributes from old node.
2478 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2479 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
2480 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2481
2482 // Add attributes to new node.
2483 nb->Attr("T", T);
2484 nb->Attr("num_split", num_split);
2485 nb->Attr("data_format", data_format);
2486 }
2487
// Copies `strides` and `dilations` (and, when the layout is not being
// changed, `data_format`) onto the node being built in `nb`.
// When `change_format` is true, the stride/dilation vectors are permuted
// from channels-last order (NHWC / NDHWC) to the corresponding
// channels-first order (NCHW / NCDHW); `data_format` is deliberately not
// copied in that case -- the caller is responsible for setting the new
// format on the node itself.
void MklLayoutRewritePass::CopyFormatAttrsConv(
    const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
    const std::vector<int32>& dilations, bool change_format) {
  string data_format;

  if (!change_format) {
    nb->Attr("strides", strides);
    nb->Attr("dilations", dilations);

    TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
    nb->Attr("data_format", data_format);
  } else {
    std::vector<int32> new_strides;
    std::vector<int32> new_dilations;
    // A 5-element stride vector implies a 3D convolution.
    if (strides.size() == 5) {
      // `strides` and `dilations` also need to be changed according to
      // `data_format`. In this case, from `NDHWC` to `NCDHW`.
      new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
                     strides[NDHWC::dim::D], strides[NDHWC::dim::H],
                     strides[NDHWC::dim::W]};

      new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
                       dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
                       dilations[NDHWC::dim::W]};
    } else {
      // `strides` and `dilations` also need to be changed according to
      // `data_format`. In this case, from `NHWC` to `NCHW`.

      new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
                     strides[NHWC::dim::H], strides[NHWC::dim::W]};

      new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
                       dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
    }
    nb->Attr("strides", new_strides);
    nb->Attr("dilations", new_dilations);
  }
}
2526
CopyAttrsConcat(const Node * orig_node,NodeBuilder * nb,bool change_format)2527 void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
2528 NodeBuilder* nb,
2529 bool change_format) {
2530 DataType T;
2531 int N;
2532
2533 // Get all attributes from old node.
2534 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2535 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2536
2537 // Add attributes to new node.
2538 nb->Attr("T", T);
2539 nb->Attr("N", N);
2540 }
2541
CopyAttrsConcatV2(const Node * orig_node,NodeBuilder * nb,bool change_format)2542 void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
2543 NodeBuilder* nb,
2544 bool change_format) {
2545 DataType T;
2546 int N;
2547 DataType tidx;
2548
2549 // Get all attributes from old node.
2550 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2551 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
2552 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
2553
2554 // Add attributes to new node.
2555 nb->Attr("T", T);
2556 nb->Attr("N", N);
2557 nb->Attr("Tidx", tidx);
2558 }
2559
CopyAttrsFusedBatchNorm(const Node * orig_node,NodeBuilder * nb,bool change_format)2560 void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
2561 NodeBuilder* nb,
2562 bool change_format) {
2563 DataType T;
2564 float epsilon;
2565 string data_format;
2566 bool is_training;
2567
2568 // Get all attributes from old node.
2569 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2570 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
2571 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2572 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
2573
2574 // Add attributes to new node.
2575 nb->Attr("T", T);
2576 nb->Attr("epsilon", epsilon);
2577 nb->Attr("data_format", data_format);
2578 nb->Attr("is_training", is_training);
2579 }
2580
CopyAttrsFusedConv2D(const Node * orig_node,NodeBuilder * nb,bool change_format)2581 void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
2582 NodeBuilder* nb,
2583 bool change_format) {
2584 DataType T;
2585 int num_args;
2586 float epsilon;
2587 string data_format;
2588 string padding;
2589 std::vector<int32> strides;
2590 std::vector<int32> dilations;
2591 std::vector<string> fused_ops;
2592
2593 // Get all attributes from old node.
2594 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2595 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
2596 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2597 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2598 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2599 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2600 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
2601 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
2602
2603 Node* filter_node = nullptr;
2604 orig_node->input_node(1, &filter_node);
2605
2606 // Add attributes to new node.
2607 nb->Attr("T", T);
2608 nb->Attr("num_args", num_args);
2609 nb->Attr("strides", strides);
2610 nb->Attr("padding", padding);
2611 nb->Attr("is_filter_const", filter_node->IsConstant());
2612 nb->Attr("data_format", data_format);
2613 nb->Attr("dilations", dilations);
2614 nb->Attr("fused_ops", fused_ops);
2615 nb->Attr("epsilon", epsilon);
2616 }
2617
2618 //////////////////////////////////////////////////////////////////////////
2619 // Helper functions related to node merge pass
2620 //////////////////////////////////////////////////////////////////////////
2621
// Returns the node with which `a` can be merged, or nullptr if no merge
// candidate exists. A candidate `b` is found via the registered MergeInfo
// table (minfo_): `a`'s op type must match one side of a MergeInfo entry,
// the entry's get_node_to_be_merged() must yield a partner node, and both
// nodes must have identical incoming control edges (merging nodes with
// different control dependencies could change execution ordering).
Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
  // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
  // once we support BiasAddGrad as Mkl layer.

  // Search for all matching mergeinfo.
  // We allow more than one match for extensibility.
  std::vector<const MergeInfo*> matching_mi;
  for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
    if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
      matching_mi.push_back(&*mi);
    }
  }

  for (const MergeInfo* mi : matching_mi) {
    // Get the operand with which 'a' can be merged.
    Node* b = nullptr;
    if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
      continue;
    }

    // Get the control edges and input of node
    const int N_in = a->num_inputs();
    gtl::InlinedVector<Node*, 4> a_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
    FillInputs(a, &a_control_edges, &a_in);

    const int B_in = b->num_inputs();
    gtl::InlinedVector<Node*, 4> b_control_edges;
    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
    FillInputs(b, &b_control_edges, &b_in);

    // Shouldn't merge if a and b have different control edges.
    if (a_control_edges != b_control_edges) {
      continue;
    } else {
      // We found a match.
      return b;
    }
  }

  // No mergeable partner found.
  return nullptr;
}
2664
// Merges a Conv2D node and the BiasAdd node it feeds into a single fused
// node of type csinfo_.conv2d_with_bias. `m` and `n` are the two nodes to
// merge, in either order: one must be Conv2D, the other BiasAdd (checked
// below). On success the fused node takes over all data and control edges
// of both originals, both originals are removed from the graph, and OK is
// returned. If attributes/devices do not match, or Conv2D has consumers
// other than the BiasAdd, an INVALID_ARGUMENT status is returned and the
// graph is left unmodified.
Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
                                                    Node* m, Node* n) {
  CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
             n->type_string() == csinfo_.conv2d)) ||
               ((n->type_string() == csinfo_.bias_add &&
                 m->type_string() == csinfo_.conv2d)),
           true);

  // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
  // BiasAdd is successor node, and Conv2D predecessor node.
  Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
  Node* succ = m->type_string() == csinfo_.bias_add ? m : n;

  // 1. Get all attributes from input nodes.
  DataType T_pred, T_succ;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  string data_format_pred, data_format_succ;
  bool use_cudnn_on_gpu;
  TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
  TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  // We check to ensure that data formats of both succ and pred are same.
  // We expect them to be same, so we can enforce this as assert.
  // But assert can be too strict, so we enforce this as a check.
  // If the check fails, then we do not merge two nodes.
  // We also do same check for devices.
  if (data_format_pred != data_format_succ || T_pred != T_succ ||
      pred->assigned_device_name() != succ->assigned_device_name() ||
      pred->def().device() != succ->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "data_format or T attribute or devices of Conv2D and "
                  "BiasAdd do not match. Will skip node merge optimization");
  }

  // Collect the (data) inputs and control edges of both nodes.
  const int succ_num = succ->num_inputs();
  gtl::InlinedVector<Node*, 4> succ_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
  FillInputs(succ, &succ_control_edges, &succ_in);

  const int pred_num = pred->num_inputs();
  gtl::InlinedVector<Node*, 4> pred_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
  FillInputs(pred, &pred_control_edges, &pred_in);

  // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is
  // not expecting output of Conv2D). If this is not the case, then we cannot
  // merge Conv2D with BiasAdd.
  const int kFirstOutputSlot = 0;
  for (const Edge* e : pred->out_edges()) {
    if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Conv2D does not feed to BiasAdd, or "
                    "it feeds BiasAdd but has multiple outputs. "
                    "Will skip node merge optimization");
    }
  }

  // 2. Get inputs from both the nodes.
  // Find the 2 inputs from the conv and the bias from the add Bias.
  // Get operand 0, 1 of conv2D.
  CHECK_EQ(pred->in_edges().size(), 2);  // Conv2D must have 2 inputs.
  // Get operand 1 of add_bias
  // BiasAdd must have 2 inputs: Conv, bias
  CHECK_EQ(succ->in_edges().size(), 2);

  // We will use the node name of BiasAdd as the name of new node
  // Build new node. We use same name as original node, but change the op
  // name.
  NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
  nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
  // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
  nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
  // In1 of BiasAdd is same as output of Conv2D.
  nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd

  // Copy attributes from Conv2D to Conv2DWithBias.
  CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);

  // Copy the device assigned to old node to new node.
  nb.Device(succ->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  CHECK_NOTNULL(new_node);

  // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
  // node are already copied in BuildNode. We handle control edges now.
  for (const Edge* e : pred->in_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
    }
  }
  for (const Edge* e : succ->in_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
    }
  }

  // Incoming edges are fixed, we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from 'pred' node.
  for (const Edge* e : pred->out_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
    }
  }

  // Second, we will fix outgoing control and data edges from 'succ' node.
  for (const Edge* e : succ->out_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
    } else {
      // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
      // output (at slot 0).
      const int kConv2DWithBiasOutputSlot = 0;
      CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(),
                                  e->dst_input()));
    }
  }

  // Copy device assigned to old node to new node.
  // It's ok to use pred or succ as we have enforced a check that
  // both have same device assigned.
  new_node->set_assigned_device_name(pred->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
          << ", and node: " << succ->DebugString()
          << ", into node:" << new_node->DebugString();

  // Both originals are fully rewired; drop them from the graph.
  (*g)->RemoveNode(succ);
  (*g)->RemoveNode(pred);

  return Status::OK();
}
2814
// Merges a Pad node with the Conv2D (or _FusedConv2D) node it feeds into a
// single fused node (csinfo_.pad_with_conv2d or csinfo_.pad_with_fused_conv2d
// respectively). `m` and `n` are the two nodes to merge, in either order:
// one must be Pad, the other Conv2D/_FusedConv2D (checked below). On success
// the fused node takes over all data and control edges of both originals,
// both originals are removed, and OK is returned. If T attributes or devices
// differ, or Pad has consumers other than the conv, an INVALID_ARGUMENT
// status is returned and the graph is left unmodified.
Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
                                                Node* m, Node* n) {
  DCHECK((m->type_string() == csinfo_.pad &&
          (n->type_string() == csinfo_.conv2d ||
           n->type_string() == csinfo_.fused_conv2d)) ||
         (n->type_string() == csinfo_.pad &&
          (m->type_string() == csinfo_.conv2d ||
           m->type_string() == csinfo_.fused_conv2d)));

  bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
                         m->type_string() == csinfo_.fused_conv2d;
  // Conv2D is successor node, and Pad predecessor node.
  Node* pred = m->type_string() == csinfo_.pad ? m : n;
  Node* succ = m->type_string() == csinfo_.pad ? n : m;

  // 1. Get all attributes from input nodes.
  DataType T_pred, T_succ;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> dilations;
  string data_format_pred, data_format_succ;

  TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
  TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
  // Check if the devices of both succ and pred are the same.
  // Assert is not used because it can be too strict.
  // Don't need to check for data formats because it is not available in Pad.
  if (T_pred != T_succ ||
      pred->assigned_device_name() != succ->assigned_device_name() ||
      pred->def().device() != succ->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "T attribute or devices of Conv2D and "
                  "Pad do not match. Will skip node merge optimization");
  }

  // Collect the (data) inputs and control edges of both nodes.
  const int succ_num = succ->num_inputs();
  gtl::InlinedVector<Node*, 4> succ_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
  FillInputs(succ, &succ_control_edges, &succ_in);

  const int pred_num = pred->num_inputs();
  gtl::InlinedVector<Node*, 4> pred_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
  FillInputs(pred, &pred_control_edges, &pred_in);

  // We need to ensure that Pad only feeds to Conv2D (some other operator is
  // not expecting output of Pad). If this is not the case, then we cannot
  // merge Conv2D with Pad.
  const int kFirstOutputSlot = 0;
  for (const Edge* e : pred->out_edges()) {
    if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
      return Status(error::Code::INVALID_ARGUMENT,
                    "Pad does not feed to Conv2D, or "
                    "it feeds Conv2D but has multiple outputs. "
                    "Will skip node merge optimization");
    }
  }

  // 2. Get inputs from both the nodes.

  // Pad must have 2 data inputs: "input" and paddings.
  int PadDataInputEdges = 0;
  for (const Edge* e : pred->in_edges()) {
    if (!e->IsControlEdge()) {
      PadDataInputEdges++;
    }
  }
  DCHECK_EQ(PadDataInputEdges, 2);

  // Conv2D must have 2 data inputs: Pad output and Filter
  // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
  int ConvDataInputEdges = 0;
  for (const Edge* e : succ->in_edges()) {
    if (!e->IsControlEdge()) {
      ConvDataInputEdges++;
    }
  }

  DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);

  // We will use the node name of Conv2D as the name of new node
  // Build new node. We use same name as original node, but change the op
  // name.

  NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
                                               : csinfo_.pad_with_conv2d);
  nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data) of Pad
  // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
  nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of conv2d
  // In1 of Conv2D is same as output of Pad.
  // Thus, only need to add In2 of Conv2D

  if (is_fused_conv2d) {
    // FusedConv2D has one additional input, args
    std::vector<NodeBuilder::NodeOut> args;
    args.emplace_back(succ_in[2].first, succ_in[2].second);
    nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
        args});                                     // In3 (args) of FusedConv2D
    nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
    // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
    CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
                                   const_cast<const Node*>(pred), &nb);
  } else {
    nb.Input(pred_in[1].first, pred_in[1].second);  // In2 (paddings) of Pad
    // Copy attributes from Pad and conv2D to PadWithConv2D.
    CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
                              const_cast<const Node*>(pred), &nb);
  }

  // Copy the device assigned to old node to new node.
  nb.Device(succ->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  DCHECK(new_node);

  // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
  // node are already copied in BuildNode.
  // We handle control edges now.
  for (const Edge* e : pred->in_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edge
      (*g)->AddControlEdge(e->src(), new_node, false);
    }
  }
  for (const Edge* e : succ->in_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edge
      (*g)->AddControlEdge(e->src(), new_node, false);
    }
  }

  // Incoming edges are fixed, we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from 'pred' node.
  for (const Edge* e : pred->out_edges()) {
    if (e->IsControlEdge()) {
      // Don't allow duplicate edge
      (*g)->AddControlEdge(new_node, e->dst(), false);
    }
  }

  // Second, we will fix outgoing control and data edges from 'succ' node.
  for (const Edge* e : succ->out_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      (*g)->AddControlEdge(new_node, e->dst(), false);
    } else {
      // Conv2D has only 1 output (at slot 0) and merged node also has only 1
      // output (at slot 0).
      const int kPadWithConv2DOutputSlot = 0;
      (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
                    e->dst_input());
    }
  }

  // Copy device assigned to old node to new node.
  // It's ok to use pred or succ as we have enforced a check that
  // both have same device assigned.
  new_node->set_assigned_device_name(pred->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
          << ", and node: " << succ->DebugString()
          << ", into node:" << new_node->DebugString();

  // Both originals are fully rewired; drop them from the graph.
  (*g)->RemoveNode(succ);
  (*g)->RemoveNode(pred);

  return Status::OK();
}
2989
Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
    std::unique_ptr<Graph>* g, Node* m, Node* n) {
  // Merges a {Conv2DBackpropFilter, BiasAddGrad} pair into a single
  // _MklConv2DBackpropFilterWithBias node. The caller guarantees that
  // {m, n} is such a pair (in either order); the CHECK enforces that.
  CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
             n->type_string() == csinfo_.conv2d_grad_filter)) ||
               ((n->type_string() == csinfo_.bias_add_grad &&
                 m->type_string() == csinfo_.conv2d_grad_filter)),
           true);

  // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
  Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
  Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;

  // Sanity check for attributes from input nodes: the merge is only valid
  // when both nodes agree on 'T', 'data_format', and device placement
  // (both the assigned and the user-requested device).
  DataType T_b, T_f;
  string data_format_b, data_format_f;
  TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
  TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
  TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
  TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
  if (data_format_b != data_format_f || T_b != T_f ||
      badd->assigned_device_name() != fltr->assigned_device_name() ||
      badd->def().device() != fltr->def().device()) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "data_format or T attribute or devices of "
                  "Conv2DBackpropFilter and BiasAddGrad do not match. "
                  "Will skip node merge optimization");
  }

  // We will use the node name of Conv2DBackpropFilter as the name of new node.
  // This is because BackpropFilterWithBias is going to emit bias output also.
  NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
  // Since Conv2DBackpropFilterWithBias has same number of inputs as
  // Conv2DBackpropFilter, we can just copy input edges directly. We dont need
  // to copy any data input of BiasAddGrad because that input also goes to
  // Conv2DBackpropFilter.
  const int fltr_ins = fltr->num_inputs();
  gtl::InlinedVector<Node*, 4> fltr_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
  FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
  for (int idx = 0; idx < fltr_ins; idx++) {
    nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
  }

  // Copy attributes from Conv2DBackpropFilter.
  CopyAttrsConv(const_cast<const Node*>(fltr), &nb);

  // Copy the device assigned to old node to new node.
  nb.Device(fltr->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  CHECK_NOTNULL(new_node);

  // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
  // new 'new_node' node are already copied in BuildNode. We handle control
  // edges now.
  for (const Edge* e : badd->in_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
    }
  }
  for (const Edge* e : fltr->in_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
    }
  }

  // Incoming edges are fixed, we will fix the outgoing edges now.
  // First, we will fix outgoing control edges from 'badd' node.
  // Conv2DBackpropFilter has 1 output -- filter_grad.
  // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
  // bias_grad. But filter_grad is at same slot number (0) in both the
  // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
  // it is at slot number 0 in BiasAddGrad.
  const int kMergedNodeFilterGradOutputIdx = 0;
  const int kMergedNodeBiasGradOutputIdx = 1;

  for (const Edge* e : badd->out_edges()) {
    if (e->IsControlEdge()) {
      // Allow duplicate while adding control edge as it would fail (return
      // NULL) if we try to add duplicate edge.
      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
    } else {
      // Data consumers of BiasAddGrad now read the merged node's bias_grad
      // output (slot 1).
      CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
                                  e->dst(), e->dst_input()));
    }
  }

  // Second, we will fix outgoing control and data edges from 'fltr' node.
  for (const Edge* e : fltr->out_edges()) {
    if (e->IsControlEdge()) {
      // We allow duplicates here since a control edge from new_node to the
      // same destination may already have been added while handling 'badd'
      // outputs above. Without allowing duplicates, AddControlEdge would
      // fail (return NULL) for such an edge.
      CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
    } else {
      // filter_grad stays at slot 0 in the merged node.
      CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
                                  e->dst(), e->dst_input()));
    }
  }

  // Copy device assigned to old node to new node.
  // It's ok to use badd or fltr as we have enforced a check that
  // both have same device assigned.
  new_node->set_assigned_device_name(badd->assigned_device_name());

  VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
          << ", and node: " << fltr->DebugString()
          << ", into node:" << new_node->DebugString();

  // Both originals are fully rewired; drop them from the graph.
  (*g)->RemoveNode(badd);
  (*g)->RemoveNode(fltr);

  return Status::OK();
}
3111
MergeNode(std::unique_ptr<Graph> * g,Node * m,Node * n)3112 Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3113 Node* n) {
3114 CHECK_NOTNULL(m);
3115 CHECK_NOTNULL(n);
3116
3117 if (((m->type_string() == csinfo_.bias_add &&
3118 n->type_string() == csinfo_.conv2d)) ||
3119 ((n->type_string() == csinfo_.bias_add &&
3120 m->type_string() == csinfo_.conv2d))) {
3121 return this->MergeConv2DWithBiasAdd(g, m, n);
3122 }
3123 if ((m->type_string() == csinfo_.pad &&
3124 (n->type_string() == csinfo_.conv2d ||
3125 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3126 (n->type_string() == csinfo_.pad &&
3127 (m->type_string() == csinfo_.conv2d ||
3128 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3129 return this->MergePadWithConv2D(g, m, n);
3130 }
3131
3132 if (((m->type_string() == csinfo_.bias_add_grad &&
3133 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3134 ((n->type_string() == csinfo_.bias_add_grad &&
3135 m->type_string() == csinfo_.conv2d_grad_filter))) {
3136 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3137 }
3138
3139 return Status(error::Code::UNIMPLEMENTED,
3140 "Unimplemented case for node merge optimization.");
3141 }
3142
3143 //////////////////////////////////////////////////////////////////////////
3144 // Helper functions for node rewrite
3145 //////////////////////////////////////////////////////////////////////////
3146
RewriteNode(std::unique_ptr<Graph> * g,Node * orig_node,const RewriteInfo * ri)3147 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3148 Node* orig_node,
3149 const RewriteInfo* ri) {
3150 CHECK_NOTNULL(ri);
3151 CHECK_NOTNULL(orig_node);
3152
3153 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3154
3155 // Get all inputs.
3156 int num_inputs = orig_node->in_edges().size();
3157
3158 // Drop count for control edges from inputs
3159 for (const Edge* e : orig_node->in_edges()) {
3160 if (e->IsControlEdge()) {
3161 num_inputs--;
3162 }
3163 }
3164
3165 gtl::InlinedVector<Node*, 4> control_edges;
3166 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
3167 FillInputs(orig_node, &control_edges, &inputs);
3168
3169 // Build new node. We use same name as original node, but change the op name.
3170 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3171 // Copy user-specified device assigned to original node to new node.
3172 nb.Device(orig_node->def().device());
3173 // Set up new inputs to the rewritten node.
3174 Status s = SetUpInputs(g, inputs, &nb, orig_node);
3175 if (s != Status::OK()) {
3176 return s;
3177 }
3178
3179 const bool kPartialCopyAttrs = false;
3180 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3181
3182 // Set the Mkl layer label for this op.
3183 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3184 DataTypeIsQuantized(orig_node->output_type(0))) {
3185 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3186 } else {
3187 nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
3188 }
3189 // Finalize graph and get new node.
3190 Node* new_node = nullptr;
3191 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3192 CHECK_NOTNULL(new_node);
3193
3194 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3195 // already copied in BuildNode. We need to handle control edges now.
3196 for (const Edge* e : orig_node->in_edges()) {
3197 if (e->IsControlEdge()) {
3198 // Allow duplicate while adding control edge as it would fail (return
3199 // NULL) if we try to add duplicate edge.
3200 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
3201 }
3202 }
3203
3204 // Copy outgoing edges from 'orig_node' node to new
3205 // 'new_node' node, since the output also follows same ordering among
3206 // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
3207 // tensors appropriately. Specifically, nth output of the original node
3208 // will become 2*nth output of the Mkl node for the interleaved ordering
3209 // of the tensors. For the contiguous ordering of the tensors, it will be n.
3210 // GetTensorDataIndex provides this mapping function.
3211 for (const Edge* e : orig_node->out_edges()) {
3212 if (e->IsControlEdge()) {
3213 // Allow duplicate while adding control edge as it would fail (return
3214 // NULL) if we try to add duplicate edge.
3215 CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
3216 } else {
3217 CHECK_NOTNULL((*g)->AddEdge(
3218 new_node,
3219 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3220 e->dst(), e->dst_input()));
3221 }
3222 }
3223
3224 // Copy the runtime device assigned from original code to new node.
3225 new_node->set_assigned_device_name(orig_node->assigned_device_name());
3226
3227 // Delete original node and mark new node as rewritten.
3228 (*g)->RemoveNode(orig_node);
3229
3230 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3231 return Status::OK();
3232 }
3233
3234 // TODO(mdfaijul): Is there any other elegent way to check for quantized ops
3235 // having attributes other than "T"?
3236 // Current implementation reflects only QuantizedConv2D and its fused Ops.
3237 const MklLayoutRewritePass::RewriteInfo*
CheckForQuantizedNodeRewrite(const Node * n) const3238 MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3239 DataType Tinput, Tfilter;
3240 if (!(GetNodeAttr(n->def(), "Tinput", &Tinput).ok() &&
3241 GetNodeAttr(n->def(), "Tfilter", &Tfilter).ok())) {
3242 return nullptr;
3243 }
3244 if (mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3245 Tinput, Tfilter)) {
3246 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3247 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3248 return &*ri;
3249 }
3250 }
3251 }
3252 return nullptr;
3253 }
3254
3255 const MklLayoutRewritePass::RewriteInfo*
CheckForNodeRewrite(const Node * n) const3256 MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3257 CHECK_NOTNULL(n);
3258
3259 // QuntizedOps may have attributes other than "T", so decoupled the check
3260 // with a function, CheckForQuantizedNodeRewrite(const Node*).
3261 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3262 if (ri != nullptr) return ri;
3263
3264 // First check if node along with its type is supported by MKL layer.
3265 // We do not want to rewrite an op into Mkl op if types are not supported.
3266 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3267 // MklRelu if type is INT32.
3268 DataType T;
3269 if (!GetNodeAttr(n->def(), "T", &T).ok()) {
3270 return nullptr;
3271 }
3272
3273 // We make an exception for __MklDummyConv2DWithBias,
3274 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3275 // names do not match Mkl node names.
3276 if (n->type_string() != csinfo_.conv2d_with_bias &&
3277 n->type_string() != csinfo_.pad_with_conv2d &&
3278 n->type_string() != csinfo_.pad_with_fused_conv2d &&
3279 n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3280 n->type_string() != csinfo_.fused_conv2d &&
3281 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3282 T)) {
3283 return nullptr;
3284 }
3285
3286 // For elementwise node, we reuse the Eigen implementation and pass the MKL
3287 // metadata tensor through so we can avoid conversions. However, if all
3288 // incoming edges are in TF format, we don't need all this overhead, so
3289 // replace the elementwise node only if at least one of its parents is a MKL
3290 // node.
3291 //
3292 // Identity nodes can also skip replacement if they are not being served by
3293 // any MKL nodes.
3294 //
3295 // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
3296 // eigen code to reduce cross-library dependency.
3297 VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string();
3298 if (mkl_op_registry::IsMklElementWiseOp(
3299 mkl_op_registry::GetMklOpName(n->type_string()), T) ||
3300 n->type_string().find("Identity") != string::npos) {
3301 VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string();
3302 bool incoming_mkl_edge = false;
3303 int num_parent = 0;
3304 for (auto parent : n->in_edges()) {
3305 if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) {
3306 VLOG(1) << "ELEMENTWISE: parent " << num_parent++
3307 << " is MKL op: " << parent->src()->type_string();
3308 incoming_mkl_edge = true;
3309 break;
3310 } else {
3311 VLOG(1) << "ELEMENTWISE: parent " << num_parent++
3312 << " is NON-MKL op: " << parent->src()->type_string();
3313 }
3314 }
3315 if (incoming_mkl_edge == false) {
3316 VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which "
3317 "has no MKL "
3318 "parents.";
3319 return nullptr;
3320 } else {
3321 VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string()
3322 << " which has MKL parents";
3323 }
3324 }
3325
3326 // We now check if rewrite rule applies for this op. If rewrite rule passes
3327 // for this op, then we rewrite it to Mkl op.
3328 // Find matching RewriteInfo and then check that rewrite rule applies.
3329 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3330 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3331 return &*ri;
3332 }
3333 }
3334
3335 // Else return not found.
3336 return nullptr;
3337 }
3338
3339 //////////////////////////////////////////////////////////////////////////
3340 // Helper functions for node fusion
3341 //////////////////////////////////////////////////////////////////////////
Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
    std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
    string data_format) {
  // Fuses the matched pattern Transpose(->NHWC) -> MklOp -> Transpose(->NCHW)
  // into a single MklOp that runs directly in 'data_format'. 'nodes' holds
  // exactly that triple, in pattern order (see CheckForNodeFusion).
  Node* transpose_to_nhwc = nodes[0];
  Node* mklop = nodes[1];
  Node* transpose_to_nchw = nodes[2];

  // Gather the (node, output-slot) inputs and control edges of all three
  // pattern nodes.
  const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
  gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
      transpose_nhwc_num_inputs);
  FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
             &transpose_nhwc_in);

  const int mklop_num_inputs = mklop->num_inputs();
  gtl::InlinedVector<Node*, 4> mklop_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
  FillInputs(mklop, &mklop_control_edges, &mklop_in);

  const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
  gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
  gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
      transpose_nchw_num_inputs);
  FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
             &transpose_nchw_in);

  // We use same name as original node, but change the op
  // type.
  NodeBuilder nb(mklop->name(), mklop->type_string());

  // Storing the output slots of the input nodes.
  for (int i = 0; i < mklop_num_inputs; i++) {
    if (mklop_in[i].first == transpose_to_nhwc) {
      // Fill "x": bypass the first Transpose by feeding its own input
      // (slot 0) straight into the fused op.
      nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
    } else {
      // Fill inputs other than "x":
      nb.Input(mklop_in[i].first, mklop_in[i].second);
    }
  }

  // Keep the MklOp's attributes but override data_format, since the fused
  // op now operates in the layout the transposes used to produce.
  copy_attrs(const_cast<const Node*>(mklop), &nb, true);
  nb.Attr("data_format", data_format);

  // Copy the device assigned to old node to new node.
  nb.Device(mklop->def().device());

  // Create node.
  Node* new_node;
  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
  DCHECK(new_node);

  // Fill outputs: consumers of the trailing Transpose now read directly
  // from the fused node's output slot 0.
  for (const Edge* e : transpose_to_nchw->out_edges()) {
    if (!e->IsControlEdge()) {
      const int kTransposeWithMklOpOutputSlot = 0;
      auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
                                    e->dst(), e->dst_input());
      DCHECK(new_edge);
    }
  }

  // Copy device assigned to old node to new node.
  new_node->set_assigned_device_name(mklop->assigned_device_name());

  // Copy requested_device and assigned_device_name_index
  new_node->set_requested_device(mklop->requested_device());
  new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());

  // All three original pattern nodes are now fully replaced.
  (*g)->RemoveNode(transpose_to_nhwc);
  (*g)->RemoveNode(mklop);
  (*g)->RemoveNode(transpose_to_nchw);

  return Status::OK();
}
3418
FuseNode(std::unique_ptr<Graph> * g,std::vector<Node * > & nodes,const MklLayoutRewritePass::FusionInfo fi)3419 Status MklLayoutRewritePass::FuseNode(
3420 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3421 const MklLayoutRewritePass::FusionInfo fi) {
3422 return fi.fuse_func(g, nodes, fi.copy_attrs);
3423 }
3424
std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
  // Tries each registered fusion pattern in finfo_ against the subgraph
  // rooted at 'a'. Returns {true, matched nodes, pattern} on the first
  // match, or {false, {}, FusionInfo()} when nothing matches.
  // Stores matched nodes, in the same order as node_checkers.
  std::vector<Node*> nodes;

  for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
    //
    // Make sure node "a" and its succeding nodes (b, c ...), match the pattern
    // defined in fusion info (ops[0], ops[1], ...),
    // a.k.a. "a->b->c" matches "op1->op2->op3"
    //
    // This is a depth-first search over out-edges with backtracking: the
    // stack mirrors 'nodes', one out-edge iterator per matched node.

    // Stores the first unvisted outgoing edge of each matched node in "nodes".
    std::stack<EdgeSet::const_iterator> current_neighbor_stack;
    nodes.clear();

    // Seed the search: 'a' itself must satisfy the pattern's first checker.
    auto node_checker = fi->node_checkers.begin();
    if (a != nullptr && (*node_checker)(a)) {
      nodes.push_back(a);
      current_neighbor_stack.push(a->out_edges().begin());
      ++node_checker;
    }

    while (!nodes.empty()) {
      auto& current_neighbor_iter = current_neighbor_stack.top();

      if (current_neighbor_iter != nodes.back()->out_edges().end()) {
        // Found an unvisited edge. Goes through the edge to get the neighbor.
        Node* neighbor_node = (*current_neighbor_iter)->dst();
        ++current_neighbor_stack.top();  // Retrieves the next unvisited edge.

        if ((*node_checker)(neighbor_node)) {
          // Found a match. Stores the node and moves to the next checker.
          nodes.push_back(neighbor_node);
          current_neighbor_stack.push(neighbor_node->out_edges().begin());
          if (++node_checker == fi->node_checkers.end()) {
            // Every checker in the pattern is satisfied: full match.
            return make_tuple(true, nodes, *fi);
          }
        }
      } else {
        // Removes the current node since none of its neighbor leads to a
        // further match.
        nodes.pop_back();
        current_neighbor_stack.pop();
        --node_checker;
      }
    }
  }

  // No pattern matched starting at 'a'.
  return make_tuple(false, std::vector<Node*>(), FusionInfo());
}
3476
3477 ///////////////////////////////////////////////////////////////////////////////
3478 // Post-rewrite Mkl metadata fixup pass
3479 ///////////////////////////////////////////////////////////////////////////////
FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph> * g,const Edge * e_data,const Edge * e_metadata)3480 bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3481 const Edge* e_data,
3482 const Edge* e_metadata) {
3483 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3484 return false;
3485 }
3486
3487 Node* n_data = e_data->src();
3488 int n_data_op_slot = e_data->src_output();
3489 int n_metadata_op_slot =
3490 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
3491
3492 // If the source of meta edge is a constant node (producing dummy Mkl metadata
3493 // tensor), then we will need to fix.
3494 if (IsConstant(e_metadata->src())) {
3495 Node* e_metadata_dst = e_metadata->dst();
3496 int e_metadata_in_slot = e_metadata->dst_input();
3497 CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
3498 e_metadata_in_slot));
3499
3500 (*g)->RemoveEdge(e_metadata);
3501 return true;
3502 }
3503
3504 return false;
3505 }
3506
bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
                                               Node* n) {
  // Checks every metadata input of Mkl node 'n' and repairs those still fed
  // by dummy constants. Returns true iff any edge was fixed.
  bool result = false;

  // If graph node is not Mkl node, then return.
  DataType T = DT_INVALID;
  if (!GetNodeAttr(n->def(), "T", &T).ok() ||
      !mkl_op_registry::IsMklOp(n->type_string(), T)) {
    return result;
  }

  // If it is Mkl node, then check if the input edges to this node that carry
  // Mkl metadata are linked up correctly with the source node.

  // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
  // data tensors + n for Mkl metadata tensors). We need to check for correct
  // connection of n metadata tensors only.
  int num_data_inputs = n->num_inputs() / 2;
  for (int idx = 0; idx < num_data_inputs; idx++) {
    // Get the edge connecting input slot with index (idx).
    const Edge* e = nullptr;
    TF_CHECK_OK(n->input_edge(idx, &e));

    // If e is control edge, then skip.
    if (e->IsControlEdge()) {
      continue;
    }

    // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
    // node, then we don't need to do anything.
    Node* e_src = e->src();
    if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
        mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
      // Source node for edge 'e' is Mkl node.
      // Destination node and destination input slot of e is node 'n' and 'idx'
      // resp.
      CHECK_EQ(e->dst(), n);
      CHECK_EQ(e->dst_input(), idx);

      // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
      // 'e'. For that, let's first get the input slot of 'n' where the meta
      // edge will feed the value.
      int e_meta_in_slot =
          GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
      const Edge* e_meta = nullptr;
      TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));

      // Let's check if we need to fix this meta edge.
      if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
        result = true;
      }
    }
  }

  return result;
}
3563
3564 ///////////////////////////////////////////////////////////////////////////////
3565 // Run function for the pass
3566 ///////////////////////////////////////////////////////////////////////////////
3567
bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
  // Runs the full pass: node merge, node fusion, node rewrite, and metadata
  // edge fixup, in that order. Returns true iff the graph was changed.
  bool result = false;
  CHECK_NOTNULL(g);

  DumpGraph("Before running MklLayoutRewritePass", &**g);

  // Phase 1: merge node pairs (e.g. Conv2D + BiasAdd) into single ops.
  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    Node* m = nullptr;
    if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      // Check if the node 'n' can be merged with any other node. If it can
      // be 'm' contains the node with which it can be merged.
      // Capture names before MergeNode, which may delete both nodes.
      string n1_name = n->name();
      string n2_name = m->name();

      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
              << n2_name << " for merging";

      if (MergeNode(g, n, m) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
                << n2_name;
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);

  // Phase 2: fuse multi-node patterns (the order vector is recomputed since
  // phase 1 may have mutated the graph).
  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    auto check_result = CheckForNodeFusion(n);
    bool found_pattern = std::get<0>(check_result);
    std::vector<Node*> nodes = std::get<1>(check_result);
    const FusionInfo fi = std::get<2>(check_result);

    // if "found_pattern" is true, we can do the fusion.
    if (found_pattern) {
      if (FuseNode(g, nodes, fi) == Status::OK()) {
        result = true;
      }
    }
  }
  DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);

  // Phase 3: rewrite remaining individual nodes into their Mkl versions.
  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }

    const RewriteInfo* ri = nullptr;
    // We will first search if node is to be rewritten.
    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
      // Capture name/op before RewriteNode, which deletes 'n' on success.
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
              << " with op " << op_name << " for rewrite using"
              << " layout optimization.";

      if (RewriteNode(g, n, ri) == Status::OK()) {
        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
                << " with op " << op_name << " for Mkl layout optimization.";
        result = true;
      }
    }
  }

  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

  // Phase 4: fix up Mkl metadata edges left pointing at dummy constants.
  order.clear();
  GetReversePostOrder(**g, &order);  // This will give us topological sort.
  for (Node* n : order) {
    // If node is not an op or it cannot run on CPU device, then skip.
    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
      continue;
    }
    if (FixMklMetaDataEdges(g, n)) {
      string node_name = n->name();
      string op_name = n->type_string();

      VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
              << node_name << " with op " << op_name;
      result = true;
    }
  }
  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
            &**g);

  return result;
}
3673
RunMklLayoutRewritePass(std::unique_ptr<Graph> * g)3674 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
3675 return MklLayoutRewritePass().RunPass(g);
3676 }
3677
Run(const GraphOptimizationPassOptions & options)3678 Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
3679 if (options.graph == nullptr && options.partition_graphs == nullptr) {
3680 return Status::OK();
3681 }
3682 if (DisableMKL()) {
3683 VLOG(2) << "TF-MKL: Disabling MKL";
3684 return Status::OK();
3685 }
3686
3687 auto process_graph = [&](std::unique_ptr<Graph>* g) {
3688 // Get the ownership of a graph
3689 std::unique_ptr<Graph>* ng = std::move(g);
3690 RunPass(ng);
3691 // Return the ownership of a graph back
3692 g->reset(ng->release());
3693 };
3694
3695 if (kMklLayoutRewritePassGroup !=
3696 OptimizationPassRegistry::POST_PARTITIONING) {
3697 // For any pre-partitioning phase, a graph is stored in options.graph.
3698 process_graph(options.graph);
3699 } else {
3700 // For post partitioning phase, graphs are stored in
3701 // options.partition_graphs.
3702 for (auto& pg : *options.partition_graphs) {
3703 process_graph(&pg.second);
3704 }
3705 }
3706
3707 return Status::OK();
3708 }
3709
3710 } // namespace tensorflow
3711
3712 #endif
3713