// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s

// Canonicalization tests for mhlo.reshape. Each case below is an independent
// split of this file. The first group reshapes an mhlo.constant and expects a
// single constant of the result type; the second group reshapes a function
// argument and expects redundant reshapes to be removed or merged.

// Splat constant of shape 1x1 reshaped to a rank-0 tensor.
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func.func @const_fold_collapse_to_scalar() -> tensor<i32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
  %cst = mhlo.constant dense<42> : tensor<1x1xi32>
  %0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<i32>
}

// -----

// Splat constant with a leading unit dimension dropped (1x2 -> 2).
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func.func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
  %cst = mhlo.constant dense<42> : tensor<1x2xi32>
  %0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<2xi32>
}

// -----

// Rank-0 splat constant expanded to a one-element rank-1 tensor.
// CHECK-LABEL: func @const_fold_expand
func.func @const_fold_expand() -> tensor<1xi32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
  %cst = mhlo.constant dense<42> : tensor<i32>
  %0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<1xi32>
}

// -----

// Splat constant with no unit dimensions involved (4x4 -> 16).
// CHECK-LABEL: func @const_fold_nontrivial
func.func @const_fold_nontrivial() -> tensor<16xi64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
  %cst = mhlo.constant dense<42> : tensor<4x4xi64>
  %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<16xi64>
}

// -----

// Splat constant flattened from 4x4 to 16 elements.
// NOTE(review): apart from its name this case is identical to
// @const_fold_nontrivial above; consider varying one of the two shapes so
// they exercise different reshapes.
// CHECK-LABEL: func @const_fold_flatten
func.func @const_fold_flatten() -> tensor<16xi64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
  %cst = mhlo.constant dense<42> : tensor<4x4xi64>
  %0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<16xi64>
}

// -----

// Non-splat constant: the six elements keep their order when 3x2 is
// flattened to 6.
// CHECK-LABEL: func @const_fold_6
func.func @const_fold_6() -> tensor<6xi32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  %cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
  %0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<6xi32>
}

// -----

// Non-splat constant expanded from 6 elements to 2x3.
// NOTE(review): despite the name, the operand and result shapes differ here
// (6 vs. 2x3).
// The expected constant is matched in three pieces so that the nested "[["
// of the dense literal is not parsed as a FileCheck variable.
// CHECK-LABEL: func @const_fold_same_shape
func.func @const_fold_same_shape() -> tensor<2x3xi32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
  // CHECK-SAME:   [1, 2, 3], [4, 5, 6]
  // CHECK-SAME: ]> : tensor<2x3xi32>
  %cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  %0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<2x3xi32>
}

// -----

// Floating-point splat constant (4x4 -> 16); the pattern tolerates any
// number of trailing zeros in the printed mantissa.
// CHECK-LABEL: func @const_fold_float
func.func @const_fold_float() -> tensor<16xf64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
  %cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
  %0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<16xf64>
}

// -----

// A reshape whose result type equals its operand type is removed and the
// argument is returned directly.
// CHECK-LABEL: func @non_const_same_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-NEXT: return [[ARG]]
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  func.return %0 : tensor<2x3xi32>
}

// -----

// Two chained reshapes whose results are both used: the second one is
// expected to read the function argument directly instead of the first
// reshape's result, and both reshapes stay because both are returned.
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
  // CHECK-NEXT: [[RES0:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3xi32>) -> tensor<3x2xi32>
  // CHECK-NEXT: [[RES1:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3xi32>) -> tensor<6xi32>
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
  // CHECK-NEXT: return [[RES0]], [[RES1]]
  func.return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}

// -----

// Two chained reshapes where only the last result is used: a single reshape
// from the argument to the final type is expected.
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
  // CHECK-NEXT: [[RES:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3xi32>) -> tensor<6xi32>
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
  // CHECK-NEXT: return [[RES]]
  func.return %1 : tensor<6xi32>
}

// -----

// Two chained reshapes that end at the original type (2x3 -> 3x2 -> 2x3):
// both are expected to disappear and the argument is returned directly.
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
  // CHECK-NEXT: return [[ARG]]
  func.return %1 : tensor<2x3xi32>
}

// -----

// Five chained reshapes with only the last result used: a single reshape
// from the argument type to the final type is expected.
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
  // CHECK-NEXT: [[RES:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
  %1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
  %2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
  %3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
  %4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
  // CHECK-NEXT: return [[RES]]
  func.return %4 : tensor<1x2x4x3xi32>
}