• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
18 
19 #ifdef INTEL_MKL
20 #include <limits>
21 #include <memory>
22 #include <vector>
23 
24 #include "mkldnn.hpp"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/kernel_shape_util.h"
27 #include "tensorflow/core/framework/numeric_op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_slice.h"
33 #include "tensorflow/core/kernels/conv_grad_ops.h"
34 #include "tensorflow/core/kernels/ops_util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/gtl/array_slice.h"
37 #include "tensorflow/core/lib/strings/numbers.h"
38 #include "tensorflow/core/lib/strings/str_util.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/util/mkl_util.h"
42 #include "tensorflow/core/util/padding.h"
43 #include "tensorflow/core/util/tensor_format.h"
44 
45 using mkldnn::convolution_forward;
46 using mkldnn::prop_kind;
47 using mkldnn::stream;
48 
49 namespace tensorflow {
50 
// Element type of the size vectors built by the helpers in this header.
// This is a macro, not a type alias; it is #undef'ed at the end of the file.
#define MKLDNN_SIZE_DTYPE memory::dim

// Shorthands for the MKL-DNN forward-convolution descriptor and
// primitive-descriptor types.
using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
55 
56 class MklDnnConvUtil {
57  protected:
58   OpKernelContext* context_;  // We don't own this.
59   std::vector<int32> strides_;
60   std::vector<int32> dilations_;
61   Padding padding_;
62   TensorFormat data_format_;
63 
64  public:
65   MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
66                  Padding pad, TensorFormat fm,
67                  const std::vector<int32>& dilations, bool is_depthwise = false)
context_(context)68       : context_(context),
69         strides_(strides),
70         dilations_(dilations),
71         padding_(pad),
72         data_format_(fm) {}
73 
~MklDnnConvUtil()74   virtual ~MklDnnConvUtil() { context_ = nullptr; }
75 
76   // Calculate Convolution strides
GetStridesInMklOrder(memory::dims * strides)77   virtual inline void GetStridesInMklOrder(memory::dims* strides) {
78     // For now we take the stride from the second and third dimensions only
79     // (we do not support striding on the batch or depth dimension).
80     DCHECK(strides);
81     if (strides_.size() == 4) {
82       int stride_rows = GetTensorDim(strides_, data_format_, 'H');
83       int stride_cols = GetTensorDim(strides_, data_format_, 'W');
84       *strides = {stride_rows, stride_cols};
85     } else if (strides_.size() == 5) {
86       int stride_planes = GetTensorDim(strides_, data_format_, '0');
87       int stride_rows = GetTensorDim(strides_, data_format_, '1');
88       int stride_cols = GetTensorDim(strides_, data_format_, '2');
89       *strides = {stride_planes, stride_rows, stride_cols};
90     }
91   }
92 
93   // Calculate Convolution dilations
GetDilationsInMklOrder(memory::dims * dilations)94   virtual inline void GetDilationsInMklOrder(memory::dims* dilations) {
95     // For now we take the dilation from the second and third dimensions only
96     // (we do not support dilation on the batch or depth dimension).
97     DCHECK(dilations);
98     if (dilations_.size() == 4) {
99       int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
100       int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
101       *dilations = {dilations_rows, dilations_cols};
102     } else if (dilations_.size() == 5) {
103       int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
104       int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
105       int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
106       *dilations = {dilations_planes, dilations_rows, dilations_cols};
107     }
108   }
109 
110   // Calculate Convolution input size in MKL-DNN order. MKL-DNN
111   // requires input in NCHW/NCDHW format. Function does not return anything.
112   // But errors arising from sanity checks are returned in context's
113   // status.
GetInputSizeInMklOrder(const TensorShape & input_shape,memory::dims * input_dims)114   virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
115                                              memory::dims* input_dims) {
116 #define CHECK_BOUNDS(val, err_msg)                                     \
117   do {                                                                 \
118     OP_REQUIRES(context_,                                              \
119                 FastBoundsCheck(val, std::numeric_limits<int>::max()), \
120                 errors::InvalidArgument(err_msg));                     \
121   } while (0)
122 
123     DCHECK(input_dims);
124 
125     // Input channel
126     int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
127     int input_depth = static_cast<int>(input_depth_raw);
128 
129     // Input batch
130     int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
131     CHECK_BOUNDS(input_batch_raw, "Input batch too large");
132     int input_batch = static_cast<int>(input_batch_raw);
133 
134     if (strides_.size() == 4) {  // NCHW format for Conv2D
135       // Input rows/height
136       int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
137       CHECK_BOUNDS(input_rows_raw, "Input rows too large");
138       int input_rows = static_cast<int>(input_rows_raw);
139 
140       // Input columns/width
141       int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
142       CHECK_BOUNDS(input_cols_raw, "Input cols too large");
143       int input_cols = static_cast<int>(input_cols_raw);
144 
145       // MKL-DNN always requires input in NCHW format Conv2D.
146       std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
147       mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
148       mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
149       mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
150       mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
151 
152       *input_dims = mkldnn_sizes;
153     } else if (strides_.size() == 5) {  // NCDHW format for Conv3D
154       // Input planes/third-dimension
155       int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
156       CHECK_BOUNDS(input_planes_raw, "Input depth too large");
157       int input_planes = static_cast<int>(input_planes_raw);
158 
159       // Input rows/height
160       int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
161       CHECK_BOUNDS(input_rows_raw, "Input rows too large");
162       int input_rows = static_cast<int>(input_rows_raw);
163 
164       // Input columns/width
165       int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
166       CHECK_BOUNDS(input_cols_raw, "Input cols too large");
167       int input_cols = static_cast<int>(input_cols_raw);
168 
169       // MKL-DNN always requires input in NCDHW format for Conv3D.
170       std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
171       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
172       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
173       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
174       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
175       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
176 
177       *input_dims = mkldnn_sizes;
178     }
179 #undef CHECK_BOUNDS
180   }
181 
182   // Calculate Convolution filter size in MKL-DNN order.
183   // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
184   // Function does not return anything.
185   // But errors arising from sanity checks are returned in context's
  // status. Unlike the index-based GetFilterSizeInMklOrder overload, this
  // overload takes the input and filter shapes directly, since Convolution
  // Backward Input gets the shape of the input tensor rather than the actual
  // tensor (Convolution forward gets the actual tensor as input).
