/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y along its first dimension (if neither X nor Y is
// adjointed). The dimension to contract along is switched when either operand
// is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
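// For example, with neither operand adjointed the returned pair is (1, 0):
// the columns of X are contracted against the rows of Y, i.e. the ordinary
// matrix product X * Y.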
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    // TODO(rmlarsen): Consider launching these contractions asynchronously.
    for (int64_t i = 0; i < batch_size; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
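      // If exactly one of the operands is adjointed, conjugate y here so that
      // the contraction computes x * conj(y). The remaining outer conjugation
      // (needed when adj_x is true) is applied by Conjugate() after Launch()
      // has processed all batches.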
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    if (batch_size == 1 && !should_bcast) {
      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
      auto Tz = out->flat_inner_dims<Scalar, 2>();
      Tz.device(d) = Tx.contract(Ty, contract_pairs);
    } else {
      auto Tx = in_x.tensor<Scalar, 3>();
      auto Ty = in_y.tensor<Scalar, 3>();
      auto Tz = out->tensor<Scalar, 3>();
      const auto& x_batch_indices = bcast.x_batch_indices();
      const auto& y_batch_indices = bcast.y_batch_indices();
      // TODO(rmlarsen): Consider launching these contractions asynchronously.
      for (int64_t i = 0; i < batch_size; ++i) {
        const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
        const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
        auto x = Tx.template chip<0>(x_batch_index);
        auto y = Ty.template chip<0>(y_batch_index);
        auto z = Tz.template chip<0>(i);

        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }
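  // Both helpers above view batch slice `slice` of a [batch, rows, cols]
  // tensor as a row-major Eigen matrix of shape [rows, cols]; no data is
  // copied.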

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, bool trans_x, bool trans_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
      // and trans_y.
      if (!adj_x && !trans_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x * y;
        } else if (adj_y) {
          z.noalias() = x * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x * y.transpose();
        }
      } else if (adj_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x.adjoint() * y;
        } else if (adj_y) {
          z.noalias() = x.adjoint() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.adjoint() * y.transpose();
        }
      } else {  // trans_x == true
        if (!adj_y && !trans_y) {
          z.noalias() = x.transpose() * y;
        } else if (adj_y) {
          z.noalias() = x.transpose() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.transpose() * y.transpose();
        }
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64_t batch_size = bcast.output_batch_size();
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64_t small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
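    // small_dim == 1 typically corresponds to a vector-matrix or matrix-vector
    // product, for which the sequential Eigen kernel above performs better
    // than the parallel tensor contraction.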
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
    // evaluation in Eigen Tensor.
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
                                trans_y, bcast, out, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  trans_x, trans_y, bcast, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // Since we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes() override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold the reference of the allocated tensors until the end of the
    // allocator.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const int64_t batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
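      // A stride of zero makes the strided-batched GEMM read the same operand
      // for every batch entry, which implements broadcasting of that operand
      // without materializing any copies.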
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64_t i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64_t i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef Scalar Coefficient;

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
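    // A row-major m x n matrix occupies the same memory as a column-major
    // n x m matrix, so passing the operands in swapped order with dimensions
    // (n, m, k) lets the BLAS library write the row-major result directly,
    // without an explicit transposition pass.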
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                               adj_x || trans_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        OP_REQUIRES_OK(context,
                       stream->ThenBlasGemm(
                           blas_transpose_b, blas_transpose_a, n, m, k,
                           *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
                           adj_x || trans_x ? m : k, c_ptrs[0], n));
      }
    } else if (use_strided_batched) {
      OP_REQUIRES_OK(context, stream->ThenBlasGemmStridedBatched(
                                  blas_transpose_b, blas_transpose_a, n, m, k,
                                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                                  adj_y || trans_y ? k : n, b_stride,
                                  *a_ptrs[0], adj_x || trans_x ? m : k,
                                  a_stride, static_cast<Coefficient>(0.0),
                                  c_ptrs[0], n, c_stride, batch_size));
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef Eigen::half Scalar;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const uint64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64_t i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64_t i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef float Coefficient;

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      OP_REQUIRES_OK(context,
                     stream->ThenBlasGemm(
                         blas_transpose_b, blas_transpose_a, n, m, k,
                         *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
                         adj_x || trans_x ? m : k, c_ptrs[0], n));
    } else if (use_strided_batched) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMStridedBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Ta, typename Tb, typename Tout>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    const Status s = ValidateInputTensors(ctx, in0, in1);
    if (!s.ok()) {
      ctx->SetStatus(s);
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Tout> f;
      f(ctx->eigen_device<Device>(), out->flat<Tout>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    if (std::is_same<Ta, bfloat16>::value &&
        std::is_same<Tb, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      // Cast tensor to desired type to reuse Eigen.
      // TODO(b/178749687): remove this cast if Eigen supports this natively.
      if (!std::is_same<Ta, Tout>::value) {
        in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped);
      }
      if (!std::is_same<Tb, Tout>::value) {
        in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped);
      }
      LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, trans_x_,
                                              trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                      const Tensor& in1) = 0;

 private:
  // TODO(171979567) Make the ops take both adj and transpose attributes.
  bool adj_x_ = false;
  bool adj_y_ = false;
  bool trans_x_ = false;
  bool trans_y_ = false;

  // Cast `t` from `SrcT` to `DstT`.
  template <typename SrcT, typename DstT>
  Tensor CastTensor(const Tensor& t) {
    Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape());
    res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>();
    return res;
  }
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Ta, typename Tb, typename Tout,
          bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Disallow broadcasting support. Ensure that all batch dimensions of the
    // input tensors match.
    if (in0.dims() != in1.dims()) {
      return errors::InvalidArgument(
          "In[0] and In[1] has different ndims: ", in0.shape().DebugString(),
          " vs. ", in1.shape().DebugString());
    }
    const int ndims = in0.dims();
    if (is_legacy_matmul) {
      if (ndims != 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
                                       ndims);
      }
    } else {
      if (ndims < 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
                                       ndims);
      }
      for (int i = 0; i < ndims - 2; ++i) {
        if (in0.dim_size(i) != in1.dim_size(i)) {
          return errors::InvalidArgument(
              "In[0].dim(", i, ") and In[1].dim(", i,
              ") must be the same: ", in0.shape().DebugString(), " vs ",
              in1.shape().DebugString());
        }
      }
    }
    return Status::OK();
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Ta, typename Tb, typename Tout>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context,
                                                /* is_legacy_matmul= */ false) {
  }

  ~BatchMatMulV2Op() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    if (in0.dims() < 2) {
      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
    }
    if (in1.dims() < 2) {
      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
    }
    return Status::OK();
  }
};

// Register for MatMul, BatchMatMul, BatchMatMulV2 where Tin = Tout.
#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),       \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>);                            \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),     \
      BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>);                          \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),            \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),       \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>);                            \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),     \
      BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>);                          \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),            \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>)

// Register for BatchMatMulV3 where Ta, Tb and Tout are not the same.
#define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>)

#define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>)
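
// These registration macros are intended to be invoked from the per-type
// kernel instantiation .cc files rather than in this header. Illustrative
// (hypothetical) usage only:
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_TOUT_CPU(Eigen::half, Eigen::half, float);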

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_