// RUN: tf-opt %s -tfl-prepare-quantize="quantize-allowlist=quantize_float_placeholder_only,not_reset_input" | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-quantize="disable-set-input-nodes-quantization-params=true" | FileCheck --check-prefix=MixedPrecision %s

// Tests for the -tfl-prepare-quantize pass. The file is processed twice (see
// the two invocations above): once with an allowlist of two functions, whose
// output is matched by the default-prefix patterns, and once with
// disable-set-input-nodes-quantization-params=true, whose output is matched
// only by the patterns carrying the second invocation's prefix.
//
// NOTE(review): many types in this copy look truncated (for example
// `!quant.uniform>` with no storage/scale parameters, bare `tensor`, and
// `tensor>`). They are kept exactly as found; confirm against the original
// file before running.

// CHECK-LABEL: main
// Uses `main` function to match the default target function of QuantSpecs and
// execute the production code path.
func.func @main(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x4xf32>) {
  %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x1x!quant.uniform>} : (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform>
  %1 = "tfl.dequantize"(%0) : (tensor<2x1x!quant.uniform>) -> (tensor<2x1xf32>)
  %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x3x!quant.uniform>} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
  %3 = "tfl.dequantize"(%2) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>)
  %4 = "tfl.concatenation"(%1, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xf32>, tensor<2x3xf32>) -> tensor<2x4xf32>
  func.return %4: tensor<2x4xf32>

// The last pattern reuses the captured name (no `:.*`), so the returned value
// must be the result of the final dequantize.
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0)
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%arg1)
// CHECK-NEXT: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK-NEXT: %[[c:.*]] = "tfl.concatenation"(%[[dq]], %[[dq_0]])
// CHECK-NEXT: %[[q_1:.*]] = "tfl.quantize"(%[[c]])
// CHECK-NEXT: %[[dq_1:.*]] = "tfl.dequantize"(%[[q_1]])
// CHECK-NEXT: return %[[dq_1]]
}

// Same ops as @main followed by a float tfl.add with a third argument. The
// expected output keeps the add unquantized and returns its result; the last
// pattern reuses the captured name so the returned value is actually verified.
// MixedPrecision-LABEL: paritial_quantized
func.func @paritial_quantized(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
  %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x1x!quant.uniform>} : (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform>
  %1 = "tfl.dequantize"(%0) : (tensor<2x1x!quant.uniform>) -> (tensor<2x1xf32>)
  %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x3x!quant.uniform>} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
  %3 = "tfl.dequantize"(%2) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>)
  %4 = "tfl.concatenation"(%1, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xf32>, tensor<2x3xf32>) -> tensor<2x4xf32>
  %5 = "tfl.add"(%4, %arg2) {fused_activation_function = "NONE"} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
  func.return %5: tensor<2x4xf32>

// MixedPrecision-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0)
// MixedPrecision-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// MixedPrecision-NEXT: %[[q_0:.*]] = "tfl.quantize"(%arg1)
// MixedPrecision-NEXT: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// MixedPrecision-NEXT: %[[c:.*]] = "tfl.concatenation"(%[[dq]], %[[dq_0]])
// MixedPrecision-NEXT: %[[q_1:.*]] = "tfl.quantize"(%[[c]])
// MixedPrecision-NEXT: %[[dq_1:.*]] = "tfl.dequantize"(%[[q_1]])
// MixedPrecision-NEXT: %[[v:.*]] = tfl.add %[[dq_1]], %arg2
// MixedPrecision-NEXT: return %[[v]]
}

// %arg0 and %arg2 are expected to receive a quantize/dequantize pair, while
// the i32 argument %arg1 is returned as is.
// CHECK-LABEL: quantize_float_placeholder_only
func.func @quantize_float_placeholder_only(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor, tensor<2x3xi32>, tensor<2x3xf32>) {
  func.return %arg0, %arg1, %arg2: tensor, tensor<2x3xi32>, tensor<2x3xf32>

// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0)
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: %[[q_0:.*]] = "tfl.quantize"(%arg2)
// CHECK-NEXT: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]]
}

