// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s

// Matrix-vector contraction maps: the matrix operand is accessed at (i, j),
// the vector operand at (j), and the accumulator/result at (i). Dimension i is
// parallel and dimension j is reduced.
#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Same as the matrix-vector maps above, except the matrix operand is accessed
// at (j, i), i.e. with its two dimensions swapped. The vector operand is at
// (j) and the accumulator/result at (i); j is the reduction dimension.
#mattransvec_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Vector-matrix contraction maps: the vector is the FIRST operand, accessed at
// (j), and the matrix is the second operand, accessed at (i, j). The
// accumulator/result is at (i); j is the reduction dimension.
#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Vector-matrix contraction maps with the matrix operand accessed at (j, i)
// (dimensions swapped). The vector is the first operand at (j), and the
// accumulator/result is at (i); j is the reduction dimension.
#vecmattrans_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (i)>
]
#vecmattrans_trait = {
  indexing_maps = #vecmattrans_accesses,
  iterator_types = ["parallel", "reduction"]
}

// 2x2 matrix times 2-vector, with the matrix indexed as (i, j). The
// outer-product lowering is expected to transpose the matrix once, then chain
// one vector.outerproduct per reduction index j (two here). Each step pairs
// row j of the transposed matrix with the scalar x[j], and the accumulator is
// threaded from one step into the next, starting from the loaded %b.
// CHECK-LABEL: func @matvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// Same product as @matvec2x2, but the matrix operand is indexed as (j, i).
// Its leading dimension is already the reduction dimension, so no
// vector.transpose is expected: rows are extracted straight from the loaded
// matrix and fed into two chained vector.outerproduct ops.
// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// Vector-matrix form: the vector is the first contract operand and the matrix,
// indexed as (i, j), is the second. The expected lowering matches
// @matvec2x2: one vector.transpose of the matrix, followed by two chained
// vector.outerproduct ops that pair a transposed row with a scalar x[j].
// CHECK-LABEL: func @vecmat2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// Vector-matrix form with the matrix indexed as (j, i). As in
// @mattransvec2x2, no vector.transpose is expected: rows come straight from
// the loaded matrix and feed two chained vector.outerproduct ops, starting
// from the loaded accumulator %b.
// CHECK-LABEL: func @vecmattrans2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}
