1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
19 #define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
20 
21 #define EIGEN_USE_THREADS
22 
23 #include <type_traits>
24 #include <utility>
25 #include <vector>
26 
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/type_traits.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/kernels/fill_functor.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/util/matmul_autotune.h"
40 #include "tensorflow/core/util/matmul_bcast.h"
41 #include "tensorflow/core/util/work_sharder.h"
42 
43 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
44 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
45 #endif
46 
47 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
48 #include "tensorflow/core/kernels/gpu_utils.h"
49 #include "tensorflow/core/platform/stream_executor.h"
50 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 #if GOOGLE_CUDA
52 #include "third_party/gpus/cuda/include/cuda.h"
53 #include "tensorflow/core/kernels/matmul_util.h"
54 #include "tensorflow/stream_executor/cuda/cuda_blas_lt.h"
55 #include "tensorflow/stream_executor/host_or_device_scalar.h"
56 #endif  // GOOGLE_CUDA
57 
58 namespace tensorflow {
59 
60 typedef Eigen::ThreadPoolDevice CPUDevice;
61 typedef Eigen::GpuDevice GPUDevice;
62 
63 namespace {
64 
65 // Returns the pair of dimensions along which to perform Tensor contraction to
66 // emulate matrix multiplication.
67 // For matrix multiplication of 2D Tensors X and Y, X is contracted along
68 // second dimension and Y is contracted along the first dimension (if neither X
69 // nor Y is adjointed). The dimension to contract along is switched when any
70 // operand is adjointed.
71 // See http://en.wikipedia.org/wiki/Tensor_contraction
72 Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
73   return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
74 }
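// For example, ContractionDims(/*adj_x=*/false, /*adj_y=*/false) returns the
// pair (1, 0): X is contracted along its second (column) dimension and Y
// along its first (row) dimension, which is the ordinary product X * Y.
// ContractionDims(true, true) returns (0, 1) instead.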
75 
76 // Parallel batch matmul kernel based on the multi-threaded tensor contraction
77 // in Eigen.
78 template <typename Scalar, bool IsComplex = true>
79 struct ParallelMatMulKernel {
80   static void Conjugate(const OpKernelContext* context, Tensor* out) {
81     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
82     auto z = out->tensor<Scalar, 3>();
83     z.device(d) = z.conjugate();
84   }
85 
86   static void Run(const OpKernelContext* context, const Tensor& in_x,
87                   const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
88                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
89                   int batch_size) {
90     static_assert(IsComplex, "Complex type expected.");
91     auto Tx = in_x.tensor<Scalar, 3>();
92     auto Ty = in_y.tensor<Scalar, 3>();
93     auto Tz = out->tensor<Scalar, 3>();
94     // We use the identities
95     //   conj(a) * conj(b) = conj(a * b)
96     //   conj(a) * b = conj(a * conj(b))
97     // to halve the number of cases. The final conjugation of the result is
98     // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
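    // Concretely, for adj_x=true, adj_y=false the desired product is
    // adjoint(x) * y = conj(x^T) * y; by the second identity this equals
    // conj(x^T * conj(y)), so below we contract x with conj(y) and rely on
    // the caller to conjugate the result (conjugate_result = adj_x).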
99     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
100     contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
101     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
102 
103     const bool should_bcast = bcast.IsBroadcastingRequired();
104     const auto& x_batch_indices = bcast.x_batch_indices();
105     const auto& y_batch_indices = bcast.y_batch_indices();
106     // TODO(rmlarsen): Consider launching these contractions asynchronously.
107     for (int64_t i = 0; i < batch_size; ++i) {
108       const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
109       const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
110 
111       auto x = Tx.template chip<0>(x_batch_index);
112       auto z = Tz.template chip<0>(i);
113       if (adj_x != adj_y) {
114         auto y = Ty.template chip<0>(y_batch_index).conjugate();
115         z.device(d) = x.contract(y, contract_pairs);
116       } else {
117         auto y = Ty.template chip<0>(y_batch_index);
118         z.device(d) = x.contract(y, contract_pairs);
119       }
120     }
121   }
122 };
123 
124 // The Eigen contraction kernel used here is very large and slow to compile,
125 // so we partially specialize ParallelMatMulKernel for real types to avoid all
126 // but one of the instantiations.
127 template <typename Scalar>
128 struct ParallelMatMulKernel<Scalar, false> {
129   static void Conjugate(const OpKernelContext* context, Tensor* out) {}
130 
131   static void Run(const OpKernelContext* context, const Tensor& in_x,
132                   const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
133                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
134                   int batch_size) {
135     const bool should_bcast = bcast.IsBroadcastingRequired();
136     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
137     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
138     contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
139     if (batch_size == 1 && !should_bcast) {
140       auto Tx = in_x.flat_inner_dims<Scalar, 2>();
141       auto Ty = in_y.flat_inner_dims<Scalar, 2>();
142       auto Tz = out->flat_inner_dims<Scalar, 2>();
143       Tz.device(d) = Tx.contract(Ty, contract_pairs);
144     } else {
145       auto Tx = in_x.tensor<Scalar, 3>();
146       auto Ty = in_y.tensor<Scalar, 3>();
147       auto Tz = out->tensor<Scalar, 3>();
148       const auto& x_batch_indices = bcast.x_batch_indices();
149       const auto& y_batch_indices = bcast.y_batch_indices();
150       // TODO(rmlarsen): Consider launching these contractions asynchronously.
151       for (int64_t i = 0; i < batch_size; ++i) {
152         const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
153         const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
154         auto x = Tx.template chip<0>(x_batch_index);
155         auto y = Ty.template chip<0>(y_batch_index);
156         auto z = Tz.template chip<0>(i);
157 
158         z.device(d) = x.contract(y, contract_pairs);
159       }
160     }
161   }
162 };
163 
164 // Sequential batch matmul kernel that calls the regular Eigen matmul.
165 // We prefer this over the tensor contraction because it performs
166 // better on vector-matrix and matrix-vector products.
167 template <typename Scalar>
168 struct SequentialMatMulKernel {
169   using Matrix =
170       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
171   using ConstMatrixMap = Eigen::Map<const Matrix>;
172   using MatrixMap = Eigen::Map<Matrix>;
173 
174   static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
175                                                       int slice) {
176     return ConstMatrixMap(
177         t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
178         t.dim_size(1), t.dim_size(2));
179   }
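  // For instance, slice i of a [batch, rows, cols] tensor t maps to a
  // rows x cols row-major matrix starting at offset i * rows * cols in
  // t's flat buffer; no data is copied, only a view is created.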
180 
181   static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
182     return MatrixMap(
183         t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
184         t->dim_size(1), t->dim_size(2));
185   }
186 
187   static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
188                   bool adj_y, bool trans_x, bool trans_y,
189                   const MatMulBCast& bcast, Tensor* out, int start, int limit) {
190     const bool should_bcast = bcast.IsBroadcastingRequired();
191     const auto& x_batch_indices = bcast.x_batch_indices();
192     const auto& y_batch_indices = bcast.y_batch_indices();
193     for (int64_t i = start; i < limit; ++i) {
194       const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
195       const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
196       auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
197       auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
198       auto z = TensorSliceToEigenMatrix(out, i);
199       // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
200       // and trans_y.
201       if (!adj_x && !trans_x) {
202         if (!adj_y && !trans_y) {
203           z.noalias() = x * y;
204         } else if (adj_y) {
205           z.noalias() = x * y.adjoint();
206         } else {  // trans_y == true
207           z.noalias() = x * y.transpose();
208         }
209       } else if (adj_x) {
210         if (!adj_y && !trans_y) {
211           z.noalias() = x.adjoint() * y;
212         } else if (adj_y) {
213           z.noalias() = x.adjoint() * y.adjoint();
214         } else {  // trans_y == true
215           z.noalias() = x.adjoint() * y.transpose();
216         }
217       } else {  // trans_x == true
218         if (!adj_y && !trans_y) {
219           z.noalias() = x.transpose() * y;
220         } else if (adj_y) {
221           z.noalias() = x.transpose() * y.adjoint();
222         } else {  // trans_y == true
223           z.noalias() = x.transpose() * y.transpose();
224         }
225       }
226     }
227   }
228 };
229 }  // namespace
230 
231 template <typename Device, typename Scalar>
232 struct LaunchBatchMatMul;
233 
234 template <typename Scalar>
235 struct LaunchBatchMatMul<CPUDevice, Scalar> {
236   static void Launch(OpKernelContext* context, const Tensor& in_x,
237                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
238                      bool trans_y, const MatMulBCast& bcast, Tensor* out) {
239     typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
240         ParallelMatMulKernel;
241     bool conjugate_result = false;
242 
243     // Number of matrix multiplies i.e. size of the batch.
244     const int64_t batch_size = bcast.output_batch_size();
245     const int64_t cost_per_unit =
246         in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
247     const int64_t small_dim = std::min(
248         std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
249     // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
250     // Jan 21, 2020.
251     const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic.
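    // For example, a batch of four 512x512x512 matmuls has cost_per_unit
    // 512^3, far above kMaxCostOuterParallelism, so each contraction is
    // parallelized internally; a batch of 1000 8x8 matmuls has cost_per_unit
    // 8^3 = 512, so the batch itself is sharded across threads instead.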
252     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
253     // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
254     // evaluation in Eigen Tensor.
255     if (small_dim > 1 &&
256         (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
257       // Parallelize over inner dims.
258       // For large matrix products it is counter-productive to parallelize
259       // over the batch dimension.
260       ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
261                                 trans_y, bcast, out, batch_size);
262       conjugate_result = adj_x;
263     } else {
264       // Parallelize over outer dims. For small matrices and large batches, it
265       // is counter-productive to parallelize the inner matrix multiplies.
266       Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
267             cost_per_unit,
268             [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
269                 int start, int limit) {
270               SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
271                                                   trans_x, trans_y, bcast, out,
272                                                   start, limit);
273             });
274     }
275     if (conjugate_result) {
276       // We used one of the identities
277       //   conj(a) * conj(b) = conj(a * b)
278       //   conj(a) * b = conj(a * conj(b))
279       // above, so we need to conjugate the final output. This is a
280       // no-op for non-complex types.
281       ParallelMatMulKernel::Conjugate(context, out);
282     }
283   }
284 };
285 
286 #if GOOGLE_CUDA
287 
288 namespace {
289 // A dummy type to group matmul autotune results together.
290 struct BlasLtMatmulAutoTuneGroup {
291   static string name() { return "MatmulLt"; }
292 };
293 
294 typedef AutotuneSingleton<BlasLtMatmulAutoTuneGroup, BlasLtMatmulPlanParams,
295                           se::blas::AlgorithmConfig,
296                           absl::Hash<BlasLtMatmulPlanParams>>
297     AutoTuneBatchMatmul;
298 
299 }  // namespace
300 
301 #endif  // GOOGLE_CUDA
302 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
303 
304 class BlasScratchAllocator : public se::ScratchAllocator {
305  public:
306   using Stream = se::Stream;
307   using DeviceMemoryBytes = se::DeviceMemory<uint8>;
308 
309   BlasScratchAllocator(OpKernelContext* context)
310       : memory_limit_(0), total_byte_size_(0), context_(context) {}
311 
312   BlasScratchAllocator(OpKernelContext* context, int64_t memory_limit)
313       : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
314 
315   int64_t GetMemoryLimitInBytes() override { return memory_limit_; }
316 
317   se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
318       int64_t byte_size) override {
319     Tensor temporary_memory;
320 
321     if (memory_limit_ > 0 && byte_size > memory_limit_) {
322       return se::port::Status{
323           se::port::error::UNAVAILABLE,
324           absl::StrCat("Requested memory size (", byte_size,
325                        ") exceeds the memory limit (", memory_limit_, ").")};
326     }
327     AllocationAttributes allocation_attr;
328     allocation_attr.retry_on_failure = false;
329     Status allocation_status(context_->allocate_temp(
330         DT_UINT8, TensorShape({byte_size}), &temporary_memory));
331     if (!allocation_status.ok()) {
332       return se::port::Status{
333           se::port::error::UNAVAILABLE,
334           absl::StrCat("Failed to allocate requested memory of ", byte_size,
335                        " bytes.")};
336     }
337     // Hold references to the allocated tensors until the allocator is
338     // destroyed, so the returned device memory stays valid.
339     allocated_tensors_.push_back(temporary_memory);
340     return se::port::StatusOr<DeviceMemoryBytes>(
341         DeviceMemoryBytes::MakeFromByteSize(
342             temporary_memory.flat<uint8>().data(),
343             temporary_memory.flat<uint8>().size()));
344   }
345 
346  private:
347   int64_t memory_limit_;
348   int64_t total_byte_size_;
349   OpKernelContext* context_;
350   std::vector<Tensor> allocated_tensors_;
351 };
352 
353 template <typename Scalar>
354 struct LaunchBatchMatMul<GPUDevice, Scalar> {
355   static void Launch(OpKernelContext* context, const Tensor& in_x,
356                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
357                      bool trans_y, const MatMulBCast& bcast, Tensor* out) {
358     static const bool use_autotune = MatmulAutotuneEnable();
359     se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
360                                    se::blas::Transpose::kTranspose,
361                                    se::blas::Transpose::kConjugateTranspose};
362     const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
363     const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
364     const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
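    // E.g. for in_x of shape [batch, M, K] and in_y of shape [batch, K, N]
    // with no adjoint/transpose flags set, m = M, k = K and n = N; setting a
    // flag swaps which of the two trailing dimensions is read.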
365     const int64_t batch_size = bcast.output_batch_size();
366     auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
367     auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
368 
369     auto* stream = context->op_device_context()->stream();
370     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
371 
372     typedef se::DeviceMemory<Scalar> DeviceMemoryType;
373     std::vector<DeviceMemoryType> a_device_memory;
374     std::vector<DeviceMemoryType> b_device_memory;
375     std::vector<DeviceMemoryType> c_device_memory;
376     std::vector<DeviceMemoryType*> a_ptrs;
377     std::vector<DeviceMemoryType*> b_ptrs;
378     std::vector<DeviceMemoryType*> c_ptrs;
379     a_device_memory.reserve(bcast.x_batch_size());
380     b_device_memory.reserve(bcast.y_batch_size());
381     c_device_memory.reserve(batch_size);
382     a_ptrs.reserve(batch_size);
383     b_ptrs.reserve(batch_size);
384     c_ptrs.reserve(batch_size);
385     auto* a_base_ptr = in_x.template flat<Scalar>().data();
386     auto* b_base_ptr = in_y.template flat<Scalar>().data();
387     auto* c_base_ptr = out->template flat<Scalar>().data();
388     uint64 a_stride;
389     uint64 b_stride;
390     uint64 c_stride;
391 
392     bool is_full_broadcast =
393         std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
394 
395     // Use float as coefficient type for half precision inputs, otherwise use
396     // the input type.
397     typedef std::conditional_t<std::is_same_v<Scalar, Eigen::half>, float,
398                                Scalar>
399         Coefficient;
400 
401 #if GOOGLE_CUDA
402     if (EnableCublasLtGemm()) {
403       static const int64_t max_scratch_size =
404           GetWorkspaceLimit(1LL << 32);  // 4GB by default
405 
406       bool requires_mixed_broadcasting =
407           bcast.IsBroadcastingRequired() && !is_full_broadcast;
408 
409       if (!requires_mixed_broadcasting) {
410         a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
411         b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
412         c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
413         a_ptrs.push_back(&a_device_memory.back());
414         b_ptrs.push_back(&b_device_memory.back());
415         c_ptrs.push_back(&c_device_memory.back());
416 
417         BlasLtMatmulPlanParams matmul_params{
418             se::blas::ToDataType<Scalar>::value,
419             static_cast<size_t>(m),
420             static_cast<size_t>(n),
421             static_cast<size_t>(k),
422             blas_transpose_a,
423             blas_transpose_b,
424             static_cast<size_t>(batch_size),
425             /*broadcast_a=*/bcast.x_batch_size() == 1,
426             /*broadcast_b=*/bcast.y_batch_size() == 1};
427 
428         std::optional<int> max_algorithm_count;
429         if (!use_autotune) max_algorithm_count = 1;
430 
431         auto plan_and_algorithms_or =
432             GetPlanAndAlgorithms(stream, matmul_params, max_algorithm_count);
433         OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
434         const auto* plan_and_algorithms =
435             std::move(plan_and_algorithms_or).value();
436         const auto& plan = plan_and_algorithms->plan;
437         const auto& algorithms = plan_and_algorithms->algorithms;
438 
439         se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
440         OP_REQUIRES(context, blas_lt != nullptr,
441                     errors::Internal("blaslt not supported"));
442 
443         se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
444         if (!use_autotune) {
445           algorithm_config.set_algorithm(0);
446         } else if (!AutoTuneBatchMatmul::GetInstance()->Find(
447                        matmul_params, &algorithm_config)) {
448           VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
449                   << " algorithms.";
450           se::blas::ProfileResult best_result;
451           se::blas::ProfileResult profile_result;
452 
453           for (size_t i = 0; i != algorithms.size(); ++i) {
454             const auto& profile_algorithm = algorithms[i];
455             // Create a new scratch allocator with every autotuning run so that
456             // scratch space is deallocated between runs.
457             BlasScratchAllocator scratch_allocator(context, max_scratch_size);
458             Status cublas_launch_status =
459                 DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
460                                profile_algorithm, scratch_allocator,
461                                /*bias = */ {}, &profile_result);
462 
463             VLOG(4) << "  Autotune algorithm " << i
464                     << " result: " << profile_result.elapsed_time_in_ms()
465                     << " ms, valid=" << profile_result.is_valid()
466                     << ", workspace_size=" << profile_algorithm.workspace_size;
467 
468             if (cublas_launch_status.ok() && profile_result.is_valid() &&
469                 profile_result.elapsed_time_in_ms() <
470                     best_result.elapsed_time_in_ms()) {
471               best_result = profile_result;
472               // Use index into algorithms array, instead of cublas internal ID.
473               best_result.set_algorithm(i);
474             }
475           }
476 
477           if (best_result.is_valid()) {
478             algorithm_config.set_algorithm(best_result.algorithm());
479           }
480           // Each matmul parameter set gets one pass of
481           // autotune. If no algorithm works, kNoAlgorithm is added to the
482           // autotune map.
483           AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params,
484                                                      algorithm_config);
485         }
486         se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
487         OP_REQUIRES(context,
488                     0 <= algorithm_idx && algorithm_idx < algorithms.size(),
489                     errors::Internal("Missing/invalid BatchMatmul algorithm"));
490         const auto& algorithm = algorithms[algorithm_idx];
491         BlasScratchAllocator scratch_allocator(context, max_scratch_size);
492         VLOG(4) << "Calling BlasLtMatMul: a.shape=(" << bcast.x_batch_size()
493                 << ", " << in_x.dim_size(1) << ", " << in_x.dim_size(2)
494                 << "), b.shape=(" << bcast.y_batch_size() << ", "
495                 << in_y.dim_size(1) << ", " << in_y.dim_size(2) << "), m=" << m
496                 << ", n=" << n << ", k=" << k << ", batch_size=" << batch_size
497                 << ", trans_x=" << trans_x << ", trans_y=" << trans_y
498                 << ", adj_x=" << adj_x << ", adj_y=" << adj_y;
499 
500         OP_REQUIRES_OK(
501             context, DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0],
502                                     *c_ptrs[0], algorithm, scratch_allocator));
503       } else {  // requires mixed broadcasting
504         const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
505         const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
506         for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
507           a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
508         }
509         for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
510           b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
511         }
512         for (int64_t i = 0; i < batch_size; ++i) {
513           c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
514           a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
515           b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
516           c_ptrs.push_back(&c_device_memory.back());
517         }
518 
519         BlasScratchAllocator scratch_allocator(context, max_scratch_size);
520         bool blas_launch_status =
521             stream
522                 ->ThenBlasGemmBatchedWithScratch(
523                     blas_transpose_b, blas_transpose_a, n, m, k,
524                     static_cast<Coefficient>(1.0), b_ptrs,
525                     adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
526                     static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
527                     &scratch_allocator)
528                 .ok();
529         if (!blas_launch_status) {
530           context->SetStatus(errors::Internal(
531               "Blas xGEMMBatched launch failed: a.shape=",
532               in_x.shape().DebugString(),
533               ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
534               ", k=", k, ", batch_size=", batch_size));
535         }
536       }
537     } else {
538 #endif  // GOOGLE_CUDA
539       bool use_strided_batched =
540           (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
541           batch_size > 1;
542       if (use_strided_batched) {
543         a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
544         b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
545         c_stride = m * n;
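        // A stride of zero makes every batch entry reuse the same input
        // matrix, which is how a fully broadcast operand is handled by the
        // strided-batched GEMM call below.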
546         a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
547         b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
548         c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
549         a_ptrs.push_back(&a_device_memory.back());
550         b_ptrs.push_back(&b_device_memory.back());
551         c_ptrs.push_back(&c_device_memory.back());
552       } else if (!bcast.IsBroadcastingRequired()) {
553         for (int64_t i = 0; i < batch_size; ++i) {
554           a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
555           b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
556           c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
557           a_ptrs.push_back(&a_device_memory.back());
558           b_ptrs.push_back(&b_device_memory.back());
559           c_ptrs.push_back(&c_device_memory.back());
560         }
561       } else {
562         const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
563         const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
564         for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
565           a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
566         }
567         for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
568           b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
569         }
570         for (int64_t i = 0; i < batch_size; ++i) {
571           c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
572           a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
573           b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
574           c_ptrs.push_back(&c_device_memory.back());
575         }
576       }
577 
578       // Blas does
579       // C = A x B
580       // where A, B and C are assumed to be in column major.
581       // We want the output to be in row-major, so we can compute
582       // C' = B' x A', where ' stands for transpose (not adjoint).
583       // TODO(yangzihao): Choose the best of the three strategies using
584       // autotune.
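      // For example, a row-major 2x3 A and 3x4 B are passed to Blas as if
      // they were column-major 3x2 and 4x3 matrices; computing B' x A'
      // yields a column-major 4x2 result, which is bit-for-bit the row-major
      // 2x4 matrix A x B that we want in the output.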
585       if (batch_size == 1) {
586         // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
587         // overhead of the scratch allocator and the batch interface.
588         // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
589         if constexpr (!std::is_same_v<Scalar, Eigen::half>) {
590           if (n == 1 &&
591               blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
592               blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
593             // This is a matrix*vector multiply so use GEMV to compute A * b.
594             // Here we are multiplying in the natural order, so we have to flip
595             // the transposition flag to compensate for the tensor being stored
596             // row-major. Since GEMV doesn't provide a way to just conjugate an
597             // argument, we have to defer those cases to GEMM below.
598             auto gemv_trans_a =
599                 blas_transpose_a == se::blas::Transpose::kTranspose
600                     ? se::blas::Transpose::kNoTranspose
601                     : se::blas::Transpose::kTranspose;
602             bool blas_launch_status =
603                 stream
604                     ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
605                                    adj_x || trans_x ? k : m,
606                                    static_cast<Coefficient>(1.0), *(a_ptrs[0]),
607                                    adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
608                                    static_cast<Coefficient>(0.0), c_ptrs[0], 1)
609                     .ok();
610             if (!blas_launch_status) {
611               context->SetStatus(errors::Internal(
612                   "Blas xGEMV launch failed : a.shape=",
613                   in_x.shape().DebugString(), ", b.shape=",
614                   in_y.shape().DebugString(), ", m=", m, ", n=", n, ", k=", k));
615             }
616             return;
617           }
618         }
619 
620         OP_REQUIRES_OK(context,
621                        stream->ThenBlasGemm(
622                            blas_transpose_b, blas_transpose_a, n, m, k,
623                            *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
624                            adj_x || trans_x ? m : k, c_ptrs[0], n,
625                            se::blas::kDefaultComputePrecision));
626       } else if (use_strided_batched) {
627         OP_REQUIRES_OK(context, stream->ThenBlasGemmStridedBatched(
628                                     blas_transpose_b, blas_transpose_a, n, m, k,
629                                     static_cast<Coefficient>(1.0), *b_ptrs[0],
630                                     adj_y || trans_y ? k : n, b_stride,
631                                     *a_ptrs[0], adj_x || trans_x ? m : k,
632                                     a_stride, static_cast<Coefficient>(0.0),
633                                     c_ptrs[0], n, c_stride, batch_size));
634       } else {
635         BlasScratchAllocator scratch_allocator(context);
636         bool blas_launch_status =
637             stream
638                 ->ThenBlasGemmBatchedWithScratch(
639                     blas_transpose_b, blas_transpose_a, n, m, k,
640                     static_cast<Coefficient>(1.0), b_ptrs,
641                     adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
642                     static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
643                     &scratch_allocator)
644                 .ok();
645         if (!blas_launch_status) {
646           context->SetStatus(errors::Internal(
647               "Blas xGEMMBatched launch failed : a.shape=",
648               in_x.shape().DebugString(),
649               ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
650               ", k=", k, ", batch_size=", batch_size));
651         }
652       }
653 #if GOOGLE_CUDA
654     }
655 #endif  // GOOGLE_CUDA
656   }
657 };
658 
659 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
660 
661 template <typename Device, typename Ta, typename Tb, typename Tout>
662 class BaseBatchMatMulOp : public OpKernel {
663  public:
664   explicit BaseBatchMatMulOp(OpKernelConstruction* context,
665                              bool is_legacy_matmul)
666       : OpKernel(context) {
667     if (is_legacy_matmul) {
668       // The old MatMul kernel has "transpose_a/transpose_b" attributes.
669       OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
670       OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
671       adj_x_ = false;
672       adj_y_ = false;
673     } else {
674       OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
675       OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
676       trans_x_ = false;
677       trans_y_ = false;
678     }
679   }
680 
681   ~BaseBatchMatMulOp() override {}
682 
683   void Compute(OpKernelContext* ctx) override {
684     const Tensor& in0 = ctx->input(0);
685     const Tensor& in1 = ctx->input(1);
686 
687     const Status s = ValidateInputTensors(ctx, in0, in1);
688     if (!s.ok()) {
689       ctx->SetStatus(s);
690       return;
691     }
692 
693     MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
694     OP_REQUIRES(
695         ctx, bcast.IsValid(),
696         errors::InvalidArgument(
697             "In[0] and In[1] must have compatible batch dimensions: ",
698             in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
699 
700     TensorShape out_shape = bcast.output_batch_shape();
701     auto batch_size = bcast.output_batch_size();
702     auto d0 = in0.dim_size(in0.dims() - 2);
703     auto d1 = in0.dim_size(in0.dims() - 1);
704     Tensor in0_reshaped;
705     OP_REQUIRES(
706         ctx,
707         in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
708         errors::Internal("Failed to reshape In[0] from ",
709                          in0.shape().DebugString()));
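    // For example, an In[0] of shape [2, 3, 5, 7] is viewed from here on as a
    // 3-D tensor of shape [6, 5, 7]: the leading batch dimensions collapse
    // into bcast.x_batch_size() and the trailing two stay as the matrix dims.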
710     auto d2 = in1.dim_size(in1.dims() - 2);
711     auto d3 = in1.dim_size(in1.dims() - 1);
712     Tensor in1_reshaped;
713     OP_REQUIRES(
714         ctx,
715         in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
716         errors::Internal("Failed to reshape In[1] from ",
717                          in1.shape().DebugString()));
718     if (adj_x_ || trans_x_) std::swap(d0, d1);
719     if (adj_y_ || trans_y_) std::swap(d2, d3);
720     OP_REQUIRES(
721         ctx, d1 == d2,
722         errors::InvalidArgument(
723             "Matrix size-incompatible: In[0]: ", in0.shape().DebugString(),
724             ", In[1]: ", in1.shape().DebugString()));
725     out_shape.AddDim(d0);
726     out_shape.AddDim(d3);
727     Tensor* out = nullptr;
728     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
729     if (out->NumElements() == 0) {
730       return;
731     }
732     if (in0.NumElements() == 0 || in1.NumElements() == 0) {
733       functor::SetZeroFunctor<Device, Tout> f;
734       f(ctx->eigen_device<Device>(), out->flat<Tout>());
735       return;
736     }
737     Tensor out_reshaped;
738     OP_REQUIRES(ctx,
739                 out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
740                 errors::Internal("Failed to reshape output from ",
741                                  out->shape().DebugString()));
742     if (std::is_same<Ta, bfloat16>::value &&
743         std::is_same<Tb, bfloat16>::value) {
744       bool is_cpu = std::is_same<Device, CPUDevice>::value;
745       OP_REQUIRES(ctx, is_cpu,
746                   errors::Internal("bfloat16 matmul is not supported by GPU"));
747       Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
748       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
749                                              &in0_reshaped_float));
750       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
751                                              &in1_reshaped_float));
752       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
753                                              &out_reshaped_float));
754 
755       // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
756       BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
757                       in0_reshaped_float.flat<float>().data(),
758                       in0_reshaped.NumElements());
759       BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
760                       in1_reshaped_float.flat<float>().data(),
761                       in1_reshaped.NumElements());
762 
763       LaunchBatchMatMul<Device, float>::Launch(
764           ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
765           trans_y_, bcast, &out_reshaped_float);
766       FloatToBFloat16(out_reshaped_float.flat<float>().data(),
767                       out_reshaped.flat<bfloat16>().data(), out->NumElements());
768     } else {
769       // Cast tensor to desired type to reuse Eigen.
770       // TODO(b/178749687): remove this cast if Eigen supports this natively.
771       if (!std::is_same<Ta, Tout>::value) {
772         in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped);
773       }
774       if (!std::is_same<Tb, Tout>::value) {
775         in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped);
776       }
777       LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped,
778                                               adj_x_, adj_y_, trans_x_,
779                                               trans_y_, bcast, &out_reshaped);
780     }
781   }
782 
783  protected:
784   virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
785                                       const Tensor& in1) = 0;
786 
787  private:
788   // TODO(171979567) Make the ops take both adj and transpose attributes.
789   bool adj_x_ = false;
790   bool adj_y_ = false;
791   bool trans_x_ = false;
792   bool trans_y_ = false;
793 
794   // Cast `t` from `SrcT` to `DstT`.
795   template <typename SrcT, typename DstT>
796   Tensor CastTensor(const Tensor& t) {
797     Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape());
798     res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>();
799     return res;
800   }
801 };
802 
803 // BatchMatMul Op implementation which disallows broadcasting.
804 template <typename Device, typename Ta, typename Tb, typename Tout,
805           bool is_legacy_matmul = false>
806 class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
807  public:
808   explicit BatchMatMulOp(OpKernelConstruction* context)
809       : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {}
810 
811   ~BatchMatMulOp() override {}
812 
813  private:
814   Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
815                               const Tensor& in1) override {
816     // Disallow broadcasting support. Ensure that all batch dimensions of the
817     // input tensors match.
818     if (in0.dims() != in1.dims()) {
819       return errors::InvalidArgument(
820           "In[0] and In[1] have different ndims: ", in0.shape().DebugString(),
821           " vs. ", in1.shape().DebugString());
822     }
823     const int ndims = in0.dims();
824     if (is_legacy_matmul) {
825       if (ndims != 2) {
826         return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
827                                        ndims);
828       }
829     } else {
830       if (ndims < 2) {
831         return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
832                                        ndims);
833       }
834       for (int i = 0; i < ndims - 2; ++i) {
835         if (in0.dim_size(i) != in1.dim_size(i)) {
836           return errors::InvalidArgument(
837               "In[0].dim(", i, ") and In[1].dim(", i,
838               ") must be the same: ", in0.shape().DebugString(), " vs ",
839               in1.shape().DebugString());
840         }
841       }
842     }
843     return Status::OK();
844   }
845 };
846 
847 // BatchMatMul Op implementation with broadcasting support.
848 template <typename Device, typename Ta, typename Tb, typename Tout>
849 class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
850  public:
851   explicit BatchMatMulV2Op(OpKernelConstruction* context)
852       : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context,
853                                                 /* is_legacy_matmul= */ false) {
854   }
855 
856   ~BatchMatMulV2Op() override {}
857 
858  private:
859   Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
860                               const Tensor& in1) override {
861     // Enable broadcasting support. Validity of broadcasting is checked in
862     // BaseBatchMatMulOp.
863     if (in0.dims() < 2) {
864       return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
865     }
866     if (in1.dims() < 2) {
867       return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
868     }
869     return Status::OK();
870   }
871 };
872 
873 // Register for MatMul, BatchMatMul, BatchMatMulV2 where Tin = Tout.
874 #define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
875   REGISTER_KERNEL_BUILDER(                                                \
876       Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
877       BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>);                        \
878   REGISTER_KERNEL_BUILDER(                                                \
879       Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
880       BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>);                      \
881   REGISTER_KERNEL_BUILDER(                                                \
882       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
883       BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)
884 
885 #define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
886   REGISTER_KERNEL_BUILDER(                                                \
887       Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
888       BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>);                        \
889   REGISTER_KERNEL_BUILDER(                                                \
890       Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
891       BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>);                      \
892   REGISTER_KERNEL_BUILDER(                                                \
893       Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
894       BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)
895 
896 // Register for BatchMatMulV3 where Ta, Tb and Tout can differ.
897 #define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout)         \
898   REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
899                               .Device(DEVICE_CPU)            \
900                               .TypeConstraint<Ta>("Ta")      \
901                               .TypeConstraint<Tb>("Tb")      \
902                               .TypeConstraint<Tout>("Tout"), \
903                           BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>)
904 
905 #define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout)         \
906   REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
907                               .Device(DEVICE_GPU)            \
908                               .TypeConstraint<Ta>("Ta")      \
909                               .TypeConstraint<Tb>("Tb")      \
910                               .TypeConstraint<Tout>("Tout"), \
911                           BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>)
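
// These macros are expected to be instantiated once per supported type from
// the per-type registration source files, e.g. (illustrative):
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_TOUT_CPU(Eigen::half, Eigen::half, float);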
912 
913 }  // namespace tensorflow
914 
915 #endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
916