• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s
2
// Test case: two ops chained through %arg1 ("lmhlo.abs" names it as its last
// operand, "lmhlo.add" names it as its first operand).
// Expected output, per the directives below: both ops appear, in the same
// order, inside a single "lmhlo.fusion" region, and the function still
// returns %arg3.
3// CHECK-LABEL: @simple_kloop_fusion
4// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
5func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
6                          %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> {
7  // CHECK: "lmhlo.fusion"() ( {
8  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
9  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
10  // CHECK: }) : () -> ()
11  // CHECK: return %[[ARG3]] : memref<?x?xf32>
12  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
13  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
14  return %arg3 : memref<?x?xf32>
15}
16
17// -----
18
// Test case: same abs -> add chain through %arg1 as the single-result case,
// but here the function returns both %arg1 and %arg3.
// Expected output, per the directives below: both ops are still placed in one
// "lmhlo.fusion" region and both values are returned in the original order.
19// CHECK-LABEL: @simple_multi_output_kloop_fusion
20// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
21func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
22                          %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
23  // CHECK: "lmhlo.fusion"() ( {
24  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
25  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
26  // CHECK: }) : () -> ()
27  // CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32>
28  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
29  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
30  return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32>
31}
32
33// -----
34
// Test case: the input places "lmhlo.dynamic_broadcast_in_dim" between the
// abs and the add, all three of which name %arg1 as an operand.
// Expected output, per the directives below: abs and add sit together inside
// one "lmhlo.fusion" region, and the dynamic_broadcast_in_dim is matched only
// after that region closes, i.e. it ends up after the add instead of before it.
35// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder
36// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>)
37func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
38                          %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
39                          %arg4: memref<2xindex>, %arg5:  memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) {
40  // CHECK: "lmhlo.fusion"() ( {
41  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
42  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
43  // CHECK: }) : () -> ()
44  // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]])
45  // CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
46  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
47  "lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
48  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
49  return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
50}
51
52// -----
53
// Test case: an abs on rank-2 memrefs, a "lmhlo.dynamic_reshape" from the
// rank-2 %arg1 into the rank-3 %arg3, then an add on rank-3 memrefs.
// Expected output, per the directives below: all three ops, in input order,
// are inside one "lmhlo.fusion" region even though the returned values
// (%arg1 and %arg5) have different ranks.
54// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion
55// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<2xi64>, %[[ARG3:.*]]: memref<?x?x?xf32>, %[[ARG4:.*]]: memref<?x?x?xf32>, %[[ARG5:.*]]: memref<?x?x?xf32>)
56func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
57                          %arg2: memref<2xi64>, %arg3: memref<?x?x?xf32>,
58                          %arg4: memref<?x?x?xf32>, %arg5:  memref<?x?x?xf32>) -> (memref<?x?xf32>, memref<?x?x?xf32>) {
59  // CHECK: "lmhlo.fusion"() ( {
60  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
61  // CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]])
62  // CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
63  // CHECK: }) : () -> ()
64  // CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?x?xf32>
65  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
66  "lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<2xi64>, memref<?x?x?xf32>) -> ()
67  "lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
68  return %arg1, %arg5 : memref<?x?xf32>, memref<?x?x?xf32>
69}
70
71// -----
72
// Negative test case: the add uses only %arg0/%arg1 and the subtract uses only
// %arg2/%arg3, so the two ops have no operand in common.
// Expected output, per the directive below: no "lmhlo.fusion" op appears
// anywhere in this function.
73// CHECK-LABEL: @check_not_kloop_fusion
74func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
75  // CHECK-NOT: "lmhlo.fusion"
76  "lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
77  "lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
78  return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32>
79}
80
81// -----
82
// Test case: five ops (add, multiply, abs, abs, multiply) whose buffers are
// allocated with memref.alloc and partly freed with memref.dealloc, with the
// alloc/dealloc ops interleaved between the compute ops in the input.
// Expected output, per the directives below: five memref.alloc ops are matched
// first, then one "lmhlo.fusion" region holding all five compute ops, then the
// three memref.dealloc ops after the region closes, then the return.
83// CHECK-LABEL: @kloop_fusion_with_dealloc
84// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>)
85func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // Note: the TMPn capture names are arbitrary and do not all equal the input
  // SSA numbers. Going by operand position in the ops below, TMP3 lines up
  // with %3, TMP5 with %4, TMP9 with %8, TMP13 with %12 and TMP16 with %16.
