• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- mlir_test_cblas_interface.cpp - Simple Blas subset interface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of a simple BLAS-subset interface: `_mlir_ciface_` entry
// points for linalg fill, copy, dot and matmul on f32 strided memrefs.
//
//===----------------------------------------------------------------------===//
12 
#include "include/mlir_test_cblas_interface.h"
#include "include/mlir_test_cblas.h"
#include <assert.h>
#include <cstdint>
#include <iostream>
17 
18 extern "C" void
_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float,0> * X,float f)19 _mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
20   X->data[X->offset] = f;
21 }
22 
23 extern "C" void
_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float,1> * X,float f)24 _mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
25                                        float f) {
26   for (unsigned i = 0; i < X->sizes[0]; ++i)
27     *(X->data + X->offset + i * X->strides[0]) = f;
28 }
29 
30 extern "C" void
_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float,2> * X,float f)31 _mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
32                                          float f) {
33   for (unsigned i = 0; i < X->sizes[0]; ++i)
34     for (unsigned j = 0; j < X->sizes[1]; ++j)
35       *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
36 }
37 
38 extern "C" void
_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float,0> * I,StridedMemRefType<float,0> * O)39 _mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
40                                          StridedMemRefType<float, 0> *O) {
41   O->data[O->offset] = I->data[I->offset];
42 }
43 
44 extern "C" void
_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float,1> * I,StridedMemRefType<float,1> * O)45 _mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
46                                              StridedMemRefType<float, 1> *O) {
47   if (I->sizes[0] != O->sizes[0]) {
48     std::cerr << "Incompatible strided memrefs\n";
49     printMemRefMetaData(std::cerr, *I);
50     printMemRefMetaData(std::cerr, *O);
51     return;
52   }
53   for (unsigned i = 0; i < I->sizes[0]; ++i)
54     O->data[O->offset + i * O->strides[0]] =
55         I->data[I->offset + i * I->strides[0]];
56 }
57 
_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(StridedMemRefType<float,2> * I,StridedMemRefType<float,2> * O)58 extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
59     StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
60   if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
61     std::cerr << "Incompatible strided memrefs\n";
62     printMemRefMetaData(std::cerr, *I);
63     printMemRefMetaData(std::cerr, *O);
64     return;
65   }
66   auto so0 = O->strides[0], so1 = O->strides[1];
67   auto si0 = I->strides[0], si1 = I->strides[1];
68   for (unsigned i = 0; i < I->sizes[0]; ++i)
69     for (unsigned j = 0; j < I->sizes[1]; ++j)
70       O->data[O->offset + i * so0 + j * so1] =
71           I->data[I->offset + i * si0 + j * si1];
72 }
73 
_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(StridedMemRefType<float,1> * X,StridedMemRefType<float,1> * Y,StridedMemRefType<float,0> * Z)74 extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
75     StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
76     StridedMemRefType<float, 0> *Z) {
77   if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
78     std::cerr << "Incompatible strided memrefs\n";
79     printMemRefMetaData(std::cerr, *X);
80     printMemRefMetaData(std::cerr, *Y);
81     printMemRefMetaData(std::cerr, *Z);
82     return;
83   }
84   Z->data[Z->offset] +=
85       mlir_test_cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
86                            Y->data + Y->offset, Y->strides[0]);
87 }
88 
_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(StridedMemRefType<float,2> * A,StridedMemRefType<float,2> * B,StridedMemRefType<float,2> * C)89 extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
90     StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
91     StridedMemRefType<float, 2> *C) {
92   if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
93       A->strides[1] != 1 || A->sizes[0] < A->strides[1] ||
94       B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] ||
95       C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] ||
96       A->sizes[1] != B->sizes[0]) {
97     printMemRefMetaData(std::cerr, *A);
98     printMemRefMetaData(std::cerr, *B);
99     printMemRefMetaData(std::cerr, *C);
100     return;
101   }
102   mlir_test_cblas_sgemm(
103       CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
104       CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
105       1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
106       B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
107 }
108