/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
      matrix.data(), matrix.dimension(0), matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
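//
// For example, with the dim_pair convention used by MatMulOp::Compute below
// (transpose_a = transpose_b = false gives dim_pair[0] = {1, 0}): if a is
// [m, k] and b is [k, 1], the output is [m, 1], so the out->dim_size(1) == 1
// branch turns the contraction into a single Eigen GEMV, out_v = a_m * b_v.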
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  typedef se::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
#endif  // TENSORFLOW_USE_SYCL
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
#ifndef TENSORFLOW_USE_SYCL
    }
#endif  // TENSORFLOW_USE_SYCL
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS.
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

template <typename T>
struct LaunchBlasGemv {
  static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
                      uint64 m, uint64 n, const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
                      se::blas::ProfileResult* output_profile) {
    const auto blas_trans = trans ? se::blas::Transpose::kTranspose
                                  : se::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed: m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
    const se::DeviceMemory<Eigen::half>& a,
    const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
    se::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

bool GetCublasAutotuneComputationType(const DataType& dtype,
                                      se::blas::ComputationType* compute_type) {
  using se::blas::ComputationType;
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      static bool use_f32_for_f16_computation =
          MatmulDoFP32ComputationFP16Input();
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}

// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          se::blas::AlgorithmConfig>
    AutoTuneMatmul;
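
// The singleton above caches the best se::blas::AlgorithmConfig found so far,
// keyed by MatmulParameters (transpose flags, m/n/k, dtype, device id), so
// each distinct matmul signature is profiled at most once; see the Find() and
// Insert() calls in LaunchMatMul<GPUDevice, ...>::launch below.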

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using se::blas::AlgorithmConfig;
    using se::blas::ComputationType;
    using se::blas::kDefaultAlgorithm;
    using se::blas::kDefaultBlasGemm;
    using se::blas::kDefaultBlasGemv;
    using se::blas::kNoAlgorithm;
    using se::blas::ProfileResult;
    using se::blas::Transpose;
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];
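
    // For example: with transpose_a = transpose_b = false, dim_pair[0] is
    // {1, 0}, so m = a.dim_size(0), k = a.dim_size(1), n = b.dim_size(1), and
    // both blas transpose flags stay kNoTranspose. With transpose_a = true,
    // dim_pair[0].first == 0, so a is stored as [k, m] and blas_transpose_a
    // becomes kTranspose.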

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // Cublas does
          //   C = A x B
          // where A, B and C are assumed to be in column-major order.
          // We want the output in row-major order, so we compute
          //   C' = B' x A'  (' stands for transpose).
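          // (A row-major [m, n] buffer has the same memory layout as a
          // column-major [n, m] matrix, so asking cuBLAS for C' = B' x A'
          // with dimensions (n, m, k) and swapped operands writes exactly
          // the row-major C we want through c_ptr.)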
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
        // Try BlasGemmWithProfiling.
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling.
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set only gets one pass of
      // autotune. If a best result is found, record its algorithm in
      // algorithm_config and insert it into the autotune map. If all internal
      // kernels of cublasGemmEx() return invalid results, we add kNoAlgorithm
      // to the autotune map.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // We use plain BlasGemm() in any of the following cases:
    // 1) the use_autotune flag is not set;
    // 2) the compute type does not support autotune;
    // 3) no algorithm was found;
    // 4) all internal kernels in autotune returned invalid results.
    // We use plain BlasGemv() in either of the following cases:
    // 1) the use_autotune flag is not set, but LaunchBlasGemv is supported
    //    and n == 1;
    // 2) the use_autotune flag is set, autotune picked BlasGemv(), and
    //    algorithm_config.algorithm() was set to kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply, so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
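        // For example, with transpose_a == false: the row-major [m, k] buffer
        // of a is read by cuBLAS as the column-major [k, m] matrix A', so
        // passing trans = !transpose_a = true with GEMV dimensions (k, m)
        // computes (A')' * b = A * b.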
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose)
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
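
    // For example, with transpose_a_ = false and transpose_b_ = true,
    // dim_pair[0] = {1, 1}: a must be [m, k], b must be [n, k], the shared
    // dimension k is contracted, and out_shape is [m, n].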
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    if (std::is_same<T, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor a_float, b_float, out_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
                      a.NumElements());
      BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
                      b.NumElements());

      LaunchMatMul<Device, float, USE_CUBLAS>::launch(
          ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
          &out_float);
      FloatToBFloat16(out_float.flat<float>().data(),
                      out->flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchMatMul<Device, T, USE_CUBLAS>::launch(
          ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
    }
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

#ifdef TENSORFLOW_USE_SYCL
// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
template <typename T>
struct MatMulFunctor<SYCLDevice, T> {
  void operator()(
      const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
      MatMulOp<SYCLDevice, T, false /* xxblas */>);              \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                         \
                              .Device(DEVICE_SYCL)               \
                              .TypeConstraint<T>("T")            \
                              .Label("eigen"),                   \
                          MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow