/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along the
// second dimension and Y is contracted along the first dimension (if neither X
// nor Y is adjointed). The dimension to contract along is switched when any
// operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
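// For example (illustrative): with no adjoints the pair is (1, 0) (Z = X * Y);
// with adj_x it is (0, 0) (Z = X^T * Y), with adj_y it is (1, 1)
// (Z = X * Y^T), and with both it is (0, 1) (Z = X^T * Y^T). Conjugation for
// complex types is handled separately by the callers below.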
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
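    // For example (illustrative): for adj_x = true, adj_y = false we want
    // Z = X^H * Y = conj(X^T * conj(Y)), so the loop below conjugates Y
    // (because adj_x != adj_y) and contracts along (0, 0); the launcher then
    // conjugates Z because adj_x is true.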
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = Tx.template chip<0>(x_batch_index);
      auto y = Ty.template chip<0>(y_batch_index);
      auto z = Tz.template chip<0>(i);

      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

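  // The two helpers below view slice i of a [batch, rows, cols] tensor as a
  // rows x cols row-major Eigen matrix starting at offset i * rows * cols of
  // the flat buffer; no data is copied.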
  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
                  int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      if (!adj_x) {
        if (!adj_y) {
          z.noalias() = x * y;
        } else {
          z.noalias() = x * y.adjoint();
        }
      } else {
        if (!adj_y) {
          z.noalias() = x.adjoint() * y;
        } else {
          z.noalias() = x.adjoint() * y.adjoint();
        }
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64 kMaxCostOuterParallelism = 128 * 128;  // heuristic.
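    // For example (illustrative): a large batch of 8x8 * 8x8 multiplies has
    // cost_per_unit = 8 * 8 * 8 = 512 <= kMaxCostOuterParallelism, so it is
    // sharded across the batch below, while a single 256x256 * 256x256
    // multiply (batch_size == 1, small_dim == 256) is handed to the parallel
    // Eigen contraction instead.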
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
                                0, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  bcast, out, start, limit);
            });
    }
    if (conjugate_result) {
      // Since we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes() override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold a reference to the allocated tensors until the end of the
    // allocator.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    constexpr se::blas::Transpose kTranspose =
        is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
                                  : se::blas::Transpose::kTranspose;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const int64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef Scalar Coefficient;

    // Blas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
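    // For example (illustrative): with a row-major 2x3 A and 3x4 B, the same
    // buffers read column-major are A' (3x2) and B' (4x3); asking Blas for the
    // 4x2 product B' x A' therefore writes C' in column-major order, which is
    // bit-for-bit the row-major 2x4 C we want.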
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    typedef Eigen::half Scalar;
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef float Coefficient;

    // Blas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto y = Ty.template chip<0>(y_batch_index);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          bcast, out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

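    // For example (illustrative): in0 of shape [2, 3, 5, 7] and in1 of shape
    // [3, 7, 11] have broadcast-compatible batch dims [2, 3] vs. [3]; below,
    // in0 is flattened to [6, 5, 7], in1 to [3, 7, 11], the [2, 3, 5, 11]
    // output is viewed as [6, 5, 11], and the kernel is launched on the 3D
    // views.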
    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatMul<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
  }

 protected:
  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                    const Tensor& in1) = 0;

 private:
  bool adj_x_;
  bool adj_y_;
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}

  ~BatchMatMulOp() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    // Disallow broadcasting support. Ensure that all batch dimensions of the
    // input tensors match.
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument("In[0] and In[1] has different ndims: ",
                                        in0.shape().DebugString(), " vs. ",
                                        in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
    }
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}

  ~BatchMatMulV2Op() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE>)
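
// These registration macros are typically expanded from the op's .cc
// translation units for each supported type, e.g. (illustrative):
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_GPU(Eigen::half);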

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<SYCLDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BatchMatMulV2").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_