1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 #include <iostream>
10 #include "common.h"
11
EIGEN_BLAS_FUNC(gemm)12 int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
13 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
14 {
15 // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
16 typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
17 static const functype func[12] = {
18 // array index: NOTR | (NOTR << 2)
19 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run),
20 // array index: TR | (NOTR << 2)
21 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run),
22 // array index: ADJ | (NOTR << 2)
23 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run),
24 0,
25 // array index: NOTR | (TR << 2)
26 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run),
27 // array index: TR | (TR << 2)
28 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run),
29 // array index: ADJ | (TR << 2)
30 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run),
31 0,
32 // array index: NOTR | (ADJ << 2)
33 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
34 // array index: TR | (ADJ << 2)
35 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
36 // array index: ADJ | (ADJ << 2)
37 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run),
38 0
39 };
40
41 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
42 const Scalar* b = reinterpret_cast<const Scalar*>(pb);
43 Scalar* c = reinterpret_cast<Scalar*>(pc);
44 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
45 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
46
47 int info = 0;
48 if(OP(*opa)==INVALID) info = 1;
49 else if(OP(*opb)==INVALID) info = 2;
50 else if(*m<0) info = 3;
51 else if(*n<0) info = 4;
52 else if(*k<0) info = 5;
53 else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
54 else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
55 else if(*ldc<std::max(1,*m)) info = 13;
56 if(info)
57 return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
58
59 if (*m == 0 || *n == 0)
60 return 0;
61
62 if(beta!=Scalar(1))
63 {
64 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
65 else matrix(c, *m, *n, *ldc) *= beta;
66 }
67
68 if(*k == 0)
69 return 0;
70
71 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);
72
73 int code = OP(*opa) | (OP(*opb) << 2);
74 func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
75 return 0;
76 }
77
EIGEN_BLAS_FUNC(trsm)78 int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
79 const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
80 {
81 // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
82 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
83 static const functype func[32] = {
84 // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
85 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run),
86 // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
87 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run),
88 // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
89 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run),\
90 0,
91 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
92 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run),
93 // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
94 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run),
95 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
96 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run),
97 0,
98 // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
99 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run),
100 // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
101 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run),
102 // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
103 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run),
104 0,
105 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
106 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run),
107 // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
108 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run),
109 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
110 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run),
111 0,
112 // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
113 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run),
114 // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
115 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run),
116 // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
117 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run),
118 0,
119 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
120 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run),
121 // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
122 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run),
123 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
124 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run),
125 0,
126 // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
127 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run),
128 // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
129 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run),
130 // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
131 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run),
132 0,
133 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
134 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run),
135 // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
136 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run),
137 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
138 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run),
139 0
140 };
141
142 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
143 Scalar* b = reinterpret_cast<Scalar*>(pb);
144 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
145
146 int info = 0;
147 if(SIDE(*side)==INVALID) info = 1;
148 else if(UPLO(*uplo)==INVALID) info = 2;
149 else if(OP(*opa)==INVALID) info = 3;
150 else if(DIAG(*diag)==INVALID) info = 4;
151 else if(*m<0) info = 5;
152 else if(*n<0) info = 6;
153 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
154 else if(*ldb<std::max(1,*m)) info = 11;
155 if(info)
156 return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
157
158 if(*m==0 || *n==0)
159 return 0;
160
161 int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
162
163 if(SIDE(*side)==LEFT)
164 {
165 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
166 func[code](*m, *n, a, *lda, b, *ldb, blocking);
167 }
168 else
169 {
170 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
171 func[code](*n, *m, a, *lda, b, *ldb, blocking);
172 }
173
174 if(alpha!=Scalar(1))
175 matrix(b,*m,*n,*ldb) *= alpha;
176
177 return 0;
178 }
179
180
181 // b = alpha*op(a)*b for side = 'L'or'l'
182 // b = alpha*b*op(a) for side = 'R'or'r'
EIGEN_BLAS_FUNC(trmm)183 int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
184 const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
185 {
186 // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
187 typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
188 static const functype func[32] = {
189 // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
190 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run),
191 // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
192 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run),
193 // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
194 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
195 0,
196 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
197 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run),
198 // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
199 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run),
200 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
201 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
202 0,
203 // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
204 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run),
205 // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
206 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run),
207 // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
208 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
209 0,
210 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
211 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run),
212 // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
213 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run),
214 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
215 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
216 0,
217 // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
218 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run),
219 // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
220 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run),
221 // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
222 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
223 0,
224 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
225 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run),
226 // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
227 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run),
228 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
229 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
230 0,
231 // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
232 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run),
233 // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
234 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run),
235 // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
236 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
237 0,
238 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
239 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run),
240 // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
241 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run),
242 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
243 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
244 0
245 };
246
247 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
248 Scalar* b = reinterpret_cast<Scalar*>(pb);
249 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
250
251 int info = 0;
252 if(SIDE(*side)==INVALID) info = 1;
253 else if(UPLO(*uplo)==INVALID) info = 2;
254 else if(OP(*opa)==INVALID) info = 3;
255 else if(DIAG(*diag)==INVALID) info = 4;
256 else if(*m<0) info = 5;
257 else if(*n<0) info = 6;
258 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
259 else if(*ldb<std::max(1,*m)) info = 11;
260 if(info)
261 return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
262
263 int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
264
265 if(*m==0 || *n==0)
266 return 1;
267
268 // FIXME find a way to avoid this copy
269 Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb);
270 matrix(b,*m,*n,*ldb).setZero();
271
272 if(SIDE(*side)==LEFT)
273 {
274 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
275 func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
276 }
277 else
278 {
279 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
280 func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
281 }
282 return 1;
283 }
284
285 // c = alpha*a*b + beta*c for side = 'L'or'l'
286 // c = alpha*b*a + beta*c for side = 'R'or'r
EIGEN_BLAS_FUNC(symm)287 int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
288 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
289 {
290 // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
291 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
292 const Scalar* b = reinterpret_cast<const Scalar*>(pb);
293 Scalar* c = reinterpret_cast<Scalar*>(pc);
294 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
295 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
296
297 int info = 0;
298 if(SIDE(*side)==INVALID) info = 1;
299 else if(UPLO(*uplo)==INVALID) info = 2;
300 else if(*m<0) info = 3;
301 else if(*n<0) info = 4;
302 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
303 else if(*ldb<std::max(1,*m)) info = 9;
304 else if(*ldc<std::max(1,*m)) info = 12;
305 if(info)
306 return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
307
308 if(beta!=Scalar(1))
309 {
310 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
311 else matrix(c, *m, *n, *ldc) *= beta;
312 }
313
314 if(*m==0 || *n==0)
315 {
316 return 1;
317 }
318
319 int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
320 #if ISCOMPLEX
321 // FIXME add support for symmetric complex matrix
322 Matrix<Scalar,Dynamic,Dynamic,ColMajor> matA(size,size);
323 if(UPLO(*uplo)==UP)
324 {
325 matA.triangularView<Upper>() = matrix(a,size,size,*lda);
326 matA.triangularView<Lower>() = matrix(a,size,size,*lda).transpose();
327 }
328 else if(UPLO(*uplo)==LO)
329 {
330 matA.triangularView<Lower>() = matrix(a,size,size,*lda);
331 matA.triangularView<Upper>() = matrix(a,size,size,*lda).transpose();
332 }
333 if(SIDE(*side)==LEFT)
334 matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb);
335 else if(SIDE(*side)==RIGHT)
336 matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA;
337 #else
338 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false);
339
340 if(SIDE(*side)==LEFT)
341 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
342 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
343 else return 0;
344 else if(SIDE(*side)==RIGHT)
345 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
346 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
347 else return 0;
348 else
349 return 0;
350 #endif
351
352 return 0;
353 }
354
355 // c = alpha*a*a' + beta*c for op = 'N'or'n'
356 // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
EIGEN_BLAS_FUNC(syrk)357 int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k,
358 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
359 {
360 // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
361 #if !ISCOMPLEX
362 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
363 static const functype func[8] = {
364 // array index: NOTR | (UP << 2)
365 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Upper>::run),
366 // array index: TR | (UP << 2)
367 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Upper>::run),
368 // array index: ADJ | (UP << 2)
369 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Upper>::run),
370 0,
371 // array index: NOTR | (LO << 2)
372 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Lower>::run),
373 // array index: TR | (LO << 2)
374 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Lower>::run),
375 // array index: ADJ | (LO << 2)
376 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Lower>::run),
377 0
378 };
379 #endif
380
381 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
382 Scalar* c = reinterpret_cast<Scalar*>(pc);
383 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
384 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
385
386 int info = 0;
387 if(UPLO(*uplo)==INVALID) info = 1;
388 else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2;
389 else if(*n<0) info = 3;
390 else if(*k<0) info = 4;
391 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
392 else if(*ldc<std::max(1,*n)) info = 10;
393 if(info)
394 return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
395
396 if(beta!=Scalar(1))
397 {
398 if(UPLO(*uplo)==UP)
399 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
400 else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
401 else
402 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
403 else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
404 }
405
406 if(*n==0 || *k==0)
407 return 0;
408
409 #if ISCOMPLEX
410 // FIXME add support for symmetric complex matrix
411 if(UPLO(*uplo)==UP)
412 {
413 if(OP(*op)==NOTR)
414 matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
415 else
416 matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
417 }
418 else
419 {
420 if(OP(*op)==NOTR)
421 matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
422 else
423 matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
424 }
425 #else
426 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
427
428 int code = OP(*op) | (UPLO(*uplo) << 2);
429 func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
430 #endif
431
432 return 0;
433 }
434
435 // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
436 // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
EIGEN_BLAS_FUNC(syr2k)437 int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha,
438 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
439 {
440 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
441 const Scalar* b = reinterpret_cast<const Scalar*>(pb);
442 Scalar* c = reinterpret_cast<Scalar*>(pc);
443 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
444 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
445
446 // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
447
448 int info = 0;
449 if(UPLO(*uplo)==INVALID) info = 1;
450 else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2;
451 else if(*n<0) info = 3;
452 else if(*k<0) info = 4;
453 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
454 else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
455 else if(*ldc<std::max(1,*n)) info = 12;
456 if(info)
457 return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
458
459 if(beta!=Scalar(1))
460 {
461 if(UPLO(*uplo)==UP)
462 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
463 else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
464 else
465 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
466 else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
467 }
468
469 if(*k==0)
470 return 1;
471
472 if(OP(*op)==NOTR)
473 {
474 if(UPLO(*uplo)==UP)
475 {
476 matrix(c, *n, *n, *ldc).triangularView<Upper>()
477 += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
478 + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
479 }
480 else if(UPLO(*uplo)==LO)
481 matrix(c, *n, *n, *ldc).triangularView<Lower>()
482 += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
483 + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
484 }
485 else if(OP(*op)==TR || OP(*op)==ADJ)
486 {
487 if(UPLO(*uplo)==UP)
488 matrix(c, *n, *n, *ldc).triangularView<Upper>()
489 += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
490 + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
491 else if(UPLO(*uplo)==LO)
492 matrix(c, *n, *n, *ldc).triangularView<Lower>()
493 += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
494 + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
495 }
496
497 return 0;
498 }
499
500
501 #if ISCOMPLEX
502
503 // c = alpha*a*b + beta*c for side = 'L'or'l'
504 // c = alpha*b*a + beta*c for side = 'R'or'r
EIGEN_BLAS_FUNC(hemm)505 int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
506 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
507 {
508 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
509 const Scalar* b = reinterpret_cast<const Scalar*>(pb);
510 Scalar* c = reinterpret_cast<Scalar*>(pc);
511 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
512 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
513
514 // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
515
516 int info = 0;
517 if(SIDE(*side)==INVALID) info = 1;
518 else if(UPLO(*uplo)==INVALID) info = 2;
519 else if(*m<0) info = 3;
520 else if(*n<0) info = 4;
521 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
522 else if(*ldb<std::max(1,*m)) info = 9;
523 else if(*ldc<std::max(1,*m)) info = 12;
524 if(info)
525 return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
526
527 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
528 else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta;
529
530 if(*m==0 || *n==0)
531 {
532 return 1;
533 }
534
535 int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
536 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false);
537
538 if(SIDE(*side)==LEFT)
539 {
540 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor>
541 ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
542 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor>
543 ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
544 else return 0;
545 }
546 else if(SIDE(*side)==RIGHT)
547 {
548 if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
549 ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);*/
550 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor>
551 ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
552 else return 0;
553 }
554 else
555 {
556 return 0;
557 }
558
559 return 0;
560 }
561
562 // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
563 // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
EIGEN_BLAS_FUNC(herk)564 int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k,
565 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
566 {
567 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
568
569 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
570 static const functype func[8] = {
571 // array index: NOTR | (UP << 2)
572 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Upper>::run),
573 0,
574 // array index: ADJ | (UP << 2)
575 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Upper>::run),
576 0,
577 // array index: NOTR | (LO << 2)
578 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Lower>::run),
579 0,
580 // array index: ADJ | (LO << 2)
581 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Lower>::run),
582 0
583 };
584
585 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
586 Scalar* c = reinterpret_cast<Scalar*>(pc);
587 RealScalar alpha = *palpha;
588 RealScalar beta = *pbeta;
589
590 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
591
592 int info = 0;
593 if(UPLO(*uplo)==INVALID) info = 1;
594 else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
595 else if(*n<0) info = 3;
596 else if(*k<0) info = 4;
597 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
598 else if(*ldc<std::max(1,*n)) info = 10;
599 if(info)
600 return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
601
602 int code = OP(*op) | (UPLO(*uplo) << 2);
603
604 if(beta!=RealScalar(1))
605 {
606 if(UPLO(*uplo)==UP)
607 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
608 else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
609 else
610 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
611 else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
612
613 if(beta!=Scalar(0))
614 {
615 matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
616 matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
617 }
618 }
619
620 if(*k>0 && alpha!=RealScalar(0))
621 {
622 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
623 func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
624 matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
625 }
626 return 0;
627 }
628
629 // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
630 // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
EIGEN_BLAS_FUNC(her2k)631 int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k,
632 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
633 {
634 const Scalar* a = reinterpret_cast<const Scalar*>(pa);
635 const Scalar* b = reinterpret_cast<const Scalar*>(pb);
636 Scalar* c = reinterpret_cast<Scalar*>(pc);
637 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
638 RealScalar beta = *pbeta;
639
640 // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
641
642 int info = 0;
643 if(UPLO(*uplo)==INVALID) info = 1;
644 else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
645 else if(*n<0) info = 3;
646 else if(*k<0) info = 4;
647 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
648 else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
649 else if(*ldc<std::max(1,*n)) info = 12;
650 if(info)
651 return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
652
653 if(beta!=RealScalar(1))
654 {
655 if(UPLO(*uplo)==UP)
656 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
657 else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
658 else
659 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
660 else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
661
662 if(beta!=Scalar(0))
663 {
664 matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
665 matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
666 }
667 }
668 else if(*k>0 && alpha!=Scalar(0))
669 matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
670
671 if(*k==0)
672 return 1;
673
674 if(OP(*op)==NOTR)
675 {
676 if(UPLO(*uplo)==UP)
677 {
678 matrix(c, *n, *n, *ldc).triangularView<Upper>()
679 += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
680 + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
681 }
682 else if(UPLO(*uplo)==LO)
683 matrix(c, *n, *n, *ldc).triangularView<Lower>()
684 += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
685 + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
686 }
687 else if(OP(*op)==ADJ)
688 {
689 if(UPLO(*uplo)==UP)
690 matrix(c, *n, *n, *ldc).triangularView<Upper>()
691 += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
692 + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
693 else if(UPLO(*uplo)==LO)
694 matrix(c, *n, *n, *ldc).triangularView<Lower>()
695 += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
696 + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
697 }
698
699 return 1;
700 }
701
702 #endif // ISCOMPLEX
703