/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
      matrix.data(), matrix.dimension(0), matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  typedef se::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
#endif  // TENSORFLOW_USE_SYCL
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
#ifndef TENSORFLOW_USE_SYCL
    }
#endif  // TENSORFLOW_USE_SYCL
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

template <typename T>
struct LaunchBlasGemv {
  static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
                      uint64 m, uint64 n, const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
                      se::blas::ProfileResult* output_profile) {
    const auto blas_trans = trans ? se::blas::Transpose::kTranspose
                                  : se::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed:  m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed:  m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
    const se::DeviceMemory<Eigen::half>& a,
    const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
    se::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

bool GetCublasAutotuneComputationType(const DataType& dtype,
                                      se::blas::ComputationType* compute_type) {
  using se::blas::ComputationType;
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      static bool use_f32_for_f16_computation =
          MatmulDoFP32ComputationFP16Input();
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}

// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
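// AutoTuneSingleton keeps a process-wide map from MatmulParameters to the
// autotuned se::blas::AlgorithmConfig, so each (shape, dtype, device)
// combination is profiled at most once.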
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          se::blas::AlgorithmConfig>
    AutoTuneMatmul;

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using se::blas::AlgorithmConfig;
    using se::blas::ComputationType;
    using se::blas::kDefaultAlgorithm;
    using se::blas::kDefaultBlasGemm;
    using se::blas::kDefaultBlasGemv;
    using se::blas::kNoAlgorithm;
    using se::blas::ProfileResult;
    using se::blas::Transpose;
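    // m and n are the non-contracted (output) dimensions of a and b, k is the
    // contracted inner dimension, and trans[] maps the derived transpose flags
    // to the corresponding cuBLAS transpose ops.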
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // Cublas does
          // C = A x B
          // where A, B and C are assumed to be in column major.
          // We want the output to be in row-major, so we can compute
          // C' = B' x A' (' stands for transpose)
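          // Reinterpreting the row-major buffers as column-major gives B', A',
          // and C'; the leading dimensions passed below are their row counts
          // in that view: transpose_b ? k : n for B', transpose_a ? m : k for
          // A', and n for C'.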
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
        // Try BlasGemmWithProfiling
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set only gets one pass of
      // autotune. If a best result is found, assign it to algorithm_config
      // and insert it into the autotune map. If all internal kernels of
      // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
      // autotune map.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // We fall back to plain BlasGemm() in the following cases:
    //  1) the use_autotune flag is not set;
    //  2) the compute type does not support autotune;
    //  3) no algorithm was found;
    //  4) all internal kernels tried during autotune returned invalid results.
    // We use plain BlasGemv() in the following cases:
    //  1) the use_autotune flag is not set, but LaunchBlasGemv is supported
    //     and n == 1;
    //  2) the use_autotune flag is set, autotune picked BlasGemv(), and
    //     algorithm_config.algorithm() was set to kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose)
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
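    // dim_pair[0] names the contraction axes: the inner dimension of a (dim 1,
    // or dim 0 when transpose_a_ is set) is multiplied against the inner
    // dimension of b (dim 0, or dim 1 when transpose_b_ is set). For a plain
    // a * b this is {1, 0}.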
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    if (std::is_same<T, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor a_float, b_float, out_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
                      a.NumElements());
      BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
                      b.NumElements());

      LaunchMatMul<Device, float, USE_CUBLAS>::launch(
          ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
          &out_float);
      FloatToBFloat16(out_float.flat<float>().data(),
                      out->flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchMatMul<Device, T, USE_CUBLAS>::launch(
          ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
    }
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

#ifdef TENSORFLOW_USE_SYCL
// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
template <typename T>
struct MatMulFunctor<SYCLDevice, T> {
  void operator()(
      const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
      MatMulOp<SYCLDevice, T, false /* xxblas */>);              \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                         \
                              .Device(DEVICE_SYCL)               \
                              .TypeConstraint<T>("T")            \
                              .Label("eigen"),                   \
                          MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow