/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y is contracted along its first dimension (if neither X
// nor Y is adjointed). An operand's contraction dimension is switched when
// that operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}
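
// Illustrative mapping of the pairs returned above (contraction dims only;
// any conjugation needed for adjoints is applied separately by the callers):
//   ContractionDims(false, false) -> (1, 0)   // X   * Y
//   ContractionDims(true,  false) -> (0, 0)   // X^T * Y
//   ContractionDims(false, true)  -> (1, 1)   // X   * Y^T
//   ContractionDims(true,  true)  -> (0, 1)   // X^T * Y^T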

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    // TODO(rmlarsen): Consider launching these contractions asynchronously.
    for (int64 i = 0; i < batch_size; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    if (batch_size == 1 && !should_bcast) {
      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
      auto Tz = out->flat_inner_dims<Scalar, 2>();
      Tz.device(d) = Tx.contract(Ty, contract_pairs);
    } else {
      auto Tx = in_x.tensor<Scalar, 3>();
      auto Ty = in_y.tensor<Scalar, 3>();
      auto Tz = out->tensor<Scalar, 3>();
      const auto& x_batch_indices = bcast.x_batch_indices();
      const auto& y_batch_indices = bcast.y_batch_indices();
      // TODO(rmlarsen): Consider launching these contractions asynchronously.
      for (int64 i = 0; i < batch_size; ++i) {
        const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
        const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
        auto x = Tx.template chip<0>(x_batch_index);
        auto y = Ty.template chip<0>(y_batch_index);
        auto z = Tz.template chip<0>(i);

        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }
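
  // Note (illustrative, assuming the rank-3 [batch, rows, cols] tensors the
  // callers pass in): slice i of a row-major Tensor of shape [B, M, N] starts
  // at flat offset i * M * N, so the Maps above view it as an M x N row-major
  // Eigen matrix without copying. For example, slice 1 of a [4, 2, 3] float
  // tensor maps the 6 floats starting at flat element 6.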

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, bool trans_x, bool trans_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
      // and trans_y.
      if (!adj_x && !trans_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x * y;
        } else if (adj_y) {
          z.noalias() = x * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x * y.transpose();
        }
      } else if (adj_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x.adjoint() * y;
        } else if (adj_y) {
          z.noalias() = x.adjoint() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.adjoint() * y.transpose();
        }
      } else {  // trans_x == true
        if (!adj_y && !trans_y) {
          z.noalias() = x.transpose() * y;
        } else if (adj_y) {
          z.noalias() = x.transpose() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.transpose() * y.transpose();
        }
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64 kMaxCostOuterParallelism = 128 * 128;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
    // evaluation in Eigen Tensor.
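    // Worked example of the heuristic below (illustrative numbers only): with
    // 32x32x32 matmuls, cost_per_unit = 32*32*32 = 32768, which exceeds
    // kMaxCostOuterParallelism (16384), so each contraction is parallelized
    // internally; with 16x16x16 matmuls and batch_size > 1, cost_per_unit =
    // 4096 <= 16384, so the work is sharded across the batch instead.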
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
                                trans_y, bcast, out, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  trans_x, trans_y, bcast, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // We used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, so we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes() override { return -1; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold a reference to the allocated tensors until the end of the
    // allocator's lifetime.
    allocated_tensors_.push_back(temporary_memory);
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const int64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
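    // Illustrative dispatch cases for the code below:
    //   x batch 8, y batch 8, no broadcasting    -> strided batched GEMM with
    //     a_stride = m*k, b_stride = k*n.
    //   x batch 8, y batch 1 (full broadcast)    -> strided batched GEMM with
    //     b_stride = 0, so every batch entry reuses the same B matrix.
    //   x batch dims [2, 1], y batch dims [1, 3] -> broadcasting that is not a
    //     full broadcast, handled by the pointer-array batched GEMM path.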
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef Scalar Coefficient;

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
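    // Concretely (illustrative): a row-major [m, k] buffer is, when read
    // column-major, the k x m matrix A'. Asking BLAS for
    //   C' (n x m) = op(B') (n x k) * op(A') (k x m)
    // and writing the result into the row-major [m, n] output buffer yields
    // C = op(A) * op(B); that is why the GEMM calls below pass B before A and
    // swap m and n.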
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                ? se::blas::Transpose::kNoTranspose
                                : se::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                               adj_x || trans_x ? k : m,
                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                               adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                               adj_y || trans_y ? k : n, *(a_ptrs[0]),
                               adj_x || trans_x ? m : k,
                               static_cast<Coefficient>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else if (use_strided_batched) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMStridedBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef Eigen::half Scalar;
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const uint64 batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();

    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
    bool use_strided_batched =
        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
        batch_size > 1;
    if (use_strided_batched) {
      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
      c_stride = m * n;
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    } else if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      }
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      }
      for (int64 i = 0; i < batch_size; ++i) {
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
        c_ptrs.push_back(&c_device_memory.back());
      }
    }

    typedef float Coefficient;
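
    // Note (hedged explanation, not verified against every BLAS backend):
    // Coefficient is float rather than Eigen::half so that alpha/beta are
    // passed as float, allowing the underlying GEMM to accumulate the fp16
    // inputs at higher precision where the backend supports it.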

    // Blas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                             adj_y || trans_y ? k : n, *(a_ptrs[0]),
                             adj_x || trans_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k));
      }
    } else if (use_strided_batched) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), *b_ptrs[0],
                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMStridedBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    } else {
      BlasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Coefficient>(1.0), b_ptrs,
                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                  &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    const Status s = ValidateInputTensors(ctx, in0, in1);
    if (!s.ok()) {
      ctx->SetStatus(s);
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
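    // Shape flow example (illustrative): for in0 = [2, 3, 4, 5] and
    // in1 = [3, 5, 6] with all adjoint/transpose flags false, the batch dims
    // broadcast to [2, 3], the inputs are viewed as [6, 4, 5] and [3, 5, 6],
    // and out_shape becomes [2, 3, 4, 6].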
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    if (std::is_same<Scalar, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                                adj_x_, adj_y_, trans_x_,
                                                trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                      const Tensor& in1) = 0;

 private:
  // TODO(171979567) Make the ops take both adj and transpose attributes.
  bool adj_x_;
  bool adj_y_;
  bool trans_x_;
  bool trans_y_;
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Disallow broadcasting support. Ensure that all batch dimensions of the
    // input tensors match.
    if (in0.dims() != in1.dims()) {
      return errors::InvalidArgument(
          "In[0] and In[1] have different ndims: ", in0.shape().DebugString(),
          " vs. ", in1.shape().DebugString());
    }
    const int ndims = in0.dims();
    if (is_legacy_matmul) {
      if (ndims != 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
                                       ndims);
      }
    } else {
      if (ndims < 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
                                       ndims);
      }
      for (int i = 0; i < ndims - 2; ++i) {
        if (in0.dim_size(i) != in1.dim_size(i)) {
          return errors::InvalidArgument(
              "In[0].dim(", i, ") and In[1].dim(", i,
              ") must be the same: ", in0.shape().DebugString(), " vs ",
              in1.shape().DebugString());
        }
      }
    }
    return Status::OK();
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context,
                                          /* is_legacy_matmul= */ false) {}

  ~BatchMatMulV2Op() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    if (in0.dims() < 2) {
      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
    }
    if (in1.dims() < 2) {
      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
    }
    return Status::OK();
  }
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
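
// Example usage (illustrative only; the real instantiations live in the
// matmul_op_*.cc translation units, typically via the TF_CALL_* type macros):
//
//   TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
//   TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
//
// Each invocation registers the BatchMatMul, BatchMatMulV2, and legacy MatMul
// kernels for one element type on the corresponding device.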

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_