190   //
191   // TODO(nhasabni): Add similar function for input and filter in MklShape.
GetFilterSizeInMklOrder(const TensorShape & input_shape,const TensorShape & filter_shape,memory::dims * filter_dims,bool is_depthwise)192   virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
193                                               const TensorShape& filter_shape,
194                                               memory::dims* filter_dims,
195                                               bool is_depthwise) {
196     DCHECK(filter_dims);
197 
198     OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
199                 errors::InvalidArgument((strides_.size() == 4)
200                                             ? "filter must be 4-dimensional: "
201                                             : "filter must be 5-dimensional: ",
202                                         filter_shape.DebugString()));
203 
204     for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
205       OP_REQUIRES(context_,
206                   FastBoundsCheck(filter_shape.dim_size(i),
207                                   std::numeric_limits<int>::max()),
208                   errors::InvalidArgument("filter too large"));
209     }
210 
211     int input_depth = GetTensorDim(input_shape, data_format_, 'C');
212 
213     if (strides_.size() == 4) {  // Conv2D
214       OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
215                   errors::InvalidArgument(
216                       "input and filter must have the same depth: ",
217                       input_depth, " vs ", filter_shape.dim_size(2)));
218 
219       // TF filter is always in (rows, cols, in_depth, out_depth) order.
220       int filter_rows =
221           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H));
222       int filter_cols =
223           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W));
224       int filter_in_depth =
225           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
226       int filter_out_depth =
227           static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O));
228       // MKL-DNN always needs filter in OIHW format for regular convolutions
229       // and GOIHW for grouped/depthwise convolutions,
230       // OIHW = (out_depth, in_depth, rows, cols)
231       // GOIHW = (group, out_depth, in_depth, rows, cols)
232       // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
233       if (is_depthwise) {
234         std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
235         mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
236         mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
237         mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
238         mkldnn_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
239         mkldnn_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
240 
241         *filter_dims = mkldnn_sizes;
242       } else {
243         std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
244         mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
245         mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
246         mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
247         mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
248 
249         *filter_dims = mkldnn_sizes;
250       }
251     } else {  // Conv3D
252       OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
253                   errors::InvalidArgument(
254                       "input and filter must have the same depth: ",
255                       input_depth, " vs ", filter_shape.dim_size(3)));
256 
257       // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
258       int filter_planes =
259           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P));
260       int filter_rows =
261           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H));
262       int filter_cols =
263           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W));
264       int filter_in_depth =
265           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I));
266       int filter_out_depth =
267           static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O));
268 
269       // MKL-DNN always needs filter in OIDHW format.
270       // OIDHW = (out_depth, in_depth, planes, rows, cols)
271       std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
272       mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
273       mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
274       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
275       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
276       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
277 
278       *filter_dims = mkldnn_sizes;
279     }
280   }
281 
282   // Calculate Convolution filter size in MKL-DNN order.
  // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