// An argument that already feeds a tfl.quantize: the expected output is that
// single quantize followed directly by the return.
// CHECK-LABEL: not_reset_input
func.func @not_reset_input(%arg0: tensor) -> (tensor>) {
  %0 = "tfl.quantize"(%arg0) {qtype = tensor>} : (tensor) -> tensor>
  func.return %0: tensor>

// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor>}
// CHECK-NEXT: return %[[q]]
}

// A dequantize immediately followed by a quantize is expected to stay as
// three separate ops in the output.
// CHECK-LABEL: DequantizeAndQuantize
func.func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform> {
  %cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform>
  %0 = "tfl.dequantize"(%cst) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
  func.return %1 : tensor<2x2x!quant.uniform>

// CHECK: %0 = "tfl.pseudo_qconst"()
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: return %2
}

// Two chained quantfork.stats ops (the second also carrying per-axis stats)
// are each expected to become a volatile quantize/dequantize pair.
// CHECK-LABEL: prepareStatistics
func.func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
  %0 = "quantfork.stats"(%arg0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  %1 = "quantfork.stats"(%0) {
    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
    axisStats = dense<[
      [-1.0, 1.0],
      [-8.0, 8.0],
      [-0.5, 0.5]
    ]> : tensor<3x2xf32>, axis = 2 : i64
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  func.return %1 : tensor<8x4x3xf32>

// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
// CHECK: return %[[dq2]]
}

// Single quantfork.stats op with a very narrow range ([-1.0e-9, 1.0e-9]).
// CHECK-LABEL: prepareNarrowStatistics
func.func @prepareNarrowStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
  %0 = "quantfork.stats"(%arg0) {
    layerStats = dense<[-1.0e-9, 1.0e-9]> : tensor<2xf32>
  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
  func.return %0 : tensor<8x4x3xf32>

// CHECK: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]]
}

// conv_2d with dequantized input/weight arguments and a float bias constant:
// the bias is expected to receive a volatile quantize/dequantize pair.
// CHECK-LABEL: QuantizeConv2DPerChannel
func.func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
  %bias = arith.constant dense<1.0> : tensor<32xf32>
  %input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
  %conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  func.return %conv : tensor<1x112x112x32xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}

// Variant whose bias is a tfl.pseudo_const; the expected output shows the
// bias as an arith.constant.
// CHECK-LABEL: QuantizeConv2DPerChannelConst
func.func @QuantizeConv2DPerChannelConst(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
  %bias = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
  %input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
  %conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  func.return %conv : tensor<1x112x112x32xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}

// NOTE(review): in this copy the body reads identically to
// @QuantizeConv2DPerChannel; any difference is presumably in the truncated
// type parameters - confirm against the original file.
// CHECK-LABEL: QuantizeConv2DPerChannels
func.func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
  %bias = arith.constant dense<1.0> : tensor<32xf32>
  %input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
  %conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  func.return %conv : tensor<1x112x112x32xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}

// conv_2d with a quantized constant weight and a quantized result: the float
// bias constant is expected to receive a volatile quantize/dequantize pair.
// CHECK-LABEL: QuantizeConv2D
func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>):
  %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
  %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>
  %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
  %5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  func.return %6 : tensor<1x112x112x32x!quant.uniform>

// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: %5 = "tfl.conv_2d"(%2, %4, %1)
// CHECK: %6 = "tfl.quantize"(%5)
// CHECK: return %6
}

// Same shape of test for tfl.fully_connected.
// NOTE(review): the pseudo_qconst qtype is 32x3x3x3 while its value and result
// types are 32x12 - confirm this mismatch is intended.
// CHECK-LABEL: QuantizeFullyConnected
func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>):
  %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
  %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x12xi8>} : () -> tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>
  %4 = "tfl.dequantize"(%3) : (tensor<32x12x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x12xf32>
  %5 = "tfl.fully_connected"(%2, %4, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  func.return %6 : tensor<1x112x112x32x!quant.uniform>

// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: %5 = "tfl.fully_connected"(%2, %4, %1)
// CHECK: %6 = "tfl.quantize"(%5)
// CHECK: return %6
}

// Same shape of test for tfl.depthwise_conv_2d.
// CHECK-LABEL: QuantizeDepthwiseConv2D
func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>):
  %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
  %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>
  %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
  %5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  func.return %6 : tensor<1x112x112x32x!quant.uniform>

// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: %5 = "tfl.depthwise_conv_2d"(%2, %4, %1)
// CHECK: %6 = "tfl.quantize"(%5)
// CHECK: return %6
}

// The pooling result is expected to gain a quantize/dequantize pair.
// CHECK-LABEL: QuantizeAveragePool2D
func.func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.average_pool_2d"(%0) {
    name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
  } : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32>
  func.return %1 : tensor<1x1x1x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.average_pool_2d"(%0)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x1x1x16xf32>
}

// The tfl.maximum result is expected to gain a quantize/dequantize pair.
// CHECK-LABEL: QuantizeMaximum
func.func @QuantizeMaximum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.dequantize"(%arg1) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %2 = "tfl.maximum"(%0, %1) : (tensor<1x6x6x16xf32>, tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %2 : tensor<1x6x6x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.dequantize"(%arg1)
// CHECK: %2 = "tfl.maximum"(%0, %1)
// CHECK: %3 = "tfl.quantize"(%2)
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: return %4 : tensor<1x6x6x16xf32>
}

// The tfl.minimum result is expected to gain a quantize/dequantize pair.
// CHECK-LABEL: QuantizeMinimum
func.func @QuantizeMinimum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.dequantize"(%arg1) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %2 = "tfl.minimum"(%0, %1) : (tensor<1x6x6x16xf32>, tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %2 : tensor<1x6x6x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.dequantize"(%arg1)
// CHECK: %2 = "tfl.minimum"(%0, %1)
// CHECK: %3 = "tfl.quantize"(%2)
// CHECK: %4 = "tfl.dequantize"(%3)
// CHECK: return %4 : tensor<1x6x6x16xf32>
}

// The tfl.slice result is expected to gain a quantize/dequantize pair.
// CHECK-LABEL: QuantizeSlice
func.func @QuantizeSlice(tensor<2x3x5x!quant.uniform>, tensor<3xi32>, tensor<3xi32>) -> tensor {
^bb0(%arg0: tensor<2x3x5x!quant.uniform>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x3x5x!quant.uniform>) -> tensor<2x3x5xf32>
  %1 = "tfl.slice"(%0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor
  func.return %1 : tensor

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.slice"(%0, %arg1, %arg2)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor
}

// The tfl.strided_slice result is expected to gain a volatile
// quantize/dequantize pair.
// CHECK-LABEL: QuantizeStridedSlice
func.func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> {
^bb0(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<12x2x2x5x!quant.uniform>) -> tensor<12x2x2x5xf32>
  %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
  func.return %1 : tensor<1x2x2x5xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3)
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x2x2x5xf32>
}

// The tfl.pad result is expected to gain a quantize/dequantize pair.
// CHECK-LABEL: QuantizePad
func.func @QuantizePad(tensor<2x1x3x!quant.uniform>, tensor<3x2xi32>) -> tensor {
^bb0(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<3x2xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform>) -> tensor<2x1x3xf32>
  %1 = "tfl.pad"(%0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
  func.return %1 : tensor

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.pad"(%0, %arg1)
// CHECK: %2 = "tfl.quantize"(%1)
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor
}

// CHECK-LABEL: QuantizePad2
// only the second tfl.pad has sufficient quantization information.
func.func @QuantizePad2(tensor<2x1x3x!quant.uniform>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor, tensor) {
^bb0(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform>) -> tensor<2x1x3xf32>
  %1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
  %2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor
  func.return %1, %2 : tensor, tensor

// CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
// CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
}

// The tfl.reshape result is expected to gain a volatile quantize/dequantize
// pair.
// CHECK-LABEL: QuantizeReshape2D
func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>):
  %cst = arith.constant dense<[1, 36, 16]> : tensor<3xi32>
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.reshape"(%0, %cst) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32>
  func.return %1 : tensor<1x36x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>)
// CHECK: %1 = "tfl.reshape"(%0, %{{.*}}) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform>)
// CHECK: return %3 : tensor<1x36x16xf32>
}

// The tfl.softmax result is expected to gain a volatile quantize/dequantize
// pair.
// CHECK-LABEL: QuantizeSoftmax
func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %1 : tensor<1x6x6x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2)
// CHECK: return %3 : tensor<1x6x6x16xf32>
}

// The tfl.logistic result is expected to gain a volatile quantize/dequantize
// pair.
// CHECK-LABEL: QuantizeLogistic
func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %1 : tensor<1x6x6x16xf32>

// CHECK: %0 = "tfl.dequantize"(%arg0)
// CHECK: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
// CHECK: return %3 : tensor<1x6x6x16xf32>
}

// A tfl.logistic that already operates on quantized types: the patterns only
// require the op and a return of its result.
// CHECK-LABEL: NotRescaleLogistic
func.func @NotRescaleLogistic(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16x!quant.uniform> {
  %0 = "tfl.logistic"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16x!quant.uniform>
  func.return %0 : tensor<1x6x6x16x!quant.uniform>

// CHECK: %[[log:.*]] = "tfl.logistic"(%arg0)
// CHECK: return %[[log]]
}

// The tfl.l2_normalization result is expected to gain a volatile
// quantize/dequantize pair.
// CHECK-LABEL: QuantizeL2Norm
func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> {
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %1 = "tfl.l2_normalization"(%0) {fused_activation_function = "NONE"} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %1 : tensor<1x6x6x16xf32>

// CHECK: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[l2:.*]] = "tfl.l2_normalization"(%[[in]])
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]] : tensor<1x6x6x16xf32>
}

// Concatenation of a float argument with a float constant: no quantize ops
// are expected in the output.
// CHECK-LABEL: NotQuantizeConcatConstantOperand
func.func @NotQuantizeConcatConstantOperand(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
  %0 = arith.constant dense<1.0> : tensor<1x2xf32>
  %1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  func.return %1 : tensor<2x2xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<1x2xf32>
// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[cst]])
// CHECK-NEXT: return %[[cc]]
}

// Only operand 0 of the concatenation comes from a dequantize: the other
// operand and the result are expected to gain volatile quantize/dequantize
// pairs.
// CHECK-LABEL: QuantizeConcatOperand0ToAll
func.func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2xf32> {
^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
  %1 = "tfl.concatenation"(%0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  func.return %1 : tensor<2x2xf32>

// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}

// Mirror of the previous test with the dequantized value as operand 1.
// CHECK-LABEL: QuantizeConcatOperand1ToAll
func.func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.uniform>) -> tensor<2x2xf32> {
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2x!quant.uniform>):
  %0 = "tfl.dequantize"(%arg1) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
  %1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  func.return %1 : tensor<2x2xf32>

// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg1) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile}
// CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32>
// CHECK: return %5 : tensor<2x2xf32>
}

// Only the concatenation result is quantized: both float operands are
// expected to gain volatile quantize/dequantize pairs.
// CHECK-LABEL: QuantizeConcatResToAll
func.func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> {
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>):
  %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
  func.return %1 : tensor<2x2x!quant.uniform>

// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %5 : tensor<2x2x!quant.uniform>
}

// Operand 0 and the result are already quantized: only the float operand is
// expected to gain a quantize/dequantize pair, with no extra quantize on
// %arg0.
// CHECK-LABEL: QuantizeConcatResToAllNoRequantize
func.func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> {
^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
  %1 = "tfl.concatenation"(%0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  %2 = "tfl.quantize"(%1) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
  func.return %2 : tensor<2x2x!quant.uniform>

// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %4 : tensor<2x2x!quant.uniform>
}

// Operand 0 is quantized by an existing tfl.quantize: the expected output
// inserts a second quantize taking the first one's quantized result
// (a requantize) before the dequantize.
// CHECK-LABEL: QuantizeConcatResToAllRequantize
func.func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> {
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>):
  %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
  %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
  %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
  func.return %3 : tensor<2x2x!quant.uniform>

// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform>
}

// Same as the previous test, but the quantized operand is a function
// argument: the inserted quantize takes %arg0 directly.
// CHECK-LABEL: QuantizeConcatResToAllRequantizeArg
func.func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> {
^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>):
  %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
  %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
  func.return %3 : tensor<2x2x!quant.uniform>

// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile}
// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform>
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform>
}

// All values are already quantized: the expected output repeats the two ops
// with the same operands and attributes.
// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
func.func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform>, %arg1: tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform> {
  %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x96x!quant.uniform>
  %10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform>, tensor<1x73x73x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform>
  func.return %10 : tensor<1x73x73x160x!quant.uniform>

// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x96x!quant.uniform>
// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform>, tensor<1x73x73x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform>
// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform>
}

// A chain of average_pool_2d -> conv_2d -> reshape -> softmax: quantize and
// dequantize pairs are expected after the bias constant and after each of the
// average_pool_2d, reshape and softmax results.
// CHECK-LABEL: QuantizeChain
func.func @QuantizeChain(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>):
  %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
  %cst_0 = arith.constant dense<[1, 36, 16]> : tensor<3xi32>
  %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>
  %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
  %5 = "tfl.average_pool_2d"(%2) {
    name = "avgpool", filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32
  } : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32>
  %6 = "tfl.conv_2d"(%5, %4, %cst) {
    dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32
  } : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %7 = "tfl.quantize"(%6) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  %8 = "tfl.dequantize"(%7) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x6x6x16xf32>
  %9 = "tfl.reshape"(%8, %cst_0) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32>
  %10 = "tfl.softmax"(%9) {beta = 1.000000e+00 : f32} : (tensor<1x36x16xf32>) -> tensor<1x36x16xf32>
  func.return %10 : tensor<1x36x16xf32>

// CHECK: %cst = arith.constant dense<-1.23697901> : tensor<32xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>)
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>)
// CHECK: %3 = "tfl.pseudo_qconst"()
// CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>)
// CHECK: %5 = "tfl.average_pool_2d"(%2)
// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform>, volatile}
// CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform>)
// CHECK: %8 = "tfl.conv_2d"(%7, %4, %1)
// CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform>}
// CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform>)
// CHECK: %11 = "tfl.reshape"(%10, %{{.*}})
// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
// CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform>)
// CHECK: %14 = "tfl.softmax"(%13)
// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform>, volatile}
// CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform>)
// CHECK: return %16 : tensor<1x36x16xf32>
}

// A returned float constant is expected to gain a volatile
// quantize/dequantize pair.
// CHECK-LABEL: QuantizeConstant
func.func @QuantizeConstant() -> tensor<2x3xf32> {
  %cst = arith.constant dense<[[-3.0, -1.0, 0.0], [0.0, 1.0, 3.0]]> : tensor<2x3xf32>
  func.return %cst : tensor<2x3xf32>

// CHECK: %cst = arith.constant dense{{.*}}tensor<2x3xf32>
// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform>, volatile}
// CHECK: %1 = "tfl.dequantize"(%0)
// CHECK: return %1 : tensor<2x3xf32>
}

// A `none`-typed tfl.no_value is expected to be returned without any
// quantize ops.
// CHECK-LABEL: NotQuantizeNoneType
func.func @NotQuantizeNoneType() -> none {
  %cst = "tfl.no_value"() {value = unit} : () -> none
  func.return %cst : none

// CHECK-NEXT: %[[cst:.*]] = "tfl.no_value"() {value} : () -> none
// CHECK-NEXT: return %[[cst]]
}

// The remaining tests each return a single splat or scalar float constant and
// expect a volatile tfl.quantize directly after the constant.

// CHECK-LABEL: QuantizeZeroSplat
func.func @QuantizeZeroSplat() -> tensor<2x3xf32> {
  %cst = arith.constant dense<0.0> : tensor<2x3xf32>
  func.return %cst : tensor<2x3xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
}

// CHECK-LABEL: QuantizeZeroScalar
func.func @QuantizeZeroScalar() -> tensor {
  %cst = arith.constant dense<0.0> : tensor
  func.return %cst : tensor

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
}

// CHECK-LABEL: QuantizePositiveSplat
func.func @QuantizePositiveSplat() -> tensor<2x3xf32> {
  %cst = arith.constant dense<25.4> : tensor<2x3xf32>
  func.return %cst : tensor<2x3xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<2.540000e+01> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
}

// CHECK-LABEL: QuantizePositiveScalar
func.func @QuantizePositiveScalar() -> tensor {
  %cst = arith.constant dense<2.54> : tensor
  func.return %cst : tensor

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<2.540000e+00> : tensor
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
}

// CHECK-LABEL: QuantizeNegativeSplat
func.func @QuantizeNegativeSplat() -> tensor<2x3xf32> {
  %cst = arith.constant dense<-2.54> : tensor<2x3xf32>
  func.return %cst : tensor<2x3xf32>

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<-2.540000e+00> : tensor<2x3xf32>
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile}
}

// CHECK-LABEL: QuantizeNegativeScalar
func.func @QuantizeNegativeScalar() -> tensor {
  %cst = arith.constant dense<-25.4> : tensor
  func.return %cst : tensor

// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<-2.540000e+01> : tensor
// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile}
}

// Make sure biases are not shared.
// A single bias constant %cst feeds both conv_2d ops in the input. The checks
// expect two separate constant -> quantize -> dequantize chains, with each
// conv_2d reading a different dequantized bias.
// CHECK-LABEL: QuantizeSharedBiases
func.func @QuantizeSharedBiases(
    %arg0: tensor<1x224x224x3x!quant.uniform>,
    %arg1: tensor<32x3x3x3x!quant.uniform:f32, 1.0>>,
    %arg2: tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> (tensor<1x56x56x32x!quant.uniform>) {
  %cst = arith.constant dense<1.0> : tensor<32xf32>
  %1 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32>
  %2 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform:f32, 1.0>>) -> tensor<32x3x3x3xf32>
  %conv1 = "tfl.conv_2d"(%1, %2, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %3 = "tfl.quantize"(%conv1) {qtype = tensor<1x112x112x32xf32>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  %4 = "tfl.dequantize"(%3) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32>
  %5 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> tensor<32x3x3x3xf32>
  %conv2 = "tfl.conv_2d"(%4, %5, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
  %6 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform>
  func.return %6 : tensor<1x56x56x32x!quant.uniform>

// CHECK: %[[cst_0:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) : (tensor<32x!quant.uniform>)
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<32x!quant.uniform>)
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]])
}

// Make sure biases are not shared.
// Here the shared zero constant is read by a tfl.add and by a tfl.conv_2d, with
// the add appearing first in the input. The checks expect two separate
// constant -> quantize -> dequantize chains; the add and the conv_2d each read
// a different one.
// CHECK-LABEL: QuantizeSharedBiases2
func.func @QuantizeSharedBiases2(
    %arg0: tensor<32x!quant.uniform>,
    %arg1: tensor<1x112x112x32x!quant.uniform>,
    %arg2: tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> (tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>) {
  %cst = arith.constant dense<0.0> : tensor<32xf32>
  %1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  %3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform>
  %5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32>
  %6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> tensor<32x3x3x3xf32>
  %conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
  %7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform>
  func.return %3, %7 : tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>

// CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}

// Make sure biases are not shared.
// Same ops as QuantizeSharedBiases2, but the conv_2d is defined before the add
// in the input. The checks again expect two separate chains, matched in the
// order conv_2d first, add second.
// CHECK-LABEL: QuantizeSharedBiases3
func.func @QuantizeSharedBiases3(
    %arg0: tensor<32x!quant.uniform>,
    %arg1: tensor<1x112x112x32x!quant.uniform>,
    %arg2: tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> (tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>) {
  %cst = arith.constant dense<0.0> : tensor<32xf32>
  %5 = "tfl.dequantize"(%arg1) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32>
  %6 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> tensor<32x3x3x3xf32>
  %conv2 = "tfl.conv_2d"(%5, %6, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32>
  %7 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform>
  %1 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %add = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  %3 = "tfl.quantize"(%add) {qtype = tensor<32xf32>} : (tensor<32xf32>) -> tensor<32x!quant.uniform>
  func.return %3, %7 : tensor<32x!quant.uniform>, tensor<1x56x56x32x!quant.uniform>

// CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: %[[cst_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf32>
// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]])
// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]])
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
}

// Make sure constants are duplicated for all users.
// One zero constant is the second operand of four tfl.minimum ops. The checks
// expect four distinct dequantize results, each used as the second operand of
// one tfl.minimum.
// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
func.func @QuantizeSharedConstantsMultipleUsers(
    %arg0: tensor<32x!quant.uniform>,
    %arg1: tensor<32x!quant.uniform>,
    %arg2: tensor<32x!quant.uniform>,
    %arg3: tensor<32x!quant.uniform>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
  %cst = arith.constant dense<0.0> : tensor<32xf32>
  %0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
  %4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  %5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  %6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  %7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  func.return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>

// The dequantize results may be printed in any order, hence the DAG matching.
// CHECK-DAG: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
// CHECK-DAG: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
// CHECK-DAG: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
// CHECK-DAG: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform>) -> tensor<32xf32>
// The NOT line below separates the two DAG groups, so the minimum ops are only
// searched for after all four dequantize matches (BLOCK_DAG is presumably a
// marker string that never occurs in the output).
// CHECK-NOT: BLOCK_DAG
// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
}

// Make sure quantization parameters are scanned from weight, but not from bias.
// The checks expect a volatile quantize/dequantize pair on the weight constant
// %w only; the bias constant %b reaches the conv_2d directly.
// CHECK-LABEL: QuantizeWeight
func.func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
  %w = arith.constant dense<1.0> : tensor<32x3x3x3xf32>
  %b = arith.constant dense<-1.0> : tensor<32xf32>
  %c = "tfl.conv_2d"(%arg0, %w, %b) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  func.return %c : tensor<1x112x112x32xf32>

// CHECK: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<32x3x3x3xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>, volatile} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<32x3x3x3x!quant.uniform:f32, 0.003937007874015748:1>>) -> tensor<32x3x3x3xf32>
// CHECK: %[[b:.*]] = arith.constant dense<-1.000000e+00> : tensor<32xf32>
// CHECK: %[[c:.*]] = "tfl.conv_2d"(%arg0, %[[dq]], %[[b]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
// CHECK: return %[[c]] : tensor<1x112x112x32xf32>
}

// Make sure quantization parameters are not scanned if quantize op is present.
// The constant already carries a quantize/dequantize pair in the input. The
// checks expect exactly constant, quantize, dequantize, return on consecutive
// lines, i.e. no additional quantize is inserted.
// CHECK-LABEL: NoRedundantQuantizeWeight
func.func @NoRedundantQuantizeWeight() -> tensor<1x112x112x32xf32> {
  %w = arith.constant dense<1.0> : tensor<1x112x112x32xf32>
  %q = "tfl.quantize"(%w) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  %dq = "tfl.dequantize"(%q) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32>
  func.return %dq : tensor<1x112x112x32xf32>

// CHECK-NEXT: %[[w:.*]] = arith.constant dense<1.000000e+00> : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<1x112x112x32x!quant.uniform>}
// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK-NEXT: return %[[dq]] : tensor<1x112x112x32xf32>
}

// The depthwise_conv_2d result is returned twice: once directly and once
// through a quantize/dequantize pair. The checks expect both returned values
// to be the dequantized one.
// CHECK-LABEL: ReturnQuantizedResult
func.func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) {
  %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
  %1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform>
  %2 = "tfl.dequantize"(%1) : (tensor<1x112x112x32x!quant.uniform>) -> (tensor<1x112x112x32xf32>)
  func.return %0, %2 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>

// CHECK: %[[dw:.*]] = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2)
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[dw]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]], %[[dq]]
}

