// RUN: mlir-hlo-opt --lhlo-fusion -split-input-file %s -o - | FileCheck %s

// Two elementwise ops linked by a buffer: abs writes %arg1 and add reads it.
// Both are expected to end up inside a single lmhlo.fusion region.
// CHECK-LABEL: @simple_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> memref<?x?xf32>
func @simple_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                          %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> memref<?x?xf32> {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]] : memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg3 : memref<?x?xf32>
}

// -----

// Same chain as above, but the intermediate buffer %arg1 is also returned.
// The fused region therefore has two live results (%arg1 and %arg3).
// CHECK-LABEL: @simple_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>)
func @simple_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                                       %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG3]] : memref<?x?xf32>, memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3 : memref<?x?xf32>, memref<?x?xf32>
}

// -----

// In the input, dynamic_broadcast_in_dim sits between abs and add.
// In the expected output it is not part of the fusion and is placed after the
// region that holds abs and add.
// CHECK-LABEL: @simple_multi_output_kloop_fusion_with_reorder
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?x?xf32>, %[[ARG4:.*]]: memref<2xindex>, %[[ARG5:.*]]: memref<?x?xf32>)
func @simple_multi_output_kloop_fusion_with_reorder(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                                                    %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
                                                    %arg4: memref<2xindex>, %arg5: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[ARG3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[ARG4]], %[[ARG5]])
  // CHECK: return %[[ARG1]], %[[ARG3]], %[[ARG5]] : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.dynamic_broadcast_in_dim"(%arg1, %arg4, %arg5) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
  "lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3, %arg5 : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}

// -----

// abs produces a rank-2 buffer, and dynamic_reshape turns it into a rank-3 buffer
// that add consumes. All three ops are expected in one fusion even though the
// rank-2 and rank-3 results differ in rank.
// CHECK-LABEL: @same_num_elements_multi_output_kloop_fusion
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<3xi64>, %[[ARG3:.*]]: memref<?x?x?xf32>, %[[ARG4:.*]]: memref<?x?x?xf32>, %[[ARG5:.*]]: memref<?x?x?xf32>)
func @same_num_elements_multi_output_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                                                  %arg2: memref<3xi64>, %arg3: memref<?x?x?xf32>,
                                                  %arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>) -> (memref<?x?xf32>, memref<?x?x?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.dynamic_reshape"(%[[ARG1]], %[[ARG2]], %[[ARG3]])
  // CHECK: "lmhlo.add"(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG5]] : memref<?x?xf32>, memref<?x?x?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // The shape operand holds one extent per result dimension, so it has 3
  // entries to match the rank-3 result %arg3. It was previously memref<2xi64>,
  // which is inconsistent with the result rank.
  "lmhlo.dynamic_reshape"(%arg1, %arg2, %arg3) : (memref<?x?xf32>, memref<3xi64>, memref<?x?x?xf32>) -> ()
  "lmhlo.add"(%arg3, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
  return %arg1, %arg5 : memref<?x?xf32>, memref<?x?x?xf32>
}

// -----

// add writes %arg1 and subtract reads only %arg2, so the two ops share no
// buffer. No fusion is expected.
// CHECK-LABEL: @check_not_kloop_fusion
func @check_not_kloop_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK-NOT: "lmhlo.fusion"
  "lmhlo.add"(%arg0, %arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.subtract"(%arg2, %arg2, %arg3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %arg1, %arg3: memref<?x?xf32>, memref<?x?xf32>
}

// -----

// A diamond of elementwise ops is expected in one fusion: add and multiply feed
// one abs each, and both abs results feed a final multiply. The memref.dealloc
// ops interleaved in the input are expected after the fusion region. All of the
// matched allocs are expected before it.
// CHECK-LABEL: @kloop_fusion_with_dealloc
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>)
func @kloop_fusion_with_dealloc(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) -> (memref<?x?xf32>, memref<?x?xf32>) {
  // CHECK: %[[TMP3:.*]] = memref.alloc
  // CHECK: %[[TMP5:.*]] = memref.alloc
  // CHECK: %[[TMP9:.*]] = memref.alloc
  // CHECK: %[[TMP13:.*]] = memref.alloc
  // CHECK: %[[TMP16:.*]] = memref.alloc
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP3]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.multiply"(%[[ARG0]], %[[ARG1]], %[[TMP5]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[TMP3]], %[[TMP9]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[TMP5]], %[[TMP13]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.multiply"(%[[TMP9]], %[[TMP13]], %[[TMP16]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: }) : () -> ()
  // CHECK: memref.dealloc %[[TMP3]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP5]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP13]] : memref<?x?xf32>
  // CHECK: return %[[TMP9]], %[[TMP16]] : memref<?x?xf32>, memref<?x?xf32>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %1 = tensor.extract %0[%c0] : tensor<2xindex>
  %2 = tensor.extract %0[%c1] : tensor<2xindex>
  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %4 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.multiply"(%arg0, %arg1, %4) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %5 = shape.shape_of %3 : memref<?x?xf32> -> tensor<2xindex>
  %6 = tensor.extract %5[%c0] : tensor<2xindex>
  %7 = tensor.extract %5[%c1] : tensor<2xindex>
  %8 = memref.alloc(%6, %7) : memref<?x?xf32>
  "lmhlo.abs"(%3, %8) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %3 : memref<?x?xf32>
  %9 = shape.shape_of %4 : memref<?x?xf32> -> tensor<2xindex>
  %10 = tensor.extract %9[%c0] : tensor<2xindex>
  %11 = tensor.extract %9[%c1] : tensor<2xindex>
  %12 = memref.alloc(%10, %11) : memref<?x?xf32>
  "lmhlo.abs"(%4, %12) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %4 : memref<?x?xf32>
  %13 = shape.shape_of %8 : memref<?x?xf32> -> tensor<2xindex>
  %14 = tensor.extract %13[%c0] : tensor<2xindex>
  %15 = tensor.extract %13[%c1] : tensor<2xindex>
  %16 = memref.alloc(%14, %15) : memref<?x?xf32>
  "lmhlo.multiply"(%8, %12, %16) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %12 : memref<?x?xf32>
  return %8, %16 : memref<?x?xf32>, memref<?x?xf32>
}

// -----

// An elementwise producer (abs) feeding a reduce over dimension 0 is expected to
// be fused together with the reduce.
// CHECK-LABEL: @simple_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @simple_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> memref<?xf32> {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG2]] : memref<?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg2: memref<?xf32>
}

// -----

// Same as simple_kinput, but the abs result %arg1 is also returned. The fusion
// therefore has two live results (%arg1 and %arg2).
// CHECK-LABEL: @multi_output_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<f32>
func @multi_output_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?xf32>, %init: memref<f32>) -> (memref<?x?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG1]], %[[ARG3]], %[[ARG2]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG1]], %[[ARG2]] : memref<?x?xf32>, memref<?xf32>
  "lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg1, %init, %arg2) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg1, %arg2: memref<?x?xf32>, memref<?xf32>
}

// -----

// Two reductions over dimension 1 share the producer %arg2: one reduces abs(%arg2)
// and the other reduces %arg2 directly. Both reductions and their producers are
// expected in a single fusion.
// CHECK-LABEL: @row_red_and_row_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_row_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}

// -----

// Same structure as row_red_and_row_red_kinput, except that the second reduction
// reduces dimension 0 instead of 1. Both reductions are still expected in a
// single fusion.
// CHECK-LABEL: @row_red_and_col_red_kinput
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>, %[[ARG2:.*]]: memref<?x?xf32>, %[[ARG3:.*]]: memref<?xf32>, %[[ARG4:.*]]: memref<?xf32>, %[[ARG5:.*]]: memref<?x?xf32>, %[[ARG6:.*]]: memref<f32>
func @row_red_and_col_red_kinput(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: memref<?xf32>, %arg4: memref<?xf32>, %arg5: memref<?x?xf32>, %init: memref<f32>) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.abs"(%[[ARG2]], %[[ARG5]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[ARG5]], %[[ARG6]], %[[ARG3]]) ( {
  // CHECK: "lmhlo.reduce"(%[[ARG2]], %[[ARG6]], %[[ARG4]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: return %[[ARG3]], %[[ARG4]] : memref<?xf32>, memref<?xf32>
  "lmhlo.add"(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.abs"(%arg2, %arg5) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  "lmhlo.reduce"(%arg5, %init, %arg3) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  "lmhlo.reduce"(%arg2, %init, %arg4) ( {
  ^bb0(%targ1: memref<f32>, %targ2: memref<f32>, %tresult: memref<f32>):
    "lmhlo.add"(%targ1, %targ2, %tresult) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  return %arg3, %arg4: memref<?xf32>, memref<?xf32>
}

