// RUN: mlir-hlo-opt %s -input-inline-fusion -split-input-file | FileCheck %s

// Test case 1: a fusion region holding a producer chain followed by a loop.
//   * lmhlo.dynamic_broadcast_in_dim writes %arg4,
//   * lmhlo.add reads %arg4 and writes %arg5,
//   * a flattened scf.parallel loop loads %arg5 and stores into %arg6.
// The directives below require that, in the pass output, neither lmhlo op
// appears between the fusion header and the scf.parallel loop, and that the
// function still returns its last argument.

// CHECK-LABEL: @inline_fusion_fusion_order
// CHECK-SAME: (%[[INPUT1:.*]]: memref<?xf32>, %[[INPUT2:.*]]: memref<3xi32>, %[[INPUT3:.*]]: memref<?x?x?xf32>, %[[INPUT4:.*]]: memref<?x?x?xf32>, %[[TMP_BUF1:.*]]: memref<?x?x?xf32>, %[[TMP_BUF2:.*]]: memref<?x?x?xf32>, %[[OUT:.*]]: memref<?x?x?xf32>) -> memref<?x?x?xf32>
func @inline_fusion_fusion_order(%arg0: memref<?xf32>, %arg1: memref<3xi32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>, %arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>, %arg6: memref<?x?x?xf32>) -> memref<?x?x?xf32> {
  %c2 = constant 2 : index
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  // CHECK: "lmhlo.fusion"() ( {
  "lmhlo.fusion"() ( {
    // CHECK-NOT: lmhlo.dynamic_broadcast_in_dim
    // CHECK-NOT: lmhlo.add
    "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %arg4) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (memref<?xf32>, memref<3xi32>, memref<?x?x?xf32>) -> ()
    "lmhlo.add"(%arg2, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
    // Product of the three dimensions of %arg6: the flat loop upper bound.
    %0 = memref.dim %arg6, %c0 : memref<?x?x?xf32>
    %1 = memref.dim %arg6, %c1 : memref<?x?x?xf32>
    %2 = muli %0, %1 : index
    %3 = memref.dim %arg6, %c2 : memref<?x?x?xf32>
    %4 = muli %2, %3 : index
    // CHECK: scf.parallel
    scf.parallel (%arg7) = (%c0) to (%4) step (%c1) {
      // Delinearize the flat index %arg7 into three indices using the two
      // trailing dimensions of %arg3.
      %5 = memref.dim %arg3, %c1 : memref<?x?x?xf32>
      %6 = memref.dim %arg3, %c2 : memref<?x?x?xf32>
      %7 = muli %6, %5 : index
      %8 = divi_unsigned %arg7, %7 : index
      %9 = remi_unsigned %arg7, %7 : index
      %10 = divi_unsigned %9, %6 : index
      %11 = remi_unsigned %9, %6 : index
      // %arg5 is the buffer written by the lmhlo.add above.
      %12 = memref.load %arg3[%8, %10, %11] : memref<?x?x?xf32>
      %13 = memref.load %arg5[%8, %10, %11] : memref<?x?x?xf32>
      %14 = mulf %12, %13 : f32
      // Store through a 1-D view of %arg6 so the flat index can be used.
      %15 = memref.reinterpret_cast %arg6 to offset: [%c0], sizes: [%4], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %14, %15[%arg7] : memref<?xf32>
      scf.yield
    }
    // CHECK: "lmhlo.terminator"() : () -> ()
    "lmhlo.terminator"() : () -> ()
  }) : () -> ()
  // CHECK: return %[[OUT]] : memref<?x?x?xf32>
  return %arg6 : memref<?x?x?xf32>
}

// The lit run line above passes -split-input-file; the marker that follows
// makes the two test cases independent modules.

// -----

// Test case 2: a fusion region whose loop writes two returned buffers.
//   * lmhlo.dynamic_broadcast_in_dim writes %arg3,
//   * lmhlo.add reads %arg2 and %arg3 and writes %arg4,
//   * the scf.parallel loop loads %arg3, stores an addf result into %arg4,
//     then reloads %arg4 for a mulf whose result is stored into %arg5.
// The directives below require that neither lmhlo op appears between the
// fusion header and the scf.parallel loop in the pass output, and that both
// %arg4 and %arg5 are still returned.

// CHECK-LABEL: @multioutput_loop_fusion_with_dependency
// CHECK-SAME: (%[[INPUT1:.*]]: memref<?xf32>, %[[INPUT2:.*]]: memref<3xi32>, %[[INPUT3:.*]]: memref<?x?x?xf32>, %[[TMP_BUF:.*]]: memref<?x?x?xf32>, %[[OUT1:.*]]: memref<?x?x?xf32>, %[[OUT2:.*]]: memref<?x?x?xf32>) -> (memref<?x?x?xf32>, memref<?x?x?xf32>)
func @multioutput_loop_fusion_with_dependency(%arg0: memref<?xf32>, %arg1: memref<3xi32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>, %arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>) -> (memref<?x?x?xf32>, memref<?x?x?xf32>) {
  %c2 = constant 2 : index
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  // CHECK: "lmhlo.fusion"() ( {
  "lmhlo.fusion"() ( {
    // CHECK-NOT: lmhlo.dynamic_broadcast_in_dim
    // CHECK-NOT: lmhlo.add
    "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %arg3) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (memref<?xf32>, memref<3xi32>, memref<?x?x?xf32>) -> ()
    "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
    // Product of the three dimensions of %arg5: the flat loop upper bound.
    %0 = memref.dim %arg5, %c0 : memref<?x?x?xf32>
    %1 = memref.dim %arg5, %c1 : memref<?x?x?xf32>
    %2 = muli %0, %1 : index
    %3 = memref.dim %arg5, %c2 : memref<?x?x?xf32>
    %4 = muli %2, %3 : index
    // CHECK: scf.parallel
    scf.parallel (%arg6) = (%c0) to (%4) step (%c1) {
      // Delinearize the flat index %arg6 into three indices using the two
      // trailing dimensions of %arg2.
      %5 = memref.dim %arg2, %c1 : memref<?x?x?xf32>
      %6 = memref.dim %arg2, %c2 : memref<?x?x?xf32>
      %7 = muli %6, %5 : index
      %8 = divi_unsigned %arg6, %7 : index
      %9 = remi_unsigned %arg6, %7 : index
      %10 = divi_unsigned %9, %6 : index
      %11 = remi_unsigned %9, %6 : index
      // First output: addf of %arg2 and %arg3 (the broadcast destination),
      // stored through a 1-D view of %arg4.
      %12 = memref.load %arg2[%8, %10, %11] : memref<?x?x?xf32>
      %13 = memref.load %arg3[%8, %10, %11] : memref<?x?x?xf32>
      %14 = addf %12, %13 : f32
      %15 = memref.dim %arg4, %c0 : memref<?x?x?xf32>
      %16 = memref.dim %arg4, %c1 : memref<?x?x?xf32>
      %17 = muli %15, %16 : index
      %18 = memref.dim %arg4, %c2 : memref<?x?x?xf32>
      %19 = muli %17, %18 : index
      %20 = memref.reinterpret_cast %arg4 to offset: [%c0], sizes: [%19], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %14, %20[%arg6] : memref<?xf32>
      // Second output: mulf of %arg2 and the value just written to %arg4,
      // stored through a 1-D view of %arg5.
      %21 = memref.load %arg2[%8, %10, %11] : memref<?x?x?xf32>
      %22 = memref.load %arg4[%8, %10, %11] : memref<?x?x?xf32>
      %23 = mulf %21, %22 : f32
      %24 = memref.reinterpret_cast %arg5 to offset: [%c0], sizes: [%4], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %23, %24[%arg6] : memref<?xf32>
      scf.yield
    }
    // CHECK: "lmhlo.terminator"() : () -> ()
    "lmhlo.terminator"() : () -> ()
  }) : () -> ()
  // CHECK: return %[[OUT1]], %[[OUT2]] : memref<?x?x?xf32>, memref<?x?x?xf32>
  return %arg4, %arg5 : memref<?x?x?xf32>, memref<?x?x?xf32>
}