/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
namespace quant {

// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be removed without losing
// accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
using AccumulatorScaleFunc =
    std::function<QuantParams(const std::vector<QuantParams>&, bool)>;

// Quantization spec of an op, driving the quantization algorithm.
struct OpQuantSpec {
  // Maps the operand index of a bias input to its quantization specifications,
  // including the non-bias operand indexes and the method for retrieving
  // quantization parameters from the list of parameters of the non-bias
  // operands. This map is empty if the op doesn't have a bias operand.
  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
      biases_params;

  // Quantization parameters for value-restricted outputs. These are the
  // "hard-coded" parameters and should be used unconditionally for the
  // quantized op. This map is empty if the op doesn't have value-restricted
  // outputs.
  llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;

  // Maps a coefficient operand index to the dimension used for per-channel
  // quantization. For QAT, this information is carried by the FakeQuant*/QDQ
  // ops, but for post-training quantization, the quantization parameters need
  // to be inferred from the tensor content and op properties. A "-1" value
  // indicates the operand doesn't support per-channel quantization.
  llvm::DenseMap<int, int> coeff_op_quant_dim;
};
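
// Illustrative example (a sketch, not part of this header): a spec getter for
// a conv-like op whose operand 0 is the input, operand 1 the filter (quantized
// per-channel along dimension 3), and operand 2 the bias could be written as
//
//   std::unique_ptr<OpQuantSpec> GetConvLikeQuantSpec(Operation* op) {
//     auto spec = std::make_unique<OpQuantSpec>();
//     // The bias scale is derived from the quantized types of operands 0
//     // and 1 by the accumulator scale function.
//     spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias};
//     spec->coeff_op_quant_dim[1] = 3;
//     return spec;
//   }
//
// `GetUniformQuantizedTypeForBias` is declared later in this header and
// matches the AccumulatorScaleFunc signature.
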
// A function signature for getting the particular OpQuantSpec for the
// provided op.
typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);

// Re-calculates scales again in float instead of simply downcasting existing
// scales.
QuantizedType DownCastScale(QuantizedType type,
                            const SmallVectorImpl<double>& mins,
                            const SmallVectorImpl<double>& maxs, Location loc);

QuantizedType DownCastScale(QuantizedType type, double min, double max,
                            Location loc);

// Converts the calibrated min/max statistics carried by a `quant.stats` op
// into quantization parameters, materialized as a quantize (Q) / dequantize
// (DQ) op pair.
template <typename Q, typename DQ>
struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
  ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
                     bool legacy_float_scale, MLIRContext* context)
      : OpRewritePattern<quant::StatisticsOp>(context),
        num_bits(num_bits),
        narrow_range(narrow_range),
        is_signed(is_signed),
        legacy_float_scale(legacy_float_scale) {}

  LogicalResult matchAndRewrite(quant::StatisticsOp op,
                                PatternRewriter& rewriter) const override {
    Type expressed = op.getType().cast<ShapedType>().getElementType();
    quant::QuantizedType quant_type;
    SmallVector<double, 4> mins, maxs;

    if (op.axisStats().hasValue()) {
      int stats_num = op.axisStats()->getNumElements();
      if (stats_num == 0 || stats_num % 2 != 0) return failure();
      auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
      if (!stats) return failure();

      for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
        double rmin = FloatAttr::getValueAsDouble(*it++);
        double rmax = FloatAttr::getValueAsDouble(*it);
        // The default nudging implementation of the mlir quant library might
        // cause clamping during inference if the calibration range isn't wide
        // enough. So here we adjust the range to include 0.0.
        rmin = std::min(rmin, 0.0);
        rmax = std::max(rmax, 0.0);
        TensorRangeSanityCheck(op, rmin, rmax);
        mins.push_back(rmin);
        maxs.push_back(rmax);
      }
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
                                      maxs, narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
      }
    } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
      double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
      double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
      // The default nudging implementation of the mlir quant library might
      // cause clamping during inference if the calibration range isn't wide
      // enough. So here we adjust the range to include 0.0.
      rmin = std::min(rmin, 0.0);
      rmax = std::max(rmax, 0.0);
      TensorRangeSanityCheck(op, rmin, rmax);
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
                                      narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
      }
    } else {
      return failure();
    }

    rewriter.setInsertionPointAfter(op.getOperation());
    Type result_type = quant_type.castFromExpressedType(op.getType());
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
    q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());

    auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
    op.getResult().replaceAllUsesWith(dq);
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
    op.erase();

    return success();
  }

 private:
  int num_bits;
  bool narrow_range;
  bool is_signed;
  bool legacy_float_scale;

  // Emits an op warning message if the calibrated range is larger than 10.0
  // and the storage type is less than or equal to 8 bits.
  void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
                              double max) const {
    double range = std::fabs(max - min);
    if (num_bits <= 8 && range >= 10.0) {
      op.emitWarning(
          "Tensor range is too wide to be quantized. Use tf.clip_by_value or "
          "tf.relu6 to narrow the tensor range. Range: " +
          std::to_string(range) + ", bit width: " + std::to_string(num_bits));
    }
  }
};
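
// As an illustration of the pattern above (types and quantization parameters
// abbreviated), a calibrated statistics op such as
//
//   %1 = "quant.stats"(%0) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>}
//       : (tensor<8xf32>) -> tensor<8xf32>
//
// is rewritten into a quantize/dequantize pair that materializes the derived
// quantization parameters:
//
//   %1 = "Q"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<...>>
//   %2 = "DQ"(%1) : (tensor<8x!quant.uniform<...>>) -> tensor<8xf32>
//
// where the "Q" op carries the `volatile` unit attribute, marking it as
// removable without losing accuracy.
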
// A base rewrite pattern which matches any N-in-M-out operations with
// quantization parameters propagated to at least one of its operands. The
// quantization parameters are annotated by the Q/DQ op pairs. Each matched
// pattern is rewritten by its quantized alternative.
//
// A concrete pattern, which extends this base pattern, can specify whether it
// allows "hybrid" operands or results. These "hybrid" operands and results
// don't have quantization parameters propagated to them, so they remain in
// float in the quantized results. The concrete pattern should define the
// following two functions:
//
//   bool AllowHybridOperand() const
//   bool AllowHybridResult() const
//
// Full integer quantization disallows "hybrid" operands or results.
// Weight quantization allows "hybrid" operands and results.
template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER>
struct QuantizationPattern : public RewritePattern {
  using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER>;

  explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
                               float error_tolerance, bool single_layer_verify,
                               bool log_if_failed = false)
      // Set the benefit to a large number so this pattern is always preferred.
      : RewritePattern(DQ::getOperationName(), 300, context),
        enable_verify(enable_verify),
        error_tolerance(error_tolerance),
        single_layer_verify(single_layer_verify),
        log_if_failed(log_if_failed) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    if (op->getNumResults() != 1) {
      return failure();
    }
    Value quantized_value = op->getResult(0);
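
    // The matched root `op` is a DQ op, so the users of its single result are
    // the float ops consuming the dequantized value; each of them is a
    // candidate to be rewritten in its quantized form.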
    for (Operation* quantized_op : quantized_value.getUsers()) {
      // If it is a requantize op, we shouldn't rewrite this op.
      if (llvm::isa<Q, DQ>(quantized_op)) {
        return failure();
      }

      // If it is a terminator, is not quantizable, or is any op from the mlir
      // quant ops dialect, we shouldn't rewrite.
      if (quantized_op->hasTrait<OpTrait::IsTerminator>() ||
          quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
              quantized_op)) {
        return failure();
      }

      // Collect all the quantized inputs and "clone" the matched op by these
      // inputs.
      SmallVector<Value, 4> inputs;
      inputs.reserve(quantized_op->getNumOperands());
      for (auto operand : quantized_op->getOperands()) {
        Type operand_type = operand.getType();
        if (operand_type.isa<NoneType>()) {
          inputs.push_back(operand);
          continue;
        }

        auto ele_type = operand.getType().cast<TensorType>().getElementType();
        if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
          inputs.push_back(op_inst.input());
        } else if (ele_type.isSignlessInteger()) {
          // If the operand is an integer tensor, then it doesn't require the
          // DQ op in the pattern.
          inputs.push_back(operand);
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
          inputs.push_back(operand);
        } else {
          return failure();
        }
      }

      // Collect all the quantized outputs and replace them by the results of
      // the new quantized op.
      llvm::SmallDenseMap<Value, int> outputs_replaced;
      SmallVector<Type, 4> output_types;
      output_types.reserve(quantized_op->getNumResults());
      for (auto enumerated_result :
           llvm::enumerate(quantized_op->getResults())) {
        Value result = enumerated_result.value();
        Type result_type = result.getType();
        // Add this to the test coverage once we create test ops with none
        // type results.
        if (result_type.isa<NoneType>()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result_type);
          continue;
        }
        Type result_ele_type =
            result.getType().cast<TensorType>().getElementType();
        // If the user is the Quantize op, it must be the only user.
        if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
          auto user = llvm::cast<Q>(*result.user_begin());
          outputs_replaced.insert({user.output(), enumerated_result.index()});
          output_types.push_back(user.getType());
        } else if (result_ele_type.isSignlessInteger()) {
          // If the result is an integer tensor, then it doesn't require the
          // Q op in the pattern.
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else {
          return failure();
        }
      }
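
      // Clone the matched op with the collected quantized inputs and result
      // types; regions, if any, are moved over wholesale, and the recorded
      // outputs are redirected to the new op's results.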
      rewriter.setInsertionPointAfter(quantized_op);
      OperationState new_state(quantized_op->getLoc(),
                               quantized_op->getName().getStringRef(), inputs,
                               output_types, quantized_op->getAttrs());
      for (int i = 0; i < quantized_op->getNumRegions(); ++i) {
        new_state.addRegion();
      }
      Operation* new_op = rewriter.createOperation(new_state);
      if (quantized_op->getNumRegions() != 0) {
        for (auto indexed_regions :
             llvm::enumerate(quantized_op->getRegions())) {
          new_op->getRegion(indexed_regions.index())
              .takeBody(indexed_regions.value());
        }
      }
      for (auto output : outputs_replaced) {
        output.getFirst().replaceAllUsesWith(
            new_op->getResult(output.getSecond()));
      }

      // To verify the numericals, the original floating-point ops are
      // preserved in the graph. The results of these floating-point ops are
      // sent to a numeric verifier op as the reference.
      if (enable_verify) {
        // For constant operands, the floating-point constant is duplicated in
        // case it is quantized.
        for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
          auto def = new_op->getOperand(i).getDefiningOp();
          if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
            DenseFPElementsAttr attr;
            if (!matchPattern(q.input(), m_Constant(&attr))) {
              continue;
            }
            auto cst = rewriter.create<ConstantOp>(new_op->getLoc(), attr);
            quantized_op->setOperand(i, cst.getResult());
          }
        }

        for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
          if (!quantized_op->getResult(i)
                   .getType()
                   .cast<ShapedType>()
                   .getElementType()
                   .isa<FloatType>()) {
            continue;
          }
          rewriter.setInsertionPointAfter(new_op);
          FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
          BoolAttr log = rewriter.getBoolAttr(log_if_failed);
          // Verify the quantized value by sending the result to the verifier.
          rewriter.create<VERIFIER>(
              quantized_op->getLoc(), new_op->getResult(i).getType(),
              new_op->getResult(i), quantized_op->getResult(i), tolerance,
              log);

          if (single_layer_verify) continue;

          // Find the Dequantize users of the new op results, and replace the
          // usage. Then all the floating-point ops are connected.
          // N.B. the return op will use this floating-point result.
          for (auto user : new_op->getResult(i).getUsers()) {
            // Skip the Requantize op; we know it has a single user.
            if (llvm::isa<Q>(user)) {
              user = *user->getResult(0).getUsers().begin();
            }
            if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
              dequantize.getResult().replaceAllUsesWith(
                  quantized_op->getResult(i));
            }
          }
        }
      }
    }
    return success();
  }

  bool enable_verify;
  float error_tolerance;
  bool single_layer_verify;
  bool log_if_failed;
};
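
// Illustrative example (a sketch; the TFL op names are assumed here): a
// concrete pattern for full integer quantization could be declared as
//
//   struct FullQuantizationPattern
//       : public QuantizationPattern<FullQuantizationPattern, TFL::QuantizeOp,
//                                    TFL::DequantizeOp, TFL::NumericVerifyOp> {
//     using BaseType::BaseType;
//     // Full integer quantization disallows "hybrid" operands and results.
//     bool AllowHybridOperand() const { return false; }
//     bool AllowHybridResult() const { return false; }
//   };
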
// Converts quantize ops with unsigned quantized types to those with signed
// quantized types and preserves the scales.
template <typename Q>
struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  using BaseType = ConvertUnsignedToSigned<Q>;
  using QType = quant::QuantizedType;

  explicit ConvertUnsignedToSigned(MLIRContext* context)
      : OpRewritePattern<Q>(context, 1) {}

  LogicalResult matchAndRewrite(Q op,
                                PatternRewriter& rewriter) const override {
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return failure();

    int num_bits = qtype.getStorageTypeIntegralWidth();
    // This is a positive value, and will be applied to the zero points and
    // fixed point ranges.
    int64_t offset =
        QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
        QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);

    auto flags = quant::QuantizationFlags::Signed;
    QType new_qtype;
    if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
      new_qtype = quant::UniformQuantizedType::getChecked(
          flags, qtype.getStorageType(), qtype.getExpressedType(),
          uqtype.getScale(), uqtype.getZeroPoint() - offset,
          uqtype.getStorageTypeMin() - offset,
          uqtype.getStorageTypeMax() - offset, op.getLoc());
    } else if (auto aqtype = qtype.template dyn_cast<
                             quant::UniformQuantizedPerAxisType>()) {
      auto zero_points = aqtype.getZeroPoints();
      llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
                                                    zero_points.end());
      for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
        new_zero_points[i] -= offset;
      }
      new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
          flags, qtype.getStorageType(), qtype.getExpressedType(),
          aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
          aqtype.getStorageTypeMin() - offset,
          aqtype.getStorageTypeMax() - offset, op.getLoc());
    } else {
      return failure();
    }

    if (!new_qtype) return failure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return success();
  }
};
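
// Worked example for the pattern above: with num_bits = 8 the offset is
// 0 - (-128) = 128, so the unsigned type !quant.uniform<u8:f32, 0.1:128>
// becomes the signed type !quant.uniform<i8:f32, 0.1:0>; the scale is kept
// and the zero point and storage bounds are shifted down by 128.
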
// Folds extra requantize ops if the preceding ops have a free scale
// requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
  explicit FoldTrivalRequantizeOp(MLIRContext* context)
      : OpRewritePattern<RQ>(context, 1) {}

  LogicalResult matchAndRewrite(RQ op,
                                PatternRewriter& rewriter) const override {
    Value pre_quantized = op.input();
    auto pre_quantized_type =
        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
    if (!pre_quantized_type) return failure();

    Operation* def = pre_quantized.getDefiningOp();
    if (!def) return failure();
    if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
        def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
      return failure();
    }

    op.emitWarning(
        "Remove trivial `rescale` op. Please fix the source graph.");

    llvm::SmallVector<Type, 4> new_output_types;
    for (auto result : def->getResults()) {
      if (result.hasOneUse() && *result.getUsers().begin() == op) {
        new_output_types.push_back(op.qtype());
      } else {
        new_output_types.push_back(result.getType());
      }
    }

    // Remove this rescale op.
    rewriter.replaceOp(op, {pre_quantized});

    // Replace the output scale of the preceding op.
    rewriter.setInsertionPointAfter(def);
    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
                             def->getOperands(), new_output_types,
                             def->getAttrs());
    Operation* new_op = rewriter.createOperation(new_state);

    rewriter.replaceOp(def, new_op->getResults());
    return success();
  }
};

// Given a quantized type `input`, magnifies its scales by the factor stored
// in `factor`. If `input` isn't a quantized type, or the `factor` doesn't
// match the dimension size of `input` or isn't floating-point, nullptr will
// be returned.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);

// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr, and
// this method returns UniformQuantizedType or UniformQuantizedPerAxisType
// respectively. `narrow_range` is set to true for weights, and `is_signed`
// is set to true if it is using signed int symmetric quantization.
//
// Note that this method may broadcast min and max to match the dimension
// length of `input_type`, if the `quant_dim` is valid. On the other hand, the
// symmetry of min and max is not adjusted by this method. The QAT workflow
// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
// if symmetric quantization is required.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
                              bool is_signed, bool legacy_float_scale = false);

// Casts the `target` type to a quantized type by using the quantization
// parameters from the type in the `source` type attribute.
// Examples:
//   f32 -> !quant.uniform<i8:f32, 1.0>
//   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
// The result is wrapped by a type attribute. Returns nullptr if the cast
// isn't valid.
//
// `axis` specifies the quantization dimension in the `target` and is only
// used if the element type of `source` is a per-channel quantized type.
// During the casting, the quantization dimension of the result type needs to
// be set to this new `axis` value.
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
                                                TypeAttr source, Type target,
                                                int axis);

// Quantizes the elements in the attribute `real_value` by the quantization
// parameters in `tensor_type`. Returns an empty Attribute if the
// `tensor_type` is not a QuantizedType or the quantization fails.
ElementsAttr Quantize(Attribute real_value, Type tensor_type);

// Quantizes the elements in "legacy mode", where it calls TOCO's methods to
// quantize values with float scale.
ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
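
// A minimal usage sketch for `Quantize` (assuming `new_type` is a tensor type
// whose element type is a quant::QuantizedType, e.g. built from
// GetUniformQuantizedTypeForWeight below):
//
//   ElementsAttr float_weights = ...;  // floating-point constant data
//   if (ElementsAttr quantized = Quantize(float_weights, new_type)) {
//     // `quantized` now holds the storage-type (e.g. int8) values.
//   }
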
// Returns the quantized type for an element attribute. The quantization
// parameters in this type are based on the min and max elements of the
// attribute. When the elements in the `attr` are not in floating-point, or
// the value range isn't straddling zero, an empty type is returned. The
// min/max are adjusted to be symmetric if the `symmetric` flag is set to
// true. And `symmetric` can only be set to true when it is signed and
// narrow_range.
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
                                      unsigned num_bits, bool is_signed,
                                      bool narrow_range,
                                      bool legacy_float_scale = false);

// Returns the per-channel quantized type for an element attribute.
// `quant_dim` defines the quantization axis. The channel min/max are adjusted
// to be symmetric if the `symmetric` flag is set to true. And `symmetric` can
// only be set to true when it is signed and narrow_range.
Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
                                             bool symmetric, unsigned num_bits,
                                             bool is_signed, bool narrow_range,
                                             bool legacy_float_scale = false);

// Returns the quantized type of a bias input, given the quantized types of
// other operands which are multiply-accumulated (the bias is added to the
// accumulated value).
quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types,
    bool legacy_float_scale = false);

// Propagates quantization parameters across ops in this function and
// satisfies the quantization specification of the ops. This method assumes
// the initial quantization parameters are stored as adjacent quantize and
// dequantize ops, and the propagation results are materialized by inserting
// pairs of quantize and dequantize ops to this function. Set
// `disable_per_channel` to true to not use per-channel quantization even if
// the op supports it. Set `infer_tensor_ranges` to true to infer quantization
// parameters from the activation ops and weight constants; this is only used
// for post-training quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale = false);

// The function might contain more stats ops than required, and it will
// introduce requantize if the calibration stats have conflicts. This method
// tries to remove all the redundant stats ops.
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

// Given quantization parameters for int8, computes the quantization
// parameters for uint if it is required, and wraps the result in a
// UniformQuantizedType.
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min = -128,
                                                int64_t storage_max = 127);
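
// A typical use (an illustrative sketch): an op whose output is fixed to the
// range [0, 1), such as softmax, uses scale = 1/256 with zero_point = -128
// for signed 8-bit storage:
//
//   auto qtype = GetFixedOutputRange(/*is_signed=*/true, /*bit_width=*/8,
//                                    tensor_type, /*scale=*/1.0 / 256,
//                                    /*zero_point=*/-128);
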
// Extracts min and max values from the DenseFPElementsAttr, and stores them
// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
// `dim_size` is the number of channels and `slice_size` is the size of the
// slice per each channel. When `symmetric` is true, the range is expanded to
// [-M, M].
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
                           int slice_size, bool symmetric,
                           SmallVectorImpl<double>& mins,
                           SmallVectorImpl<double>& maxs);

// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
                      ArrayRef<double> max, int quant_dim,
                      int storage_type_width, bool narrow_range,
                      bool is_signed, bool legacy_float_scale = false);

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_