/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <type_traits>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/kernels/matmul_util.h"
#include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y is contracted along its first dimension (if neither X
// nor Y is adjointed). The dimension to contract along is switched when either
// operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
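// For example, ContractionDims(false, false) returns IndexPair(1, 0): X's
// columns are contracted against Y's rows, i.e. plain matrix multiplication.
// ContractionDims(true, false) returns IndexPair(0, 0), so the adjointed X is
// contracted along its rows instead.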
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
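    // Concretely: y is conjugated inside the loop below whenever exactly one
    // of adj_x/adj_y is set, and the whole result is conjugated afterwards
    // (see conjugate_result in Launch) whenever adj_x is set.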
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    // TODO(rmlarsen): Consider launching these contractions asynchronously.
    for (int64_t i = 0; i < batch_size; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    if (batch_size == 1 && !should_bcast) {
      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
      auto Tz = out->flat_inner_dims<Scalar, 2>();
      Tz.device(d) = Tx.contract(Ty, contract_pairs);
    } else {
      auto Tx = in_x.tensor<Scalar, 3>();
      auto Ty = in_y.tensor<Scalar, 3>();
      auto Tz = out->tensor<Scalar, 3>();
      const auto& x_batch_indices = bcast.x_batch_indices();
      const auto& y_batch_indices = bcast.y_batch_indices();
      // TODO(rmlarsen): Consider launching these contractions asynchronously.
      for (int64_t i = 0; i < batch_size; ++i) {
        const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
        const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
        auto x = Tx.template chip<0>(x_batch_index);
        auto y = Ty.template chip<0>(y_batch_index);
        auto z = Tz.template chip<0>(i);

        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, bool trans_x, bool trans_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
      // and trans_y.
      if (!adj_x && !trans_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x * y;
        } else if (adj_y) {
          z.noalias() = x * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x * y.transpose();
        }
      } else if (adj_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x.adjoint() * y;
        } else if (adj_y) {
          z.noalias() = x.adjoint() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.adjoint() * y.transpose();
        }
      } else {  // trans_x == true
        if (!adj_y && !trans_y) {
          z.noalias() = x.transpose() * y;
        } else if (adj_y) {
          z.noalias() = x.transpose() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.transpose() * y.transpose();
        }
      }
    }
  }
};
}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies, i.e. the size of the batch.
    const int64_t batch_size = bcast.output_batch_size();
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64_t small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
    // evaluation in Eigen Tensor.
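    // Example: a single 512x512 by 512x512 matmul has cost_per_unit = 512^3,
    // which exceeds kMaxCostOuterParallelism, so the contraction itself is
    // parallelized below; a large batch of 8x8 matmuls (cost_per_unit = 512)
    // instead falls through to the sharded per-batch loop.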
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
                                trans_y, bcast, out, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  trans_x, trans_y, bcast, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // Because we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output here. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
// A dummy type to group matmul autotune results together.
struct BlasLtMatmulAutoTuneGroup {
  static string name() { return "MatmulLt"; }
};

typedef AutotuneSingleton<BlasLtMatmulAutoTuneGroup, BlasLtMatmulPlanParams,
                          se::blas::AlgorithmConfig,
                          absl::Hash<BlasLtMatmulPlanParams>>
    AutoTuneBatchMatmul;
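// The singleton maps a BlasLtMatmulPlanParams key to the best AlgorithmConfig
// found so far, so each unique matmul shape/layout combination is autotuned
// only once per process.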

}  // namespace

#endif  // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context)
      : memory_limit_(0), total_byte_size_(0), context_(context) {}

  BlasScratchAllocator(OpKernelContext* context, int64_t memory_limit)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}

  int64_t GetMemoryLimitInBytes() override { return memory_limit_; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;

    if (memory_limit_ > 0 && byte_size > memory_limit_) {
      return se::port::Status{
          se::port::error::UNAVAILABLE,
          absl::StrCat("Requested memory size (", byte_size,
                       ") exceeds the memory limit (", memory_limit_, ").")};
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::Status{
          se::port::error::UNAVAILABLE,
          absl::StrCat("Failed to allocate requested memory of (", byte_size,
                       ").")};
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed so the underlying scratch memory stays alive.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  int64_t memory_limit_;
  int64_t total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    static const bool use_autotune = MatmulAutotuneEnable();
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const int64_t batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
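    // "Full" broadcast means at least one operand contributes a single matrix
    // that is shared across the whole batch; mixed broadcasting (broadcasting
    // over some batch dimensions while both operands still have batch size
    // greater than one) instead requires explicit per-batch pointer arrays.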

    // Use float as coefficient type for half precision inputs, otherwise use
    // the input type.
    typedef std::conditional_t<std::is_same_v<Scalar, Eigen::half>, float,
                               Scalar>
        Coefficient;

#if GOOGLE_CUDA
    if (EnableCublasLtGemm()) {
      static const int64_t max_scratch_size =
          GetWorkspaceLimit(1LL << 32);  // 4GB by default

      bool requires_mixed_broadcasting =
          bcast.IsBroadcastingRequired() && !is_full_broadcast;

      if (!requires_mixed_broadcasting) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());

        BlasLtMatmulPlanParams matmul_params{
            se::blas::ToDataType<Scalar>::value,
            static_cast<size_t>(m),
            static_cast<size_t>(n),
            static_cast<size_t>(k),
            blas_transpose_a,
            blas_transpose_b,
            static_cast<size_t>(batch_size),
            /*broadcast_a=*/bcast.x_batch_size() == 1,
            /*broadcast_b=*/bcast.y_batch_size() == 1};

        std::optional<int> max_algorithm_count;
        if (!use_autotune) max_algorithm_count = 1;

        auto plan_and_algorithms_or =
            GetPlanAndAlgorithms(stream, matmul_params, max_algorithm_count);
        OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
        const auto* plan_and_algorithms =
            std::move(plan_and_algorithms_or).value();
        const auto& plan = plan_and_algorithms->plan;
        const auto& algorithms = plan_and_algorithms->algorithms;

        se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
        OP_REQUIRES(context, blas_lt != nullptr,
                    errors::Internal("blaslt not supported"));

        se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
        if (!use_autotune) {
          algorithm_config.set_algorithm(0);
        } else if (!AutoTuneBatchMatmul::GetInstance()->Find(
                       matmul_params, &algorithm_config)) {
          VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
                  << " algorithms.";
          se::blas::ProfileResult best_result;
          se::blas::ProfileResult profile_result;

          for (size_t i = 0; i != algorithms.size(); ++i) {
            const auto& profile_algorithm = algorithms[i];
            // Create a new scratch allocator with every autotuning run so that
            // scratch space is deallocated between runs.
            BlasScratchAllocator scratch_allocator(context, max_scratch_size);
            Status cublas_launch_status =
                DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
                               profile_algorithm, scratch_allocator,
                               /*bias = */ {}, &profile_result);

            VLOG(4) << " Autotune algorithm " << i
                    << " result: " << profile_result.elapsed_time_in_ms()
                    << " ms, valid=" << profile_result.is_valid()
                    << ", workspace_size=" << profile_algorithm.workspace_size;

            if (cublas_launch_status.ok() && profile_result.is_valid() &&
                profile_result.elapsed_time_in_ms() <
                    best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
              // Use index into algorithms array, instead of cublas internal ID.
              best_result.set_algorithm(i);
            }
          }

          if (best_result.is_valid()) {
            algorithm_config.set_algorithm(best_result.algorithm());
          }
          // Each matmul parameter set gets one pass of autotune. If no
          // algorithm works, kNoAlgorithm is added to the autotune map.
          AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params,
                                                     algorithm_config);
        }
        se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
        OP_REQUIRES(context,
                    0 <= algorithm_idx && algorithm_idx < algorithms.size(),
                    errors::Internal("Missing/invalid BatchMatmul algorithm"));
        const auto& algorithm = algorithms[algorithm_idx];
        BlasScratchAllocator scratch_allocator(context, max_scratch_size);
        VLOG(4) << "Calling BlasLtMatMul: a.shape=(" << bcast.x_batch_size()
                << ", " << in_x.dim_size(1) << ", " << in_x.dim_size(2)
                << "), b.shape=(" << bcast.y_batch_size() << ", "
                << in_y.dim_size(1) << ", " << in_y.dim_size(2) << "), m=" << m
                << ", n=" << n << ", k=" << k << ", batch_size=" << batch_size
                << ", trans_x = " << trans_x << ", trans_y = " << trans_y
                << ", adj_x = " << adj_x << ", adj_y = " << adj_y;

        OP_REQUIRES_OK(
            context, DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0],
                                    *c_ptrs[0], algorithm, scratch_allocator));
      } else {  // requires mixed broadcasting
        const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
        const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
        for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        }
        for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        }
        for (int64_t i = 0; i < batch_size; ++i) {
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
          b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
          c_ptrs.push_back(&c_device_memory.back());
        }

        BlasScratchAllocator scratch_allocator(context, max_scratch_size);
        bool blas_launch_status =
            stream
                ->ThenBlasGemmBatchedWithScratch(
                    blas_transpose_b, blas_transpose_a, n, m, k,
                    static_cast<Coefficient>(1.0), b_ptrs,
                    adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                    static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                    &scratch_allocator)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMMBatched launch failed: a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k, ", batch_size=", batch_size));
        }
      }
    } else {
#endif  // GOOGLE_CUDA
      bool use_strided_batched =
          (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
          batch_size > 1;
      if (use_strided_batched) {
        a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
        b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
        c_stride = m * n;
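        // A stride of 0 makes strided-batched GEMM re-read the same matrix
        // for every batch entry, which implements broadcasting of that
        // operand without materializing copies.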
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      } else if (!bcast.IsBroadcastingRequired()) {
        for (int64_t i = 0; i < batch_size; ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory.back());
          b_ptrs.push_back(&b_device_memory.back());
          c_ptrs.push_back(&c_device_memory.back());
        }
      } else {
        const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
        const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
        for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        }
        for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        }
        for (int64_t i = 0; i < batch_size; ++i) {
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
          b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
          c_ptrs.push_back(&c_device_memory.back());
        }
      }

      // Blas does
      //   C = A x B
      // where A, B and C are assumed to be in column major.
      // We want the output to be in row-major, so we can compute
      //   C' = B' x A', where ' stands for transpose (not adjoint).
      // TODO(yangzihao): Choose the best of the three strategies using
      // autotune.
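      // This is why the GEMM calls below pass the operands in the order
      // (b, a) with swapped transpose flags and dimensions (n, m, k): the
      // row-major buffers are handed to column-major BLAS unchanged, and the
      // column-major result C' is exactly the row-major C.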
      if (batch_size == 1) {
        // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
        // overhead of the scratch allocator and the batch interface.
        // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
        if constexpr (!std::is_same_v<Scalar, Eigen::half>) {
          if (n == 1 &&
              blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
              blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
            // This is a matrix*vector multiply so use GEMV to compute A * b.
            // Here we are multiplying in the natural order, so we have to flip
            // the transposition flag to compensate for the tensor being stored
            // row-major. Since GEMV doesn't provide a way to just conjugate an
            // argument, we have to defer those cases to GEMM below.
            auto gemv_trans_a =
                blas_transpose_a == se::blas::Transpose::kTranspose
                    ? se::blas::Transpose::kNoTranspose
                    : se::blas::Transpose::kTranspose;
            bool blas_launch_status =
                stream
                    ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                                   adj_x || trans_x ? k : m,
                                   static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                                   adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                                   static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                    .ok();
            if (!blas_launch_status) {
              context->SetStatus(errors::Internal(
                  "Blas xGEMV launch failed : a.shape=",
                  in_x.shape().DebugString(), ", b.shape=",
                  in_y.shape().DebugString(), ", m=", m, ", n=", n, ", k=", k));
            }
            return;
          }
        }

        OP_REQUIRES_OK(context,
                       stream->ThenBlasGemm(
                           blas_transpose_b, blas_transpose_a, n, m, k,
                           *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
                           adj_x || trans_x ? m : k, c_ptrs[0], n,
                           se::blas::kDefaultComputePrecision));
      } else if (use_strided_batched) {
        OP_REQUIRES_OK(context, stream->ThenBlasGemmStridedBatched(
                                    blas_transpose_b, blas_transpose_a, n, m, k,
                                    static_cast<Coefficient>(1.0), *b_ptrs[0],
                                    adj_y || trans_y ? k : n, b_stride,
                                    *a_ptrs[0], adj_x || trans_x ? m : k,
                                    a_stride, static_cast<Coefficient>(0.0),
                                    c_ptrs[0], n, c_stride, batch_size));
      } else {
        BlasScratchAllocator scratch_allocator(context);
        bool blas_launch_status =
            stream
                ->ThenBlasGemmBatchedWithScratch(
                    blas_transpose_b, blas_transpose_a, n, m, k,
                    static_cast<Coefficient>(1.0), b_ptrs,
                    adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                    static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                    &scratch_allocator)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMMBatched launch failed : a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k, ", batch_size=", batch_size));
        }
      }
#if GOOGLE_CUDA
    }
#endif  // GOOGLE_CUDA
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Ta, typename Tb, typename Tout>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    const Status s = ValidateInputTensors(ctx, in0, in1);
    if (!s.ok()) {
      ctx->SetStatus(s);
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
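    // After the swaps, d1 and d2 are the contracted (inner) dimensions that
    // must match, and d0 and d3 become the row/column counts of each output
    // matrix.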
    OP_REQUIRES(
        ctx, d1 == d2,
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", in0.shape().DebugString(),
            ", In[1]: ", in1.shape().DebugString()));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Tout> f;
      f(ctx->eigen_device<Device>(), out->flat<Tout>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    if (std::is_same<Ta, bfloat16>::value &&
        std::is_same<Tb, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      // Cast tensor to desired type to reuse Eigen.
      // TODO(b/178749687): remove this cast if Eigen supports this natively.
      if (!std::is_same<Ta, Tout>::value) {
        in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped);
      }
      if (!std::is_same<Tb, Tout>::value) {
        in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped);
      }
      LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, trans_x_,
                                              trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                      const Tensor& in1) = 0;

 private:
  // TODO(171979567) Make the ops take both adj and transpose attributes.
  bool adj_x_ = false;
  bool adj_y_ = false;
  bool trans_x_ = false;
  bool trans_y_ = false;

  // Cast `t` from `SrcT` to `DstT`.
  template <typename SrcT, typename DstT>
  Tensor CastTensor(const Tensor& t) {
    Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape());
    res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>();
    return res;
  }
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Ta, typename Tb, typename Tout,
          bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Disallow broadcasting support. Ensure that all batch dimensions of the
    // input tensors match.
    if (in0.dims() != in1.dims()) {
      return errors::InvalidArgument(
          "In[0] and In[1] have different ndims: ", in0.shape().DebugString(),
          " vs. ", in1.shape().DebugString());
    }
    const int ndims = in0.dims();
    if (is_legacy_matmul) {
      if (ndims != 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
                                       ndims);
      }
    } else {
      if (ndims < 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
                                       ndims);
      }
      for (int i = 0; i < ndims - 2; ++i) {
        if (in0.dim_size(i) != in1.dim_size(i)) {
          return errors::InvalidArgument(
              "In[0].dim(", i, ") and In[1].dim(", i,
              ") must be the same: ", in0.shape().DebugString(), " vs ",
              in1.shape().DebugString());
        }
      }
    }
    return Status::OK();
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Ta, typename Tb, typename Tout>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context,
                                                /* is_legacy_matmul= */ false) {
  }

  ~BatchMatMulV2Op() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    if (in0.dims() < 2) {
      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
    }
    if (in1.dims() < 2) {
      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
    }
    return Status::OK();
  }
};

// Register for MatMul, BatchMatMul, BatchMatMulV2 where Tin = Tout.
#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>);                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>);                      \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>);                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>);                      \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)

// Register for BatchMatMulV3 where Ta, Tb and Tout are not the same.
#define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>)

#define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
