1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
19 #define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
20 
21 #define EIGEN_USE_THREADS
22 
23 #include <vector>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/type_traits.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/kernels/fill_functor.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/gtl/inlined_vector.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow/core/util/matmul_bcast.h"
39 #include "tensorflow/core/util/work_sharder.h"
40 
41 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
42 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
43 #endif
44 
45 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
46 #include "tensorflow/core/platform/stream_executor.h"
47 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
48 
49 namespace tensorflow {
50 
51 typedef Eigen::ThreadPoolDevice CPUDevice;
52 typedef Eigen::GpuDevice GPUDevice;
53 
54 namespace {
55 
56 // Returns the pair of dimensions along which to perform Tensor contraction to
57 // emulate matrix multiplication.
58 // For matrix multiplication of 2D Tensors X and Y, X is contracted along its
59 // second dimension and Y is contracted along its first dimension (if neither
60 // X nor Y is adjointed). The dimension to contract along is switched when an
61 // operand is adjointed.
62 // See http://en.wikipedia.org/wiki/Tensor_contraction
63 Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
64   return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
65 }
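// Illustrative mapping (sketch, not part of the kernel logic): the return
// value above selects, for the four (adj_x, adj_y) combinations,
//   (false, false) -> (1, 0): contract X dim 1 with Y dim 0, i.e. X * Y
//   (true,  false) -> (0, 0): contract X dim 0 with Y dim 0, i.e. X^T * Y
//   (false, true ) -> (1, 1): contract X dim 1 with Y dim 1, i.e. X * Y^T
//   (true,  true ) -> (0, 1): contract X dim 0 with Y dim 1, i.e. X^T * Y^T
// Conjugation for true adjoints is handled separately by the callers below.
// A hypothetical stand-alone use would look like:
//   Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs;
//   pairs[0] = ContractionDims(/*adj_x=*/false, /*adj_y=*/true);
//   auto z = x.contract(y, pairs);  // == x * y^T for 2D tensors x and y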
66 
67 // Parallel batch matmul kernel based on the multi-threaded tensor contraction
68 // in Eigen.
69 template <typename Scalar, bool IsComplex = true>
70 struct ParallelMatMulKernel {
71   static void Conjugate(const OpKernelContext* context, Tensor* out) {
72     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
73     auto z = out->tensor<Scalar, 3>();
74     z.device(d) = z.conjugate();
75   }
76 
77   static void Run(const OpKernelContext* context, const Tensor& in_x,
78                   const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
79                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
80                   int batch_size) {
81     static_assert(IsComplex, "Complex type expected.");
82     auto Tx = in_x.tensor<Scalar, 3>();
83     auto Ty = in_y.tensor<Scalar, 3>();
84     auto Tz = out->tensor<Scalar, 3>();
85     // We use the identities
86     //   conj(a) * conj(b) = conj(a * b)
87     //   conj(a) * b = conj(a * conj(b))
88     // to halve the number of cases. The final conjugation of the result is
89     // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
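    // Worked example (for illustration): with adj_x = true, adj_y = false the
    // branch below contracts x with conj(y) along the transposed dims,
    // producing x^T * conj(y); the final Conjugate() call in Launch() then
    // yields conj(x^T * conj(y)) = conj(x)^T * y = x^H * y, the requested
    // adjoint product.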
90     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
91     contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
92     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
93 
94     const bool should_bcast = bcast.IsBroadcastingRequired();
95     const auto& x_batch_indices = bcast.x_batch_indices();
96     const auto& y_batch_indices = bcast.y_batch_indices();
97     // TODO(rmlarsen): Consider launching these contractions asynchronously.
98     for (int64_t i = 0; i < batch_size; ++i) {
99       const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
100       const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
101 
102       auto x = Tx.template chip<0>(x_batch_index);
103       auto z = Tz.template chip<0>(i);
104       if (adj_x != adj_y) {
105         auto y = Ty.template chip<0>(y_batch_index).conjugate();
106         z.device(d) = x.contract(y, contract_pairs);
107       } else {
108         auto y = Ty.template chip<0>(y_batch_index);
109         z.device(d) = x.contract(y, contract_pairs);
110       }
111     }
112   }
113 };
114 
115 // The Eigen contraction kernel used here is very large and slow to compile,
116 // so we partially specialize ParallelMatMulKernel for real types to avoid all
117 // but one of the instantiations.
118 template <typename Scalar>
119 struct ParallelMatMulKernel<Scalar, false> {
120   static void Conjugate(const OpKernelContext* context, Tensor* out) {}
121 
122   static void Run(const OpKernelContext* context, const Tensor& in_x,
123                   const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
124                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
125                   int batch_size) {
126     const bool should_bcast = bcast.IsBroadcastingRequired();
127     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
128     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
129     contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
130     if (batch_size == 1 && !should_bcast) {
131       auto Tx = in_x.flat_inner_dims<Scalar, 2>();
132       auto Ty = in_y.flat_inner_dims<Scalar, 2>();
133       auto Tz = out->flat_inner_dims<Scalar, 2>();
134       Tz.device(d) = Tx.contract(Ty, contract_pairs);
135     } else {
136       auto Tx = in_x.tensor<Scalar, 3>();
137       auto Ty = in_y.tensor<Scalar, 3>();
138       auto Tz = out->tensor<Scalar, 3>();
139       const auto& x_batch_indices = bcast.x_batch_indices();
140       const auto& y_batch_indices = bcast.y_batch_indices();
141       // TODO(rmlarsen): Consider launching these contractions asynchronously.
142       for (int64_t i = 0; i < batch_size; ++i) {
143         const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
144         const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
145         auto x = Tx.template chip<0>(x_batch_index);
146         auto y = Ty.template chip<0>(y_batch_index);
147         auto z = Tz.template chip<0>(i);
148 
149         z.device(d) = x.contract(y, contract_pairs);
150       }
151     }
152   }
153 };
154 
155 // Sequential batch matmul kernel that calls the regular Eigen matmul.
156 // We prefer this over the tensor contraction because it performs
157 // better on vector-matrix and matrix-vector products.
158 template <typename Scalar>
159 struct SequentialMatMulKernel {
160   using Matrix =
161       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
162   using ConstMatrixMap = Eigen::Map<const Matrix>;
163   using MatrixMap = Eigen::Map<Matrix>;
164 
165   static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
166                                                       int slice) {
167     return ConstMatrixMap(
168         t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
169         t.dim_size(1), t.dim_size(2));
170   }
171 
172   static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
173     return MatrixMap(
174         t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
175         t->dim_size(1), t->dim_size(2));
176   }
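  // Note (illustrative): for a rank-3 tensor t of shape [batch, rows, cols],
  // slice i is the contiguous block of rows * cols scalars starting at offset
  // i * rows * cols, which the maps above expose as a row-major rows x cols
  // Eigen matrix without copying.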
177 
178   static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
179                   bool adj_y, bool trans_x, bool trans_y,
180                   const MatMulBCast& bcast, Tensor* out, int start, int limit) {
181     const bool should_bcast = bcast.IsBroadcastingRequired();
182     const auto& x_batch_indices = bcast.x_batch_indices();
183     const auto& y_batch_indices = bcast.y_batch_indices();
184     for (int64_t i = start; i < limit; ++i) {
185       const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
186       const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
187       auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
188       auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
189       auto z = TensorSliceToEigenMatrix(out, i);
190       // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
191       // and trans_y.
192       if (!adj_x && !trans_x) {
193         if (!adj_y && !trans_y) {
194           z.noalias() = x * y;
195         } else if (adj_y) {
196           z.noalias() = x * y.adjoint();
197         } else {  // trans_y == true
198           z.noalias() = x * y.transpose();
199         }
200       } else if (adj_x) {
201         if (!adj_y && !trans_y) {
202           z.noalias() = x.adjoint() * y;
203         } else if (adj_y) {
204           z.noalias() = x.adjoint() * y.adjoint();
205         } else {  // trans_y == true
206           z.noalias() = x.adjoint() * y.transpose();
207         }
208       } else {  // trans_x == true
209         if (!adj_y && !trans_y) {
210           z.noalias() = x.transpose() * y;
211         } else if (adj_y) {
212           z.noalias() = x.transpose() * y.adjoint();
213         } else {  // trans_y == true
214           z.noalias() = x.transpose() * y.transpose();
215         }
216       }
217     }
218   }
219 };
220 
221 }  // namespace
222 
223 template <typename Device, typename Scalar>
224 struct LaunchBatchMatMul;
225 
226 template <typename Scalar>
227 struct LaunchBatchMatMul<CPUDevice, Scalar> {
228   static void Launch(OpKernelContext* context, const Tensor& in_x,
229                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
230                      bool trans_y, const MatMulBCast& bcast, Tensor* out) {
231     typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
232         ParallelMatMulKernel;
233     bool conjugate_result = false;
234 
235     // Number of matrix multiplies i.e. size of the batch.
236     const int64_t batch_size = bcast.output_batch_size();
237     const int64_t cost_per_unit =
238         in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
239     const int64_t small_dim = std::min(
240         std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
241     // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
242     // Jan 21, 2020.
243     const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic.
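    // Illustrative numbers (a sketch, not a measurement): for square 256x256
    // operands, cost_per_unit = 256 * 256 * 256 (~16.8M) exceeds 128 * 128,
    // so the branch below parallelizes inside each contraction; for 8x8
    // operands in a larger batch, cost_per_unit = 512, so the work is instead
    // sharded across the batch dimension.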
244     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
245     // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
246     // evaluation in Eigen Tensor.
247     if (small_dim > 1 &&
248         (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
249       // Parallelize over inner dims.
250       // For large matrix products it is counter-productive to parallelize
251       // over the batch dimension.
252       ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
253                                 trans_y, bcast, out, batch_size);
254       conjugate_result = adj_x;
255     } else {
256       // Parallelize over outer dims. For small matrices and large batches, it
257       // is counter-productive to parallelize the inner matrix multiplies.
258       Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
259             cost_per_unit,
260             [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
261                 int start, int limit) {
262               SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
263                                                   trans_x, trans_y, bcast, out,
264                                                   start, limit);
265             });
266     }
267     if (conjugate_result) {
268       // We used one of the identities
269       //   conj(a) * conj(b) = conj(a * b)
270       //   conj(a) * b = conj(a * conj(b))
271       // above, so we need to conjugate the final output. This is a
272       // no-op for non-complex types.
273       ParallelMatMulKernel::Conjugate(context, out);
274     }
275   }
276 };
277 
278 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
279 
280 namespace {
281 template <typename T>
282 se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
283   se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
284   se::DeviceMemory<T> typed(wrapped);
285   return typed;
286 }
287 
288 class BlasScratchAllocator : public se::ScratchAllocator {
289  public:
290   using Stream = se::Stream;
291   using DeviceMemoryBytes = se::DeviceMemory<uint8>;
292 
293   BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
294 
295   int64 GetMemoryLimitInBytes() override { return -1; }
296 
297   se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
298       int64_t byte_size) override {
299     Tensor temporary_memory;
300 
301     Status allocation_status(context_->allocate_temp(
302         DT_UINT8, TensorShape({byte_size}), &temporary_memory));
303     if (!allocation_status.ok()) {
304       return se::port::StatusOr<DeviceMemoryBytes>(
305           DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
306     }
307     // Hold references to the allocated tensors until the allocator is
308     // destroyed.
309     allocated_tensors_.push_back(temporary_memory);
310     return se::port::StatusOr<DeviceMemoryBytes>(
311         DeviceMemoryBytes::MakeFromByteSize(
312             temporary_memory.flat<uint8>().data(),
313             temporary_memory.flat<uint8>().size()));
314   }
315 
316  private:
317   OpKernelContext* context_;
318   std::vector<Tensor> allocated_tensors_;
319 };
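// Usage sketch (mirroring the calls further below): one allocator is
// constructed per launch and handed to the batched GEMM so the BLAS library
// obtains workspace through TensorFlow's allocator, e.g.
//   BlasScratchAllocator scratch_allocator(context);
//   stream->ThenBlasGemmBatchedWithScratch(..., &scratch_allocator);
// The Tensors held in allocated_tensors_ keep that workspace alive until the
// allocator goes out of scope at the end of Launch().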
320 }  // namespace
321 
322 template <typename Scalar>
323 struct LaunchBatchMatMul<GPUDevice, Scalar> {
324   static void Launch(OpKernelContext* context, const Tensor& in_x,
325                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
326                      bool trans_y, const MatMulBCast& bcast, Tensor* out) {
327     se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
328                                    se::blas::Transpose::kTranspose,
329                                    se::blas::Transpose::kConjugateTranspose};
330     const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
331     const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
332     const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
333     const int64_t batch_size = bcast.output_batch_size();
334     auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
335     auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
336 
337     auto* stream = context->op_device_context()->stream();
338     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
339 
340     typedef se::DeviceMemory<Scalar> DeviceMemoryType;
341     std::vector<DeviceMemoryType> a_device_memory;
342     std::vector<DeviceMemoryType> b_device_memory;
343     std::vector<DeviceMemoryType> c_device_memory;
344     std::vector<DeviceMemoryType*> a_ptrs;
345     std::vector<DeviceMemoryType*> b_ptrs;
346     std::vector<DeviceMemoryType*> c_ptrs;
347     a_device_memory.reserve(bcast.x_batch_size());
348     b_device_memory.reserve(bcast.y_batch_size());
349     c_device_memory.reserve(batch_size);
350     a_ptrs.reserve(batch_size);
351     b_ptrs.reserve(batch_size);
352     c_ptrs.reserve(batch_size);
353     auto* a_base_ptr = in_x.template flat<Scalar>().data();
354     auto* b_base_ptr = in_y.template flat<Scalar>().data();
355     auto* c_base_ptr = out->template flat<Scalar>().data();
356     uint64 a_stride;
357     uint64 b_stride;
358     uint64 c_stride;
359 
360     bool is_full_broadcast =
361         std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
362     bool use_strided_batched =
363         (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
364         batch_size > 1;
365     if (use_strided_batched) {
366       a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
367       b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
368       c_stride = m * n;
369       a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
370       b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
371       c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
372       a_ptrs.push_back(&a_device_memory.back());
373       b_ptrs.push_back(&b_device_memory.back());
374       c_ptrs.push_back(&c_device_memory.back());
375     } else if (!bcast.IsBroadcastingRequired()) {
376       for (int64_t i = 0; i < batch_size; ++i) {
377         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
378         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
379         c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
380         a_ptrs.push_back(&a_device_memory.back());
381         b_ptrs.push_back(&b_device_memory.back());
382         c_ptrs.push_back(&c_device_memory.back());
383       }
384     } else {
385       const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
386       const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
387       for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
388         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
389       }
390       for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
391         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
392       }
393       for (int64_t i = 0; i < batch_size; ++i) {
394         c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
395         a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
396         b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
397         c_ptrs.push_back(&c_device_memory.back());
398       }
399     }
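    // Example (illustrative): multiplying A of shape [4, m, k] by B of shape
    // [1, k, n] is a full broadcast (the smaller batch size is 1), so the
    // strided path above is taken with a_stride = m * k and b_stride = 0,
    // i.e. the single B matrix is reused for every batch entry without
    // materializing per-batch pointer arrays.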
400 
401     typedef Scalar Coefficient;
402 
403     // Blas does
404     // C = A x B
405     // where A, B and C are assumed to be in column major.
406     // We want the output to be in row-major, so we can compute
407     // C' = B' x A', where ' stands for transpose (not adjoint).
408     // TODO(yangzihao): Choose the best of the three strategies using autotune.
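    // Why this works (for illustration): a row-major m x n matrix C occupies
    // the same memory as a column-major n x m matrix holding C'. The BLAS
    // routines assume column-major, so requesting C' = B' x A' (an n x m
    // result from an (n x k) x (k x m) product) writes exactly the row-major
    // C = A x B that TensorFlow expects; hence the swapped operand order and
    // swapped m/n arguments in the calls below.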
409     if (batch_size == 1) {
410       // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
411       // overhead of the scratch allocator and the batch interface.
412       if (n == 1 &&
413           blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
414           blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
415         // This is a matrix*vector multiply so use GEMV to compute A * b.
416         // Here we are multiplying in the natural order, so we have to flip
417         // the transposition flag to compensate for the tensor being stored
418         // row-major. Since GEMV doesn't provide a way to just conjugate an
419         // argument, we have to defer those cases to GEMM below.
420         auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
421                                 ? se::blas::Transpose::kNoTranspose
422                                 : se::blas::Transpose::kTranspose;
423         bool blas_launch_status =
424             stream
425                 ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
426                                adj_x || trans_x ? k : m,
427                                static_cast<Coefficient>(1.0), *(a_ptrs[0]),
428                                adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
429                                static_cast<Coefficient>(0.0), c_ptrs[0], 1)
430                 .ok();
431         if (!blas_launch_status) {
432           context->SetStatus(errors::Internal(
433               "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
434               ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
435               ", k=", k));
436         }
437       } else {
438         OP_REQUIRES_OK(context,
439                        stream->ThenBlasGemm(
440                            blas_transpose_b, blas_transpose_a, n, m, k,
441                            *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
442                            adj_x || trans_x ? m : k, c_ptrs[0], n));
443       }
444     } else if (use_strided_batched) {
445       OP_REQUIRES_OK(context, stream->ThenBlasGemmStridedBatched(
446                                   blas_transpose_b, blas_transpose_a, n, m, k,
447                                   static_cast<Coefficient>(1.0), *b_ptrs[0],
448                                   adj_y || trans_y ? k : n, b_stride,
449                                   *a_ptrs[0], adj_x || trans_x ? m : k,
450                                   a_stride, static_cast<Coefficient>(0.0),
451                                   c_ptrs[0], n, c_stride, batch_size));
452     } else {
453       BlasScratchAllocator scratch_allocator(context);
454       bool blas_launch_status =
455           stream
456               ->ThenBlasGemmBatchedWithScratch(
457                   blas_transpose_b, blas_transpose_a, n, m, k,
458                   static_cast<Coefficient>(1.0), b_ptrs,
459                   adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
460                   static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
461                   &scratch_allocator)
462               .ok();
463       if (!blas_launch_status) {
464         context->SetStatus(errors::Internal(
465             "Blas xGEMMBatched launch failed : a.shape=",
466             in_x.shape().DebugString(),
467             ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
468             ", k=", k, ", batch_size=", batch_size));
469       }
470     }
471   }
472 };
473 
474 template <>
475 struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
476   static void Launch(OpKernelContext* context, const Tensor& in_x,
477                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
478                      bool trans_y, const MatMulBCast& bcast, Tensor* out) {
479     typedef Eigen::half Scalar;
480     se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
481                                    se::blas::Transpose::kTranspose,
482                                    se::blas::Transpose::kConjugateTranspose};
483     const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
484     const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
485     const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
486     const uint64 batch_size = bcast.output_batch_size();
487     auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
488     auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
489 
490     auto* stream = context->op_device_context()->stream();
491     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
492 
493     typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
494     std::vector<DeviceMemoryType> a_device_memory;
495     std::vector<DeviceMemoryType> b_device_memory;
496     std::vector<DeviceMemoryType> c_device_memory;
497     std::vector<DeviceMemoryType*> a_ptrs;
498     std::vector<DeviceMemoryType*> b_ptrs;
499     std::vector<DeviceMemoryType*> c_ptrs;
500     a_device_memory.reserve(bcast.x_batch_size());
501     b_device_memory.reserve(bcast.y_batch_size());
502     c_device_memory.reserve(batch_size);
503     a_ptrs.reserve(batch_size);
504     b_ptrs.reserve(batch_size);
505     c_ptrs.reserve(batch_size);
506     auto* a_base_ptr = in_x.template flat<Scalar>().data();
507     auto* b_base_ptr = in_y.template flat<Scalar>().data();
508     auto* c_base_ptr = out->template flat<Scalar>().data();
509 
510     uint64 a_stride;
511     uint64 b_stride;
512     uint64 c_stride;
513 
514     bool is_full_broadcast =
515         std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
516     bool use_strided_batched =
517         (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
518         batch_size > 1;
519     if (use_strided_batched) {
520       a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
521       b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
522       c_stride = m * n;
523       a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
524       b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
525       c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
526       a_ptrs.push_back(&a_device_memory.back());
527       b_ptrs.push_back(&b_device_memory.back());
528       c_ptrs.push_back(&c_device_memory.back());
529     } else if (!bcast.IsBroadcastingRequired()) {
530       for (int64_t i = 0; i < batch_size; ++i) {
531         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
532         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
533         c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
534         a_ptrs.push_back(&a_device_memory.back());
535         b_ptrs.push_back(&b_device_memory.back());
536         c_ptrs.push_back(&c_device_memory.back());
537       }
538     } else {
539       const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
540       const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
541       for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
542         a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
543       }
544       for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
545         b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
546       }
547       for (int64_t i = 0; i < batch_size; ++i) {
548         c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
549         a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
550         b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
551         c_ptrs.push_back(&c_device_memory.back());
552       }
553     }
554 
555     typedef float Coefficient;
556 
557     // Blas does
558     // C = A x B
559     // where A, B and C are assumed to be in column major.
560     // We want the output to be in row-major, so we can compute
561     // C' = B' x A', where ' stands for transpose (not adjoint).
562     // TODO(yangzihao): Choose the best of the three strategies using autotune.
563     if (batch_size == 1) {
564       // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
565       // overhead of the scratch allocator and the batch interface.
566       // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
567       OP_REQUIRES_OK(context,
568                      stream->ThenBlasGemm(
569                          blas_transpose_b, blas_transpose_a, n, m, k,
570                          *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
571                          adj_x || trans_x ? m : k, c_ptrs[0], n));
572     } else if (use_strided_batched) {
573       bool blas_launch_status =
574           stream
575               ->ThenBlasGemmStridedBatched(
576                   blas_transpose_b, blas_transpose_a, n, m, k,
577                   static_cast<Coefficient>(1.0), *b_ptrs[0],
578                   adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
579                   adj_x || trans_x ? m : k, a_stride,
580                   static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
581                   batch_size)
582               .ok();
583       if (!blas_launch_status) {
584         context->SetStatus(errors::Internal(
585             "Blas xGEMMStridedBatched launch failed : a.shape=",
586             in_x.shape().DebugString(),
587             ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
588             ", k=", k, ", batch_size=", batch_size));
589       }
590     } else {
591       BlasScratchAllocator scratch_allocator(context);
592       bool blas_launch_status =
593           stream
594               ->ThenBlasGemmBatchedWithScratch(
595                   blas_transpose_b, blas_transpose_a, n, m, k,
596                   static_cast<Coefficient>(1.0), b_ptrs,
597                   adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
598                   static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
599                   &scratch_allocator)
600               .ok();
601       if (!blas_launch_status) {
602         context->SetStatus(errors::Internal(
603             "Blas xGEMMBatched launch failed : a.shape=",
604             in_x.shape().DebugString(),
605             ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
606             ", k=", k, ", batch_size=", batch_size));
607       }
608     }
609   }
610 };
611 
612 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
613 
614 template <typename Device, typename Ta, typename Tb, typename Tout>
615 class BaseBatchMatMulOp : public OpKernel {
616  public:
617   explicit BaseBatchMatMulOp(OpKernelConstruction* context,
618                              bool is_legacy_matmul)
619       : OpKernel(context) {
620     if (is_legacy_matmul) {
621       // The old MatMul kernel has "transpose_a/transpose_b" attributes.
622       OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
623       OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
624       adj_x_ = false;
625       adj_y_ = false;
626     } else {
627       OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
628       OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
629       trans_x_ = false;
630       trans_y_ = false;
631     }
632   }
633 
634   ~BaseBatchMatMulOp() override {}
635 
636   void Compute(OpKernelContext* ctx) override {
637     const Tensor& in0 = ctx->input(0);
638     const Tensor& in1 = ctx->input(1);
639 
640     const Status s = ValidateInputTensors(ctx, in0, in1);
641     if (!s.ok()) {
642       ctx->SetStatus(s);
643       return;
644     }
645 
646     MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
647     OP_REQUIRES(
648         ctx, bcast.IsValid(),
649         errors::InvalidArgument(
650             "In[0] and In[1] must have compatible batch dimensions: ",
651             in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
652 
653     TensorShape out_shape = bcast.output_batch_shape();
654     auto batch_size = bcast.output_batch_size();
655     auto d0 = in0.dim_size(in0.dims() - 2);
656     auto d1 = in0.dim_size(in0.dims() - 1);
657     Tensor in0_reshaped;
658     OP_REQUIRES(
659         ctx,
660         in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
661         errors::Internal("Failed to reshape In[0] from ",
662                          in0.shape().DebugString()));
663     auto d2 = in1.dim_size(in1.dims() - 2);
664     auto d3 = in1.dim_size(in1.dims() - 1);
665     Tensor in1_reshaped;
666     OP_REQUIRES(
667         ctx,
668         in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
669         errors::Internal("Failed to reshape In[1] from ",
670                          in1.shape().DebugString()));
671     if (adj_x_ || trans_x_) std::swap(d0, d1);
672     if (adj_y_ || trans_y_) std::swap(d2, d3);
673     OP_REQUIRES(ctx, d1 == d2,
674                 errors::InvalidArgument(
675                     "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
676                     in0.shape().DebugString(), " ", in1.shape().DebugString(),
677                     " ", adj_x_, " ", adj_y_));
678     out_shape.AddDim(d0);
679     out_shape.AddDim(d3);
680     Tensor* out = nullptr;
681     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
682     if (out->NumElements() == 0) {
683       return;
684     }
685     if (in0.NumElements() == 0 || in1.NumElements() == 0) {
686       functor::SetZeroFunctor<Device, Tout> f;
687       f(ctx->eigen_device<Device>(), out->flat<Tout>());
688       return;
689     }
690     Tensor out_reshaped;
691     OP_REQUIRES(ctx,
692                 out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
693                 errors::Internal("Failed to reshape output from ",
694                                  out->shape().DebugString()));
695     if (std::is_same<Ta, bfloat16>::value &&
696         std::is_same<Tb, bfloat16>::value) {
697       bool is_cpu = std::is_same<Device, CPUDevice>::value;
698       OP_REQUIRES(ctx, is_cpu,
699                   errors::Internal("bfloat16 matmul is not supported by GPU"));
700       Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
701       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
702                                              &in0_reshaped_float));
703       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
704                                              &in1_reshaped_float));
705       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
706                                              &out_reshaped_float));
707 
708       // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
709       BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
710                       in0_reshaped_float.flat<float>().data(),
711                       in0_reshaped.NumElements());
712       BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
713                       in1_reshaped_float.flat<float>().data(),
714                       in1_reshaped.NumElements());
715 
716       LaunchBatchMatMul<Device, float>::Launch(
717           ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
718           trans_y_, bcast, &out_reshaped_float);
719       FloatToBFloat16(out_reshaped_float.flat<float>().data(),
720                       out_reshaped.flat<bfloat16>().data(), out->NumElements());
721     } else {
722       // Cast tensor to desired type to reuse Eigen.
723       // TODO(b/178749687): remove this cast if Eigen supports this natively.
724       if (!std::is_same<Ta, Tout>::value) {
725         in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped);
726       }
727       if (!std::is_same<Tb, Tout>::value) {
728         in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped);
729       }
730       LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped,
731                                               adj_x_, adj_y_, trans_x_,
732                                               trans_y_, bcast, &out_reshaped);
733     }
734   }
735 
736  protected:
737   virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
738                                       const Tensor& in1) = 0;
739 
740  private:
741   // TODO(171979567) Make the ops take both adj and transpose attributes.
742   bool adj_x_ = false;
743   bool adj_y_ = false;
744   bool trans_x_ = false;
745   bool trans_y_ = false;
746 
747   // Cast `t` from `SrcT` to `DstT`.
748   template <typename SrcT, typename DstT>
749   Tensor CastTensor(const Tensor& t) {
750     Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape());
751     res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>();
752     return res;
753   }
754 };
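// Shape-handling sketch (illustrative): for in0 of shape [2, 3, 5, 7] and in1
// of shape [3, 7, 11] (adj/trans all false), Compute() reshapes the inputs to
// [6, 5, 7] and [3, 7, 11], MatMulBCast broadcasts the batch dims [2, 3] vs.
// [3] to a batch of 6, and the output is allocated as [2, 3, 5, 11] and
// viewed as [6, 5, 11] for the launch.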
755 
756 // BatchMatMul Op implementation which disallows broadcasting.
757 template <typename Device, typename Ta, typename Tb, typename Tout,
758           bool is_legacy_matmul = false>
759 class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
760  public:
761   explicit BatchMatMulOp(OpKernelConstruction* context)
762       : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {}
763 
764   ~BatchMatMulOp() override {}
765 
766  private:
767   Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
768                               const Tensor& in1) override {
769     // Disallow broadcasting support. Ensure that all batch dimensions of the
770     // input tensors match.
771     if (in0.dims() != in1.dims()) {
772       return errors::InvalidArgument(
773           "In[0] and In[1] have different ndims: ", in0.shape().DebugString(),
774           " vs. ", in1.shape().DebugString());
775     }
776     const int ndims = in0.dims();
777     if (is_legacy_matmul) {
778       if (ndims != 2) {
779         return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
780                                        ndims);
781       }
782     } else {
783       if (ndims < 2) {
784         return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
785                                        ndims);
786       }
787       for (int i = 0; i < ndims - 2; ++i) {
788         if (in0.dim_size(i) != in1.dim_size(i)) {
789           return errors::InvalidArgument(
790               "In[0].dim(", i, ") and In[1].dim(", i,
791               ") must be the same: ", in0.shape().DebugString(), " vs ",
792               in1.shape().DebugString());
793         }
794       }
795     }
796     return Status::OK();
797   }
798 };
799 
800 // BatchMatMul Op implementation with broadcasting support.
801 template <typename Device, typename Ta, typename Tb, typename Tout>
802 class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
803  public:
804   explicit BatchMatMulV2Op(OpKernelConstruction* context)
805       : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context,
806                                                 /* is_legacy_matmul= */ false) {
807   }
808 
809   ~BatchMatMulV2Op() override {}
810 
811  private:
812   Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
813                               const Tensor& in1) override {
814     // Enable broadcasting support. Validity of broadcasting is checked in
815     // BaseBatchMatMulOp.
816     if (in0.dims() < 2) {
817       return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
818     }
819     if (in1.dims() < 2) {
820       return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
821     }
822     return Status::OK();
823   }
824 };
825 
826 // Register for MatMul, BatchMatMul, BatchMatMulV2 where Tin = Tout.
827 #define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
828   REGISTER_KERNEL_BUILDER(                                                \
829       Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
830       BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>);                        \
831   REGISTER_KERNEL_BUILDER(                                                \
832       Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
833       BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>);                      \
834   REGISTER_KERNEL_BUILDER(                                                \
835       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
836       BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)
837 
838 #define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
839   REGISTER_KERNEL_BUILDER(                                                \
840       Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
841       BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>);                        \
842   REGISTER_KERNEL_BUILDER(                                                \
843       Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
844       BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>);                      \
845   REGISTER_KERNEL_BUILDER(                                                \
846       Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
847       BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)
848 
849 // Register for BatchMatMulV3 where Ta, Tb and Tout are not the same.
850 #define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout)         \
851   REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
852                               .Device(DEVICE_CPU)            \
853                               .TypeConstraint<Ta>("Ta")      \
854                               .TypeConstraint<Tb>("Tb")      \
855                               .TypeConstraint<Tout>("Tout"), \
856                           BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>)
857 
858 #define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout)         \
859   REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
860                               .Device(DEVICE_GPU)            \
861                               .TypeConstraint<Ta>("Ta")      \
862                               .TypeConstraint<Tb>("Tb")      \
863                               .TypeConstraint<Tout>("Tout"), \
864                           BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>)
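// Example invocation (illustrative only; the real instantiations live in
// separate .cc translation units, and the type combinations shown here are
// hypothetical):
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_GPU(Eigen::half);
//   REGISTER_BATCH_MATMUL_TOUT_CPU(bfloat16, bfloat16, float);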
865 
866 }  // namespace tensorflow
867 
868 #endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
869