/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along the
// second dimension and Y is contracted along the first dimension (if neither
// X nor Y is adjointed). The dimension to contract along is switched when
// either operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}
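
// For reference, the index pair returned for each (adj_x, adj_y) combination
// (0 = rows, 1 = columns of the 2D operands):
//   adj_x=0, adj_y=0 -> (1, 0): columns of X with rows of Y       (X * Y)
//   adj_x=0, adj_y=1 -> (1, 1): columns of X with columns of Y    (X * Y^T)
//   adj_x=1, adj_y=0 -> (0, 0): rows of X with rows of Y          (X^T * Y)
//   adj_x=1, adj_y=1 -> (0, 1): rows of X with columns of Y       (X^T * Y^T)
// Conjugation for complex types is handled separately by the callers below.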

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
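    // Concretely, only y is ever conjugated inside this loop (when
    // adj_x != adj_y), and the caller conjugates the final result when adj_x
    // is true:
    //   adj_x=0, adj_y=0: compute x * y            (no final conjugation)
    //   adj_x=0, adj_y=1: compute x * conj(y)^T    (no final conjugation)
    //   adj_x=1, adj_y=0: compute x^T * conj(y),   then conjugate the result
    //   adj_x=1, adj_y=1: compute x^T * y^T,       then conjugate the result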
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(i).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(i);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// TODO(rmlarsen): Get rid of this when we have upstreamed improvements
// for matrix*vector and vector*matrix to Eigen's general matrix product.
template <typename Tx, typename Ty, typename Tz>
static void Multiply(bool adj_x, bool adj_y, Tx x, Ty y, Tz z) {
  if (!adj_x) {
    if (!adj_y) {
      z.noalias() = x * y;
    } else {
      z.noalias() = x * y.adjoint();
    }
  } else {
    if (!adj_y) {
      z.noalias() = x.adjoint() * y;
    } else {
      z.noalias() = x.adjoint() * y.adjoint();
    }
  }
}
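
// For example, Multiply(/*adj_x=*/false, /*adj_y=*/true, x, y, z) evaluates
// z = x * y.adjoint(); for real Scalar types adjoint() is just a transpose,
// for complex types it is the conjugate transpose.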

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

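  // The two helpers below view slice i of a contiguous [batch, rows, cols]
  // tensor as a rows x cols row-major Eigen matrix; slice i starts at element
  // offset i * rows * cols into the flat buffer.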
  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, Tensor* out, int start, int limit) {
    for (int i = start; i < limit; ++i) {
      auto x = ConstTensorSliceToEigenMatrix(in_x, i);
      auto y = ConstTensorSliceToEigenMatrix(in_y, i);
      auto z = TensorSliceToEigenMatrix(out, i);
      // TODO(rmlarsen): Get rid of the special casing here when we have
      // upstreamed improvements for matrix*vector and vector*matrix to
      // Eigen's general matrix product.
      if (!adj_x && x.rows() == 1) {
        Multiply(adj_x, adj_y, x.row(0), y, z);
      } else if (adj_x && x.cols() == 1) {
        Multiply(adj_x, adj_y, x.col(0), y, z);
      } else if (!adj_y && y.cols() == 1) {
        Multiply(adj_x, adj_y, x, y.col(0), z);
      } else if (adj_y && y.rows() == 1) {
        Multiply(adj_x, adj_y, x, y.row(0), z);
      } else {
        Multiply(adj_x, adj_y, x, y, z);
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    const int64 kMaxCostOuterParallelism = 128 * 128 * 256;  // heuristic.
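    // cost_per_unit approximates the work of one m x k times k x n matrix
    // multiply (ignoring the constant factor of 2 for multiply-adds). For
    // example, a batch of 64 products of 32 x 32 matrices gives
    // cost_per_unit = 32 * 32 * 32 = 32768 < kMaxCostOuterParallelism, so the
    // else-branch below shards the batch across threads instead of
    // parallelizing each individual product.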
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
                                batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // We used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, so we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}
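
// AsDeviceMemory only wraps the raw device pointer in a non-owning
// se::DeviceMemory<T> handle so it can be handed to StreamExecutor BLAS
// calls; it does not allocate or take ownership of any memory.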

class CublasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
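
// CublasScratchAllocator keeps each successfully allocated temporary Tensor
// in allocated_tensors_, so the scratch memory stays alive for as long as the
// allocator object itself does; a failed allocation is reported back to
// StreamExecutor as a null DeviceMemory of size zero.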
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    constexpr se::blas::Transpose kTranspose =
        is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
                                  : se::blas::Transpose::kTranspose;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }

    typedef Scalar Coefficient;

    // Cublas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
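    // For example, a row-major m x k slice of A is the column-major k x m
    // matrix A', and a row-major k x n slice of B is the column-major n x k
    // matrix B'. Asking cuBLAS for B' x A' therefore produces the n x m
    // column-major matrix C', whose memory layout is exactly the row-major
    // m x n product C = A x B that we want. This is why the b/a arguments and
    // the n/m sizes are swapped in the calls below.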
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
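        // In other words, cuBLAS sees the row-major m x k slice of A as the
        // column-major k x m matrix A', so computing A * b (no adjoint)
        // requires requesting transpose(A'), while computing A^T * b requires
        // requesting A' as stored; the dimension arguments below are swapped
        // accordingly when adj_x is set.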
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef Eigen::half Scalar;
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }

    typedef float Coefficient;
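
    // Note: the alpha/beta coefficients are passed to the BLAS calls as float
    // even though the operands are Eigen::half, presumably so the underlying
    // GEMM can scale and accumulate in fp32 rather than fp16.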

    // Cublas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(
            errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
                             in_x.shape().DebugString(), ", b.shape=",
                             in_y.shape().DebugString(), ", m=", m, ", n=", n,
                             ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
 public:
  explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMul() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument(
                    "In[0] and In[1] have different ndims: ",
                    in0.shape().DebugString(), " vs. ",
                    in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
      out_shape.AddDim(in0.dim_size(i));
    }
    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
    auto d0 = in0.dim_size(ndims - 2);
    auto d1 = in0.dim_size(ndims - 1);
    Tensor in0_reshaped;
    CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
    auto d2 = in1.dim_size(ndims - 2);
    auto d3 = in1.dim_size(ndims - 1);
    Tensor in1_reshaped;
    CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
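    // Example: with adj_x_ = adj_y_ = false, inputs of shape [2, 3, 4, 5] and
    // [2, 3, 5, 6] give n = 6, so the launch below multiplies the reshaped
    // [6, 4, 5] and [6, 5, 6] batches and writes into out_reshaped of shape
    // [6, 4, 6], which aliases the [2, 3, 4, 6] output tensor.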
    LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, &out_reshaped);
  }

 private:
  bool adj_x_;
  bool adj_y_;
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),  \
      BatchMatMul<CPUDevice, TYPE>)
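
// These registration macros are intended to be expanded from the .cc files
// that include this header, typically once per supported type, e.g.
//   REGISTER_BATCH_MATMUL_CPU(float);
// The instantiation sites themselves are not part of this header.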

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),  \
      BatchMatMul<GPUDevice, TYPE>)

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMul<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_