// RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s

// -----
// Degenerate quint8 range (min == max == 0.0): the expected storage type is
// u8 with a scale of 1.0 and no explicit zero point.
// CHECK-LABEL: fakeQuantArgs_Quint8_0
func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %fq = "quant.const_fake_quant"(%arg0) {min = 0.0 : f32, max = 0.0 : f32, num_bits = 8} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %fq : tensor<8x4x3xf32>
}

// -----
// Degenerate quint8 range (min == max == 0.0) with narrow_range = true: the
// expected storage range is u8<1:255> with scale 1.0 and zero point 1.
// CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange
func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %fq = "quant.const_fake_quant"(%arg0) {min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %fq : tensor<8x4x3xf32>
}

// -----
// Verifies a quint8 asymmetric 0..1 range.
// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/255 (the 0..1 range spread over 256 u8 levels).
  // The qcast result is captured in a FileCheck variable rather than relying
  // on the printer's SSA numbering.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/254 over storage range [1, 255], with zero point 1.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a quint8 symmetric range of -1..127/128.
// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/128 with zero point 128, so -1.0 maps to storage 0.
  // The qcast result is captured in a FileCheck variable rather than relying
  // on the printer's SSA numbering.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 single point.
// CHECK-LABEL: fakeQuantArgs_Qint8_0
func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Degenerate range (min == max == 0.0): scale 1.0, zero point at the signed
  // storage minimum (-128).
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 single point (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange
func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Degenerate range with narrow_range = true: storage range i8<-127:127>,
  // scale 1.0, zero point -127.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 asymmetric 0..1 range.
// CHECK-LABEL: fakeQuantArgs_Qint8_0_1
func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/255 with zero point -128, so 0.0 maps to the signed
  // storage minimum. The qcast result is captured in a FileCheck variable
  // rather than relying on the printer's SSA numbering.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/254 over storage range i8<-127:127>, zero point -127.
  // The qcast result is captured in a FileCheck variable rather than relying
  // on the printer's SSA numbering.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 symmetric range of -1..127/128.
// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/128 and the printed type has no explicit zero point.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
// 0 and range -1.0 .. 32767/32768.
// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
  // Expected scale is 1/32768 and the printed type has no explicit zero point.
  // The qcast result is captured in a FileCheck variable rather than relying
  // on the printer's SSA numbering.
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies that lowering to qcast/dcast barriers works for a rank-0 tensor.
// Note: despite the function name, tensor<f32> is a rank-0 (scalar) tensor,
// not an unranked tensor<*xf32>.
// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
^bb0(%arg0: tensor<f32>):
  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<f32>)
  // CHECK-SAME: -> tensor<!quant.uniform<u8:f32, 0.0039215686274509803>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<!quant.uniform<u8:f32, 0.0039215686274509803>>)
  // CHECK-SAME: -> tensor<f32>
  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
  } : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----
// Verifies a qint8 range lying entirely above zero (0.5..1.5). The expected
// scale (1/255) and zero point (-128) describe a 0..1 real range, i.e. the
// input range is shifted so that it includes zero.
// CHECK-LABEL: fakeQuantArgs_all_positive
func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>

  %0 = "quant.const_fake_quant"(%arg0) {
    min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 range lying entirely below zero (-1.5..-0.5). The expected
// scale (1/255) and zero point (127) describe a -1..0 real range, i.e. the
// input range is shifted so that it includes zero.
// CHECK-LABEL: fakeQuantArgs_all_negative
func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>

  %0 = "quant.const_fake_quant"(%arg0) {
    min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 per-axis fake quant along axis 2. Each of the three
// channels gets its own scale/zero point:
//   [-1.0, 0.9921875] -> scale 1/128, no explicit zero point
//   [0.0, 0.0]        -> scale 1.0, zero point -128
//   [0.0, 1.0]        -> scale 1/255, zero point -128
// CHECK-LABEL: fakeQuantPerAxis
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

  // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
  // CHECK-NEXT: %[[d:.*]] = "quant.dcast"(%[[q]])
  // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
  // CHECK-SAME: -> tensor<8x4x3xf32>
  // CHECK: return %[[d]]

  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
    min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
    max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
    num_bits = 8, narrow_range = false, is_signed = true, axis = 2
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  return %0 : tensor<8x4x3xf32>
}