/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along the
// second dimension and Y is contracted along the first dimension (if neither X
// nor Y is adjointed). The dimension to contract along is switched when either
// operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}
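
// For reference, the (x_dim, y_dim) pairs produced above for the 2D slices
// are:
//   adj_x=false, adj_y=false -> (1, 0)  X columns contracted with Y rows
//   adj_x=true,  adj_y=false -> (0, 0)  X rows    contracted with Y rows
//   adj_x=false, adj_y=true  -> (1, 1)  X columns contracted with Y columns
//   adj_x=true,  adj_y=true  -> (0, 1)  X rows    contracted with Y columns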

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
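    // Concretely: y is conjugated below when exactly one operand is adjointed
    // (adj_x != adj_y), and Launch() conjugates the final result when adj_x
    // is set.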
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(i).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(i);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// TODO(rmlarsen): Get rid of this when we have upstreamed improvements
// for matrix*vector and vector*matrix to Eigen's general matrix product.
template <typename Tx, typename Ty, typename Tz>
static void Multiply(bool adj_x, bool adj_y, Tx x, Ty y, Tz z) {
  if (!adj_x) {
    if (!adj_y) {
      z.noalias() = x * y;
    } else {
      z.noalias() = x * y.adjoint();
    }
  } else {
    if (!adj_y) {
      z.noalias() = x.adjoint() * y;
    } else {
      z.noalias() = x.adjoint() * y.adjoint();
    }
  }
}

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, Tensor* out, int start, int limit) {
    for (int i = start; i < limit; ++i) {
      auto x = ConstTensorSliceToEigenMatrix(in_x, i);
      auto y = ConstTensorSliceToEigenMatrix(in_y, i);
      auto z = TensorSliceToEigenMatrix(out, i);
      // TODO(rmlarsen): Get rid of the special casing here when we have
      // upstreamed improvements for matrix*vector and vector*matrix to
      // Eigen's general matrix product.
      if (!adj_x && x.rows() == 1) {
        Multiply(adj_x, adj_y, x.row(0), y, z);
      } else if (adj_x && x.cols() == 1) {
        Multiply(adj_x, adj_y, x.col(0), y, z);
      } else if (!adj_y && y.cols() == 1) {
        Multiply(adj_x, adj_y, x, y.col(0), z);
      } else if (adj_y && y.rows() == 1) {
        Multiply(adj_x, adj_y, x, y.row(0), z);
      } else {
        Multiply(adj_x, adj_y, x, y, z);
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    const int64 kMaxCostOuterParallelism = 128 * 128 * 256;  // heuristic.
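    // (128 * 128 * 256 is roughly 4.2e6 multiply-adds, the approximate cost
    // of a single [128, 128] x [128, 256] matrix product in the same units as
    // cost_per_unit above.)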
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
                                batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // Because we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class CublasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold references to the allocated tensors until the end of the
    // allocator's lifetime.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    constexpr se::blas::Transpose kTranspose =
        is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
                                  : se::blas::Transpose::kTranspose;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }

    typedef Scalar Coefficient;

    // Cublas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
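    // (A row-major [m, n] matrix has the same memory layout as a column-major
    // [n, m] matrix, so passing B and A with swapped roles makes cuBLAS write
    // C' in column-major, which is exactly C in row-major.)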
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef Eigen::half Scalar;
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }

    typedef float Coefficient;

    // Cublas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(
            errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
                             in_x.shape().DebugString(), ", b.shape=",
                             in_y.shape().DebugString(), ", m=", m, ", n=", n,
                             ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
 public:
  explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMul() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument(
                    "In[0] and In[1] have different ndims: ",
                    in0.shape().DebugString(), " vs. ",
                    in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
      out_shape.AddDim(in0.dim_size(i));
    }
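    // Flatten all leading (batch) dimensions into a single batch dimension n
    // and view both inputs as rank-3 tensors of shape [n, rows, cols].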
    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
    auto d0 = in0.dim_size(ndims - 2);
    auto d1 = in0.dim_size(ndims - 1);
    Tensor in0_reshaped;
    CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
    auto d2 = in1.dim_size(ndims - 2);
    auto d3 = in1.dim_size(ndims - 1);
    Tensor in1_reshaped;
    CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
    LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, &out_reshaped);
  }

 private:
  bool adj_x_;
  bool adj_y_;
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<CPUDevice, TYPE>)
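
// For example, REGISTER_BATCH_MATMUL_CPU(float) registers
// BatchMatMul<CPUDevice, float> as the CPU kernel for the "BatchMatMul" op;
// the per-type invocations of these macros typically live in the .cc files
// that include this header.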

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<GPUDevice, TYPE>)

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMul<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_