1 //===- mlir_test_cblas.cpp - Simple Blas subset implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Simple Blas subset implementation.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "include/mlir_test_cblas.h"
14 #include <assert.h>
15
/// Minimal reference implementation of `cblas_sdot`.
///
/// Returns the dot product sum_{i=0}^{N-1} x_i * y_i, where x and y are the
/// strided views of X and Y selected by incX and incY (in elements).
///
/// Follows the reference BLAS convention for negative increments: the vector
/// is traversed backwards, starting at element (1 - N) * inc, so that its last
/// logical element is at offset 0. Returns 0 when N <= 0.
extern "C" float mlir_test_cblas_sdot(const int N, const float *X,
                                      const int incX, const float *Y,
                                      const int incY) {
  float res = 0.0f;
  if (N <= 0)
    return res;
  // Use a wide index type so the offset products cannot overflow `int`.
  long long ix = incX < 0 ? (1LL - N) * incX : 0;
  long long iy = incY < 0 ? (1LL - N) * incY : 0;
  for (int i = 0; i < N; ++i) {
    res += X[ix] * Y[iy];
    ix += incX;
    iy += incY;
  }
  return res;
}
24
mlir_test_cblas_sgemm(const enum CBLAS_ORDER Order,const enum CBLAS_TRANSPOSE TransA,const enum CBLAS_TRANSPOSE TransB,const int M,const int N,const int K,const float alpha,const float * A,const int lda,const float * B,const int ldb,const float beta,float * C,const int ldc)25 extern "C" void mlir_test_cblas_sgemm(
26 const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
27 const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
28 const float alpha, const float *A, const int lda, const float *B,
29 const int ldb, const float beta, float *C, const int ldc) {
30 assert(Order == CBLAS_ORDER::CblasRowMajor);
31 assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
32 assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
33 for (int m = 0; m < M; ++m) {
34 auto *pA = A + m * lda;
35 auto *pC = C + m * ldc;
36 for (int n = 0; n < N; ++n) {
37 float c = pC[n];
38 float res = 0.0f;
39 for (int k = 0; k < K; ++k) {
40 auto *pB = B + k * ldb;
41 res += pA[k] * pB[n];
42 }
43 pC[n] = alpha * c + beta * res;
44 }
45 }
46 }
47