/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <limits>
#include <unordered_map>
#include <vector>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::concat;
using mkldnn::stream;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

// List of TensorShape objects. Used in Concat/Split layers.
typedef std::vector<TensorShape> TensorShapeList;

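// Whether the concatenation dimension is passed as the "axis" argument
// (ConcatV2) or as the "concat_dim" argument (legacy Concat).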
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
// reference inputs.
// --------------------------------------------------------------------------
//                      Eigen Concat Op
// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

  // Although we modify Compute for this op to accept an extra parameter,
  // we still need an empty Compute() here because Compute() is a pure
  // virtual function in OpKernel.
  void Compute(OpKernelContext* c) {}

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
               const TensorShapeList& input_shapes) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing the values from the context, we use the values
    // passed in to this Compute overload.
    const int N = values.size();
    const int input_dims = input_shapes[0].dims();
    const TensorShape& input_shape = input_shapes[0];

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
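    // For example, concatenating inputs of shape {2, 3, 4} and {2, 5, 4}
    // along axis 1 flattens them to matrices of shape {2, 12} and {2, 20};
    // the output is written as a {2, 32} matrix and reported as {2, 8, 4}.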
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
      OP_REQUIRES(
          c,
          (input_shapes[i].dims() == input_dims) ||
              (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      output_concat_dim +=
          input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }
};
// --------------------------------------------------------------------------
//                      Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      OpInputList input_tensors;
      GetMklInputList(context, "values", &input_tensors);
      const int N = input_tensors.size();

      // Get Tensor shapes.
      std::vector<MklDnnShape> mkl_input_shapes(N);
      GetMklShapeList(context, "values", &mkl_input_shapes);

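      // For legacy Concat the concat dimension is the first input; for
      // ConcatV2 the axis input follows the N "values" inputs.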
      const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
                                            ? MklGetInput(context, 0)
                                            : MklGetInput(context, N);
      // Sanity checks
      OP_REQUIRES(
          context, IsLegacyScalar(concat_dim_tensor.shape()),
          errors::InvalidArgument(
              "Concat dim tensor should be a scalar integer, but got shape ",
              concat_dim_tensor.shape().DebugString()));
      int32 concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

      // Check that the ranks of all input tensors match and that their
      // shapes match except for concat_dim.
      int i = 0;
      bool invoke_eigen = false;
      bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
      const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor()
                                             ? mkl_input_shapes[0].GetTfShape()
                                             : input_tensors[0].shape();
      size_t expected_dims = expected_shape.dims();

      if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

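      // Validate every input: ranks must match and all non-concat dimensions
      // must agree. Also record whether the inputs are all in MKL format or
      // all in TF format, and fall back to Eigen for non-4D tensors.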
      for (auto& s : mkl_input_shapes) {
        TensorShape s_shape =
            s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
        size_t s_dims = s_shape.dims();

        OP_REQUIRES(
            context, s_dims == expected_dims,
            errors::InvalidArgument(
                "_MklConcatOp : Ranks of all input tensors should match:"
                " input dimensions = ",
                s_dims, " vs. expected rank = ", expected_dims));

        for (int d = 0; d < expected_dims; ++d) {
          if (d == concat_dim) continue;

          size_t expected_size = expected_shape.dim_size(d);
          size_t s_size = s_shape.dim_size(d);
          OP_REQUIRES(
              context, expected_size == s_size,
              errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                      "should match: shape[0][",
                                      d, "] = ", expected_size, " vs. shape[",
                                      i, "][", d, "] = ", s_size));
        }

        if (s.IsMklTensor())
          are_all_tf_inputs = false;
        else
          are_all_mkl_inputs = false;

        if (s_dims != 4) invoke_eigen = true;
        ++i;
      }

      // If the inputs are not all in one format (all TF or all MKL), we have
      // a mixed-input case. We could potentially optimize it by converting
      // the TF-format inputs to MKL format, but for now we fall back to the
      // Eigen implementation.
      if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

      OpInputList input_mins, input_maxes;
      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
        // MKL-DNN concat does not support input tensors that have different
        // ranges. Check if the ranges of all input tensors are the same.
        // If not, forward the op to the Eigen implementation.

        OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
        OP_REQUIRES(context, (input_mins.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected mins input list length ",
                        input_mins.size(), " to equal values length ", N));

        OP_REQUIRES_OK(context,
                       context->input_list("input_maxes", &input_maxes));
        OP_REQUIRES(context, (input_maxes.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected maxes input list length ",
                        input_maxes.size(), " to equal values length ", N));
        float input_min = input_mins[0].flat<float>()(0);
        float input_max = input_maxes[0].flat<float>()(0);
        const float eps = 1.0e-6;
        for (int i = 1; i < N; ++i) {
          float min = input_mins[i].flat<float>()(0);
          float max = input_maxes[i].flat<float>()(0);

          if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
            invoke_eigen = true;
            break;
          }
        }
      }

      // Call the Eigen library.
      if (invoke_eigen) {
        // MKL-DNN quantized concat does not support input tensors with
        // different ranges.
        // TODO(mabuzain): Add a quantized version of CallEigen() to support
        // this case.
        OP_REQUIRES(
            context,
            (!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
            errors::Unimplemented("MKL DNN quantized concat does not "
                                  "support input tensors that have "
                                  "different ranges"));
        CallEigenVersion(context, input_tensors, mkl_input_shapes);
        return;
      }

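      // From this point on the MKL-DNN path is taken: every input is 4D and
      // the inputs are either all in MKL format or all in TF format.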
      memory::dims dst_dims;

      if (are_all_mkl_inputs)
        dst_dims = TFShapeToMklDnnDims(mkl_input_shapes[0].GetTfShape());
      else
        // When all inputs are in TensorFlow format, we do not know the input
        // data format. In that case, we simply give the output the same
        // format as the inputs.
        dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());

      std::vector<memory::primitive_desc> srcs_pd;
      std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
      int64 dst_concat_dim_size = 0;

      bool isMklReorderNeeded = false;
      memory::format mkl_common_format = memory::format::any;
      if (are_all_mkl_inputs) {
        mkl_common_format =
            FindMklCommonFormat(mkl_input_shapes, concat_dim,
                                &isMklReorderNeeded, &dst_concat_dim_size);

        if (!isMklReorderNeeded) {
          // All MKL tensors have the same format. No reorder is needed.
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;

            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);
            auto src_mpd = srcs[k].GetUsrMemPrimDesc();
            srcs_pd.push_back(src_mpd);
          }
        } else {
          // MKL tensors have different formats.
          // Reorder them to the most common format.
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;

            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);

            if (src_md.data.format != mkl_common_format) {
              memory::dims src_dims(src_md.data.dims,
                                    &src_md.data.dims[src_md.data.ndims]);
              src_md =
                  memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);
            }

            srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
          }
        }
      } else {  // All TF inputs
        for (int k = 0; k < N; k++) {
          if (input_tensors[k].NumElements() == 0) continue;

          memory::dims src_dims = TFShapeToMklDnnDims(input_tensors[k].shape());
          dst_concat_dim_size += src_dims[concat_dim];

          // It does not matter which data format is used here (NHWC versus
          // NCHW); we just need to ensure that the output uses the same data
          // format as the inputs.
          auto src_md =
              memory::desc(src_dims, MklDnnType<T>(), memory::format::nchw);

          srcs[k].SetUsrMem(src_md, &input_tensors[k]);
          auto src_mpd = srcs[k].GetUsrMemPrimDesc();
          srcs_pd.push_back(src_mpd);
        }
      }
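      // The output keeps the dimensions of the first input except along
      // concat_dim, which becomes the sum of the inputs' concat_dim sizes.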
      dst_dims[concat_dim] = dst_concat_dim_size;

      MklDnnData<T> dst(&cpu_engine);
      memory::desc dst_md({}, memory::data_undef, memory::format_undef);
      memory::dims dst_dims_in_nchw;
      if (are_all_mkl_inputs) {
        // Since we are passing a specific format for the destination,
        // we need dst_dims in MklDnn order (NCHW).
        auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat();
        dst_dims_in_nchw = MklDnnDimsInNCHW(
            dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
        // Set the output format to the most common input format to avoid
        // layout conversions.
        dst_md =
            memory::desc(dst_dims_in_nchw, MklDnnType<T>(), mkl_common_format);
      } else {
        // All inputs are TF tensors.
        // Set the output format to the same format as the inputs (nchw).
        dst_md = memory::desc(dst_dims, MklDnnType<T>(), memory::format::nchw);
      }

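      // Reorder any source that is not already in the chosen common format,
      // then collect the (possibly reordered) memories as concat inputs;
      // empty tensors are skipped.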
      std::vector<primitive::at> inputs;
      if (isMklReorderNeeded) {
        for (int k = 0; k < input_tensors.size(); k++) {
          if (input_tensors[k].NumElements() > 0) {
            srcs[k].CheckReorderToOpMem(srcs_pd[k]);
          }
        }
      }
      for (int k = 0; k < input_tensors.size(); k++) {
        if (input_tensors[k].NumElements() > 0) {
          inputs.push_back(srcs[k].GetOpMem());
        }
      }

      // If all inputs are in MKL format, then the meaning of concat_dim needs
      // to change. The value of concat_dim is tied to the input TensorFlow
      // data format (NHWC or NCHW), while MklDnn dimensions are in NCHW
      // order. So if the TensorFlow tensors are in NCHW order, the concat_dim
      // semantics are preserved, but if the input tensors are in NHWC order,
      // the semantics need to change. E.g., if we are concatenating over
      // Channel (dimension 3 for NHWC), then since the MklDnn order is NCHW,
      // concat_dim needs to be 1.
      if (are_all_mkl_inputs)
        concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim);

      auto concat_pd = concat::primitive_desc(concat_dim, srcs_pd);
      auto dst_pd = concat_pd.dst_primitive_desc();

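      // Allocate the output tensor: as an MKL tensor carrying the concat
      // primitive's destination layout when all inputs are in MKL format,
      // or as a plain TF tensor otherwise.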
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      Tensor* dst_tensor = nullptr;
      if (are_all_mkl_inputs) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = concat_pd.dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
                                  mkl_input_shapes[0].GetTfDataFormat());
        tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
      }
      AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);
      CHECK_NOTNULL(dst_tensor);

      dst_md =
          dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md;
      dst.SetUsrMem(dst_md, dst_tensor);

      auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
      std::vector<primitive> net;
      net.push_back(concat_op);
      stream(stream::kind::eager).submit(net).wait();

      // For quantized concat, min and max outputs are also computed.
      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
        Tensor* output_min = nullptr;
        Tensor* output_max = nullptr;
        MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
        output_min_mkl_shape.SetMklTensor(false);
        output_max_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, 1, &output_min, {},
                                  output_min_mkl_shape);
        AllocateOutputSetMklShape(context, 2, &output_max, {},
                                  output_max_mkl_shape);
        // All input tensors should have the same range, so just use the
        // first one.
        output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
        output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

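  // Fallback path: convert any MKL-format inputs to plain TF tensors, run
  // the Eigen-based concat, and emit a dummy (non-MKL) meta shape for the
  // output.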
  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const MklDnnShapeList& mkl_input_shapes) {
    CHECK_EQ(values.size(), mkl_input_shapes.size());

    std::vector<Tensor> converted_values;
    TensorShapeList tf_input_shapes;
    for (int i = 0; i < mkl_input_shapes.size(); i++) {
      if (mkl_input_shapes[i].IsMklTensor()) {
        // Do the conversion from MKL to TF format.
        Tensor tmp_tensor =
            ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i]);
        converted_values.push_back(tmp_tensor);
        tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
      } else {
        // No conversion is needed since it is already a TF tensor.
        converted_values.push_back(values[i]);
        tf_input_shapes.push_back(values[i].shape());
      }
    }

    // Call the Eigen concat.
    eigen_concat_op_.Compute(context, converted_values, tf_input_shapes);

    // Set the output Mkl meta shape for this op.
    MklDnnShape dnn_shape_output;
    dnn_shape_output.SetMklTensor(false);
    dnn_shape_output.SetDimensions(4);
    Tensor* output_tensor = nullptr;
    TensorShape tf_shape_output;
    tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       GetTensorMetaDataIndex(0, context->num_outputs()),
                       tf_shape_output, &output_tensor));
    dnn_shape_output.SerializeMklDnnShape(
        output_tensor->flat<uint8>().data(),
        output_tensor->flat<uint8>().size() * sizeof(uint8));
  }

  // This method finds the most common format across all MKL inputs.
  // Inputs:
  //   1. input_shapes: shapes of the input (MKL) tensors.
  //   2. concat_dim: concat dimension.
  // Outputs:
  //   1. is_reorder_needed is set to true if the inputs have different
  //      formats. It is set to false otherwise.
  //   2. concat_dim_size is the size of concat_dim.
  // Return:
  //   the most common MKL format across the inputs.
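  // For example, if the input formats are {nchw, nChw16c, nChw16c}, the
  // method returns nChw16c and sets is_reorder_needed to true.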
  memory::format FindMklCommonFormat(const MklDnnShapeList& input_shapes,
                                     int concat_dim, bool* is_reorder_needed,
                                     int64* concat_dim_size) {
    *is_reorder_needed = false;
    *concat_dim_size = 0;
    std::unordered_map<int, int> occurrence_map;
    if (input_shapes.size() == 0) return memory::format::any;

    // Compute the occurrences of each format across all inputs.
    for (int k = 0; k < input_shapes.size(); k++) {
      auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
      *concat_dim_size += src_dims[concat_dim];
      int fmt = static_cast<int>(input_shapes[k].GetMklLayout().data.format);
      occurrence_map[fmt] += 1;
    }

    if (occurrence_map.size() == 1) {
      // All inputs have the same format; return it with is_reorder_needed
      // left as false.
      return static_cast<memory::format>(
          input_shapes[0].GetMklLayout().data.format);
    }

    // Input tensors have different formats, so a reorder is needed.
    // We pick the most common format to minimize the total number of input
    // reorders.
    memory::format commonest_format = memory::format::any;
    int max_occurrence = 0;
    *is_reorder_needed = true;
    for (auto item : occurrence_map) {
      if (item.second > max_occurrence) {
        commonest_format = static_cast<memory::format>(item.first);
        max_occurrence = item.second;
      }
    }
    return commonest_format;
  }
};

/* Use optimized concat for float type only */
#define REGISTER_MKL_CPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("_MklConcat")                                \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .HostMemory("concat_dim")                     \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
  REGISTER_KERNEL_BUILDER(Name("_MklConcatV2")                              \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<int32>("Tidx")                \
                              .HostMemory("axis")                           \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)

TF_CALL_float(REGISTER_MKL_CPU);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)

#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL