• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s
2
// Reshaping a one-element 1x1 constant down to rank 0 must fold into a
// rank-0 constant, with no reshape left between it and the return.
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func.func @const_fold_collapse_to_scalar() -> tensor<i32> {
  %c42 = mhlo.constant dense<42> : tensor<1x1xi32>
  %scalar = mhlo.reshape %c42 : (tensor<1x1xi32>) -> tensor<i32>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<42> : tensor<i32>
  // CHECK-NEXT: return [[FOLDED]]
  return %scalar : tensor<i32>
}
11
12// -----
13
// Dropping a leading unit dimension (1x2 -> 2) from a splat constant must
// fold into a rank-1 constant that is returned directly.
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func.func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
  %splat = mhlo.constant dense<42> : tensor<1x2xi32>
  %collapsed = mhlo.reshape %splat : (tensor<1x2xi32>) -> tensor<2xi32>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
  // CHECK-NEXT: return [[FOLDED]]
  return %collapsed : tensor<2xi32>
}
22
23// -----
24
// Expanding a rank-0 constant to a one-element rank-1 tensor must fold into
// a tensor<1xi32> constant.
// CHECK-LABEL: func @const_fold_expand
func.func @const_fold_expand() -> tensor<1xi32> {
  %scalar = mhlo.constant dense<42> : tensor<i32>
  %expanded = mhlo.reshape %scalar : (tensor<i32>) -> tensor<1xi32>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
  // CHECK-NEXT: return [[FOLDED]]
  return %expanded : tensor<1xi32>
}
33
34// -----
35
// Collapses a non-square 8x2 splat constant to rank 1. The source shape is
// deliberately not 4x4 so this case stays distinct from @const_fold_flatten,
// which already covers the square 4x4 -> 16 collapse.
// CHECK-LABEL: func @const_fold_nontrivial
func.func @const_fold_nontrivial() -> tensor<16xi64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
  %cst = mhlo.constant dense<42> : tensor<8x2xi64>
  %0 = "mhlo.reshape"(%cst) : (tensor<8x2xi64>) -> tensor<16xi64>
  // CHECK-NEXT: return [[CST]]
  func.return %0 : tensor<16xi64>
}
44
45// -----
46
// Flattening a square 4x4 splat constant to 16 elements must fold into a
// single rank-1 constant.
// CHECK-LABEL: func @const_fold_flatten
func.func @const_fold_flatten() -> tensor<16xi64> {
  %square = mhlo.constant dense<42> : tensor<4x4xi64>
  %flat = mhlo.reshape %square : (tensor<4x4xi64>) -> tensor<16xi64>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
  // CHECK-NEXT: return [[FOLDED]]
  return %flat : tensor<16xi64>
}
55
56// -----
57
// Non-splat input: folding the 3x2 constant must flatten its elements in
// row-major order (rows [1, 2], [3, 4], [5, 6] become 1..6).
// CHECK-LABEL: func @const_fold_6
func.func @const_fold_6() -> tensor<6xi32> {
  %rows = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
  %flat = mhlo.reshape %rows : (tensor<3x2xi32>) -> tensor<6xi32>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  // CHECK-NEXT: return [[FOLDED]]
  return %flat : tensor<6xi32>
}
66
67// -----
68
// Despite its name, this case unflattens a rank-1 constant into 2x3. The
// folded constant must regroup the elements row-major: [[1, 2, 3], [4, 5, 6]].
// The expected literal is split across three lines so the nested "[[" is not
// parsed as a FileCheck variable.
// CHECK-LABEL: func @const_fold_same_shape
func.func @const_fold_same_shape() -> tensor<2x3xi32> {
  %flat = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  %grid = mhlo.reshape %flat : (tensor<6xi32>) -> tensor<2x3xi32>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<[
  // CHECK-SAME:   [1, 2, 3], [4, 5, 6]
  // CHECK-SAME: ]> : tensor<2x3xi32>
  // CHECK-NEXT: return [[FOLDED]]
  return %grid : tensor<2x3xi32>
}
79
80// -----
81
// Floating-point splat constants fold the same way as integer ones. The
// pattern tolerates any number of trailing zeros in the printed mantissa.
// CHECK-LABEL: func @const_fold_float
func.func @const_fold_float() -> tensor<16xf64> {
  %square = mhlo.constant dense<4.2> : tensor<4x4xf64>
  %flat = mhlo.reshape %square : (tensor<4x4xf64>) -> tensor<16xf64>
  // CHECK-NEXT: [[FOLDED:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
  // CHECK-NEXT: return [[FOLDED]]
  return %flat : tensor<16xf64>
}
90
91// -----
92
// A reshape whose result type equals its operand type must disappear, so
// the function returns its argument directly.
// CHECK-LABEL: func @non_const_same_shape
// CHECK-SAME: [[IN:%[a-zA-Z0-9]+]]
func.func @non_const_same_shape(%input : tensor<2x3xi32>) -> tensor<2x3xi32> {
  %same = mhlo.reshape %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  // CHECK-NEXT: return [[IN]]
  return %same : tensor<2x3xi32>
}
100
101// -----
102
// When the intermediate reshape still has a use (it is returned), it must be
// kept, while the second reshape is rewritten to read the argument directly.
// Both results are captured and the return is pinned, so the test fails if
// the results are mis-wired after canonicalization.
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
  // CHECK-NEXT: [[R0:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3xi32>) -> tensor<3x2xi32>
  // CHECK-NEXT: [[R1:%.+]] = mhlo.reshape [[ARG]] : (tensor<2x3xi32>) -> tensor<6xi32>
  // CHECK-NEXT: return [[R0]], [[R1]]
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
  func.return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}
