• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 
18 #include "rsCpuIntrinsic.h"
19 #include "rsCpuIntrinsicInlines.h"
20 #include "rsCpuBLASDispatch.h"
21 #include "eight_bit_int_gemm.h"
22 
23 using namespace android;
24 using namespace android::renderscript;
25 
26 namespace android {
27 namespace renderscript {
28 
29 
30 class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic {
31 public:
32     void invokeForEach(uint32_t slot,
33                        const Allocation ** ain,
34                        uint32_t inLen,
35                        Allocation * aout,
36                        const void * usr,
37                        uint32_t usrLen,
38                        const RsScriptCall *sc) override;
39 
40     void populateScript(Script *) override;
41     ~RsdCpuScriptIntrinsicBLAS() override;
42     RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
43 
44 protected:
45 
46     uint8_t a_offset = 0;
47     uint8_t b_offset = 0;
48     uint8_t c_offset = 0;
49 
50 #ifdef RS_COMPATIBILITY_LIB
51     bool isBlasLibInitialized = false;
52 #endif
53     static void kernelBNNM(size_t m, size_t n, size_t k,
54                            const uint8_t* a, uint8_t a_offset, size_t lda,
55                            const uint8_t* b, uint8_t b_offset, size_t ldb,
56                            uint8_t* c, int32_t c_offset, size_t ldc,
57                            int32_t c_mult_int);
58 
59 
60 
61 };
62 
63 }
64 }
65 
populateScript(Script * s)66 void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) {
67     s->mHal.info.exportedVariableCount = 0;
68 }
69 
initABC(const Allocation ** ain,size_t size,void ** A,void ** B,void ** C,int * lda,int * ldb,int * ldc)70 static void initABC(const Allocation ** ain,
71                     size_t size,
72                     void** A,
73                     void** B,
74                     void** C,
75                     int* lda,
76                     int* ldb,
77                     int* ldc)
78 {
79     if (ain[0]) {
80         *A = ain[0]->mHal.drvState.lod[0].mallocPtr;
81         *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size);
82     }
83     if (ain[1]) {
84         *B = ain[1]->mHal.drvState.lod[0].mallocPtr;
85         *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size);
86     }
87     if (ain[2]) {
88         *C = ain[2]->mHal.drvState.lod[0].mallocPtr;
89         *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
90     }
91 
92 
93 }
94 
invokeForEach(uint32_t slot,const Allocation ** ain,uint32_t inLen,Allocation * aout,const void * usr,uint32_t usrLen,const RsScriptCall * sc)95 void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
96                                               const Allocation ** ain,
97                                               uint32_t inLen,
98                                               Allocation * aout,
99                                               const void * usr,
100                                               uint32_t usrLen,
101                                               const RsScriptCall *sc) {
102     RsBlasCall* call = (RsBlasCall*) usr;
103     // setup BLAS enum args
104     enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
105     enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
106     enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo;
107     enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag;
108     enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side;
109 
110     void *A = nullptr;
111     void *B = nullptr;
112     void *C = nullptr;
113     void *X = nullptr;
114     void *Y = nullptr;
115 
116     int lda = 0, ldb = 0, ldc = 0;
117 
118 #ifdef RS_COMPATIBILITY_LIB
119     // Allow BNNM even without libblas
120     if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
121         if (!loadBLASLib()) {
122             ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n");
123             return;
124         }
125         isBlasLibInitialized = true;
126     }
127 #endif
128 
129     switch (call->func) {
130 
131     // Level 1 BLAS: returns into a 1D Allocation
132 
133 
134     // Level 2 BLAS
135     case (RsBlas_sgemv):
136         initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
137         cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A,
138                     lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
139         break;
140     case (RsBlas_sgbmv):
141         initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
142         cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
143                     call->alpha.f, (float*)A, lda, (float*)X, call->incX,
144                     call->beta.f, (float*)Y, call->incY);
145         break;
146     case (RsBlas_strmv):
147         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
148         cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
149                     lda, (float*)X, call->incX);
150         break;
151     case (RsBlas_stbmv):
152         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
153         cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
154                     lda, (float*)X, call->incX);
155         break;
156     // stpmv takes a packed 1D Allocation only
157     case (RsBlas_stpmv):
158         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
159         cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
160                     (float*)X, call->incX);
161         break;
162     case (RsBlas_strsv):
163         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
164         cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda,
165                     (float*)X, call->incX);
166         break;
167     case (RsBlas_stbsv):
168         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
169         cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A,
170                     lda, (float*)X, call->incX);
171         break;
172     case (RsBlas_stpsv):
173         initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr);
174         cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A,
175                     (float*)X, call->incX);
176         break;
177     case (RsBlas_dgemv):
178         initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
179         cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A,
180                     lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
181         break;
182     case (RsBlas_dgbmv):
183         initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
184         cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
185                     call->alpha.d, (double*)A, lda, (double*)X, call->incX,
186                     call->beta.d, (double*)Y, call->incY);
187         break;
188     case (RsBlas_dtrmv):
189         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
190         cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
191                     lda, (double*)X, call->incX);
192         break;
193     case (RsBlas_dtbmv):
194         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
195         cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
196                     lda, (double*)X, call->incX);
197         break;
198     // stpmv takes a packed 1D Allocation only
199     case (RsBlas_dtpmv):
200         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
201         cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
202                     (double*)X, call->incX);
203         break;
204     case (RsBlas_dtrsv):
205         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
206         cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda,
207                     (double*)X, call->incX);
208         break;
209     case (RsBlas_dtbsv):
210         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
211         cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A,
212                     lda, (double*)X, call->incX);
213         break;
214     case (RsBlas_dtpsv):
215         initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr);
216         cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A,
217                     (double*)X, call->incX);
218         break;
219     case (RsBlas_cgemv):
220         initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
221         cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A,
222                     lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY);
223         break;
224     case (RsBlas_cgbmv):
225         initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
226         cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
227                     (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX,
228                     (void*)&call->beta.c, (void*)Y, call->incY);
229         break;
230     case (RsBlas_ctrmv):
231         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
232         cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
233                     lda, (void*)X, call->incX);
234         break;
235     case (RsBlas_ctbmv):
236         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
237         cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
238                     lda, (void*)X, call->incX);
239         break;
240     // stpmv takes a packed 1D Allocation only
241     case (RsBlas_ctpmv):
242         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
243         cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
244                     (void*)X, call->incX);
245         break;
246     case (RsBlas_ctrsv):
247         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
248         cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
249                     (void*)X, call->incX);
250         break;
251     case (RsBlas_ctbsv):
252         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
253         cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
254                     lda, (void*)X, call->incX);
255         break;
256     case (RsBlas_ctpsv):
257         initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
258         cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
259                     (void*)X, call->incX);
260         break;
261     case (RsBlas_zgemv):
262         initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
263         cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A,
264                     lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY);
265         break;
266     case (RsBlas_zgbmv):
267         initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
268         cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU,
269                     (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX,
270                     (void*)&call->beta.z, (void*)Y, call->incY);
271         break;
272     case (RsBlas_ztrmv):
273         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
274         cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
275                     lda, (void*)X, call->incX);
276         break;
277     case (RsBlas_ztbmv):
278         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
279         cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
280                     lda, (void*)X, call->incX);
281         break;
282     // stpmv takes a packed 1D Allocation only
283     case (RsBlas_ztpmv):
284         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
285         cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
286                     (void*)X, call->incX);
287         break;
288     case (RsBlas_ztrsv):
289         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
290         cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda,
291                     (void*)X, call->incX);
292         break;
293     case (RsBlas_ztbsv):
294         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
295         cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A,
296                     lda, (void*)X, call->incX);
297         break;
298     case (RsBlas_ztpsv):
299         initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr);
300         cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A,
301                     (void*)X, call->incX);
302         break;
303 
304 
305     // S and D only
306     case (RsBlas_ssymv):
307         initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
308         cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda,
309                     (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
310         break;
311     case (RsBlas_ssbmv):
312         initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
313         cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f,
314                     (float*)A, lda, (float*)X, call->incX, call->beta.f,
315                     (float*)Y, call->incY);
316         break;
317     //sspmv requires a packed 1D Allocation
318     case (RsBlas_sspmv):
319         initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc);
320         cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A,
321                     (float*)X, call->incX, call->beta.f, (float*)Y, call->incY);
322         break;
323     // following calls have init reordered because A is output matrix
324     case (RsBlas_sger):
325         initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
326         cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X,
327                    call->incX, (float*)Y, call->incY, (float*)A, lda);
328         break;
329     case (RsBlas_ssyr):
330         initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
331         cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
332                    (float*)A, lda);
333         break;
334     // sspr is packed 1D Allocation A only
335     case (RsBlas_sspr):
336         initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr);
337         cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
338                    (float*)A);
339         break;
340     case (RsBlas_ssyr2):
341         initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
342         cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
343                     (float*)Y, call->incY, (float*)A, lda);
344         break;
345     // sspr2 is packed 1D Allocation A only
346     case (RsBlas_sspr2):
347         initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda);
348         cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX,
349                     (float*)Y, call->incY, (float*)A);
350         break;
351     case (RsBlas_dsymv):
352         initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
353         cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda,
354                     (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
355         break;
356     case (RsBlas_dsbmv):
357         initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
358         cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d,
359                     (double*)A, lda, (double*)X, call->incX, call->beta.d,
360                     (double*)Y, call->incY);
361         break;
362     // dspmv requires a packed 1D Allocation
363     case (RsBlas_dspmv):
364         initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc);
365         cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A,
366                     (double*)X, call->incX, call->beta.d, (double*)Y, call->incY);
367         break;
368     // following calls have init reordered because A is output matrix
369     case (RsBlas_dger):
370         initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
371         cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X,
372                    call->incX, (double*)Y, call->incY, (double*)A, lda);
373         break;
374     case (RsBlas_dsyr):
375         initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
376         cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
377                    (double*)A, lda);
378         break;
379     // dspr is packed 1D Allocation A only
380     case (RsBlas_dspr):
381         initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr);
382         cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
383                    (double*)A);
384         break;
385     case (RsBlas_dsyr2):
386         initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
387         cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
388                     (double*)Y, call->incY, (double*)A, lda);
389         break;
390     // dspr2 is packed 1D Allocation A only
391     case (RsBlas_dspr2):
392         initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda);
393         cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX,
394                     (double*)Y, call->incY, (double*)A);
395         break;
396 
397     // C and Z only
398     case (RsBlas_chemv):
399         initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
400         cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda,
401                     X, call->incX, (void*)&call->beta.c, Y, call->incY);
402         break;
403     case (RsBlas_chbmv):
404         initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
405         cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c,
406                     A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY);
407         break;
408     case (RsBlas_chpmv):
409         initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc);
410         cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A,
411                     X, call->incX, (void*)&call->beta.c, Y, call->incY);
412         break;
413     case (RsBlas_cgeru):
414         initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
415         cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
416                     X, call->incX, Y, call->incY, A, lda);
417         break;
418     case (RsBlas_cgerc):
419         initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
420         cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c,
421                     X, call->incX, Y, call->incY, A, lda);
422         break;
423     case (RsBlas_cher):
424         initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
425         cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f,
426                    X, call->incX, A, lda);
427         break;
428     // packed 1D Allocations only
429     case (RsBlas_chpr):
430         initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
431         cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X,
432                    call->incX, A);
433         break;
434     case (RsBlas_cher2):
435         initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
436         cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c,
437                    X, call->incX, Y, call->incY, A, lda);
438         break;
439     // packed 1D Allocations only
440     case (RsBlas_chpr2):
441         initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda);
442         cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X,
443                    call->incX, Y, call->incY, A);
444         break;
445     case (RsBlas_zhemv):
446         initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
447         cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda,
448                     X, call->incX, (void*)&call->beta.z, Y, call->incY);
449         break;
450     case (RsBlas_zhbmv):
451         initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
452         cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z,
453                     A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY);
454         break;
455     case (RsBlas_zhpmv):
456         initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc);
457         cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A,
458                     X, call->incX, (void*)&call->beta.z, Y, call->incY);
459         break;
460     case (RsBlas_zgeru):
461         initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
462         cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
463                     X, call->incX, Y, call->incY, A, lda);
464         break;
465     case (RsBlas_zgerc):
466         initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
467         cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z,
468                     X, call->incX, Y, call->incY, A, lda);
469         break;
470     case (RsBlas_zher):
471         initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
472         cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d,
473                    X, call->incX, A, lda);
474         break;
475     // packed 1D Allocations only
476     case (RsBlas_zhpr):
477         initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda);
478         cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X,
479                    call->incX, A);
480         break;
481     case (RsBlas_zher2):
482         initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
483         cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z,
484                    X, call->incX, Y, call->incY, A, lda);
485         break;
486     // packed 1D Allocations only
487     case (RsBlas_zhpr2):
488         initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda);
489         cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X,
490                    call->incX, Y, call->incY, A);
491         break;
492 
493     // Level 3 BLAS
494     case (RsBlas_sgemm):
495         initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
496         cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
497                     (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
498         break;
499     case (RsBlas_ssymm):
500         initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
501         cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A,
502                     lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
503         break;
504     case (RsBlas_ssyrk):
505         initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc);
506         cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
507                     lda, call->beta.f, (float*)C, ldc);
508         break;
509     case (RsBlas_ssyr2k):
510         initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
511         cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A,
512                      lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
513         break;
514     case (RsBlas_strmm):
515         initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
516         cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
517                     (float*)A, lda, (float*)B, ldb);
518         break;
519     case (RsBlas_strsm):
520         initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr);
521         cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f,
522                     (float*)A, lda, (float*)B, ldb);
523         break;
524 
525 
526     case (RsBlas_dgemm):
527         initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
528         cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
529                     (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
530         break;
531     case (RsBlas_dsymm):
532         initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
533         cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A,
534                     lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
535         break;
536     case (RsBlas_dsyrk):
537         initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc);
538         cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
539                     lda, call->beta.d, (double*)C, ldc);
540         break;
541     case (RsBlas_dsyr2k):
542         initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
543         cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A,
544                      lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
545         break;
546     case (RsBlas_dtrmm):
547         initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
548         cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
549                     (double*)A, lda, (double*)B, ldb);
550         break;
551     case (RsBlas_dtrsm):
552         initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr);
553         cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d,
554                     (double*)A, lda, (double*)B, ldb);
555         break;
556 
557     case (RsBlas_cgemm):
558         initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
559         cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
560                     A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
561         break;
562     case (RsBlas_csymm):
563         initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
564         cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A,
565                     lda, B, ldb, (void*)&call->beta.c, C, ldc);
566         break;
567     case (RsBlas_csyrk):
568         initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
569         cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
570                     lda, (void*)&call->beta.c, C, ldc);
571         break;
572     case (RsBlas_csyr2k):
573         initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
574         cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A,
575                      lda, B, ldb, (void*)&call->beta.c, C, ldc);
576         break;
577     case (RsBlas_ctrmm):
578         initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
579         cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
580                     A, lda, B, ldb);
581         break;
582     case (RsBlas_ctrsm):
583         initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
584         cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c,
585                     A, lda, B, ldb);
586         break;
587 
588     case (RsBlas_zgemm):
589         initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
590         cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
591                     A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
592         break;
593     case (RsBlas_zsymm):
594         initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
595         cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A,
596                     lda, B, ldb, (void*)&call->beta.z, C, ldc);
597         break;
598     case (RsBlas_zsyrk):
599         initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
600         cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
601                     lda, (void*)&call->beta.z, C, ldc);
602         break;
603     case (RsBlas_zsyr2k):
604         initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
605         cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A,
606                      lda, B, ldb, (void*)&call->beta.z, C, ldc);
607         break;
608     case (RsBlas_ztrmm):
609         initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
610         cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
611                     A, lda, B, ldb);
612         break;
613     case (RsBlas_ztrsm):
614         initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr);
615         cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z,
616                     A, lda, B, ldb);
617         break;
618 
619     // Level 3 C and Z only
620     case (RsBlas_chemm):
621         initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
622         cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda,
623                     B, ldb, (void*)&call->beta.c, C, ldc);
624         break;
625     case (RsBlas_cherk):
626         initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
627         cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda,
628                     call->beta.f, C, ldc);
629         break;
630     case (RsBlas_cher2k):
631         initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
632         cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda,
633                      B, ldb, call->beta.f, C, ldc);
634         break;
635 
636     case (RsBlas_zhemm):
637         initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
638         cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda,
639                     B, ldb, (void*)&call->beta.z, C, ldc);
640         break;
641     case (RsBlas_zherk):
642         initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc);
643         cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda,
644                     call->beta.d, C, ldc);
645         break;
646     case (RsBlas_zher2k):
647         initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
648         cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda,
649                      B, ldb, call->beta.d, C, ldc);
650         break;
651 
652 
653     case (RsBlas_bnnm):
654         initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc);
655         kernelBNNM(call->M, call->N, call->K,
656                     (const uint8_t*)A, call->a_offset, lda,
657                     (const uint8_t*)B, call->b_offset, ldb,
658                     (uint8_t*)C, call->c_offset, ldc,
659                     call->c_mult_int);
660 
661         break;
662 
663     default:
664         ALOGE("unimplemented\n");
665     }
666 
667 
668 }
669 
kernelBNNM(size_t m,size_t n,size_t k,const uint8_t * a,uint8_t a_offset,size_t lda,const uint8_t * b,uint8_t b_offset,size_t ldb,uint8_t * c,int32_t c_offset,size_t ldc,int32_t c_mult_int)670 void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k,
671                                            const uint8_t* a, uint8_t a_offset, size_t lda,
672                                            const uint8_t* b, uint8_t b_offset, size_t ldb,
673                                            uint8_t* c, int32_t c_offset, size_t ldc,
674                                            int32_t c_mult_int) {
675     const int c_shift = 21;
676 #if defined(ARCH_ARM_HAVE_VFP) || defined(ARCH_ARM_USE_INTRINSICS)
677     // Non-optimized path for ARMv7 devices without SIMD instructions.
678     if (!gArchUseSIMD) {
679         /*
680          * Calculations are done in 1.10.21 fixed-point format for the final output,
681          * just before there's a shift down to drop the fractional parts. The output
682          * values are gated to 0 to 255 to fit in a byte, but the 10-bit format
683          * gives some headroom to avoid wrapping around on small overflows.
684          */
685         size_t i = 0, j = 0, l = 0;
686         for (j = 0; j < n; j++) {
687             for (i = 0; i < m; i++) {
688                 int32_t total = 0;
689                 for (l = 0; l < k; l++) {
690                     const int a_index = ((i * lda) + l);
691                     const uint8_t a_as_byte = a[a_index];
692                     const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset);
693                     const int b_index = ((j * ldb) + l);
694                     const uint8_t b_as_byte = b[b_index];
695                     const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset);
696                     const int32_t mult_as_int = (a_as_int * b_as_int);
697                     total += mult_as_int;
698                 }
699                 const int c_index = ((ldc * i) + j);
700                 int32_t output =
701                     ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1)))
702                      >> c_shift);
703                 if (output > 255) {
704                     output = 255;
705                 }
706                 if (output < 0) {
707                     output = 0;
708                 }
709                 c[c_index] = (uint8_t)(output);
710             }
711         }
712         return;
713     }
714 #endif
715 
716     // Using gemmlowp to calculate the low precision 8 bit GEMM.
717     bool transpose_a = true;
718     bool transpose_b = false;
719     bool transpose_c = true;
720     gemmlowp::eight_bit_int_gemm::EightBitIntGemm(transpose_a, transpose_b, transpose_c,
721                                                   m, n, k, a, -a_offset, lda,
722                                                   b, -b_offset, ldb, c, c_offset,
723                                                   c_mult_int, c_shift, ldc,
724                                                   gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
725 
726 }
727 
728 
729 
730 
731 
RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl * ctx,const Script * s)732 RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx,
733                                                    const Script *s)
734             : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) {
735 
736 
737 }
738 
~RsdCpuScriptIntrinsicBLAS()739 RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() {
740 }
741 
742 
743 
744 
745 
rsdIntrinsic_BLAS(RsdCpuReferenceImpl * ctx,const Script * s,const Element * e)746 RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx,
747                                     const Script *s, const Element *e) {
748 
749     return new RsdCpuScriptIntrinsicBLAS(ctx, s);
750 }
751