/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses the oneDNN library to accelerate Batch Matrix-Matrix
// Multiplication (MatMul) operations. We currently register this kernel only
// for oneDNN supported data types (float, bfloat16). The maximum number of
// dimensions (rank) for the output tensor is 12 in oneDNN. If the output
// tensor rank exceeds 12, the kernel in this file reports an InvalidArgument
// error (see the rank check in Compute()).
#define EIGEN_USE_THREADS

#if defined(INTEL_MKL)

#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

// Batch matrix multiplication kernel backed by a oneDNN (DNNL) matmul
// primitive.
//
// Template parameters:
//   Device   - device type used for the zero-fill functor.
//   Scalar   - element type of both inputs and of the output.
//   v2_bcast - true if we are using V2 (only requires each input to have rank
//              >= 2 before the common broadcast validation), false for V1
//              (inputs must have equal rank and identical batch dimensions).
template <typename Device, typename Scalar, bool v2_bcast>
class BatchMatMulMkl : public OpKernel {
 public:
  // Reads the "adj_x" and "adj_y" attributes, which select whether the
  // corresponding input has its two innermost dimensions swapped before the
  // multiplication.
  explicit BatchMatMulMkl(OpKernelConstruction* context)
      : OpKernel(context), eigen_batch_mm_v2_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  ~BatchMatMulMkl() override = default;

