/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
namespace quant {

// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be erased without losing
// accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
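// For example, a cleanup step might guard erasure on this attribute (a
// hypothetical use, not an API defined in this file):
//   if (op->hasAttr(kVolatileOpAttrName)) rewriter.eraseOp(op);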

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
using AccumulatorScaleFunc =
    std::function<QuantParams(const std::vector<QuantParams>&, bool)>;

// Quantization spec of an op, driving the quantization algorithm.
struct OpQuantSpec {
  // Maps the operand index of a bias input to its quantization specifications,
  // including the non-bias operand indexes and the method for retrieving
  // quantization parameters from the parameters of the non-bias operands.
  // This map is empty if the op doesn't have a bias operand.
  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
      biases_params;

  // Quantization parameters for value-restricted outputs. These are the
  // "hard-coded" parameters and should be used unconditionally for the
  // quantized op. This vector is empty if the op doesn't have value-restricted
  // outputs.
  llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;

  // Coefficient operand index and whether it supports per-channel
  // quantization. For QAT, this information is carried by the FakeQuant*/QDQ
  // ops, but for post-training quantization, the quantization parameters need
  // to be inferred from the tensor content and op property. A "-1" value
  // indicates the operand doesn't support per-channel quantization.
  llvm::DenseMap<int, int> coeff_op_quant_dim;
};

// A function signature for getting the particular OpQuantSpec for the provided
// op.
typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);
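
// A minimal sketch of such a getter for a hypothetical conv-like op whose
// bias is operand 2 (the op layout and indexes are illustrative only;
// GetUniformQuantizedTypeForBias is declared further below in this header):
//
//   std::unique_ptr<OpQuantSpec> GetConvOpQuantSpec(Operation* op) {
//     auto spec = std::make_unique<OpQuantSpec>();
//     // Bias scale is derived from the input (0) and filter (1) params.
//     spec->biases_params[2] = {
//         {0, 1},
//         [](const std::vector<QuantParams>& op_types, bool legacy) {
//           return GetUniformQuantizedTypeForBias(op_types, legacy);
//         }};
//     // The filter supports per-channel quantization along dimension 3.
//     spec->coeff_op_quant_dim[1] = 3;
//     return spec;
//   }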

// Recalculates the scales in float instead of simply downcasting existing
// scales.
QuantizedType DownCastScale(QuantizedType type,
                            const SmallVectorImpl<double>& mins,
                            const SmallVectorImpl<double>& maxs, Location loc);

QuantizedType DownCastScale(QuantizedType type, double min, double max,
                            Location loc);

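// Converts the quant.stats ops, which hold calibration statistics, into pairs
// of quantize/dequantize (QDQ) ops that carry the derived quantized type.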
template <typename Q, typename DQ>
struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
  ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
                     bool legacy_float_scale, MLIRContext* context)
      : OpRewritePattern<quant::StatisticsOp>(context),
        num_bits(num_bits),
        narrow_range(narrow_range),
        is_signed(is_signed),
        legacy_float_scale(legacy_float_scale) {}

  LogicalResult matchAndRewrite(quant::StatisticsOp op,
                                PatternRewriter& rewriter) const override {
    Type expressed = op.getType().cast<ShapedType>().getElementType();
    quant::QuantizedType quant_type;
    SmallVector<double, 4> mins, maxs;

    if (op.axisStats().hasValue()) {
      int stats_num = op.axisStats()->getNumElements();
      if (stats_num == 0 || stats_num % 2 != 0) return failure();
      auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
      if (!stats) return failure();

      for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
        double rmin = FloatAttr::getValueAsDouble(*it++);
        double rmax = FloatAttr::getValueAsDouble(*it);
        // The default nudging implementation of the MLIR quant library might
        // cause clamping during inference if the calibration range isn't wide
        // enough. So here we adjust the range to include 0.0.
        rmin = std::min(rmin, 0.0);
        rmax = std::max(rmax, 0.0);
        TensorRangeSanityCheck(op, rmin, rmax);
        mins.push_back(rmin);
        maxs.push_back(rmax);
      }
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
                                      maxs, narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
      }
    } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
      double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
      double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
      // The default nudging implementation of the MLIR quant library might
      // cause clamping during inference if the calibration range isn't wide
      // enough. So here we adjust the range to include 0.0.
      rmin = std::min(rmin, 0.0);
      rmax = std::max(rmax, 0.0);
      TensorRangeSanityCheck(op, rmin, rmax);
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
                                      narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
      }
    } else {
      return failure();
    }

    rewriter.setInsertionPointAfter(op.getOperation());
    Type result_type = quant_type.castFromExpressedType(op.getType());
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
    q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());

    auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
    op.getResult().replaceAllUsesWith(dq);
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
    op.erase();

    return success();
  }

 private:
  int num_bits;
  bool narrow_range;
  bool is_signed;
  bool legacy_float_scale;

  // Emits an op warning message if the calibrated range is 10.0 or wider and
  // the storage type is 8 bits or narrower.
  void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
                              double max) const {
    double range = std::fabs(max - min);
    if (num_bits <= 8 && range >= 10.0) {
      op.emitWarning(
          "Tensor range is too wide to be quantized. Use tf.clip_by_value or "
          "tf.relu6 to narrow the tensor range. Range: " +
          std::to_string(range) + ", bit width: " + std::to_string(num_bits));
    }
  }
};
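
// A sketch of how this pattern might be registered in a pass, assuming the
// TFL dialect's QuantizeOp/DequantizeOp as the Q/DQ pair (exact pattern-set
// and driver names vary across MLIR versions):
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>>(
//       /*num_bits=*/8, /*narrow_range=*/false, /*is_signed=*/true,
//       /*legacy_float_scale=*/false, ctx);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));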

// A base rewrite pattern which matches any N-in-M-out operations with
// quantization parameters propagated to at least one of its operands. The
// quantization parameters are annotated by the Q/DQ op pairs. Each
// matched pattern is rewritten by its quantized alternative.
//
// The concrete pattern, extending this base pattern, can specify whether it
// allows "hybrid" operands or results. These "hybrid" operands and results
// don't have quantization parameters propagated to them, so they remain in
// float in the quantized results. The concrete pattern should define the
// following two functions:
//
//   bool AllowHybridOperand() const
//   bool AllowHybridResult() const
//
// Full integer quantization disallows "hybrid" operands or results.
// Weight quantization allows "hybrid" operands and results.
template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER>
struct QuantizationPattern : public RewritePattern {
  using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER>;

  explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
                               float error_tolerance, bool single_layer_verify,
                               bool log_if_failed = false)
      // Set the benefit to a large number so this pattern is always preferred.
      : RewritePattern(DQ::getOperationName(), 300, context),
        enable_verify(enable_verify),
        error_tolerance(error_tolerance),
        single_layer_verify(single_layer_verify),
        log_if_failed(log_if_failed) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    if (op->getNumResults() != 1) {
      return failure();
    }
    Value quantized_value = op->getResult(0);
    for (Operation* quantized_op : quantized_value.getUsers()) {
      // If it is a requantize op, we shouldn't rewrite this op.
      if (llvm::isa<Q, DQ>(quantized_op)) {
        return failure();
      }

      // If it is a terminator, not quantizable, or an op from the mlir quant
      // ops dialect, we shouldn't rewrite.
      if (quantized_op->hasTrait<OpTrait::IsTerminator>() ||
          quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
              quantized_op)) {
        return failure();
      }

      // Collect all the quantized inputs and "clone" the matched op with
      // these inputs.
      SmallVector<Value, 4> inputs;
      inputs.reserve(quantized_op->getNumOperands());
      for (auto operand : quantized_op->getOperands()) {
        Type operand_type = operand.getType();
        if (operand_type.isa<NoneType>()) {
          inputs.push_back(operand);
          continue;
        }

        auto ele_type = operand.getType().cast<TensorType>().getElementType();
        if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
          inputs.push_back(op_inst.input());
        } else if (ele_type.isSignlessInteger()) {
          // If the operand is an integer tensor, then it doesn't require the
          // DQ op in the pattern.
          inputs.push_back(operand);
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
          inputs.push_back(operand);
        } else {
          return failure();
        }
      }

      // Collect all the quantized outputs and replace them by the results of
      // the new quantized op.
      llvm::SmallDenseMap<Value, int> outputs_replaced;
      SmallVector<Type, 4> output_types;
      output_types.reserve(quantized_op->getNumResults());
      for (auto enumerated_result :
           llvm::enumerate(quantized_op->getResults())) {
        Value result = enumerated_result.value();
        Type result_type = result.getType();
        // Add this to the test coverage once we create test ops with none type
        // results.
        if (result_type.isa<NoneType>()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result_type);
          continue;
        }
        Type result_ele_type =
            result.getType().cast<TensorType>().getElementType();
        // If the user is the Quantize op, it must be the only user.
        if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
          auto user = llvm::cast<Q>(*result.user_begin());
          outputs_replaced.insert({user.output(), enumerated_result.index()});
          output_types.push_back(user.getType());
        } else if (result_ele_type.isSignlessInteger()) {
          // If the result is an integer tensor, then it doesn't require the
          // Q op in the pattern.
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else {
          return failure();
        }
      }

      rewriter.setInsertionPointAfter(quantized_op);
      OperationState new_state(quantized_op->getLoc(),
                               quantized_op->getName().getStringRef(), inputs,
                               output_types, quantized_op->getAttrs());
      for (int i = 0; i < quantized_op->getNumRegions(); ++i) {
        new_state.addRegion();
      }
      Operation* new_op = rewriter.createOperation(new_state);
      if (quantized_op->getNumRegions() != 0) {
        for (auto indexed_regions :
             llvm::enumerate(quantized_op->getRegions())) {
          new_op->getRegion(indexed_regions.index())
              .takeBody(indexed_regions.value());
        }
      }
      for (auto output : outputs_replaced) {
        output.getFirst().replaceAllUsesWith(
            new_op->getResult(output.getSecond()));
      }

      // To verify the numericals, the original floating-point ops are
      // preserved in the graph. The results of these floating-point ops are
      // sent to a numeric verifier op as the reference.
      if (enable_verify) {
        // For constant operands, the floating-point constant is duplicated in
        // case it is quantized.
        for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
          auto def = new_op->getOperand(i).getDefiningOp();
          if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
            DenseFPElementsAttr attr;
            if (!matchPattern(q.input(), m_Constant(&attr))) {
              continue;
            }
            auto cst = rewriter.create<ConstantOp>(new_op->getLoc(), attr);
            quantized_op->setOperand(i, cst.getResult());
          }
        }

        for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
          if (!quantized_op->getResult(i)
                   .getType()
                   .cast<ShapedType>()
                   .getElementType()
                   .isa<FloatType>()) {
            continue;
          }
          rewriter.setInsertionPointAfter(new_op);
          FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
          BoolAttr log = rewriter.getBoolAttr(log_if_failed);
          // Verify the quantized value by sending the result to the verifier.
          rewriter.create<VERIFIER>(
              quantized_op->getLoc(), new_op->getResult(i).getType(),
              new_op->getResult(i), quantized_op->getResult(i), tolerance, log);

          if (single_layer_verify) continue;

          // Find the Dequantize/Requantize users of the new op results, and
          // replace the usage. Then all the floating-point ops are connected.
          // N.B. the return op will use this floating-point result.
          for (auto user : new_op->getResult(i).getUsers()) {
            // Skip the Requantize op, and we know it has a single user.
            if (llvm::isa<Q>(user)) {
              user = *user->getResult(0).getUsers().begin();
            }
            if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
              dequantize.getResult().replaceAllUsesWith(
                  quantized_op->getResult(i));
            }
          }
        }
      }
    }
    return success();
  }

  bool enable_verify;
  float error_tolerance;
  bool single_layer_verify;
  bool log_if_failed;
};
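
// A minimal sketch of a concrete pattern built on this base (the class name
// is an illustrative assumption; TFL ops fill the Q/DQ/verifier slots). Full
// integer quantization disallows hybrid kernels:
//
//   struct TFLFullQuantizationPattern
//       : public QuantizationPattern<TFLFullQuantizationPattern,
//                                    TFL::QuantizeOp, TFL::DequantizeOp,
//                                    TFL::NumericVerifyOp> {
//     using BaseType::BaseType;
//     // Neither operands nor results may stay in float.
//     bool AllowHybridOperand() const { return false; }
//     bool AllowHybridResult() const { return false; }
//   };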

// Converts quantize ops with unsigned quantized types to those with signed
// quantized types and preserves the scales.
template <typename Q>
struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  using BaseType = ConvertUnsignedToSigned<Q>;
  using QType = quant::QuantizedType;

  explicit ConvertUnsignedToSigned(MLIRContext* context)
      : OpRewritePattern<Q>(context, 1) {}

  LogicalResult matchAndRewrite(Q op,
                                PatternRewriter& rewriter) const override {
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return failure();

    int num_bits = qtype.getStorageTypeIntegralWidth();
    // This is a positive value, and will be applied to the zero points and
    // fixed point ranges.
    int64_t offset =
        QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
        QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
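    // For example, with num_bits = 8: offset = 0 - (-128) = 128, so a uint8
    // zero point of 128 becomes an int8 zero point of 0 while the scale is
    // preserved.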

    auto flags = quant::QuantizationFlags::Signed;
    QType new_qtype;
    if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
      new_qtype = quant::UniformQuantizedType::getChecked(
          flags, qtype.getStorageType(), qtype.getExpressedType(),
          uqtype.getScale(), uqtype.getZeroPoint() - offset,
          uqtype.getStorageTypeMin() - offset,
          uqtype.getStorageTypeMax() - offset, op.getLoc());
    } else if (auto aqtype = qtype.template dyn_cast<
                             quant::UniformQuantizedPerAxisType>()) {
      auto zero_points = aqtype.getZeroPoints();
      llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
                                                    zero_points.end());
      for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
        new_zero_points[i] -= offset;
      }
      new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
          flags, qtype.getStorageType(), qtype.getExpressedType(),
          aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
          aqtype.getStorageTypeMin() - offset,
          aqtype.getStorageTypeMax() - offset, op.getLoc());
    } else {
      return failure();
    }

    if (!new_qtype) return failure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return success();
  }
};

// Folds extra requantize ops if the preceding op has a free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
  explicit FoldTrivalRequantizeOp(MLIRContext* context)
      : OpRewritePattern<RQ>(context, 1) {}

  LogicalResult matchAndRewrite(RQ op,
                                PatternRewriter& rewriter) const override {
    Value pre_quantized = op.input();
    auto pre_quantized_type =
        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
    if (!pre_quantized_type) return failure();

    Operation* def = pre_quantized.getDefiningOp();
    if (!def) return failure();
    if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
        def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
      return failure();
    }

    op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");

    llvm::SmallVector<Type, 4> new_output_types;
    for (auto result : def->getResults()) {
      if (result.hasOneUse() && *result.getUsers().begin() == op) {
        new_output_types.push_back(op.qtype());
      } else {
        new_output_types.push_back(result.getType());
      }
    }

    // Remove this rescale op.
    rewriter.replaceOp(op, {pre_quantized});

    // Replace the output scale of the preceding op.
    rewriter.setInsertionPointAfter(def);
    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
                             def->getOperands(), new_output_types,
                             def->getAttrs());
    Operation* new_op = rewriter.createOperation(new_state);

    rewriter.replaceOp(def, new_op->getResults());
    return success();
  }
};

// Given a quantized type `input`, magnifies its scales by the factor stored in
// `factor`. If `input` isn't a quantized type, or the `factor` doesn't match
// the dimension size of `input` or isn't floating-point, nullptr will be
// returned.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);

// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr, and
// the result is a UniformQuantizedType or UniformQuantizedPerAxisType,
// respectively. `narrow_range` is set to true for weights, and `is_signed` is
// set to true if signed int symmetric quantization is used.
//
// Note that this method may broadcast min and max to match the dimension
// length of `input_type`, if the `quant_dim` is valid. On the other hand, the
// symmetry of min and max is not adjusted by this method. The QAT workflow
// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
// if symmetric quantization is required.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
                              bool is_signed, bool legacy_float_scale = false);

// Casts the `target` type to a quantized type by using the quantization
// parameters from the type in the `source` type attribute.
// Examples:
//   f32 -> !quant.uniform<i8:f32, 1.0>
//   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
// The result is wrapped by a type attribute. Returns nullptr if the cast
// isn't valid.
//
// `axis` is used to specify the quantization dimension in the `target`, and is
// only used if the element type of `source` is a per-channel quantized type.
// During the casting, the quantization dimension of the result type needs to
// be set to this new `axis` value.
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
                                                TypeAttr source, Type target,
                                                int axis);

// Quantizes the elements in the attribute `real_value` by the quantization
// parameters in `tensor_type`. Returns an empty Attribute if the
// `tensor_type` is not a QuantizedType or the quantization fails.
ElementsAttr Quantize(Attribute real_value, Type tensor_type);

// Quantizes the elements in "legacy mode", where it calls TOCO's methods to
// quantize values with float scale.
ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);

// Returns the quantized type for an element attribute. The quantization
// parameters in this type are based on the min and max elements of the
// attribute. When the elements in the `attr` are not floating-point, or the
// value range doesn't straddle zero, an empty type is returned. The min/max
// are adjusted to be symmetric if the `symmetric` flag is set to true, and
// `symmetric` can only be set to true when it is signed and narrow_range.
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
                                      unsigned num_bits, bool is_signed,
                                      bool narrow_range,
                                      bool legacy_float_scale = false);

// Returns the per-channel quantized type for an element attribute.
// `quant_dim` defines the quantization axis. The channel min/max are adjusted
// to be symmetric if the `symmetric` flag is set to true, and `symmetric` can
// only be set to true when it is signed and narrow_range.
Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
                                             bool symmetric, unsigned num_bits,
                                             bool is_signed, bool narrow_range,
                                             bool legacy_float_scale = false);

// Returns the quantized type of a bias input, given the quantized types of
// other operands which are multiply-accumulated (the bias is added to the
// accumulated value).
quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types,
    bool legacy_float_scale = false);

// Propagates quantization parameters across ops in this function and satisfies
// the quantization specification of the ops. This method assumes the initial
// quantization parameters are stored as adjacent quantize and dequantize ops
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. Set `disable_per_channel` to true to
// not use per-channel quantization even if the op supports it. Set
// `infer_tensor_ranges` to true to infer quantization parameters from the
// activation ops and weight constants. This is only used for post-training
// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale = false);
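
// A sketch of a typical call site in a prepare-quantize style pass
// (GetOpQuantSpec here is a placeholder for a real OpQuantSpecGetter):
//
//   ApplyQuantizationParamsPropagation(
//       func, /*is_signed=*/true, /*disable_per_channel=*/false,
//       GetOpQuantSpec, /*infer_tensor_ranges=*/true,
//       /*legacy_float_scale=*/false);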

// The function might contain more stats ops than required, and it will
// introduce requantize ops if the calibration stats have conflicts. This
// method tries to remove all the redundant stats ops.
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

// Given quantization parameters for int8, computes the quantization parameters
// for uint if it is required, and wraps the result in a UniformQuantizedType.
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min = -128,
                                                int64_t storage_max = 127);

// Extracts min and max values from the DenseFPElementsAttr, and stores them
// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
// `dim_size` is the number of channels and `slice_size` is the size of the
// slice for each channel. When `symmetric` is true, the range is expanded to
// [-M, M].
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
                           int slice_size, bool symmetric,
                           SmallVectorImpl<double>& mins,
                           SmallVectorImpl<double>& maxs);

// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
                      ArrayRef<double> max, int quant_dim,
                      int storage_type_width, bool narrow_range, bool is_signed,
                      bool legacy_float_scale = false);

}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_