#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlas.h>
#include <ATen/native/sparse/SparseCsrTensorMath.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/sparse_sampled_addmm_native.h>
#include <ATen/ops/triangular_solve_native.h>
#endif

#include <c10/util/MaybeOwned.h>

namespace at::native {

/*
  Computes `result` <- α*(A @ B) * spy(C) + β*C, where spy(C) is the sparsity pattern matrix of C.

  Args:
  * `mat1` - [in] dense Tensor A of size m × k.
  * `mat2` - [in] dense Tensor B of size k × n.
  * `self` - [in] sparse Tensor C of size m × n.
  * `result` - [out] sparse Tensor of size m × n.
*/
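// A minimal usage sketch (illustrative only, not part of this file), assuming the
// dispatcher-generated at::sparse_sampled_addmm entry point; the shapes and values
// below are arbitrary assumptions:
//
//   auto A = at::randn({3, 4}, at::kCUDA);                          // dense, m × k
//   auto B = at::randn({4, 5}, at::kCUDA);                          // dense, k × n
//   auto C = at::eye(3, 5, at::device(at::kCUDA)).to_sparse_csr();  // sparse CSR, m × n
//   // Only the entries at the nonzero positions of C are computed.
//   auto out = at::sparse_sampled_addmm(C, A, B, /*beta=*/1.0, /*alpha=*/1.0);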
Tensor& sparse_sampled_addmm_out_sparse_csr_cuda(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  at::native::sparse::sparse_sampled_addmm_check_inputs(
      self, mat1, mat2, beta, alpha, result);

  if (&result != &self) {
    // We allow self to be a single matrix when mat1 and mat2 are batched
    auto result_sizes = DimVector(mat1.sizes().slice(0, mat1.dim() - 2));
    result_sizes.push_back(self.size(-2));
    result_sizes.push_back(self.size(-1));
    at::sparse_csr::get_sparse_csr_impl(result)->resize_(self._nnz(), result_sizes);
    result.copy_(self);
  }

  // there's a segfault when calling cuSPARSE on 0-sized matrices
  if (mat1.numel() == 0 || mat2.numel() == 0) {
    result.mul_(beta);
    return result;
  }

  sparse::impl::cuda::sampled_addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
  return result;
}

Tensor sparse_sampled_addmm_sparse_csr_cuda(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha) {
  auto result = at::empty({0, 0}, self.options());
  at::native::sparse_sampled_addmm_out_sparse_csr_cuda(self, mat1, mat2, beta, alpha, result);
  return result;
}

// result = beta * self + alpha * (mat1 @ mat2)
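// A minimal sketch of the semantics (illustrative only; mat1_csr, mat2 and self
// are assumed names, not defined in this file): with a sparse CSR mat1 of size
// m × k, a dense mat2 of size k × n and a dense self of size m × n, a call such as
//
//   auto out = at::addmm(self, mat1_csr, mat2, /*beta=*/0.0, /*alpha=*/1.0);
//
// computes just the sparse-dense product mat1_csr @ mat2, since beta == 0 makes
// the self term drop out.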
Tensor& addmm_out_sparse_compressed_cuda(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  sparse::impl::_check_is_cuda(self, "self");
  sparse::impl::_check_is_cuda(mat1, "mat1");
  sparse::impl::_check_is_cuda(mat2, "mat2");
  sparse::impl::_check_is_cuda(result, "result");

  // Same checks as in TORCH_META_FUNC(addmm) at
  // aten/src/ATen/native/LinearAlgebra.cpp
  sparse::impl::_check_dim(mat1, 2, "mat1");
  sparse::impl::_check_dim(mat2, 2, "mat2");

  TORCH_CHECK(
      mat1.size(1) == mat2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1.size(0), "x", mat1.size(1), " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  // From addmm_out_cuda_impl at ATen/native/cuda/Blas.cpp
  // TODO: remove code duplication and unify code
  // There were undefined symbol problems when using the same function for the
  // CUDA and SparseCsrCUDA dispatch keys. Also, structured kernels do not
  // support sparse output.
  c10::MaybeOwned<at::Tensor> self_;
  // Don't expand self if this is an in-place operation
  if (&result == &self) {
    self_ = c10::MaybeOwned<Tensor>::borrowed(self);
  } else {
    self_ = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
  }

  sparse::impl::_check_dim(*self_, 2, "self");
  TORCH_CHECK(((self_->dim() == 2) &&
               (self_->size(0) == mat1.size(0)) &&
               (self_->size(1) == mat2.size(1))),
      "The input tensor must be a matrix with size ",
      mat1.size(0),
      "x",
      mat2.size(1),
      ", but got a ",
      self_->dim(),
      "-D tensor with size ",
      self_->size(0),
      "x",
      self_->size(1));

  if (!result.is_same(self)) {
    if (result.layout() == kStrided) {
      at::native::resize_output(result, self_->sizes());
    } else {
      result.resize_as_sparse_(*self_);
    }
  }

  if (result.numel() == 0) {
    return result;
  }

  if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
    // According to docs, when beta==0 values in self should be ignored.
    // nans and infs should not propagate
    const auto beta_val = beta.toComplexDouble();
    if (beta_val == 0.) {
      result.zero_();
    } else {
      if (!result.is_same(self)) {
        result.copy_(*self_);
      }
      if (beta_val != 1.) {
        result.mul_(beta);
      }
    }
    return result;
  }

  sparse::impl::cuda::addmm_out_sparse_csr(*self_, mat1, mat2, beta, alpha, result);
  return result;
}

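// result = beta * self + alpha * (mat1 @ mat2), computed batch-wise; mat1 is
// expected to be sparse CSR, while self, mat2 and result are dense (strided).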
Tensor& baddbmm_out_sparse_csr_cuda(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.is_sparse_csr());

  TORCH_CHECK(self.layout() == kStrided, "torch.baddbmm: Expected self to be strided, but got layout ", self.layout());
  TORCH_CHECK(mat2.layout() == kStrided, "torch.baddbmm: Expected mat2 to be strided, but got layout ", mat2.layout());
  TORCH_CHECK(result.layout() == kStrided, "torch.baddbmm: Expected result to be strided, but got layout ", result.layout());

  if (!result.is_same(self)) {
    at::native::resize_output(result, self.sizes());
  }

  if (mat1._nnz() == 0) {
    // According to docs, when beta==0 values in self should be ignored
    // nans and infs should not propagate
    if (beta.toComplexDouble() == 0.) {
      result.zero_();
    } else {
      if (!result.is_same(self)) {
        result.copy_(self);
      }
      if (beta.toComplexDouble() != 1.) {
        result.mul_(beta);
      }
    }
    return result;
  }

  sparse::impl::cuda::addmm_out_sparse_csr(self, mat1, mat2, beta, alpha, result);
  return result;
}

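// result = mat1 @ mat2, computed batch-wise; implemented below as baddbmm with
// beta = 0 and alpha = 1, so the initial values of result are ignored.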
Tensor& bmm_out_sparse_csr_cuda(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result) {
  Scalar beta(0.0);
  Scalar alpha(1.0);
  return at::native::baddbmm_out_sparse_csr_cuda(result, mat1, mat2, beta, alpha, result);
}

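// result = beta * self + alpha * (mat @ vec), where mat is a sparse compressed
// matrix and vec, self and result are dense.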
Tensor& addmv_out_sparse_compressed_cuda(
    const Tensor& self,
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  if (mat.layout() == kSparseCsc) {
    return addmv_out_sparse_compressed_cuda(
        self, mat.to_sparse_csr(), vec, beta, alpha, result);
  }
  TORCH_CHECK(mat.layout() != kSparseBsc, "addmv_out_sparse_compressed_cuda currently does not support layout SparseBsc for input mat.");

  TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
  TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

  // Preprocessing code is copied from TORCH_IMPL_FUNC(addmv_out_cuda) at
  // aten/src/ATen/native/cuda/Blas.cpp
  // It would be nice to have it unified, but there were undefined symbol
  // problems when using the same function for the CUDA and SparseCsrCUDA
  // dispatch keys and a structured kernel.
  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta.toComplexDouble();

  if (&result != &self) {
    at::native::resize_output(result, self_->sizes());
    if (betaval != 0.0) {
      at::native::copy_(result, *self_);
    }
  }

  if (mat._nnz() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and
    // infs should not propagate
    if (betaval == 0.0) {
      return result.zero_();
    } else {
      return at::mul_out(
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta,
              self.scalar_type(),
              std::nullopt /* layout */,
              at::kCPU,
              std::nullopt /* pin_memory */));
    }
  }

  sparse::impl::cuda::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
  return result;
}

/*
  Solves a system of linear equations whose coefficients are represented in a sparse triangular matrix A:
  op(A) X = B.

  Args:
  * `B` - dense Tensor of size m × nrhs.
  * `A` - sparse Tensor of size m × m.
  * `upper` - controls whether the upper or lower triangular part of A is considered in the computation.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
  * `X` - dense Tensor of size m × nrhs.
  * `clone_A` - cloned matrix A, required only for compatibility with the strided layout interface.
*/
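// A minimal sketch of the system being solved (illustrative only): with
// upper = true, transpose = false and unitriangular = false, X is computed so
// that triu(A) @ X = B, i.e. only the upper triangular part of the sparse
// matrix A participates in the solve.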
std::tuple<Tensor&, Tensor&> triangular_solve_out_sparse_csr_cuda(
    const Tensor& B,
    const Tensor& A,
    bool upper,
    bool transpose,
    bool unitriangular,
    Tensor& X,
    Tensor& clone_A) {
  sparse::impl::cuda::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
  return std::tuple<Tensor&, Tensor&>(X, clone_A);
}

} // namespace at::native