• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
2
3// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
4// CHECK-LABEL: forward_extract_op
5// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
// Test input: %arg1 (a 3-element index buffer) is viewed as a tensor, its three
// elements are extracted, and they are used as the dynamic sizes of a freshly
// allocated 3-D f32 buffer, which the function returns.
6func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
  // Index constants used to address elements 0, 1 and 2 of the size operand.
7  %c0 = constant 0 : index
8  %c1 = constant 1 : index
9  %c2 = constant 2 : index
  // Expected pass output: no tensor view / tensor element read is left ahead of
  // the loads, each size is read straight from the %arg1 buffer, and the three
  // loaded values feed the alloc in order.
10  // CHECK-NOT: memref.tensor_load
11  // CHECK-NOT: tensor.extract
12  // CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
13  // CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
14  // CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
15  // CHECK: memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
  // Pattern under test: a tensor view of %arg1 followed by per-element extracts.
16  %0 = memref.tensor_load %arg1 : memref<3xindex>
17  %1 = tensor.extract %0[%c0] : tensor<3xindex>
18  %2 = tensor.extract %0[%c1] : tensor<3xindex>
19  %3 = tensor.extract %0[%c2] : tensor<3xindex>
  // The extracted values become the three dynamic dimensions of the result.
20  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
  // %4 is passed as the last operand and the op has no results; presumably it
  // is the output buffer -- NOTE(review): confirm against the lmhlo op definition.
21  "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
22  return %4 : memref<?x?x?xf32>
23}
24
25// -----
26
27// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
28// CHECK-LABEL: forward_shape_of_op
29// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
// Test input: %arg0 is viewed as a tensor and shape.shape_of is applied to that
// tensor view; the resulting 2-element index tensor is returned.
30func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
  // Expected pass output: no tensor view is left ahead of shape.shape_of, which
  // now takes the memref function argument directly and keeps its result type.
31  // CHECK-NOT: memref.tensor_load
32  // CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
  // Pattern under test: tensor view of a memref consumed only by shape.shape_of.
33  %0 = memref.tensor_load %arg0 : memref<?x?xf32>
34  %1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
35  return %1 : tensor<2xindex>
36}
37