/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y along its first (when neither X nor Y is adjointed).
// The dimensions to contract along are switched when either operand is
// adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}
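// For example, ContractionDims(false, false) yields the pair (1, 0): the
// second dimension of X is contracted against the first dimension of Y,
// i.e. Z = X * Y for row-major 2D slices. ContractionDims(true, false)
// yields (0, 0), which emulates adjoint(X) * Y (up to conjugation, which is
// handled separately by the callers below).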

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
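    // For example, with adj_x == true and adj_y == false the loop below
    // computes x^T * conj(y); conjugating that result at the end yields
    // conj(x^T * conj(y)) = conj(x)^T * y, which is the requested
    // adjoint(x) * y.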
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = Tx.template chip<0>(x_batch_index);
      auto y = Ty.template chip<0>(y_batch_index);
      auto z = Tz.template chip<0>(i);

      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }
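
  // Example (illustrative): for a tensor of shape [batch, 2, 3], slice i maps
  // the 6 contiguous scalars starting at flat offset i * 2 * 3 onto a 2x3
  // row-major Eigen matrix, without copying any data.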

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
                  int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
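      // The output tensor is allocated separately from the inputs, so z never
      // aliases x or y; noalias() lets Eigen write the product directly
      // without a temporary.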
      if (!adj_x) {
        if (!adj_y) {
          z.noalias() = x * y;
        } else {
          z.noalias() = x * y.adjoint();
        }
      } else {
        if (!adj_y) {
          z.noalias() = x.adjoint() * y;
        } else {
          z.noalias() = x.adjoint() * y.adjoint();
        }
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies, i.e. the size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64 kMaxCostOuterParallelism = 128 * 128;  // heuristic.
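    // Illustrative numbers: a batch of 64 [256, 256] x [256, 256] products has
    // cost_per_unit = 256^3, far above kMaxCostOuterParallelism, so the
    // parallel (inner-dim) kernel is used; a batch of 1024 [8, 8] x [8, 8]
    // products has cost_per_unit = 512 and is sharded across the batch with
    // the sequential kernel below.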
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
                                0, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  bcast, out, start, limit);
            });
    }
    if (conjugate_result) {
      // We used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, so we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {
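// Wraps a raw device pointer in a typed se::DeviceMemory handle. The wrapper
// is non-owning: the underlying buffer remains owned by whoever allocated it
// (here, the TensorFlow tensors backing in_x, in_y and out).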
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}

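  // A negative value is treated as "no limit" on the amount of scratch space
  // that may be requested (assumption about the se::ScratchAllocator contract
  // in this version).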
  int64 GetMemoryLimitInBytes() override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed, so the scratch memory stays alive for the duration of the
    // BLAS call.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    constexpr se::blas::Transpose kTranspose =
        is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
                                  : se::blas::Transpose::kTranspose;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const int64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef Scalar Coefficient;

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be column-major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
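    // Concretely: a row-major m x n matrix C occupies the same bytes as a
    // column-major n x m matrix C'. Asking BLAS for C' = B' x A' (reading B
    // as an n x k and A as a k x m column-major matrix) therefore writes
    // exactly the row-major C = A x B, with no explicit transposition of the
    // data in memory.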
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    typedef Eigen::half Scalar;
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef float Coefficient;
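    // Note that alpha/beta are passed as float rather than half: the intent is
    // that the fp16 GEMM accumulates in fp32 (assumption about the underlying
    // cuBLAS/rocBLAS path), which avoids the precision loss of pure fp16
    // accumulation.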

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be column-major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto y = Ty.template chip<0>(y_batch_index);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of matrix multiplies, i.e. the size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          bcast, out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] and In[1] have mismatched inner dimensions: ", d1,
                    " vs. ", d2, ": ", in0.shape().DebugString(), " ",
                    in1.shape().DebugString(), " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
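    // Example (illustrative shapes): with broadcasting enabled (BatchMatMulV2),
    // in0 = [2, 3, 4, 5] and in1 = [3, 5, 6] with adj_x = adj_y = false give
    // batch dims [2, 3] and out_shape = [2, 3, 4, 6].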
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatMul<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
  }

 protected:
  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                    const Tensor& in1) = 0;

 private:
  bool adj_x_;
  bool adj_y_;
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}

  ~BatchMatMulOp() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    // Disallow broadcasting: ensure that all batch dimensions of the input
    // tensors match.
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument(
                    "In[0] and In[1] have different ndims: ",
                    in0.shape().DebugString(), " vs. ",
                    in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
    }
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}

  ~BatchMatMulV2Op() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
};

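// The registration macros below are expected to be invoked from the per-type
// instantiation .cc files (e.g. REGISTER_BATCH_MATMUL_CPU(float);) rather
// than here, so that each type is compiled in its own translation unit.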
#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE>)

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<SYCLDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BatchMatMulV2").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_