/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <limits>
#include <unordered_map>
#include <vector>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::concat;
using mkldnn::stream;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

// List of TensorShape objects. Used in Concat/Split layers.
typedef std::vector<TensorShape> TensorShapeList;

enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
// reference inputs.
// --------------------------------------------------------------------------
// Eigen Concat Op
// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

  // Although we modify Compute for this op to accept one extra parameter,
  // we still need an empty Compute because Compute is a pure virtual
  // function.
  void Compute(OpKernelContext* c) {}

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
               const TensorShapeList& input_shapes) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing the values from the context, we use the inputs
    // passed to Compute.
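    // (The values and input_shapes arguments are supplied by
    // MklConcatOp::CallEigenVersion, which first converts any MKL-layout
    // inputs to plain TF tensors.)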
    const int N = values.size();
    const int input_dims = input_shapes[0].dims();
    const TensorShape& input_shape = input_shapes[0];

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
      OP_REQUIRES(
          c,
          (input_shapes[i].dims() == input_dims) ||
              (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      output_concat_dim +=
          input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }
};

// --------------------------------------------------------------------------
// Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      OpInputList input_tensors;
      GetMklInputList(context, "values", &input_tensors);
      const int N = input_tensors.size();

      // Get Tensor shapes.
      std::vector<MklDnnShape> mkl_input_shapes(N);
      GetMklShapeList(context, "values", &mkl_input_shapes);

      const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
                                            ? MklGetInput(context, 0)
                                            : MklGetInput(context, N);
      // Sanity checks.
      OP_REQUIRES(
          context, IsLegacyScalar(concat_dim_tensor.shape()),
          errors::InvalidArgument(
              "Concat dim tensor should be a scalar integer, but got shape ",
              concat_dim_tensor.shape().DebugString()));
      int32 concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

      // Check that the ranks of all tensors match and that their shapes
      // match except for concat_dim.
      int i = 0;
      bool invoke_eigen = false;
      bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
      const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor()
                                             ? mkl_input_shapes[0].GetTfShape()
                                             : input_tensors[0].shape();
      size_t expected_dims = expected_shape.dims();

      if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

      for (auto& s : mkl_input_shapes) {
        TensorShape s_shape =
            s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
        size_t s_dims = s_shape.dims();

        OP_REQUIRES(
            context, s_dims == expected_dims,
            errors::InvalidArgument(
                "_MklConcatOp : Ranks of all input tensors should match:"
                " input dimensions = ",
                s_dims, " vs. expected rank = ", expected_dims));

        for (int d = 0; d < expected_dims; ++d) {
          if (d == concat_dim) continue;

          size_t expected_size = expected_shape.dim_size(d);
          size_t s_size = s_shape.dim_size(d);
          OP_REQUIRES(
              context, expected_size == s_size,
              errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                      "should match: shape[0][",
                                      d, "]= ", expected_size, " vs. shape[", i,
                                      "][", d, "] = ", s_size));
        }

        if (s.IsMklTensor())
          are_all_tf_inputs = false;
        else
          are_all_mkl_inputs = false;

        if (s_dims != 4) invoke_eigen = true;
        ++i;
      }

      // If the inputs are not all in one format (TF or MKL), we have a mixed
      // input case. We could potentially optimize it by converting the
      // TF-format inputs to MKL format and avoiding the Eigen version, but
      // currently we fall back to Eigen for this case.
      if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

      OpInputList input_mins, input_maxes;
      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
        // MKL-DNN concat does not support input tensors that have different
        // ranges. Check if the ranges of all input tensors are the same.
        // If not, forward it to the Eigen implementation.
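        // For example, (min, max) ranges of (-1.0f, 1.0f) and (-2.0f, 2.0f)
        // imply different quantization scales, so such inputs cannot be
        // concatenated by a plain byte-wise copy.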
        OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
        OP_REQUIRES(context, (input_mins.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected mins input list length ",
                        input_mins.size(), " to equal values length ", N));

        OP_REQUIRES_OK(context,
                       context->input_list("input_maxes", &input_maxes));
        OP_REQUIRES(context, (input_maxes.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected maxes input list length ",
                        input_maxes.size(), " to equal values length ", N));
        float input_min = input_mins[0].flat<float>()(0);
        float input_max = input_maxes[0].flat<float>()(0);
        const float eps = 1.0e-6;
        for (int i = 1; i < N; ++i) {
          float min = input_mins[i].flat<float>()(0);
          float max = input_maxes[i].flat<float>()(0);

          if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
            invoke_eigen = true;
            break;
          }
        }
      }

      // Call the Eigen library.
      if (invoke_eigen) {
        // MKL-DNN quantized concat does not support input tensors with
        // different ranges.
        // TODO (mabuzain): Add a quantized version of CallEigen() to support
        // this case.
        OP_REQUIRES(
            context,
            (!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
            errors::Unimplemented("MKL DNN quantized concat does not "
                                  "support input tensors that have "
                                  "different ranges"));
        CallEigenVersion(context, input_tensors, mkl_input_shapes);
        return;
      }

      memory::dims dst_dims;

      if (are_all_mkl_inputs)
        dst_dims = TFShapeToMklDnnDims(mkl_input_shapes[0].GetTfShape());
      else
        // When all inputs are in TensorFlow format, we do not know the input
        // data format. In that case, we just use an output format that is the
        // same as the input format.
        dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());

      std::vector<memory::primitive_desc> srcs_pd;
      std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
      int64 dst_concat_dim_size = 0;

      bool isMklReorderNeeded = false;
      memory::format mkl_common_format = memory::format::any;
      if (are_all_mkl_inputs) {
        mkl_common_format =
            FindMklCommonFormat(mkl_input_shapes, concat_dim,
                                &isMklReorderNeeded, &dst_concat_dim_size);

        if (!isMklReorderNeeded) {
          // All MKL tensors have the same format. A reorder is not needed.
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;

            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);
            auto src_mpd = srcs[k].GetUsrMemPrimDesc();
            srcs_pd.push_back(src_mpd);
          }
        } else {
          // The MKL tensors have different formats.
          // Reorder them to the most common format.
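          // Only the target primitive descriptors (in the common format) are
          // recorded here; the actual reorders are issued later through
          // CheckReorderToOpMem() once all sources have been set up.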
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;

            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);

            if (src_md.data.format != mkl_common_format) {
              memory::dims src_dims(src_md.data.dims,
                                    &src_md.data.dims[src_md.data.ndims]);
              src_md =
                  memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);
            }

            srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
          }
        }
      } else {  // All TF inputs.
        for (int k = 0; k < N; k++) {
          if (input_tensors[k].NumElements() == 0) continue;

          memory::dims src_dims = TFShapeToMklDnnDims(input_tensors[k].shape());
          dst_concat_dim_size += src_dims[concat_dim];

          // It does not matter which data format is used (NHWC versus NCHW);
          // we just need to ensure that the output uses the same data format
          // as the inputs.
          auto src_md =
              memory::desc(src_dims, MklDnnType<T>(), memory::format::nchw);

          srcs[k].SetUsrMem(src_md, &input_tensors[k]);
          auto src_mpd = srcs[k].GetUsrMemPrimDesc();
          srcs_pd.push_back(src_mpd);
        }
      }
      dst_dims[concat_dim] = dst_concat_dim_size;

      MklDnnData<T> dst(&cpu_engine);
      memory::desc dst_md({}, memory::data_undef, memory::format_undef);
      memory::dims dst_dims_in_nchw;
      if (are_all_mkl_inputs) {
        // Since we are passing a specific format for the destination,
        // we need to have dst_dims in MklDnn order (NCHW).
        auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat();
        dst_dims_in_nchw = MklDnnDimsInNCHW(
            dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
        // Set the output format to the most common input format
        // to avoid layout conversions.
        dst_md =
            memory::desc(dst_dims_in_nchw, MklDnnType<T>(), mkl_common_format);
      } else {
        // All inputs are TF tensors.
        // Set the output format to the same as the input format (nchw).
        dst_md = memory::desc(dst_dims, MklDnnType<T>(), memory::format::nchw);
      }

      std::vector<primitive::at> inputs;
      if (isMklReorderNeeded) {
        for (int k = 0; k < input_tensors.size(); k++) {
          if (input_tensors[k].NumElements() > 0) {
            srcs[k].CheckReorderToOpMem(srcs_pd[k]);
          }
        }
      }
      for (int k = 0; k < input_tensors.size(); k++) {
        if (input_tensors[k].NumElements() > 0) {
          inputs.push_back(srcs[k].GetOpMem());
        }
      }

      // If all inputs are in MKL format, then the meaning of concat_dim needs
      // to change. The value of concat_dim is tied to the input TensorFlow
      // data format (NHWC or NCHW), whereas MklDnn dimensions are in NCHW
      // order. So if the TensorFlow tensors are in NCHW order, the concat_dim
      // semantics are preserved, but if the input tensors are in NHWC order,
      // the semantics need to change. E.g., if we are concatenating over the
      // channel dimension (dimension 3 for NHWC), then since the MklDnn order
      // is NCHW, concat_dim needs to be 1.
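      // TfDimIdx() below performs that mapping from the TensorFlow-order
      // dimension index to the corresponding MklDnn (NCHW-order) index.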
      if (are_all_mkl_inputs)
        concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim);

      auto concat_pd = concat::primitive_desc(concat_dim, srcs_pd);
      auto dst_pd = concat_pd.dst_primitive_desc();

      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      Tensor* dst_tensor = nullptr;
      if (are_all_mkl_inputs) {
        dnn_shape_dst.SetMklTensor(true);
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
                                  mkl_input_shapes[0].GetTfDataFormat());
        tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
      }
      AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);
      CHECK_NOTNULL(dst_tensor);

      dst_md =
          dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md;
      dst.SetUsrMem(dst_md, dst_tensor);

      auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
      std::vector<primitive> net;
      net.push_back(concat_op);
      stream(stream::kind::eager).submit(net).wait();

      // For quantized concat, min and max outputs are also computed.
      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
        Tensor* output_min = nullptr;
        Tensor* output_max = nullptr;
        MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
        output_min_mkl_shape.SetMklTensor(false);
        output_max_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, 1, &output_min, {},
                                  output_min_mkl_shape);
        AllocateOutputSetMklShape(context, 2, &output_max, {},
                                  output_max_mkl_shape);
        // All input tensors should have the same range; just use the
        // first one.
        output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
        output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const MklDnnShapeList& mkl_input_shapes) {
    CHECK_EQ(values.size(), mkl_input_shapes.size());

    std::vector<Tensor> converted_values;
    TensorShapeList tf_input_shapes;
    for (int i = 0; i < mkl_input_shapes.size(); i++) {
      if (mkl_input_shapes[i].IsMklTensor()) {
        // Do the conversion from MKL to TF.
        Tensor tmp_tensor =
            ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i]);
        converted_values.push_back(tmp_tensor);
        tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
      } else {
        // No conversion needed since it is already a TF tensor.
        converted_values.push_back(values[i]);
        tf_input_shapes.push_back(values[i].shape());
      }
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values, tf_input_shapes);

    // Set the output Mkl meta tensor for this op.
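    // The Eigen path produces a plain TF tensor, so the serialized metadata
    // marks the output as a non-MKL tensor for downstream MKL ops.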
    MklDnnShape dnn_shape_output;
    dnn_shape_output.SetMklTensor(false);
    dnn_shape_output.SetDimensions(4);
    Tensor* output_tensor = nullptr;
    TensorShape tf_shape_output;
    tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       GetTensorMetaDataIndex(0, context->num_outputs()),
                       tf_shape_output, &output_tensor));
    dnn_shape_output.SerializeMklDnnShape(
        output_tensor->flat<uint8>().data(),
        output_tensor->flat<uint8>().size() * sizeof(uint8));
  }

  // This method finds the most common format across all MKL inputs.
  // Inputs:
  //   1. input_shapes: shapes of the input (MKL) tensors.
  //   2. concat_dim: the concat dimension.
  // Outputs:
  //   1. is_reorder_needed is set to true if the inputs have different
  //      formats; it is set to false otherwise.
  //   2. concat_dim_size is the size of concat_dim.
  // Returns:
  //   The most common MKL format.
  memory::format FindMklCommonFormat(const MklDnnShapeList& input_shapes,
                                     int concat_dim, bool* is_reorder_needed,
                                     int64* concat_dim_size) {
    *is_reorder_needed = false;
    *concat_dim_size = 0;
    std::unordered_map<int, int> occurrence_map;
    if (input_shapes.size() == 0) return memory::format::any;

    // Count the occurrences of each input format.
    for (int k = 0; k < input_shapes.size(); k++) {
      auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
      *concat_dim_size += src_dims[concat_dim];
      int fmt = static_cast<int>(input_shapes[k].GetMklLayout().data.format);
      occurrence_map[fmt] += 1;
    }

    if (occurrence_map.size() == 1) {
      // All inputs have the same format; return it with is_reorder_needed
      // left false.
      return static_cast<memory::format>(
          input_shapes[0].GetMklLayout().data.format);
    }

    // The input tensors have different formats, so a reorder is needed.
    // We pick the most common format to minimize the total number of
    // input reorders.
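    // For example, if three inputs are in nChw16c and one is in nchw,
    // nChw16c is returned and only the nchw input needs a reorder.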
    memory::format commonest_format = memory::format::any;
    int max_occurrence = 0;
    *is_reorder_needed = true;
    for (auto item : occurrence_map) {
      if (item.second > max_occurrence) {
        commonest_format = static_cast<memory::format>(item.first);
        max_occurrence = item.second;
      }
    }
    return commonest_format;
  }
};

/* Use optimized concat for float type only */
#define REGISTER_MKL_CPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("_MklConcat")                                \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .HostMemory("concat_dim")                     \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
  REGISTER_KERNEL_BUILDER(Name("_MklConcatV2")                              \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<int32>("Tidx")                \
                              .HostMemory("axis")                           \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)

TF_CALL_float(REGISTER_MKL_CPU);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)

#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL