1 /*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18 #include "rsCpuIntrinsic.h"
19 #include "rsCpuIntrinsicInlines.h"
20 #include "rsCpuBLASDispatch.h"
21
22 using namespace android;
23 using namespace android::renderscript;
24
25 namespace android {
26 namespace renderscript {
27
28
29 class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
30 public:
31 void invokeForEach(uint32_t slot,
32 const Allocation ** ain,
33 uint32_t inLen,
34 Allocation * aout,
35 const void * usr,
36 uint32_t usrLen,
37 const RsScriptCall *sc) override;
38
39 void populateScript(Script *) override;
40 ~RsdCpuScriptIntrinsicBLAS() override;
41 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
42
43 protected:
44
45 uint8_t a_offset = 0;
46 uint8_t b_offset = 0;
47 uint8_t c_offset = 0;
48
49 #ifdef RS_COMPATIBILITY_LIB
50 bool isBlasLibInitialized = false;
51 #endif
52 static void kernelBNNM(size_t m, size_t n, size_t k,
53 const uint8_t* a, uint8_t a_offset, size_t lda,
54 const uint8_t* b, uint8_t b_offset, size_t ldb,
55 uint8_t* c, int32_t c_offset, size_t ldc,
56 int32_t c_mult_int);
57
58
59
60 };
61
62 }
63 }
64
populateScript(Script * s)65 void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
66 s->mHal.info.exportedVariableCount = 0;
67 }
68
initABC(const Allocation ** ain,size_t size,void ** A,void ** B,void ** C,int * lda,int * ldb,int * ldc)69 static void initABC(const Allocation ** ain,
70 size_t size,
71 void** A,
72 void** B,
73 void** C,
74 int* lda,
75 int* ldb,
76 int* ldc)
77 {
78 if (ain[0]) {
79 *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
80 *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
81 }
82 if (ain[1]) {
83 *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
84 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
85 }
86 if (ain[2]) {
87 *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
88 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
89 }
90
91
92 }
93
/**
 * Executes one BLAS request described by an RsBlasCall.
 *
 * `usr` is reinterpreted as an RsBlasCall. Its `func` field selects the
 * routine; its remaining fields (transA/transB/uplo/diag/side, M/N/K/KL/KU,
 * alpha/beta, incX/incY, and the bnnm offsets/multiplier) supply the scalar
 * arguments. Matrix and vector operands come from ain[0..2] via initABC(),
 * which yields each operand's data pointer and its leading dimension in
 * elements. The order in which &A/&X/&Y (or &A/&B/&C) are handed to initABC()
 * determines which ain[] slot feeds which operand of the routine.
 *
 * All cblas_* calls use CblasRowMajor. An unrecognized `func` only logs
 * "unimplemented".
 *
 * @param slot    not referenced by this implementation
 * @param ain     input Allocations; entries 0..2 are read (entries may be null)
 * @param inLen   not referenced by this implementation
 * @param aout    not referenced by this implementation
 * @param usr     pointer to the RsBlasCall describing the request
 * @param usrLen  not referenced by this implementation
 * @param sc      not referenced by this implementation
 */
void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
                                              const Allocation ** ain,
                                              uint32_t inLen,
                                              Allocation * aout,
                                              const void * usr,
                                              uint32_t usrLen,
                                              const RsScriptCall *sc) {
    RsBlasCall* call = (RsBlasCall*) usr;
    // setup BLAS enum args
    enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
    enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
    enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
    enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
    enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;

    // Operand pointers; only the ones relevant to the selected routine are
    // filled in by initABC(), the rest stay null.
    void *A = nullptr;
    void *B = nullptr;
    void *C = nullptr;
    void *X = nullptr;
    void *Y = nullptr;

    // Leading dimensions in elements. For vector operands the value written
    // by initABC() is simply not passed on to the cblas call.
    int lda = 0, ldb = 0, ldc = 0;

#ifdef RS_COMPATIBILITY_LIB
    // Allow BNNM even without libblas
    // The library is loaded on the first non-bnnm request; on failure the
    // request is dropped after logging and a later call will retry the load.
    if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
        if (!loadBLASLib()) {
            ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n");
            return;
        }
        isBlasLibInitialized = true;
    }
#endif

    switch (call->func) {

    // Level 1 BLAS: returns into a 1D Allocation
    // (no Level 1 routines are dispatched by this switch)


    // Level 2 BLAS
    // Operands are taken as ain = {A, X, Y} unless noted otherwise.
    case (RsBlas_sgemv):
        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
                    lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
        break;
    case (RsBlas_sgbmv):
        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
                    call->alpha.f, (float*)A, lda, (float*)X, call->incX,
                    call->beta.f, (float*)Y, call->incY);
        break;
    case (RsBlas_strmv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
                    lda, (float*)X, call->incX);
        break;
    case (RsBlas_stbmv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
                    lda, (float*)X, call->incX);
        break;
    // stpmv takes a packed 1D Allocation only
    case (RsBlas_stpmv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
                    (float*)X, call->incX);
        break;
    case (RsBlas_strsv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
                    (float*)X, call->incX);
        break;
    case (RsBlas_stbsv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
                    lda, (float*)X, call->incX);
        break;
    case (RsBlas_stpsv):
        initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
                    (float*)X, call->incX);
        break;
    case (RsBlas_dgemv):
        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
                    lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
        break;
    case (RsBlas_dgbmv):
        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
                    call->alpha.d, (double*)A, lda, (double*)X, call->incX,
                    call->beta.d, (double*)Y, call->incY);
        break;
    case (RsBlas_dtrmv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
                    lda, (double*)X, call->incX);
        break;
    case (RsBlas_dtbmv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
                    lda, (double*)X, call->incX);
        break;
    // dtpmv takes a packed 1D Allocation only
    case (RsBlas_dtpmv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
                    (double*)X, call->incX);
        break;
    case (RsBlas_dtrsv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
                    (double*)X, call->incX);
        break;
    case (RsBlas_dtbsv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
                    lda, (double*)X, call->incX);
        break;
    case (RsBlas_dtpsv):
        initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
                    (double*)X, call->incX);
        break;
    // Complex routines: the element size is two floats (or two doubles) and
    // alpha/beta are passed by address.
    case (RsBlas_cgemv):
        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
                    lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
        break;
    case (RsBlas_cgbmv):
        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
                    (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
                    (void*)&call->beta.c, (void*)Y, call->incY);
        break;
    case (RsBlas_ctrmv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    case (RsBlas_ctbmv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    // ctpmv takes a packed 1D Allocation only
    case (RsBlas_ctpmv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    (void*)X, call->incX);
        break;
    case (RsBlas_ctrsv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
                    (void*)X, call->incX);
        break;
    case (RsBlas_ctbsv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    case (RsBlas_ctpsv):
        initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    (void*)X, call->incX);
        break;
    case (RsBlas_zgemv):
        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
                    lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
        break;
    case (RsBlas_zgbmv):
        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
                    (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
                    (void*)&call->beta.z, (void*)Y, call->incY);
        break;
    case (RsBlas_ztrmv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    case (RsBlas_ztbmv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    // ztpmv takes a packed 1D Allocation only
    case (RsBlas_ztpmv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    (void*)X, call->incX);
        break;
    case (RsBlas_ztrsv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
                    (void*)X, call->incX);
        break;
    case (RsBlas_ztbsv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
                    lda, (void*)X, call->incX);
        break;
    case (RsBlas_ztpsv):
        initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
        cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
                    (void*)X, call->incX);
        break;


    // S and D only
    case (RsBlas_ssymv):
        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
        break;
    case (RsBlas_ssbmv):
        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
                    (float*)A, lda, (float*)X, call->incX, call->beta.f,
                    (float*)Y, call->incY);
        break;
    //sspmv requires a packed 1D Allocation
    case (RsBlas_sspmv):
        initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
                    (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
        break;
    // following calls have init reordered because A is output matrix
    // (A is read from the last populated ain[] slot, and lda with it)
    case (RsBlas_sger):
        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
                   call->incX, (float*)Y, call->incY, (float*)A, lda);
        break;
    case (RsBlas_ssyr):
        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
        cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
                   (float*)A, lda);
        break;
    // sspr is packed 1D Allocation A only
    case (RsBlas_sspr):
        initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
        cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
                   (float*)A);
        break;
    case (RsBlas_ssyr2):
        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
                    (float*)Y, call->incY, (float*)A, lda);
        break;
    // sspr2 is packed 1D Allocation A only
    case (RsBlas_sspr2):
        initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
                    (float*)Y, call->incY, (float*)A);
        break;
    case (RsBlas_dsymv):
        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
        break;
    case (RsBlas_dsbmv):
        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
                    (double*)A, lda, (double*)X, call->incX, call->beta.d,
                    (double*)Y, call->incY);
        break;
    // dspmv requires a packed 1D Allocation
    case (RsBlas_dspmv):
        initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
                    (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
        break;
    // following calls have init reordered because A is output matrix
    case (RsBlas_dger):
        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
                   call->incX, (double*)Y, call->incY, (double*)A, lda);
        break;
    case (RsBlas_dsyr):
        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
        cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
                   (double*)A, lda);
        break;
    // dspr is packed 1D Allocation A only
    case (RsBlas_dspr):
        initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
        cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
                   (double*)A);
        break;
    case (RsBlas_dsyr2):
        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
                    (double*)Y, call->incY, (double*)A, lda);
        break;
    // dspr2 is packed 1D Allocation A only
    case (RsBlas_dspr2):
        initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
                    (double*)Y, call->incY, (double*)A);
        break;

    // C and Z only
    case (RsBlas_chemv):
        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
        break;
    case (RsBlas_chbmv):
        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
                    A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
        break;
    case (RsBlas_chpmv):
        initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
                    X, call->incX, (void*)&call->beta.c, Y, call->incY);
        break;
    case (RsBlas_cgeru):
        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    case (RsBlas_cgerc):
        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    // cher/chpr take a real alpha (alpha.f); X comes from ain[0], A from ain[2].
    case (RsBlas_cher):
        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
        cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
                   X, call->incX, A, lda);
        break;
    // packed 1D Allocations only
    case (RsBlas_chpr):
        initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
        cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
                   call->incX, A);
        break;
    case (RsBlas_cher2):
        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    // packed 1D Allocations only
    case (RsBlas_chpr2):
        initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
                    call->incX, Y, call->incY, A);
        break;
    case (RsBlas_zhemv):
        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
        break;
    case (RsBlas_zhbmv):
        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
                    A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
        break;
    case (RsBlas_zhpmv):
        initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
        cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
                    X, call->incX, (void*)&call->beta.z, Y, call->incY);
        break;
    case (RsBlas_zgeru):
        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    case (RsBlas_zgerc):
        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    // zher/zhpr take a real alpha (alpha.d); X comes from ain[0], A from ain[2].
    case (RsBlas_zher):
        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
        cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
                   X, call->incX, A, lda);
        break;
    // packed 1D Allocations only
    case (RsBlas_zhpr):
        initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
        cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
                   call->incX, A);
        break;
    case (RsBlas_zher2):
        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
                    X, call->incX, Y, call->incY, A, lda);
        break;
    // packed 1D Allocations only
    case (RsBlas_zhpr2):
        initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
        cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
                    call->incX, Y, call->incY, A);
        break;

    // Level 3 BLAS
    // Operands are taken as ain = {A, B, C}; routines without a B operand
    // pass nullptr for the middle slot.
    case (RsBlas_sgemm):
        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
                    (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
        break;
    case (RsBlas_ssymm):
        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
                    lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
        break;
    case (RsBlas_ssyrk):
        initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
                    lda, call->beta.f, (float*)C, ldc);
        break;
    case (RsBlas_ssyr2k):
        initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
        break;
    case (RsBlas_strmm):
        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
                    (float*)A, lda, (float*)B, ldb);
        break;
    case (RsBlas_strsm):
        initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
                    (float*)A, lda, (float*)B, ldb);
        break;


    case (RsBlas_dgemm):
        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
                    (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
        break;
    case (RsBlas_dsymm):
        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
                    lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
        break;
    case (RsBlas_dsyrk):
        initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
                    lda, call->beta.d, (double*)C, ldc);
        break;
    case (RsBlas_dsyr2k):
        initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
        cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
        break;
    case (RsBlas_dtrmm):
        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
                    (double*)A, lda, (double*)B, ldb);
        break;
    case (RsBlas_dtrsm):
        initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
                    (double*)A, lda, (double*)B, ldb);
        break;

    case (RsBlas_cgemm):
        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
                    A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
        break;
    case (RsBlas_csymm):
        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
                    lda, B, ldb, (void*)&call->beta.c, C, ldc);
        break;
    case (RsBlas_csyrk):
        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
                    lda, (void*)&call->beta.c, C, ldc);
        break;
    case (RsBlas_csyr2k):
        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
        break;
    case (RsBlas_ctrmm):
        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
                    A, lda, B, ldb);
        break;
    case (RsBlas_ctrsm):
        initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
                    A, lda, B, ldb);
        break;

    case (RsBlas_zgemm):
        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
                    A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
        break;
    case (RsBlas_zsymm):
        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
                    lda, B, ldb, (void*)&call->beta.z, C, ldc);
        break;
    case (RsBlas_zsyrk):
        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
                    lda, (void*)&call->beta.z, C, ldc);
        break;
    case (RsBlas_zsyr2k):
        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
        break;
    case (RsBlas_ztrmm):
        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
                    A, lda, B, ldb);
        break;
    case (RsBlas_ztrsm):
        initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
        cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
                    A, lda, B, ldb);
        break;

    // Level 3 C and Z only
    // herk takes real alpha and beta; her2k takes a complex alpha and a real beta.
    case (RsBlas_chemm):
        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
                    B, ldb, (void*)&call->beta.c, C, ldc);
        break;
    case (RsBlas_cherk):
        initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
                    call->beta.f, C, ldc);
        break;
    case (RsBlas_cher2k):
        initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
                     B, ldb, call->beta.f, C, ldc);
        break;

    case (RsBlas_zhemm):
        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
                    B, ldb, (void*)&call->beta.z, C, ldc);
        break;
    case (RsBlas_zherk):
        initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
        cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
                    call->beta.d, C, ldc);
        break;
    case (RsBlas_zher2k):
        initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
        cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
                     B, ldb, call->beta.d, C, ldc);
        break;


    // 8-bit quantized multiply: handled by the built-in kernel rather than
    // cblas, which is why it is allowed to run without the BLAS library.
    case (RsBlas_bnnm):
        initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
        kernelBNNM(call->M, call->N, call->K,
                   (const uint8_t*)A, call->a_offset, lda,
                   (const uint8_t*)B, call->b_offset, ldb,
                   (uint8_t*)C, call->c_offset, ldc,
                   call->c_mult_int);

        break;

    default:
        // Unknown function id: nothing is computed, only logged.
        ALOGE("unimplemented\n");
    }


}
668
kernelBNNM(size_t m,size_t n,size_t k,const uint8_t * a,uint8_t a_offset,size_t lda,const uint8_t * b,uint8_t b_offset,size_t ldb,uint8_t * c,int32_t c_offset,size_t ldc,int32_t c_mult_int)669 void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
670 const uint8_t* a, uint8_t a_offset, size_t lda,
671 const uint8_t* b, uint8_t b_offset, size_t ldb,
672 uint8_t* c, int32_t c_offset, size_t ldc,
673 int32_t c_mult_int) {
674 // Calculations are done in 1.10.21 fixed-point format for the final output,
675 // just before there's a shift down to drop the fractional parts. The output
676 // values are gated to 0 to 255 to fit in a byte, but the 10-bit format
677 // gives some headroom to avoid wrapping around on small overflows.
678 const int c_shift = 21;
679 size_t i = 0, j = 0, l = 0;
680 for (j = 0; j < n; j++) {
681 for (i = 0; i < m; i++) {
682 int32_t total = 0;
683 for (l = 0; l < k; l++) {
684 const int a_index = ((i * lda) + l);
685 const uint8_t a_as_byte = a[a_index];
686 const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset);
687 const int b_index = ((j * ldb) + l);
688 const uint8_t b_as_byte = b[b_index];
689 const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset);
690 const int32_t mult_as_int = (a_as_int * b_as_int);
691 total += mult_as_int;
692 }
693 const int c_index = ((ldc * i) + j);
694 int32_t output =
695 ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1)))
696 >> c_shift);
697 if (output > 255) {
698 output = 255;
699 }
700 if (output < 0) {
701 output = 0;
702 }
703 c[c_index] = (uint8_t)(output);
704 }
705 }
706 }
707
708
709
710
711
RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl * ctx,const Script * s)712 RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
713 const Script *s)
714 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
715
716
717 }
718
~RsdCpuScriptIntrinsicBLAS()719 RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
720 }
721
722
723
724
725
rsdIntrinsic_BLAS(RsdCpuReferenceImpl * ctx,const Script * s,const Element * e)726 RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
727 const Script *s, const Element *e) {
728
729 return new RsdCpuScriptIntrinsicBLAS(ctx, s);
730 }
731