// Series of values needing requantization -- first the args then the results
// of concatenation operations. concat(concat(arg2, arg0), concat(arg1, arg0)),
// concat(concat(arg2, arg0), arg3). arg0 should be requantized twice --
// concat(arg2, arg0) should be requantized twice as well.
// Every argument and every concatenation result in the input is annotated
// with a quantfork.stats op; the checks pin the complete expected output, one
// op per line, from the first quantize through the return.
// CHECK-LABEL: QuantizedCatsAddRequantsTest
func.func @QuantizedCatsAddRequantsTest(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1xf32>, %arg2: tensor<1x1xf32>, %arg3: tensor<1x1xf32>) -> (tensor<1x4xf32>, tensor<1x3xf32>) {
  %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.440728068, 0.189515018]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.154693216, 0.26483655]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %2 = "quantfork.stats"(%arg2) {layerStats = dense<[-0.488159984, 0.16362021]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %3 = "quantfork.stats"(%arg3) {layerStats = dense<[-0.25180456, 0.398609281]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %6 = "tfl.concatenation"(%1, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
  %7 = "quantfork.stats"(%6) {layerStats = dense<[-0.440728068, 0.26483655]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
  %8 = "tfl.concatenation"(%2, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
  %9 = "quantfork.stats"(%8) {layerStats = dense<[-0.488159984, 0.189515018]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
  %10 = "tfl.concatenation"(%9, %7) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32>
  %11 = "quantfork.stats"(%10) {layerStats = dense<[-0.488159984, 0.26483655]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  %13 = "tfl.concatenation"(%9, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
  %14 = "quantfork.stats"(%13) {layerStats = dense<[-0.488159984, 0.398609281]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
  func.return %10, %14 : tensor<1x4xf32>, tensor<1x3xf32>

// CHECK-NEXT: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[r0q0:.*]] = "tfl.quantize"(%[[q0]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[r1q0:.*]] = "tfl.quantize"(%[[q0]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d1q0:.*]] = "tfl.dequantize"(%[[r1q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[d0q0:.*]] = "tfl.dequantize"(%[[r0q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[r0q1:.*]] = "tfl.quantize"(%[[q1]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q1:.*]] = "tfl.dequantize"(%[[r0q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg2) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[r0q2:.*]] = "tfl.quantize"(%[[q2]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q2:.*]] = "tfl.dequantize"(%[[r0q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg3) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[r0q3:.*]] = "tfl.quantize"(%[[q3]]) {qtype = tensor<1x1x!quant.uniform>} : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform>
// CHECK-NEXT: %[[d0q3:.*]] = "tfl.dequantize"(%[[r0q3]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32>
// CHECK-NEXT: %[[cat1_0:.*]] = "tfl.concatenation"(%[[d0q1]], %[[d1q0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[qcat1_0:.*]] = "tfl.quantize"(%[[cat1_0]]) {qtype = tensor<1x2x!quant.uniform>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[r0qcat1_0:.*]] = "tfl.quantize"(%[[qcat1_0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[d0qcat1_0:.*]] = "tfl.dequantize"(%[[r0qcat1_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[cat_2_0:.*]] = "tfl.concatenation"(%[[d0q2]], %[[d0q0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[qcat_2_0:.*]] = "tfl.quantize"(%[[cat_2_0]]) {qtype = tensor<1x2x!quant.uniform>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[r0qcat_2_0:.*]] = "tfl.quantize"(%[[qcat_2_0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform>
// CHECK-NEXT: %[[d0qcat_2_0:.*]] = "tfl.dequantize"(%[[r0qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[dqcat_2_0:.*]] = "tfl.dequantize"(%[[qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32>
// CHECK-NEXT: %[[cat_2_0_1_0:.*]] = "tfl.concatenation"(%[[dqcat_2_0]], %[[d0qcat1_0]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[qcat_2_0_1_0:.*]] = "tfl.quantize"(%[[cat_2_0_1_0]]) {qtype = tensor<1x4x!quant.uniform>, volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform>
// CHECK-NEXT: %[[dqcat_2_0_1_0:.*]] = "tfl.dequantize"(%[[qcat_2_0_1_0]]) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[cat_2_0_3:.*]] = "tfl.concatenation"(%[[d0qcat_2_0]], %[[d0q3]]) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32>
// CHECK-NEXT: %[[qcat_2_0_3:.*]] = "tfl.quantize"(%[[cat_2_0_3]]) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform>
// CHECK-NEXT: %[[dqcat_2_0_3:.*]] = "tfl.dequantize"(%[[qcat_2_0_3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32>
// CHECK-NEXT: return %[[dqcat_2_0_1_0]], %[[dqcat_2_0_3]] : tensor<1x4xf32>, tensor<1x3xf32>
}