  // Validates input(0) (lhs) and input(1) (rhs), allocates output(0) and
  // fills it with the batched product.
  //
  // Fails with InvalidArgument when:
  //   * V1: the ranks differ, the rank is < 2, or a batch dimension differs;
  //   * V2: either input has rank < 2;
  //   * the batch dimensions are not compatible according to MatMulBCast;
  //   * the contracted dimensions (after applying adj_x / adj_y) differ;
  //   * the output rank exceeds 12.
  void Compute(OpKernelContext* ctx) override {
    const Tensor& lhs = ctx->input(0);
    const Tensor& rhs = ctx->input(1);

    if (!v2_bcast) {
      // Using V1, so check to make sure lhs and rhs dimensions are correct and
      // no broadcasting is needed.
      OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                  errors::InvalidArgument("lhs and rhs has different ndims: ",
                                          lhs.shape().DebugString(), " vs. ",
                                          rhs.shape().DebugString()));
      const int ndims = lhs.dims();
      OP_REQUIRES(
          ctx, ndims >= 2,
          errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
      for (int i = 0; i < ndims - 2; ++i) {
        OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                    errors::InvalidArgument(
                        "lhs.dim(", i, ") and rhs.dim(", i,
                        ") must be the same: ", lhs.shape().DebugString(),
                        " vs ", rhs.shape().DebugString()));
      }
    } else {
      OP_REQUIRES(
          ctx, lhs.dims() >= 2,
          errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims()));
      OP_REQUIRES(
          ctx, rhs.dims() >= 2,
          errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims()));
    }

    // lhs and rhs can have different dimensions
    const auto ndims_lhs = lhs.dims();
    const auto ndims_rhs = rhs.dims();

    // Get broadcast info
    MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));

    // Start from the broadcast batch shape; the two matrix dimensions are
    // appended once the contraction check has passed.
    TensorShape out_shape = bcast.output_batch_shape();

    auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
    auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
    auto rhs_rows = rhs.dim_size(ndims_rhs - 2);
    auto rhs_cols = rhs.dim_size(ndims_rhs - 1);

    // An adjointed operand contributes its matrix dimensions in swapped
    // order.
    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
                    ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));

    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    // The maximum number of dimensions for a tensor in DNNL is 12. Outputs of
    // higher rank are rejected here; eigen_batch_mm_v2_ is not invoked.
    OP_REQUIRES(
        ctx, out_shape.dims() <= 12,
        errors::InvalidArgument(
            "Rank of output tensor must be <= 12, but is ", out_shape.dims(),
            ". Current implementation supports upto rank 12 tensors."));

    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    // Nothing to compute for an empty output.
    if (out->NumElements() == 0) {
      return;
    }
    // Non-empty output but an empty input: the result is filled with zeros.
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }

    // Compute parameters for DNNL matmul primitive.
    auto params = CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape);
    // Create or retrieve matmul primitive from cache.
    MklMatMulPrimitive<Scalar>* matmul_prim =
        MklMatMulPrimitiveFactory<Scalar>::Get(
            *params, false /* value for do_not_cache */);
    // Execute matmul primitive.
    std::shared_ptr<stream> cpu_stream(
        CreateStream(ctx, matmul_prim->GetEngine()));
    matmul_prim->Execute(lhs.flat<Scalar>().data(), rhs.flat<Scalar>().data(),
                         out->flat<Scalar>().data(), cpu_stream);
  }

 private:
  bool adj_x_;
  bool adj_y_;
  // Constructed from the same OpKernelConstruction as this kernel. It is not
  // referenced by any other code visible in this class.
  // NOTE(review): confirm whether this member is still needed.
  BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;

  using dims = dnnl::memory::dims;

  // This method makes the rank (ndims) of input same as the output by adding
  // new axes to the input. For example, if input shape is [a, b, c, d] and
  // output shape is [e, f, g, h, i, j], then the reshaped input would have a
  // shape of [1, 1, a, b, c, d].
  //
  // The result is written to *reshaped_dims, replacing its previous contents.
  // Callers are expected to pass an input of strictly lower rank than the
  // output (checked with DCHECK only).
  void ExpandInputDimsToOutputShape(const TensorShape& input_shape,
                                    const TensorShape& output_shape,
                                    dims* reshaped_dims) {
    auto ndims_input = input_shape.dims();
    auto ndims_output = output_shape.dims();
    auto dim_offset = ndims_output - ndims_input;
    DCHECK(dim_offset > 0);
    // Start with all ones, then copy the input sizes into the trailing slots.
    reshaped_dims->clear();
    reshaped_dims->resize(ndims_output, 1);
    auto input_dims = input_shape.dim_sizes();
    for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx)
      reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx];
  }

  // Builds the dims and strides of lhs, rhs and output that are handed to
  // MklMatMulParams.
  //
  // Inputs of lower rank than the output are left-padded with size-1
  // dimensions. When adj_x_ / adj_y_ is set, the corresponding input is
  // described with its two innermost dims swapped and matching strides.
  std::unique_ptr<MklMatMulParams> CreateMatMulParams(
      const TensorShape& lhs_shape, const TensorShape& rhs_shape,
      const TensorShape& out_shape) {
    const auto ndims_lhs = lhs_shape.dims();
    const auto ndims_rhs = rhs_shape.dims();
    const auto ndims_out = out_shape.dims();
    auto lhs_dims = TFShapeToMklDnnDims(lhs_shape);
    auto rhs_dims = TFShapeToMklDnnDims(rhs_shape);
    auto out_dims = TFShapeToMklDnnDims(out_shape);

    // DNNL matmul_primitive requires ranks of inputs and output to be same.
    // Create dnnl::memory::dims for inputs and output of same rank.
    // It is assumed here that MatMulBCast object creates output_batch_shape as
    // a conforming superset of input batch shapes, i.e., ndims_out >=
    // ndims_lhs and ndims_out >= ndims_rhs.
    if (ndims_lhs < ndims_out) {
      ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
    }
    if (ndims_rhs < ndims_out) {
      ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims);
    }

    auto lhs_strides = CalculateTFStrides(lhs_dims);
    auto rhs_strides = CalculateTFStrides(rhs_dims);
    auto out_strides = CalculateTFStrides(out_dims);

    // Swaps the two innermost entries of `mat_dims` and rewrites the matching
    // strides: the new innermost dimension gets the original innermost size
    // as its stride, and the new second-innermost dimension gets stride 1.
    // All other (batch) entries are left untouched.
    const auto transpose_inner_dims = [ndims_out](auto& mat_dims,
                                                  auto& mat_strides) {
      const int last = ndims_out - 1;
      const int second_last = ndims_out - 2;
      // Read the size before the swap below changes mat_dims[last].
      mat_strides[last] = mat_dims[last];
      mat_strides[second_last] = 1;
      std::swap(mat_dims[last], mat_dims[second_last]);
    };

    if (adj_x_) transpose_inner_dims(lhs_dims, lhs_strides);
    if (adj_y_) transpose_inner_dims(rhs_dims, rhs_strides);

    return std::make_unique<MklMatMulParams>(
        lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides, out_strides);
  }
};

// Registers the V1 kernel (no broadcasting) for TYPE on CPU.
#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMul")                             \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, false>)

// Registers the V2 kernel (v2_bcast = true) for TYPE on CPU.
#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2")                           \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, true>)
#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
#endif  // ENABLE_MKL

}  // end namespace tensorflow
#endif  // INTEL_MKL