86  // CHECK: %[[TMP3:.*]] = memref.alloc
87  // CHECK: %[[TMP5:.*]] = memref.alloc
88  // CHECK: %[[TMP9:.*]] = memref.alloc
89  // CHECK: %[[TMP13:.*]] = memref.alloc
90  // CHECK: %[[TMP16:.*]] = memref.alloc
91  // CHECK: "lmhlo.fusion"() ( {
92  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
93  // CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
94  // CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
95  // CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
96  // CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
97  // CHECK: }) : () -> ()
98  // CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32>
99  // CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32>
100  // CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32>
101  // CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32>
102  %c0 = constant 0 : index
103  %c1 = constant 1 : index
104  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
105  %1 = tensor.extract %0[%c0] : tensor<2xindex>
106  %2 = tensor.extract %0[%c1] : tensor<2xindex>
107  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
108  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
109  %4 = memref.alloc(%1, %2) : memref<?x?xf32>
110  "lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
111  %5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex>
112  %6 = tensor.extract %5[%c0] : tensor<2xindex>
113  %7 = tensor.extract %5[%c1] : tensor<2xindex>
114  %8 = memref.alloc(%6, %7) : memref<?x?xf32>
115  "lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
116  memref.dealloc %3 : memref<?x?xf32>
117  %9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex>
118  %10 = tensor.extract %9[%c0] : tensor<2xindex>
119  %11 = tensor.extract %9[%c1] : tensor<2xindex>
120  %12 = memref.alloc(%10, %11) : memref<?x?xf32>
121  "lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
122  memref.dealloc %4 : memref<?x?xf32>
123  %13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex>
124  %14 = tensor.extract %13[%c0] : tensor<2xindex>
125  %15 = tensor.extract %13[%c1] : tensor<2xindex>
126  %16 = memref.alloc(%14, %15) : memref<?x?xf32>
127  "lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
128  memref.dealloc %12 : memref<?x?xf32>
129  return %8, %16 : memref<?x?xf32>, memref<?x?xf32>
130}
131
132// -----
133
// Test case: an abs writing %arg1, followed by a "lmhlo.reduce" that takes
// %arg1 and %init, has dimensions = [0] and an add-based body, and names
// %arg2 as its last operand.
// Expected output, per the directives below: the abs and the reduce are inside
// one "lmhlo.fusion" region. ARG3 in the directives binds the fourth function
// argument, which the input names %init.
134// CHECK-LABEL: @simple_kinput
135// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
136func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> {
137  // CHECK: "lmhlo.fusion"() ( {
138  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
139  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
140  // CHECK: }) : () -> ()
141  // CHECK: return %[[ARG2]] : memref<?xf32>
142  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
143  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
144  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
145    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
146    "lmhlo.terminator"() : () -> ()
147  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
148  return %arg2: memref<?xf32>
149}
150
151// -----
152
// Test case: the same abs -> reduce (dimensions = [0]) chain as the
// single-result reduce case, but the function returns both %arg1 and %arg2.
// Expected output, per the directives below: the abs and the reduce are still
// inside one "lmhlo.fusion" region and both values are returned.
153// CHECK-LABEL: @multi_output_kinput
154// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
155func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) {
156  // CHECK: "lmhlo.fusion"() ( {
157  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
158  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
159  // CHECK: }) : () -> ()
160  // CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32>
161  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
162  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
163  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
164    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
165    "lmhlo.terminator"() : () -> ()
166  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
167  return %arg1, %arg2: memref<?x?xf32>, memref<?xf32>
168}
169
170// -----
171
// Test case: add -> abs, followed by two "lmhlo.reduce" ops that both use
// dimensions = [1]; the first takes %arg5 (named last by the abs), the second
// takes %arg2 (named last by the add).
// Expected output, per the directives below: the add, the abs and both reduces
// are inside one "lmhlo.fusion" region. ARG6 in the directives binds the
// seventh function argument, which the input names %init.
172// CHECK-LABEL: @row_red_and_row_red_kinput
173// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
174func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
175  // CHECK: "lmhlo.fusion"() ( {
176  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
177  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
178  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
179  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
180  // CHECK: }) : () -> ()
181  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
182  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
183  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
184  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
185  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
186    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
187    "lmhlo.terminator"() : () -> ()
188  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
189  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
190  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
191    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
192    "lmhlo.terminator"() : () -> ()
193  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
194  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
195}
196
197// -----
198
// Test case: add -> abs, followed by two "lmhlo.reduce" ops over different
// dimensions: the first (on %arg5) uses dimensions = [1], the second (on
// %arg2) uses dimensions = [0].
// Expected output, per the directives below: the add, the abs and both reduces
// are inside one "lmhlo.fusion" region despite the differing reduce
// dimensions. ARG6 in the directives binds the argument the input names %init.
199// CHECK-LABEL: @row_red_and_col_red_kinput
200// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
201func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
202  // CHECK: "lmhlo.fusion"() ( {
203  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
204  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
205  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
206  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
207  // CHECK: }) : () -> ()
208  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
209  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
210  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
211  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
212  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
213    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
214    "lmhlo.terminator"() : () -> ()
215  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
216  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
217  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
218    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
219    "lmhlo.terminator"() : () -> ()
220  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
221  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
222}
223
224// -----
225
// Test case: add -> subtract -> (constant) -> reduce, followed by a final add
// whose two inputs are both the reduce's %9 buffer.
// Expected output, per the directives below: one "lmhlo.fusion" region holds
// the add, subtract, constant and reduce; the final add that reads the reduce
// buffer is matched only after the region closes, together with its own
// memref.alloc, so it stays outside the fusion.
226// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion
227// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>
228func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
229-> (memref<?x?xf32>, memref<?xf32>) {
  // Note: the TMPn capture names are arbitrary. Going by operand position in
  // the ops below, TMP4 lines up with %3, TMP7 with %7, TMP8 with %8, TMP9
  // with %9 and TMP12 with %12.
230  // CHECK: %[[TMP4:.*]] = memref.alloc
231  // CHECK: %[[TMP7:.*]] = memref.alloc
232  // CHECK: %[[TMP8:.*]] = memref.alloc
233  // CHECK: %[[TMP9:.*]] = memref.alloc
234  // CHECK: "lmhlo.fusion"() ( {
235  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
236  // CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
237  // CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
238  // CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( {
239  // CHECK: }) : () -> ()
240  // CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32>
241  // CHECK: memref.dealloc %[[TMP8]] : memref<f32>
242  // CHECK: %[[TMP12:.*]] = memref.alloc
243  // CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
244  // CHECK: memref.dealloc %[[TMP9]] : memref<?xf32>
245  // CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32>
246  %c1 = constant 1 : index
247  %c0 = constant 0 : index
248  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
249  %1 = tensor.extract %0[%c0] : tensor<2xindex>
250  %2 = tensor.extract %0[%c1] : tensor<2xindex>
251  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
252  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
253  %4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
254  %5 = tensor.extract %4[%c0] : tensor<2xindex>
255  %6 = tensor.extract %4[%c1] : tensor<2xindex>
256  %7 = memref.alloc(%5, %6) : memref<?x?xf32>
257  "lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
258  memref.dealloc %3 : memref<?x?xf32>
259  %8 = memref.alloc() : memref<f32>
260  "lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
261  %9 = memref.alloc(%5) : memref<?xf32>
262  "lmhlo.reduce"(%7, %8, %9) ( {
263  ^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>):  // no predecessors
264    "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> ()
265    "lmhlo.terminator"() : () -> ()
266  }) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
267  memref.dealloc %8 : memref<f32>
268  %10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex>
269  %11 = tensor.extract %10[%c0] : tensor<1xindex>
270  %12 = memref.alloc(%11) : memref<?xf32>
271  "lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
272  memref.dealloc %9 : memref<?xf32>
273  return %7, %12 : memref<?x?xf32>, memref<?xf32>
274}
275
276// -----
277
// Negative test case: a "lmhlo.constant" fills %0, which is both an input of
// the following add and one of the function's returned values.
// Expected output, per the directive below: no lmhlo.fusion op appears
// anywhere in this function.
278// CHECK-LABEL: @const_should_not_be_output
279func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) {
280  // CHECK-NOT: lmhlo.fusion
281  %0 = memref.alloc() : memref<f32>
282  "lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
283  %1 = memref.alloc() : memref<f32>
284  "lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
285  return %0, %1 : memref<f32>, memref<f32>
286}
287