// -----

// The ops producing the reduce's inputs are expected to be fused with the reduce:
// add, subtract, and the constant init value. The add that consumes the reduce's
// result (%9) is expected to stay outside the fusion, after the fused region.
// CHECK-LABEL: @reduce_should_not_have_consumer_in_the_fusion
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?x?xf32>
func @reduce_should_not_have_consumer_in_the_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>)
-> (memref<?x?xf32>, memref<?xf32>) {
  // CHECK: %[[TMP4:.*]] = memref.alloc
  // CHECK: %[[TMP7:.*]] = memref.alloc
  // CHECK: %[[TMP8:.*]] = memref.alloc
  // CHECK: %[[TMP9:.*]] = memref.alloc
  // CHECK: "lmhlo.fusion"() ( {
  // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[TMP4]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.subtract"(%[[ARG0]], %[[TMP4]], %[[TMP7]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  // CHECK: "lmhlo.constant"(%[[TMP8]]) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  // CHECK: "lmhlo.reduce"(%[[TMP7]], %[[TMP8]], %[[TMP9]]) ( {
  // CHECK: }) : () -> ()
  // CHECK: memref.dealloc %[[TMP4]] : memref<?x?xf32>
  // CHECK: memref.dealloc %[[TMP8]] : memref<f32>
  // CHECK: %[[TMP12:.*]] = memref.alloc
  // CHECK: "lmhlo.add"(%[[TMP9]], %[[TMP9]], %[[TMP12]]) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
  // CHECK: memref.dealloc %[[TMP9]] : memref<?xf32>
  // CHECK: return %[[TMP7]], %[[TMP12]] : memref<?x?xf32>, memref<?xf32>
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %0 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %1 = tensor.extract %0[%c0] : tensor<2xindex>
  %2 = tensor.extract %0[%c1] : tensor<2xindex>
  %3 = memref.alloc(%1, %2) : memref<?x?xf32>
  "lmhlo.add"(%arg0, %arg1, %3) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  %4 = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  %5 = tensor.extract %4[%c0] : tensor<2xindex>
  %6 = tensor.extract %4[%c1] : tensor<2xindex>
  %7 = memref.alloc(%5, %6) : memref<?x?xf32>
  "lmhlo.subtract"(%arg0, %3, %7) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  memref.dealloc %3 : memref<?x?xf32>
  %8 = memref.alloc() : memref<f32>
  "lmhlo.constant"(%8) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  %9 = memref.alloc(%5) : memref<?xf32>
  "lmhlo.reduce"(%7, %8, %9) ( {
  ^bb0(%arg2: memref<f32>, %arg3: memref<f32>, %arg4: memref<f32>):  // no predecessors
    "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
  memref.dealloc %8 : memref<f32>
  %10 = shape.shape_of %9 : memref<?xf32> -> tensor<1xindex>
  %11 = tensor.extract %10[%c0] : tensor<1xindex>
  %12 = memref.alloc(%11) : memref<?xf32>
  "lmhlo.add"(%9, %9, %12) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
  memref.dealloc %9 : memref<?xf32>
  return %7, %12 : memref<?x?xf32>, memref<?xf32>
}

// -----

// The constant's buffer %0 is both consumed by add and returned from the
// function. No fusion is expected here. Presumably this is because fusing would
// make the constant's buffer a fusion output; confirm against the pass.
// CHECK-LABEL: @const_should_not_be_output
func @const_should_not_be_output(%arg0: memref<f32>) -> (memref<f32>, memref<f32>) {
  // CHECK-NOT: lmhlo.fusion
  %0 = memref.alloc() : memref<f32>
  "lmhlo.constant"(%0) {value = dense<1.000000e+00> : tensor<f32>} : (memref<f32>) -> ()
  %1 = memref.alloc() : memref<f32>
  "lmhlo.add"(%arg0, %0, %1) : (memref<f32>, memref<f32>, memref<f32>) -> ()
  return %0, %1 : memref<f32>, memref<f32>
}