// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

func private @print_memref_f32(memref<*xf32>)

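// Computes C = A * B with a single linalg.matmul; C is allocated and zero-filled first.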
func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %f0 = constant 0.0 : f32
  %x = dim %A, %c0 : memref<?x?xf32>
  %y = dim %B, %c1 : memref<?x?xf32>
  %C = alloc(%x, %y) : memref<?x?xf32>
  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                outs(%C: memref<?x?xf32>)
  return %C : memref<?x?xf32>
}

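// Computes the same product one column at a time: C[:, j] = A * B[:, j] via linalg.matvec.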
func @matvec(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %f0 = constant 0.0 : f32
  %m = dim %A, %c0 : memref<?x?xf32>
  %x = dim %A, %c1 : memref<?x?xf32>
  %n = dim %B, %c1 : memref<?x?xf32>
  %C = alloc(%m, %n) : memref<?x?xf32>
  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
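  // For each column %i, take rank-reducing subviews of B and C and run a matrix-vector product.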
  scf.for %i = %c0 to %n step %c1 {
    %b = subview %B[0, %i][%x, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
    %c = subview %C[0, %i][%m, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
    linalg.matvec ins(%A, %b: memref<?x?xf32>, memref<?xf32, offset: ?, strides: [?]>)
                  outs(%c: memref<?xf32, offset: ?, strides: [?]>)
  }
  return %C : memref<?x?xf32>
}

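// Builds small test operands, runs both kernels, asserts that they agree, and prints the result.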
func @main() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %m = constant 5 : index
  %x = constant 3 : index
  %n = constant 2 : index
  %val1 = constant 13.0 : f32
  %val2 = constant 17.0 : f32
  %A = alloc(%m, %x) : memref<?x?xf32>
  %B = alloc(%x, %n) : memref<?x?xf32>
  linalg.fill(%A, %val1) : memref<?x?xf32>, f32
  linalg.fill(%B, %val2) : memref<?x?xf32>, f32
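  // Overwrite B[0, 0] with 13.0 so the two columns of the product differ.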
  store %val1, %B[%c0, %c0] : memref<?x?xf32>
  %C1 = call @matmul(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
  %C2 = call @matvec(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
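  // Check element-wise that matmul and matvec produced identical results.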
  scf.for %i = %c0 to %m step %c1 {
    scf.for %j = %c0 to %n step %c1 {
      %e1 = load %C1[%i, %j] : memref<?x?xf32>
      %e2 = load %C2[%i, %j] : memref<?x?xf32>
      %c = cmpf "oeq", %e1, %e2 : f32
      assert %c, "Matmul does not produce same output as matvec"
    }
  }
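  // Print the matvec result; the CHECK lines below verify the printed values.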
  %C2_ = memref_cast %C2 : memref<?x?xf32> to memref<*xf32>
  call @print_memref_f32(%C2_) : (memref<*xf32>) -> ()
  return
}

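// With A (5x3) filled with 13.0 and B (3x2) filled with 17.0 except B[0, 0] = 13.0,
// every row of the result is [13*13 + 2*13*17, 3*13*17] = [611, 663].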
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [5, 2] strides = [2, 1] data =
// CHECK-NEXT:      [
// CHECK-SAME:  [611,   663],
// CHECK-NEXT:  [611,   663],
// CHECK-NEXT:  [611,   663],
// CHECK-NEXT:  [611,   663],
// CHECK-NEXT:  [611,   663]
// CHECK-SAME: ]