/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_

#ifdef INTEL_MKL
#include <limits>
#include <memory>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

#define MKLDNN_SIZE_DTYPE memory::dim

using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;

class MklDnnConvUtil {
 protected:
  OpKernelContext* context_;  // We don't own this.
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  Padding padding_;
  TensorFormat data_format_;

 public:
  MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
                 Padding pad, TensorFormat fm,
                 const std::vector<int32>& dilations,
                 bool is_depthwise = false)
      : context_(context),
        strides_(strides),
        dilations_(dilations),
        padding_(pad),
        data_format_(fm) {}

  virtual ~MklDnnConvUtil() { context_ = nullptr; }

  // Calculate Convolution strides.
  virtual inline void GetStridesInMklOrder(memory::dims* strides) {
    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    DCHECK(strides);
    if (strides_.size() == 4) {
      int stride_rows = GetTensorDim(strides_, data_format_, 'H');
      int stride_cols = GetTensorDim(strides_, data_format_, 'W');
      *strides = {stride_rows, stride_cols};
    } else if (strides_.size() == 5) {
      int stride_planes = GetTensorDim(strides_, data_format_, '0');
      int stride_rows = GetTensorDim(strides_, data_format_, '1');
      int stride_cols = GetTensorDim(strides_, data_format_, '2');
      *strides = {stride_planes, stride_rows, stride_cols};
    }
  }
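  // Illustrative example (not part of the original TF source; conv_util is a
  // hypothetical MklDnnConvUtil instance): for a Conv2D kernel constructed
  // with NHWC strides {1, 2, 2, 1}, the call
  //
  //   memory::dims strides;
  //   conv_util.GetStridesInMklOrder(&strides);
  //
  // yields strides == {2, 2}, i.e. only the H and W strides survive.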
  // Calculate Convolution dilations.
  virtual inline void GetDilationsInMklOrder(memory::dims* dilations) {
    // For now we take the dilation from the second and third dimensions only
    // (we do not support dilation on the batch or depth dimension).
    DCHECK(dilations);
    if (dilations_.size() == 4) {
      int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
      int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
      *dilations = {dilations_rows, dilations_cols};
    } else if (dilations_.size() == 5) {
      int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
      int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
      int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
      *dilations = {dilations_planes, dilations_rows, dilations_cols};
    }
  }
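  // Illustrative example (not part of the original TF source): for NHWC
  // dilations {1, 2, 2, 1}, GetDilationsInMklOrder returns {2, 2}. Note the
  // values are still in TensorFlow convention (1 == no dilation); callers
  // building MKL-DNN descriptors typically subtract 1 from each value, since
  // MKL-DNN treats 0 as no dilation.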
  // Calculate Convolution input size in MKL-DNN order. MKL-DNN
  // requires input in NCHW/NCDHW format. Function does not return anything,
  // but errors arising from sanity checks are set in the context's status.
  virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
                                             memory::dims* input_dims) {
#define CHECK_BOUNDS(val, err_msg)                                     \
  do {                                                                 \
    OP_REQUIRES(context_,                                              \
                FastBoundsCheck(val, std::numeric_limits<int>::max()), \
                errors::InvalidArgument(err_msg));                     \
  } while (0)

    DCHECK(input_dims);

    // Input channel
    int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
    int input_depth = static_cast<int>(input_depth_raw);

    // Input batch
    int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    int input_batch = static_cast<int>(input_batch_raw);

    if (strides_.size() == 4) {  // NCHW format for Conv2D
      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // MKL-DNN always requires input in NCHW format for Conv2D.
      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
      mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
      mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
      mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
      mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;

      *input_dims = mkldnn_sizes;
    } else if (strides_.size() == 5) {  // NCDHW format for Conv3D
      // Input planes/third-dimension
      int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
      CHECK_BOUNDS(input_planes_raw, "Input depth too large");
      int input_planes = static_cast<int>(input_planes_raw);

      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // MKL-DNN always requires input in NCDHW format for Conv3D.
      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
      mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
      mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
      mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
      mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
      mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;

      *input_dims = mkldnn_sizes;
    }
#undef CHECK_BOUNDS
  }
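  // Illustrative example (not part of the original TF source): an NHWC input
  // of shape [8, 224, 224, 3] produces input_dims == {8, 3, 224, 224}, i.e.
  // the same extents reordered into NCHW.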
  // Calculate Convolution filter size in MKL-DNN order.
  // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // Function does not return anything, but errors arising from sanity
  // checks are set in the context's status. This function differs from the
  // index-based overload below in its input parameters - it accepts shapes,
  // since Convolution Backward Input gets the shape of the input tensor
  // rather than the actual tensor (Convolution forward gets the actual
  // tensor as input).
  //
  // TODO(nhasabni): Add similar function for input and filter in MklShape.
  virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
                                              const TensorShape& filter_shape,
                                              memory::dims* filter_dims,
                                              bool is_depthwise) {
    DCHECK(filter_dims);

    OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
                errors::InvalidArgument((strides_.size() == 4)
                                            ? "filter must be 4-dimensional: "
                                            : "filter must be 5-dimensional: ",
                                        filter_shape.DebugString()));

    for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
      OP_REQUIRES(context_,
                  FastBoundsCheck(filter_shape.dim_size(i),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("filter too large"));
    }

    int input_depth = GetTensorDim(input_shape, data_format_, 'C');

    if (strides_.size() == 4) {  // Conv2D
      OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
                  errors::InvalidArgument(
                      "input and filter must have the same depth: ",
                      input_depth, " vs ", filter_shape.dim_size(2)));

      // TF filter is always in (rows, cols, in_depth, out_depth) order.
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O));
      // MKL-DNN always needs filter in OIHW format for regular convolutions
      // and GOIHW for grouped/depthwise convolutions:
      //   OIHW = (out_depth, in_depth, rows, cols)
      //   GOIHW = (group, out_depth, in_depth, rows, cols)
      // Specifically for depthwise: G = filter_in_depth,
      // O = filter_out_depth, I = 1.
      if (is_depthwise) {
        std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
        mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
        mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
        mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
        mkldnn_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
        mkldnn_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;

        *filter_dims = mkldnn_sizes;
      } else {
        std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
        mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
        mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
        mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
        mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;

        *filter_dims = mkldnn_sizes;
      }
    } else {  // Conv3D
      OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
                  errors::InvalidArgument(
                      "input and filter must have the same depth: ",
                      input_depth, " vs ", filter_shape.dim_size(3)));

      // TF filter is always in (planes, rows, cols, in_depth, out_depth)
      // order.
      int filter_planes =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P));
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O));

      // MKL-DNN always needs filter in OIDHW format:
      //   OIDHW = (out_depth, in_depth, planes, rows, cols)
      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
      mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
      mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
      mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
      mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
      mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;

      *filter_dims = mkldnn_sizes;
    }
  }

  // Calculate Convolution filter size in MKL-DNN order.
  // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // Function does not return anything, but errors arising from sanity
  // checks are set in the context's status.
  virtual inline void GetFilterSizeInMklOrder(size_t src_index,
                                              size_t filter_index,
                                              memory::dims* filter_dims,
                                              bool is_depthwise) {
    DCHECK(filter_dims);
    GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
                            GetTfShape(context_, filter_index), filter_dims,
                            is_depthwise);
  }
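  // Illustrative example (not part of the original TF source): a TF Conv2D
  // filter of shape [3, 3, 8, 16] (HWIO) maps to filter_dims == {16, 8, 3, 3}
  // (OIHW). With is_depthwise == true and filter shape [3, 3, 8, 2], the
  // result is {8, 2, 1, 3, 3} (GOIHW): 8 groups with a channel multiplier
  // of 2.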
  // Calculate Bias size for 2D or 3D Convolution. Function does not
  // return anything, but may set an error in context status.
  virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
                                            memory::dims* bias_dims) {
    const Tensor& bias = MklGetInput(context_, bias_index);
    OP_REQUIRES(context_, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional: ",
                                        bias.shape().DebugString()));

    *bias_dims = {static_cast<int>(bias.dim_size(0))};
  }
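  // Illustrative example (not part of the original TF source): a bias tensor
  // of shape [64] yields bias_dims == {64}; any higher-rank bias fails the
  // 1-dimensional check above.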
  // Function to calculate output and padding size for 2D/3D convolution.
  //
  // Calculates the output shape of Convolution in MKL-DNN and TensorFlow
  // order. MKL-DNN uses NCHW (Conv2D) or NCDHW (Conv3D) for the output order,
  // while the TensorFlow output is in NHWC or NCHW (Conv2D), or NDHWC or
  // NCDHW (Conv3D), depending on the data format. The function also
  // calculates the left, right, top and bottom pads. It does not return any
  // status; errors are set in the context's status.
  //
  // TODO(nhasabni): Add similar function for input and filter in MklShape.
  virtual inline void GetOutputAndPadSizeInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      const memory::dims& strides, const memory::dims& dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
      bool is_depthwise = false) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    bool is_conv2d = (strides_.size() == 4);
    int input_planes, input_rows, input_cols;
    if (is_conv2d) {
      input_rows = GetTensorDim(input_shape, data_format_, 'H');
      input_cols = GetTensorDim(input_shape, data_format_, 'W');
    } else {
      input_planes = GetTensorDim(input_shape, data_format_, '0');
      input_rows = GetTensorDim(input_shape, data_format_, '1');
      input_cols = GetTensorDim(input_shape, data_format_, '2');
    }

    // Filter dimensions:
    // Conv2D:
    //   First dimension: rows/height.
    //   Second dimension: cols/width.
    // Conv3D:
    //   First dimension: planes/depth.
    //   Second dimension: rows/height.
    //   Third dimension: cols/width.
    int filter_planes, filter_rows, filter_cols;
    if (is_conv2d) {
      filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W);
    } else {
      filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P);
      filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W);
    }

    int stride_planes, stride_rows, stride_cols;
    int dilation_planes, dilation_rows, dilation_cols;
    if (is_conv2d) {
      // Conv2D stride is a vector of 2 elements: {s_r, s_c}
      stride_rows = strides[0];
      stride_cols = strides[1];
      dilation_rows = dilations[0];
      dilation_cols = dilations[1];
    } else {
      // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
      stride_planes = strides[0];
      stride_rows = strides[1];
      stride_cols = strides[2];
      dilation_planes = dilations[0];
      dilation_rows = dilations[1];
      dilation_cols = dilations[2];
    }

    // Output batch is same as input batch.
    int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    int out_depth;

    // TODO: Add support for 3-D depthwise.

    // Output depth is same as the last dimension of the filter for regular
    // convolutions. For depthwise it is in_depth * channel_multiplier.
    // The channel_multiplier is the last dimension of the TF filter for
    // depthwise convolutions.
    if (is_depthwise) {
      out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) *
                   filter_shape.dim_size(TF_2DFILTER_DIM_O));
    } else {
      out_depth = filter_shape.dim_size(
          is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O)
                    : static_cast<int>(TF_3DFILTER_DIM_O));
    }

    int64 out_rows = 0, out_cols = 0, out_planes = 0;
    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;
    int64 pad_D1 = 0, pad_D2 = 0;

    if (is_conv2d) {
      Padding padding_type;
      if (pad_enabled) {
        padding_type = Padding::EXPLICIT;
        pad_top = static_cast<int64>((*pad_l)[0]);
        pad_left = static_cast<int64>((*pad_l)[1]);
        pad_bottom = static_cast<int64>((*pad_r)[0]);
        pad_right = static_cast<int64>((*pad_r)[1]);
      } else {
        padding_type = padding_;
      }
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_type, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_type, &out_cols, &pad_left, &pad_right));
    } else {
      OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2(
                                   input_planes, filter_planes,
                                   dilation_planes, stride_planes, padding_,
                                   &out_planes, &pad_D1, &pad_D2));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_, &out_cols, &pad_left, &pad_right));
    }

    if (is_conv2d) {
      // Conv + pad fusion is enabled only for 2D.
      // If pad_enabled, i.e., pad and conv op are fused, then
      // all pads are already passed from the pad op through
      // *pad_l and *pad_r and they don't need to be set here.
      if (!pad_enabled) {
        *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
        *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
      }
    } else {
      // Set padding for Conv3D here.
      *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
                static_cast<int>(pad_left)};
      *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
                static_cast<int>(pad_right)};
    }
    // TensorFlow output is in data_format order.
    //   Conv2D: NHWC or NCHW
    //   Conv3D: NDHWC or NCDHW
    // MKL-DNN uses asymmetric padding.
    TensorShape out_shape =
        is_conv2d
            ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
                              out_depth)
            : ShapeFromFormat(data_format_, out_batch,
                              {{out_planes, out_rows, out_cols}}, out_depth);
    *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);

    if (is_conv2d) {
      // For Conv2D, MKL-DNN always needs output in NCHW format.
      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
      mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
      mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
      mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
      mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = mkldnn_sizes;
    } else {
      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
      mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
      mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
      mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
      mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
      mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = mkldnn_sizes;
    }
  }
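  // Illustrative example (not part of the original TF source): Conv2D with
  // SAME padding, a 224x224 input, a 3x3 filter, stride 2, and dilation 1
  // gives out_rows == out_cols == 112 with *pad_l == {0, 0} and
  // *pad_r == {1, 1} (a total padding of 1 per spatial dimension, applied on
  // the bottom/right, as TensorFlow's SAME padding prescribes).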
  // Calculate output and pad size of forward Convolution operator.
  // See comment on GetOutputAndPadSizeInMklOrder above for parameters.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetOutputAndPadSizeInMklOrder(
      size_t src_index, size_t filter_index, const memory::dims& strides,
      const memory::dims& dilations, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r, bool is_depthwise) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    auto input_tf_shape = GetTfShape(context_, src_index);
    auto filter_tf_shape = GetTfShape(context_, filter_index);

    if (strides_.size() == 4) {
      // Conv2D
      OP_REQUIRES(context_, input_tf_shape.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional: ",
                                          input_tf_shape.DebugString()));
    } else {
      // Conv3D
      OP_REQUIRES(context_, input_tf_shape.dims() == 5,
                  errors::InvalidArgument("input must be 5-dimensional: ",
                                          input_tf_shape.DebugString()));
    }

    // Pass is_depthwise explicitly as the last argument so it does not land
    // on the pad_enabled parameter of the shape-based overload above.
    GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
                                  dilations, output_dims_tf_order,
                                  output_dims_mkl_order, pad_l, pad_r,
                                  /*pad_enabled=*/false, is_depthwise);
  }

  // Wrapper function to calculate input, filter, and output sizes of
  // Conv2D/Conv3D in MKL order:
  //   Conv2D: NCHW for input and output; OIHW for filter.
  //   Conv3D: NCDHW for input and output; OIDHW for filter.
  // Function also calculates output shape in TensorFlow order.
  // Additionally, it also calculates strides and paddings.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetConvFwdSizesInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      memory::dims* input_dims, memory::dims* filter_dims,
      memory::dims* strides, memory::dims* dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool pad_enabled = false,
      bool is_depthwise = false) {
    DCHECK(input_dims);
    DCHECK(filter_dims);
    DCHECK(strides);
    DCHECK(dilations);
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    GetInputSizeInMklOrder(input_shape, input_dims);
    if (!context_->status().ok()) return;
    GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims,
                            is_depthwise);
    if (!context_->status().ok()) return;
    GetStridesInMklOrder(strides);
    GetDilationsInMklOrder(dilations);
    GetOutputAndPadSizeInMklOrder(
        input_shape, filter_shape, *strides, *dilations, output_dims_tf_order,
        output_dims_mkl_order, pad_l, pad_r, pad_enabled, is_depthwise);
    if (!context_->status().ok()) return;
  }
};
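// Illustrative usage sketch (not part of the original TF source; variable
// names are hypothetical). Inside an MKL convolution kernel's Compute(), the
// helper class above is typically driven as follows:
//
//   MklDnnConvUtil conv_util(context, strides, padding, data_format,
//                            dilations);
//   memory::dims src_dims, filter_dims, stride_dims, dilation_dims;
//   memory::dims dst_dims_tf, dst_dims_mkl, pad_l, pad_r;
//   conv_util.GetConvFwdSizesInMklOrder(
//       src_shape, filter_shape, &src_dims, &filter_dims, &stride_dims,
//       &dilation_dims, &dst_dims_tf, &dst_dims_mkl, &pad_l, &pad_r);
//   if (!context->status().ok()) return;  // a sanity check may have failed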
/////////////////////////////////////////////////////////////////////
/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////

template <typename Device, class T, bool is_depthwise>
class MklConvBackpropCommonOp : public OpKernel {
 public:
  ~MklConvBackpropCommonOp() {}

  explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context,
                   context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));

    // Depthwise Convolution doesn't have a dilation parameter.
    if (!is_depthwise) {
      OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
      if (strides_.size() == 4) {
        // Check Conv2D dilations.
        OP_REQUIRES(
            context, dilations_.size() == 4,
            errors::InvalidArgument("Sliding window dilations field must "
                                    "specify 4 dimensions"));
        int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
        int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
        int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
        int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
        OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
                    errors::InvalidArgument(
                        "Current implementation does not yet support "
                        "dilations in the batch and depth dimensions."));
        OP_REQUIRES(
            context, dilation_h > 0 && dilation_w > 0,
            errors::InvalidArgument("Dilated rates should be larger than 0."));
      }
    } else {
      // Set dilations to 1 for depthwise conv, for future support
      // to align with TensorFlow.
      dilations_ = {1, 1, 1, 1};
    }

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

 protected:
  // Data members accessible to derived classes.
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;  // NCHW or NHWC
};
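// Illustrative example (not part of the original TF source): with data_format
// "NHWC" and strides {1, 2, 2, 1}, stride_n and stride_c in the constructor
// above both evaluate to 1, so construction succeeds; strides such as
// {2, 1, 1, 1} are rejected with an InvalidArgument error.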
/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
/// output of node fusion in the graph
/////////////////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklDummyOp : public OpKernel {
 public:
  ~MklDummyOp() {}

  explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TF_CHECK_OK(
        errors::Unimplemented("This is a dummy op. "
                              "It should not have been invoked."));
  }
};

#undef MKLDNN_SIZE_DTYPE

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_