// RUN: mlir-hlo-opt -lhlo-legalize-tensor-load-op -split-input-file %s -o - | FileCheck %s

// Each case below feeds the result of a `memref.tensor_load` into a consumer
// op. After the pass, the consumer is expected to read the source memref
// directly, and no `memref.tensor_load` should remain. The split-input-file
// flag makes the `// -----` separators below take effect, so every case is
// parsed and rewritten as an independent module.

// test: `memref -> memref.tensor_load -> tensor.extract` -> `memref -> memref.load`
// CHECK-LABEL: forward_extract_op
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<3xindex>)
func @forward_extract_op(%arg0: memref<?x?xf32>, %arg1: memref<3xindex>) -> memref<?x?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  // The tensor_load and the extracts must be gone. Each extracted element is
  // instead expected to be loaded straight from %arg1, and those loads must
  // feed the dynamic sizes of the alloc, which is still what gets returned.
  // CHECK-NOT: memref.tensor_load
  // CHECK-NOT: tensor.extract
  // CHECK: %[[DIM0:.*]] = memref.load %[[ARG1]][%c0]
  // CHECK: %[[DIM1:.*]] = memref.load %[[ARG1]][%c1]
  // CHECK: %[[DIM2:.*]] = memref.load %[[ARG1]][%c2]
  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
  // CHECK: return %[[OUT]]
  %0 = memref.tensor_load %arg1 : memref<3xindex>
  %1 = tensor.extract %0[%c0] : tensor<3xindex>
  %2 = tensor.extract %0[%c1] : tensor<3xindex>
  %3 = tensor.extract %0[%c2] : tensor<3xindex>
  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf32>
  "lmhlo.dynamic_broadcast_in_dim"(%arg0, %arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()
  return %4 : memref<?x?x?xf32>
}

// -----

// test: `memref -> memref.tensor_load -> shape.shape_of` -> `memref -> shape.shape_of`
// CHECK-LABEL: forward_shape_of_op
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>)
func @forward_shape_of_op(%arg0: memref<?x?xf32>) -> tensor<2xindex> {
  // shape_of is expected to take the memref argument directly, with no
  // intermediate tensor_load left in front of it.
  // CHECK-NOT: memref.tensor_load
  // CHECK: shape.shape_of %[[ARG]] : memref<?x?xf32> -> tensor<2xindex>
  %0 = memref.tensor_load %arg0 : memref<?x?xf32>
  %1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
  return %1 : tensor<2xindex>
}