112
113// -----
114
// With only the final reshape used, the 2x3 -> 3x2 -> 6 chain must collapse
// into a single reshape of the argument. The intermediate 3x2 reshape must
// not survive.
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[IN:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape_unused_parent(%input : tensor<2x3xi32>) -> tensor<6xi32> {
  %transposed_shape = mhlo.reshape %input : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %flat = mhlo.reshape %transposed_shape : (tensor<3x2xi32>) -> tensor<6xi32>
  // CHECK-NEXT: [[OUT:%.+]] = mhlo.reshape [[IN]] : (tensor<2x3xi32>) -> tensor<6xi32>
  // CHECK-NEXT: return [[OUT]]
  return %flat : tensor<6xi32>
}
124
125// -----
126
// Reshaping 2x3 -> 3x2 -> 2x3 lands back on the argument's type, so the
// whole chain must vanish and the argument is returned as-is.
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[IN:%[a-zA-Z0-9]+]]
func.func @non_const_chained_reshape_becomes_noop(%input : tensor<2x3xi32>) -> tensor<2x3xi32> {
  %there = mhlo.reshape %input : (tensor<2x3xi32>) -> tensor<3x2xi32>
  %back = mhlo.reshape %there : (tensor<3x2xi32>) -> tensor<2x3xi32>
  // CHECK-NEXT: return [[IN]]
  return %back : tensor<2x3xi32>
}
135
136// -----
137
// A chain of five reshapes must collapse into one reshape that goes straight
// from the argument's type to the final result type.
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[IN:%[a-zA-Z0-9]+]]
func.func @non_const_many_chained_reshapes(%input : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
  %as_4x3x2 = mhlo.reshape %input : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
  %as_12x2 = mhlo.reshape %as_4x3x2 : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
  %as_2x12 = mhlo.reshape %as_12x2 : (tensor<12x2xi32>) -> tensor<2x12xi32>
  %as_24 = mhlo.reshape %as_2x12 : (tensor<2x12xi32>) -> tensor<24xi32>
  %as_1x2x4x3 = mhlo.reshape %as_24 : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
  // CHECK-NEXT: [[OUT:%.+]] = mhlo.reshape [[IN]] : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
  // CHECK-NEXT: return [[OUT]]
  return %as_1x2x4x3 : tensor<1x2x4x3xi32>
}
150