• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- mlir_test_cblas.cpp - Simple Blas subset implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Simple Blas subset implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "include/mlir_test_cblas.h"
14 #include <assert.h>
15 
/// Strided single-precision dot product.
///
/// Sums X[i * incX] * Y[i * incY] for i = 0 .. N-1, in increasing order of i,
/// into a single float accumulator.
///
/// \param N    number of element pairs to multiply; no element is read and
///             0.0f is returned when N <= 0.
/// \param X    base pointer of the first vector.
/// \param incX element step between consecutive entries of X.
/// \param Y    base pointer of the second vector.
/// \param incY element step between consecutive entries of Y.
/// \return the accumulated sum of products.
extern "C" float mlir_test_cblas_sdot(const int N, const float *X,
                                      const int incX, const float *Y,
                                      const int incY) {
  float dot = 0.0f;
  int idx = 0;
  while (idx < N) {
    const float lhs = X[idx * incX];
    const float rhs = Y[idx * incY];
    dot += lhs * rhs;
    ++idx;
  }
  return dot;
}
24 
mlir_test_cblas_sgemm(const enum CBLAS_ORDER Order,const enum CBLAS_TRANSPOSE TransA,const enum CBLAS_TRANSPOSE TransB,const int M,const int N,const int K,const float alpha,const float * A,const int lda,const float * B,const int ldb,const float beta,float * C,const int ldc)25 extern "C" void mlir_test_cblas_sgemm(
26     const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
27     const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
28     const float alpha, const float *A, const int lda, const float *B,
29     const int ldb, const float beta, float *C, const int ldc) {
30   assert(Order == CBLAS_ORDER::CblasRowMajor);
31   assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
32   assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
33   for (int m = 0; m < M; ++m) {
34     auto *pA = A + m * lda;
35     auto *pC = C + m * ldc;
36     for (int n = 0; n < N; ++n) {
37       float c = pC[n];
38       float res = 0.0f;
39       for (int k = 0; k < K; ++k) {
40         auto *pB = B + k * ldb;
41         res += pA[k] * pB[n];
42       }
43       pC[n] = alpha * c + beta * res;
44     }
45   }
46 }
47