284   // Function does not return anything. But errors arising from sanity
285   // checks are returned in context's status.
GetFilterSizeInMklOrder(size_t src_index,size_t filter_index,memory::dims * filter_dims,bool is_depthwise)286   virtual inline void GetFilterSizeInMklOrder(size_t src_index,
287                                               size_t filter_index,
288                                               memory::dims* filter_dims,
289                                               bool is_depthwise) {
290     DCHECK(filter_dims);
291     GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
292                             GetTfShape(context_, filter_index), filter_dims,
293                             is_depthwise);
294   }
295 
296   // Calculate Bias size for 2D or 3D Convolution. Function does not
297   // return anything, but may set an error in context status.
GetBiasSizeInMklOrder(size_t bias_index,memory::dims * bias_dims)298   virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
299                                             memory::dims* bias_dims) {
300     const Tensor& bias = MklGetInput(context_, bias_index);
301     OP_REQUIRES(context_, bias.dims() == 1,
302                 errors::InvalidArgument("bias must be 1-dimensional: ",
303                                         bias.shape().DebugString()));
304 
305     *bias_dims = {static_cast<int>(bias.dim_size(0))};
306   }
307 
308   // Function to calculate output and padding size for 2D/3D convolution.
309   //
310   // Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
311   // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order.
312   // But TensorFlow output will be in NHWC||NCHW(Conv2D) or
313   // NDHWC||NCDHW(Conv3D) format depending on data format.
314   // Function also calculates left, right, top and bottom pads.
315   // Function does not return any status which is set with context status.
316   //
317   // TODO(nhasabni): Add similar function for input and filter in MklShape.
318   virtual inline void GetOutputAndPadSizeInMklOrder(
319       const TensorShape& input_shape, const TensorShape& filter_shape,
320       const memory::dims& strides, const memory::dims& dilations,
321       memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
322       memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
323       bool is_depthwise = false) {
324     DCHECK(output_dims_tf_order);
325     DCHECK(output_dims_mkl_order);
326     DCHECK(pad_l);
327     DCHECK(pad_r);
328 
329     bool is_conv2d = (strides_.size() == 4);
330     int input_planes, input_rows, input_cols;
331     if (is_conv2d) {
332       input_rows = GetTensorDim(input_shape, data_format_, 'H');
333       input_cols = GetTensorDim(input_shape, data_format_, 'W');
334     } else {
335       input_planes = GetTensorDim(input_shape, data_format_, '0');
336       input_rows = GetTensorDim(input_shape, data_format_, '1');
337       input_cols = GetTensorDim(input_shape, data_format_, '2');
338     }
339 
340     // Filter dimension
341     // Conv2D:
342     //    First dimension: rows/height.
343     //    Second dimension: cols/width.
344     // Conv3D:
345     //    First dimension: planes/depth.
346     //    Second dimension: rows/height.
347     //    Third dimension: cols/width.
348 
349     int filter_planes, filter_rows, filter_cols;
350     if (is_conv2d) {
351       filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H);
352       filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W);
353     } else {
354       filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P);
355       filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H);
356       filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W);
357     }
358 
359     int stride_planes, stride_rows, stride_cols;
360     int dilation_planes, dilation_rows, dilation_cols;
361     if (is_conv2d) {
362       // Conv2D stride is a vector of 2 elements: {s_r, s_c}
363       stride_rows = strides[0];
364       stride_cols = strides[1];
365       dilation_rows = dilations[0];
366       dilation_cols = dilations[1];
367     } else {
368       // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
369       stride_planes = strides[0];
370       stride_rows = strides[1];
371       stride_cols = strides[2];
372       dilation_planes = dilations[0];
373       dilation_rows = dilations[1];
374       dilation_cols = dilations[2];
375     }
376 
377     // Output batch is same as input batch.
378     int out_batch = GetTensorDim(input_shape, data_format_, 'N');
379     int out_depth;
380 
381     // TODO add support for 3-D Depthwise
382 
383     // Output depth is same as last dimension for filters for regular
384     // convolutions. For depthwise it is in_depth * channel_multiplier.
385     // The channel_multiplier is the last dimension of TF filter for
386     // depthwise convolutions.
387     if (is_depthwise) {
388       out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) *
389                    filter_shape.dim_size(TF_2DFILTER_DIM_O));
390     } else {
391       out_depth = filter_shape.dim_size(
392           is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O)
393                     : static_cast<int>(TF_3DFILTER_DIM_O));
394     }
395 
396     int64 out_rows = 0, out_cols = 0, out_planes = 0;
397     int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;
398     int64 pad_D1, pad_D2;
399 
400     if (is_conv2d) {
401       Padding padding_type;
402       if (pad_enabled) {
403         padding_type = Padding::EXPLICIT;
404         pad_top = static_cast<int64>((*pad_l)[0]);
405         pad_left = static_cast<int64>((*pad_l)[1]);
406         pad_bottom = static_cast<int64>((*pad_r)[0]);
407         pad_right = static_cast<int64>((*pad_r)[1]);
408       } else {
409         padding_type = padding_;
410       }
411       OP_REQUIRES_OK(context_,
412                      GetWindowedOutputSizeVerboseV2(
413                          input_rows, filter_rows, dilation_rows, stride_rows,
414                          padding_type, &out_rows, &pad_top, &pad_bottom));
415       OP_REQUIRES_OK(context_,
416                      GetWindowedOutputSizeVerboseV2(
417                          input_cols, filter_cols, dilation_cols, stride_cols,
418                          padding_type, &out_cols, &pad_left, &pad_right));
419     } else {
420       OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2(
421                                    input_planes, filter_planes, dilation_planes,
422                                    stride_planes, padding_, &out_planes,
423                                    &pad_D1, &pad_D2));
424       OP_REQUIRES_OK(context_,
425                      GetWindowedOutputSizeVerboseV2(
426                          input_rows, filter_rows, dilation_rows, stride_rows,
427                          padding_, &out_rows, &pad_top, &pad_bottom));
428       OP_REQUIRES_OK(context_,
429                      GetWindowedOutputSizeVerboseV2(
430                          input_cols, filter_cols, dilation_cols, stride_cols,
431                          padding_, &out_cols, &pad_left, &pad_right));
432     }
433 
434     if (is_conv2d) {
435       // Conv + pad fusion is enabled only for 2D.
436       // If pad_enabled, i.e., pad and conv op are fused, then
437       // all pads are already passed from pad op through
438       // *pad_l and *pad_r and they don't need to be set here.
439       if (!pad_enabled) {
440         *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
441         *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
442       }
443     } else {
444       // Set padding for Conv3D here
445       *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
446                 static_cast<int>(pad_left)};
447       *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
448                 static_cast<int>(pad_right)};
449     }
450     // Tensorflow output is in data_format order.
451     //     Conv2D: NHWC or NCHW
452     //     Conv3D: NDHWC or NCDHW
453     // MKL-DNN uses asymmetric padding.
454     TensorShape out_shape =
455         is_conv2d
456             ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
457                               out_depth)
458             : ShapeFromFormat(data_format_, out_batch,
459                               {{out_planes, out_rows, out_cols}}, out_depth);
460     *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
461 
462     if (is_conv2d) {
463       // For Conv2D, MKL-DNN always needs output in NCHW format.
464       std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
465       mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
466       mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
467       mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
468       mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
469       *output_dims_mkl_order = mkldnn_sizes;
470     } else {
471       std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
472       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
473       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
474       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
475       mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
476       mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
477       *output_dims_mkl_order = mkldnn_sizes;
478     }
479   }
480 
481   // Calculate output and pad size of forward Convolution operator.
  // See comment on the shape-based GetOutputAndPadSizeInMklOrder overload
  // for parameters.
