// RUN: mlir-hlo-opt %s -input-inline-fusion -split-input-file | FileCheck %s

// Test case: a two-step producer chain feeding a single loop.
//   * lmhlo.dynamic_broadcast_in_dim writes %arg4.
//   * lmhlo.add reads %arg2 and %arg4 and writes %arg5.
//   * The scf.parallel body loads %arg3 and %arg5, multiplies them and stores
//     the product into %arg6.
// The negative checks inside the fusion region require that neither lmhlo
// producer remains between the fusion op and the loop after the pass runs.
// NOTE(review): the loop only reads the add's output directly, so removing the
// broadcast presumably relies on the add being inlined first -- confirm
// against the pass implementation.
// CHECK-LABEL: @inline_fusion_fusion_order
// CHECK-SAME: (%[[INPUT1:.*]]: memref<?xf32>, %[[INPUT2:.*]]: memref<3xi32>, %[[INPUT3:.*]]: memref<?x?x?xf32>, %[[INPUT4:.*]]: memref<?x?x?xf32>, %[[TMP_BUF1:.*]]: memref<?x?x?xf32>, %[[TMP_BUF2:.*]]: memref<?x?x?xf32>, %[[OUT:.*]]: memref<?x?x?xf32>) -> memref<?x?x?xf32>
func @inline_fusion_fusion_order(%arg0: memref<?xf32>, %arg1: memref<3xi32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>, %arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>, %arg6: memref<?x?x?xf32>) -> memref<?x?x?xf32> {
  // Index constants used for the memref.dim queries and the loop bounds/step.
  %c2 = constant 2 : index
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  // CHECK: "lmhlo.fusion"() ( {
  "lmhlo.fusion"() ( {
    // CHECK-NOT: lmhlo.dynamic_broadcast_in_dim
    // CHECK-NOT: lmhlo.add
    "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %arg4) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (memref<?xf32>, memref<3xi32>, memref<?x?x?xf32>) -> ()
    "lmhlo.add"(%arg2, %arg4, %arg5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
    // %4 = dim0 * dim1 * dim2 of %arg6; it is the upper bound of the loop.
    %0 = memref.dim %arg6, %c0 : memref<?x?x?xf32>
    %1 = memref.dim %arg6, %c1 : memref<?x?x?xf32>
    %2 = muli %0, %1 : index
    %3 = memref.dim %arg6, %c2 : memref<?x?x?xf32>
    %4 = muli %2, %3 : index
    // CHECK: scf.parallel
    scf.parallel (%arg7) = (%c0) to (%4) step (%c1) {
      // Split the linear index %arg7 into three indices (%8, %10, %11) using
      // dims 1 and 2 of %arg3.
      %5 = memref.dim %arg3, %c1 : memref<?x?x?xf32>
      %6 = memref.dim %arg3, %c2 : memref<?x?x?xf32>
      %7 = muli %6, %5 : index
      %8 = divi_unsigned %arg7, %7 : index
      %9 = remi_unsigned %arg7, %7 : index
      %10 = divi_unsigned %9, %6 : index
      %11 = remi_unsigned %9, %6 : index
      // %arg5 is the buffer written by lmhlo.add above.
      %12 = memref.load %arg3[%8, %10, %11] : memref<?x?x?xf32>
      %13 = memref.load %arg5[%8, %10, %11] : memref<?x?x?xf32>
      %14 = mulf %12, %13 : f32
      // Store through a 1-D view of %arg6 (size %4, stride 1) at the linear index.
      %15 = memref.reinterpret_cast %arg6 to offset: [%c0], sizes: [%4], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %14, %15[%arg7] : memref<?xf32>
      scf.yield
    }
    // CHECK: "lmhlo.terminator"() : () -> ()
    "lmhlo.terminator"() : () -> ()
  }) : () -> ()
  // CHECK: return %[[OUT]] : memref<?x?x?xf32>
  return %arg6 : memref<?x?x?xf32>
}
// -----

// Test case: a loop with two outputs where the second depends on the first.
//   * lmhlo.dynamic_broadcast_in_dim writes %arg3.
//   * lmhlo.add reads %arg2 and %arg3 and writes %arg4.
//   * The scf.parallel body computes addf of loads from %arg2 and %arg3 and
//     stores it to %arg4, then reloads %arg4, multiplies it with a load from
//     %arg2 and stores the product to %arg5.
// Both %arg4 and %arg5 are returned. The negative checks inside the fusion
// region require that neither lmhlo producer remains between the fusion op and
// the loop after the pass runs.
// CHECK-LABEL: @multioutput_loop_fusion_with_dependency
// CHECK-SAME: (%[[INPUT1:.*]]: memref<?xf32>, %[[INPUT2:.*]]: memref<3xi32>, %[[INPUT3:.*]]: memref<?x?x?xf32>, %[[TMP_BUF:.*]]: memref<?x?x?xf32>, %[[OUT1:.*]]: memref<?x?x?xf32>, %[[OUT2:.*]]: memref<?x?x?xf32>) -> (memref<?x?x?xf32>, memref<?x?x?xf32>)
func @multioutput_loop_fusion_with_dependency(%arg0: memref<?xf32>, %arg1: memref<3xi32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>, %arg4: memref<?x?x?xf32>, %arg5: memref<?x?x?xf32>) -> (memref<?x?x?xf32>, memref<?x?x?xf32>) {
  // Index constants used for the memref.dim queries and the loop bounds/step.
  %c2 = constant 2 : index
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  // CHECK: "lmhlo.fusion"() ( {
  "lmhlo.fusion"() ( {
    // CHECK-NOT: lmhlo.dynamic_broadcast_in_dim
    // CHECK-NOT: lmhlo.add
    "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %arg3) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (memref<?xf32>, memref<3xi32>, memref<?x?x?xf32>) -> ()
    "lmhlo.add"(%arg2, %arg3, %arg4) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
    // %4 = dim0 * dim1 * dim2 of %arg5; it is the upper bound of the loop.
    %0 = memref.dim %arg5, %c0 : memref<?x?x?xf32>
    %1 = memref.dim %arg5, %c1 : memref<?x?x?xf32>
    %2 = muli %0, %1 : index
    %3 = memref.dim %arg5, %c2 : memref<?x?x?xf32>
    %4 = muli %2, %3 : index
    // CHECK: scf.parallel
    scf.parallel (%arg6) = (%c0) to (%4) step (%c1) {
      // Split the linear index %arg6 into three indices (%8, %10, %11) using
      // dims 1 and 2 of %arg2.
      %5 = memref.dim %arg2, %c1 : memref<?x?x?xf32>
      %6 = memref.dim %arg2, %c2 : memref<?x?x?xf32>
      %7 = muli %6, %5 : index
      %8 = divi_unsigned %arg6, %7 : index
      %9 = remi_unsigned %arg6, %7 : index
      %10 = divi_unsigned %9, %6 : index
      %11 = remi_unsigned %9, %6 : index
      // First output: sum of %arg2 and %arg3 elements, stored to %arg4.
      %12 = memref.load %arg2[%8, %10, %11] : memref<?x?x?xf32>
      %13 = memref.load %arg3[%8, %10, %11] : memref<?x?x?xf32>
      %14 = addf %12, %13 : f32
      // %19 = dim0 * dim1 * dim2 of %arg4; it sizes the 1-D view used for the store.
      %15 = memref.dim %arg4, %c0 : memref<?x?x?xf32>
      %16 = memref.dim %arg4, %c1 : memref<?x?x?xf32>
      %17 = muli %15, %16 : index
      %18 = memref.dim %arg4, %c2 : memref<?x?x?xf32>
      %19 = muli %17, %18 : index
      %20 = memref.reinterpret_cast %arg4 to offset: [%c0], sizes: [%19], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %14, %20[%arg6] : memref<?xf32>
      // Second output: reloads %arg4 (just stored above) and multiplies it
      // with the %arg2 element; stored to %arg5 through a 1-D view.
      %21 = memref.load %arg2[%8, %10, %11] : memref<?x?x?xf32>
      %22 = memref.load %arg4[%8, %10, %11] : memref<?x?x?xf32>
      %23 = mulf %21, %22 : f32
      %24 = memref.reinterpret_cast %arg5 to offset: [%c0], sizes: [%4], strides: [%c1] : memref<?x?x?xf32> to memref<?xf32>
      memref.store %23, %24[%arg6] : memref<?xf32>
      scf.yield
    }
    // CHECK: "lmhlo.terminator"() : () -> ()
    "lmhlo.terminator"() : () -> ()
  }) : () -> ()
  // CHECK: return %[[OUT1]], %[[OUT2]] : memref<?x?x?xf32>, memref<?x?x?xf32>
  return %arg4, %arg5 : memref<?x?x?xf32>, memref<?x?x?xf32>
}