483   //
484   // Function does not return anything, but sets error in context status.
GetOutputAndPadSizeInMklOrder(size_t src_index,size_t filter_index,const memory::dims & strides,const memory::dims & dilations,memory::dims * output_dims_tf_order,memory::dims * output_dims_mkl_order,memory::dims * pad_l,memory::dims * pad_r,bool is_depthwise)485   inline void GetOutputAndPadSizeInMklOrder(
486       size_t src_index, size_t filter_index, const memory::dims& strides,
487       const memory::dims& dilations, memory::dims* output_dims_tf_order,
488       memory::dims* output_dims_mkl_order, memory::dims* pad_l,
489       memory::dims* pad_r, bool is_depthwise) {
490     DCHECK(output_dims_tf_order);
491     DCHECK(output_dims_mkl_order);
492     DCHECK(pad_l);
493     DCHECK(pad_r);
494 
495     auto input_tf_shape = GetTfShape(context_, src_index);
496     auto filter_tf_shape = GetTfShape(context_, filter_index);
497 
498     if (strides_.size() == 4) {
499       // Conv2D
500       OP_REQUIRES(context_, input_tf_shape.dims() == 4,
501                   errors::InvalidArgument("input must be 4-dimensional",
502                                           input_tf_shape.DebugString()));
503     } else {
504       // Conv3D
505       OP_REQUIRES(context_, input_tf_shape.dims() == 5,
506                   errors::InvalidArgument("input must be 5-dimensional",
507                                           input_tf_shape.DebugString()));
508     }
509 
510     GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
511                                   dilations, output_dims_tf_order,
512                                   output_dims_mkl_order, pad_l, pad_r,
513                                   is_depthwise);
514   }
515 
516   // Wrapper function to calculate input, filter, and output sizes of
517   // Conv2D/Conv3D in MKL order:
518   //     Conv2D: NCHW for input and output; OIHW for filter.
519   //     Conv3D: NCDHW for input and output; OIDHW for filter.
520   // Function also calculates output shape in Tensorflow order.
521   // Additionally, it also calculates strides and paddings.
522   //
523   // Function does not return anything, but sets error in context status.
524   inline void GetConvFwdSizesInMklOrder(
525       const TensorShape& input_shape, const TensorShape& filter_shape,
526       memory::dims* input_dims, memory::dims* filter_dims,
527       memory::dims* strides, memory::dims* dilations,
528       memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
529       memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
530       bool is_depthwise = false) {
531     DCHECK(input_dims);
532     DCHECK(filter_dims);
533     DCHECK(strides);
534     DCHECK(dilations);
535     DCHECK(output_dims_tf_order);
536     DCHECK(output_dims_mkl_order);
537     DCHECK(pad_l);
538     DCHECK(pad_r);
539 
540     GetInputSizeInMklOrder(input_shape, input_dims);
541     if (!context_->status().ok()) return;
542     GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims,
543                             is_depthwise);
544     if (!context_->status().ok()) return;
545     GetStridesInMklOrder(strides);
546     GetDilationsInMklOrder(dilations);
547     GetOutputAndPadSizeInMklOrder(
548         input_shape, filter_shape, *strides, *dilations, output_dims_tf_order,
549         output_dims_mkl_order, pad_l, pad_r, pad_enabled, is_depthwise);
550     if (!context_->status().ok()) return;
551   }
552 };
553 
554 /////////////////////////////////////////////////////////////////////
555 ///  Common class that implements ConvBackpropFilter and Input
556 /////////////////////////////////////////////////////////////////////
557 
558 template <typename Device, class T, bool is_depthwise>
559 class MklConvBackpropCommonOp : public OpKernel {
560  public:
~MklConvBackpropCommonOp()561   ~MklConvBackpropCommonOp() {}
MklConvBackpropCommonOp(OpKernelConstruction * context)562   explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
563       : OpKernel(context) {
564     string data_format_str;
565     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
566     OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
567                 errors::InvalidArgument("Invalid data format"));
568     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
569     int stride_n = GetTensorDim(strides_, data_format_, 'N');
570     int stride_c = GetTensorDim(strides_, data_format_, 'C');
571     OP_REQUIRES(
572         context, (stride_n == 1 && stride_c == 1),
573         errors::InvalidArgument("Current implementation does not yet support "
574                                 "strides in the batch and depth dimensions."));
575 
576     // Depthwise Convolution doesn't have dilation parameter
577     if (!is_depthwise) {
578       OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
579       if (strides_.size() == 4) {
580         // Check Conv2D dilations
581         OP_REQUIRES(
582             context, dilations_.size() == 4,
583             errors::InvalidArgument("Sliding window dilations field must "
584                                     "specify 4 dimensions"));
585         int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
586         int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
587         int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
588         int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
589         OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
590                     errors::InvalidArgument(
591                         "Current implementation does not yet support "
592                         "dilations in the batch and depth dimensions."));
593         OP_REQUIRES(
594             context, dilation_h > 0 && dilation_w > 0,
595             errors::InvalidArgument("Dilated rates should be larger than 0."));
596       }
597     } else {
598       // Set dilations as 1 for depthwise conv
599       // for future support to align with Tensorflow
600       dilations_ = {1, 1, 1, 1};
601     }
602 
603     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
604   }
605 
606  protected:
607   // data members accessible to derived classes.
608   std::vector<int32> dilations_;
609   std::vector<int32> strides_;
610   Padding padding_;
611   TensorFormat data_format_;  // NCHW or NHWC
612 };
613 
614 /////////////////////////////////////////////////////////////////////
615 ///  Dummy Mkl op that is just used for operators that are intermediate
616 ///  output of node fusion in the graph
617 /////////////////////////////////////////////////////////////////////
618 
619 template <typename Device, typename T>
620 class MklDummyOp : public OpKernel {
621  public:
~MklDummyOp()622   ~MklDummyOp() {}
623 
MklDummyOp(OpKernelConstruction * context)624   explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}
625 
Compute(OpKernelContext * context)626   void Compute(OpKernelContext* context) override {
627     TF_CHECK_OK(
628         errors::Unimplemented("This is a dummy op."
629                               "It should not have been invoked."));
630   }
631 };
632 
633 #undef MKLDNN_SIZE_DTYPE
634 
635 }  // namespace tensorflow
636 
637 #endif  // INTEL_MKL
638 #endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
639