/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
namespace quant {

// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be erased without losing
// accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";

enum QuantizationTrait { FullyQuantizable, NotQuantizable };
extern const char kQuantTraitAttr[];
extern const absl::string_view QuantTraitValues[];

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
using AccumulatorScaleFunc =
    std::function<QuantParams(const std::vector<QuantParams>&, bool)>;
using StringSet = absl::flat_hash_set<std::string>;

// Quantization spec of an op, driving the quantization algorithm.
struct OpQuantSpec {
  // Maps the operand index of a bias input to its quantization specification:
  // the non-bias operand indexes, and the method that retrieves quantization
  // parameters from the list of parameters of the non-bias operands. This map
  // is empty if the op doesn't have a bias operand.
  std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
      biases_params;

  // Quantization parameters for value-restricted outputs. These are the
  // "hard-coded" parameters and should be used unconditionally for the
  // quantized op. This vector is empty if the op doesn't have value-restricted
  // outputs.
  llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;

  // Coefficient operand index and whether it supports per-channel
  // quantization. For QAT, this information is carried by the FakeQuant*/QDQ
  // ops, but for post-training quantization, the quantization parameters need
  // to be inferred from the tensor content and op property. A "-1" value
  // indicates the operand doesn't support per-channel quantization.
  llvm::DenseMap<int, int> coeff_op_quant_dim;
};
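
// For illustration only, a hypothetical OpQuantSpec getter for a conv-like op
// with operands (input = 0, filter = 1, bias = 2) could be filled in as
// sketched below; the function name and operand layout are assumptions, not
// part of this header:
//
//   std::unique_ptr<OpQuantSpec> GetConvLikeQuantSpec(Operation* op) {
//     auto spec = std::make_unique<OpQuantSpec>();
//     // The bias (operand 2) derives its scale from operands 0 and 1.
//     spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias};
//     // The filter (operand 1) is per-channel quantizable along dimension 0.
//     spec->coeff_op_quant_dim[1] = 0;
//     return spec;
//   }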

// A function signature for getting the particular OpQuantSpec for the provided
// op.
typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);

// Recalculates the scales in float instead of simply downcasting existing
// scales.
QuantizedType DownCastScale(QuantizedType type,
                            const SmallVectorImpl<double>& mins,
                            const SmallVectorImpl<double>& maxs, Location loc);

QuantizedType DownCastScale(QuantizedType type, double min, double max,
                            Location loc);

bool IsOpNotQuantizable(Operation* op);

// Specialized version of location-to-string for flatbuffer-exported locations.
inline std::string GetTensorNameFromLoc(Location loc) {
  if (auto name_loc = loc.dyn_cast<NameLoc>()) {
    return name_loc.getName().str();
  }
  return "";
}
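
// For example, a value whose location is a NameLoc named "conv1" yields
// "conv1"; any other location kind (fused, callsite, unknown) yields "".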

template <typename Q, typename DQ>
struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
  ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
                     bool legacy_float_scale, MLIRContext* context)
      : OpRewritePattern<quant::StatisticsOp>(context),
        num_bits(num_bits),
        narrow_range(narrow_range),
        is_signed(is_signed),
        legacy_float_scale(legacy_float_scale) {}

  LogicalResult matchAndRewrite(quant::StatisticsOp op,
                                PatternRewriter& rewriter) const override {
    Type expressed = op.getType().cast<ShapedType>().getElementType();
    quant::QuantizedType quant_type;
    SmallVector<double, 4> mins, maxs;

    if (op.axisStats().hasValue()) {
      int stats_num = op.axisStats()->getNumElements();
      if (stats_num == 0 || stats_num % 2 != 0) return failure();
      auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
      if (!stats) return failure();

      for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
        double rmin = FloatAttr::getValueAsDouble(*it++);
        double rmax = FloatAttr::getValueAsDouble(*it);
        // The default nudging implementation of the MLIR quant library might
        // cause clamping during inference if the calibration range isn't wide
        // enough. So here we adjust the range to include 0.0.
        rmin = std::min(rmin, 0.0);
        rmax = std::max(rmax, 0.0);
        TensorRangeSanityCheck(op, rmin, rmax);
        mins.push_back(rmin);
        maxs.push_back(rmax);
      }
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
                                      maxs, narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
      }
    } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
      double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
      double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
      // The default nudging implementation of the MLIR quant library might
      // cause clamping during inference if the calibration range isn't wide
      // enough. So here we adjust the range to include 0.0.
      rmin = std::min(rmin, 0.0);
      rmax = std::max(rmax, 0.0);
      TensorRangeSanityCheck(op, rmin, rmax);
      quant_type =
          quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
                                      narrow_range, expressed, is_signed);
      if (legacy_float_scale) {
        quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
      }
    } else {
      return failure();
    }

    rewriter.setInsertionPointAfter(op.getOperation());
    Type result_type = quant_type.castFromExpressedType(op.getType());
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
    q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());

    auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
    op.getResult().replaceAllUsesWith(dq);
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
    op.erase();

    return success();
  }

 private:
  int num_bits;
  bool narrow_range;
  bool is_signed;
  bool legacy_float_scale;

  // Emits an op warning message if the calibrated range is larger than 10.0
  // and the storage type is less than or equal to 8 bits.
  void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
                              double max) const {
    double range = std::fabs(max - min);
    if (num_bits <= 8 && range >= 10.0) {
      op.emitWarning(
          "Tensor range is too wide to be quantized. Use tf.clip_by_value or "
          "tf.relu6 to narrow the tensor range. Range: " +
          std::to_string(range) + ", bit width: " + std::to_string(num_bits));
    }
  }
};
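
// A minimal registration sketch (for illustration only; assumes TFLite's
// TFL::QuantizeOp/TFL::DequantizeOp and an OwningRewritePatternList-era API):
//
//   OwningRewritePatternList patterns;
//   patterns.insert<ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>>(
//       /*num_bits=*/8, /*narrow_range=*/false, /*is_signed=*/true,
//       /*legacy_float_scale=*/false, ctx);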

// A base rewrite pattern which matches any N-in-M-out operation with
// quantization parameters propagated to at least one of its operands. The
// quantization parameters are annotated by the Q/DQ op pairs. Each
// matched pattern is rewritten by its quantized alternative.
//
// The concrete pattern, which extends this base pattern, can specify whether
// it allows "hybrid" operands or results. These "hybrid" operands and results
// don't have quantization parameters propagated to them, so they remain in
// float in the quantized results. The concrete pattern should define the
// following two functions:
//
//   bool AllowHybridOperand() const
//   bool AllowHybridResult() const
//
// Full integer quantization disallows "hybrid" operands or results.
// Weight quantization allows "hybrid" operands and results.
template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER,
          typename RootOp = DQ>
struct QuantizationPattern : public RewritePattern {
  using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER, RootOp>;

  explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
                               float error_tolerance, bool whole_model_verify,
                               bool log_if_failed = false,
                               const StringSet& ops_blocklist = {},
                               const StringSet& nodes_blocklist = {})
      // Set the benefit to a large number so this pattern is always preferred.
      : RewritePattern(RootOp::getOperationName(), 300, context),
        enable_verify(enable_verify),
        error_tolerance(error_tolerance),
        whole_model_verify(whole_model_verify),
        log_if_failed(log_if_failed),
        ops_blocklist(ops_blocklist),
        nodes_blocklist(nodes_blocklist) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    llvm::SmallVector<Operation*, 4> quantized_ops;

    // Collect all the quantized ops as the users / defs of the root op.
    if (std::is_same<RootOp, DQ>::value) {
      if (op->getNumResults() != 1) {
        return failure();
      }
      auto users = op->getResult(0).getUsers();
      quantized_ops.append(users.begin(), users.end());
    } else if (std::is_same<RootOp, Q>::value) {
      if (op->getNumOperands() != 1) {
        return failure();
      }
      Value quantize_operand = op->getOperand(0);
      if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) {
        // The input of this Q op has been quantized, i.e. this is a rescale.
        return failure();
      }
      DenseFPElementsAttr attr;
      if (matchPattern(quantize_operand, m_Constant(&attr))) {
        // The Const->Q pattern is handled separately.
        return failure();
      }
      if (Operation* quantized_op = quantize_operand.getDefiningOp()) {
        quantized_ops.push_back(quantized_op);
      }
    }

    // Rewrite the quantized ops from the floating-point version to the
    // quantized version.
    for (Operation* quantized_op : quantized_ops) {
      // If it is a requantize op, we shouldn't rewrite this op.
      if (llvm::isa<Q, DQ>(quantized_op)) {
        return failure();
      }

      // If it is a terminator, not quantizable, or any op from the MLIR quant
      // ops dialect, we shouldn't rewrite.
      if (IsOpNotQuantizable(quantized_op)) {
        return failure();
      }

      if (!ops_blocklist.empty() &&
          (ops_blocklist.find(quantized_op->getName().getStringRef().str()) !=
           ops_blocklist.end())) {
        return failure();
      }

      if (!nodes_blocklist.empty()) {
        if (auto name_loc = quantized_op->getLoc().dyn_cast<NameLoc>()) {
          std::string sloc = name_loc.getName().str();
          if (!sloc.empty() &&
              (nodes_blocklist.find(sloc) != nodes_blocklist.end())) {
            return failure();
          }
        }
      }

      // An op with float inputs and outputs is expected when it's used by a
      // NumericVerify op. Skip this op and look at the next users.
      if (enable_verify) {
        bool used_by_verifier = false;
        for (auto result : quantized_op->getResults()) {
          if (used_by_verifier) break;
          for (auto user : result.getUsers()) {
            if (llvm::isa<VERIFIER>(user)) {
              used_by_verifier = true;
              break;
            }
          }
        }
        if (used_by_verifier) continue;
      }

      // Collect all the quantized inputs and "clone" the matched op with
      // these inputs.
      SmallVector<Value, 4> inputs;
      inputs.reserve(quantized_op->getNumOperands());
      for (auto operand : quantized_op->getOperands()) {
        Type operand_type = operand.getType();
        if (operand_type.isa<NoneType>()) {
          inputs.push_back(operand);
          continue;
        }

        auto ele_type = operand.getType().cast<TensorType>().getElementType();
        if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
          inputs.push_back(op_inst.input());
        } else if (!ele_type.isF32()) {
          // If the operand is an integer tensor, then it doesn't require the
          // DQ op in the pattern.
          inputs.push_back(operand);
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
          inputs.push_back(operand);
        } else {
          return failure();
        }
      }

      // Collect all the quantized outputs and replace them by the results of
      // the new quantized op.
      llvm::SmallDenseMap<Value, int> outputs_replaced;
      SmallVector<Type, 4> output_types;
      output_types.reserve(quantized_op->getNumResults());
      for (auto enumerated_result :
           llvm::enumerate(quantized_op->getResults())) {
        Value result = enumerated_result.value();
        Type result_type = result.getType();
        // Add this to the test coverage once we create test ops with none type
        // results.
        if (result_type.isa<NoneType>()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result_type);
          continue;
        }
        Type result_ele_type =
            result.getType().cast<TensorType>().getElementType();
        // If the user is the Quantize op, it must be the only user.
        if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
          auto user = llvm::cast<Q>(*result.user_begin());
          outputs_replaced.insert({user.output(), enumerated_result.index()});
          output_types.push_back(user.getType());
        } else if (!result_ele_type.isF32()) {
          // If the result is an integer tensor, then it doesn't require the
          // DQ op in the pattern.
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
          outputs_replaced.insert({result, enumerated_result.index()});
          output_types.push_back(result.getType());
        } else {
          return failure();
        }
      }

      rewriter.setInsertionPointAfter(quantized_op);
      OperationState new_state(quantized_op->getLoc(),
                               quantized_op->getName().getStringRef(), inputs,
                               output_types, quantized_op->getAttrs());
      for (int i = 0; i < quantized_op->getNumRegions(); ++i) {
        new_state.addRegion();
      }
      Operation* new_op = rewriter.createOperation(new_state);
      if (quantized_op->getNumRegions() != 0) {
        for (auto indexed_regions :
             llvm::enumerate(quantized_op->getRegions())) {
          Region& target_region = new_op->getRegion(indexed_regions.index());
          BlockAndValueMapping mapping;
          indexed_regions.value().cloneInto(&target_region, mapping);
        }
      }
      for (auto output : outputs_replaced) {
        output.getFirst().replaceAllUsesWith(
            new_op->getResult(output.getSecond()));
      }
      // To verify the numerics, the original floating-point ops are preserved
      // in the graph. The results of these floating-point ops are sent to a
      // numeric verifier op as the reference.
      if (enable_verify) {
        // For constant operands, the floating-point constant is duplicated in
        // case it is quantized.
        for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
          auto def = new_op->getOperand(i).getDefiningOp();
          if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
            DenseFPElementsAttr attr;
            if (!matchPattern(q.input(), m_Constant(&attr))) {
              continue;
            }
            auto cst = rewriter.create<ConstantOp>(new_op->getLoc(), attr);
            quantized_op->setOperand(i, cst.getResult());
          }
        }

        for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
          if (!quantized_op->getResult(i)
                   .getType()
                   .cast<ShapedType>()
                   .getElementType()
                   .isa<FloatType>()) {
            continue;
          }
          rewriter.setInsertionPointAfter(new_op);
          FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
          BoolAttr log = rewriter.getBoolAttr(log_if_failed);
          // Verify the quantized value by sending the result to the verifier.
          rewriter.create<VERIFIER>(
              quantized_op->getLoc(), new_op->getResult(i).getType(),
              new_op->getResult(i), quantized_op->getResult(i), tolerance, log);

          if (!whole_model_verify) continue;

          // Find the Dequantize users of the new op's results (looking
          // through any Requantize op) and replace their uses, so all the
          // floating-point ops stay connected. N.B. the return op will use
          // this floating-point result.
          for (auto user : new_op->getResult(i).getUsers()) {
            // Skip through the Requantize op; it is known to have a single
            // user.
            if (llvm::isa<Q>(user)) {
              user = *user->getResult(0).getUsers().begin();
            }
            if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
              dequantize.getResult().replaceAllUsesWith(
                  quantized_op->getResult(i));
            }
          }
        }
      }
    }
    return success();
  }

  bool enable_verify;
  float error_tolerance;
  bool whole_model_verify;
  bool log_if_failed;
  const StringSet ops_blocklist;
  const StringSet nodes_blocklist;
};
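
// A minimal concrete-pattern sketch (names hypothetical; assumes TFLite's
// TFL::QuantizeOp, TFL::DequantizeOp and TFL::NumericVerifyOp):
//
//   struct FullQuantizationPattern
//       : public QuantizationPattern<FullQuantizationPattern, TFL::QuantizeOp,
//                                    TFL::DequantizeOp, TFL::NumericVerifyOp> {
//     using BaseType::BaseType;
//     // Full integer quantization disallows hybrid operands/results.
//     bool AllowHybridOperand() const { return false; }
//     bool AllowHybridResult() const { return false; }
//   };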

// Converts a quantized tensor type with a signed integer storage type to the
// corresponding quantized tensor type with an unsigned integer storage type.
Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc);

// Converts quantize ops with unsigned quantized types to ops with signed
// quantized types, preserving the scales.
template <typename Q>
struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
  using BaseType = ConvertUnsignedToSigned<Q>;
  using QType = quant::QuantizedType;

  explicit ConvertUnsignedToSigned(MLIRContext* context)
      : OpRewritePattern<Q>(context, 1) {}

  LogicalResult matchAndRewrite(Q op,
                                PatternRewriter& rewriter) const override {
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return failure();

    int num_bits = qtype.getStorageTypeIntegralWidth();
    // This is a positive value, and will be applied to the zero points and
    // storage ranges.
    int64_t offset =
        QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
        QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
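    // For example, with num_bits = 8: offset = 0 - (-128) = 128, so a uint8
    // zero point of 128 becomes an int8 zero point of 0.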

    auto flags = quant::QuantizationFlags::Signed;
    QType new_qtype;
    if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
      new_qtype = quant::UniformQuantizedType::getChecked(
          op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
          uqtype.getScale(), uqtype.getZeroPoint() - offset,
          uqtype.getStorageTypeMin() - offset,
          uqtype.getStorageTypeMax() - offset);
    } else if (auto aqtype = qtype.template dyn_cast<
                             quant::UniformQuantizedPerAxisType>()) {
      auto zero_points = aqtype.getZeroPoints();
      llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
                                                    zero_points.end());
      for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
        new_zero_points[i] -= offset;
      }
      new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
          op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
          aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
          aqtype.getStorageTypeMin() - offset,
          aqtype.getStorageTypeMax() - offset);
    } else {
      return failure();
    }

    if (!new_qtype) return failure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return success();
  }
};

// Folds extra Requantize ops if the preceding op has no restriction on its
// output scale (i.e. the scale is free to change).
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
  explicit FoldTrivalRequantizeOp(MLIRContext* context)
      : OpRewritePattern<RQ>(context, 1) {}

  LogicalResult matchAndRewrite(RQ op,
                                PatternRewriter& rewriter) const override {
    Value pre_quantized = op.input();
    auto pre_quantized_type =
        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
    if (!pre_quantized_type) return failure();

    Operation* def = pre_quantized.getDefiningOp();
    if (!def) return failure();
    if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
        def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
      return failure();
    }

    op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");

    llvm::SmallVector<Type, 4> new_output_types;
    for (auto result : def->getResults()) {
      if (result.hasOneUse() && *result.getUsers().begin() == op) {
        new_output_types.push_back(op.qtype());
      } else {
        new_output_types.push_back(result.getType());
      }
    }

    // Remove this rescale op.
    rewriter.replaceOp(op, {pre_quantized});

    // Replace the output scale of the preceding op.
    rewriter.setInsertionPointAfter(def);
    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
                             def->getOperands(), new_output_types,
                             def->getAttrs());
    Operation* new_op = rewriter.createOperation(new_state);

    rewriter.replaceOp(def, new_op->getResults());
    return success();
  }
};

// Given a quantized type `input`, magnifies its scales by the factor stored in
// `factor`. If `input` isn't a quantized type, or `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr is returned.
TypeAttr RescaleQuantizedType(Type input, Attribute factor);

// Converts the min/max/num_bits/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType.
// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr, which
// yield a UniformQuantizedType or UniformQuantizedPerAxisType respectively.
// `narrow_range` is set to true for weights, and `is_signed` is set to true
// if signed symmetric integer quantization is used.
//
// Note that this method may broadcast min and max to match the dimension
// length of `input_type`, if the `quant_dim` is valid. On the other hand, the
// symmetry of min and max is not adjusted by this method. The QAT workflow
// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
// if symmetric quantization is required.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
                              bool is_signed, bool legacy_float_scale = false);

// Casts the `target` type to a quantized type by using the quantization
// parameters from the type in the `source` type attribute.
// Examples:
//   f32 -> !quant.uniform<i8:f32, 1.0>
//   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
// The result is wrapped by a type attribute. Returns nullptr if the cast
// isn't valid.
//
// `axis` specifies the quantization dimension in the `target` and is only
// used if the element type of `source` is a per-channel quantized type. During
// the casting, the quantization dimension of the result type needs to be set
// to this new `axis` value.
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
                                                TypeAttr source, Type target,
                                                int axis);

// Quantizes the elements in the attribute `real_value` by the quantization
// parameters in `tensor_type`. Returns an empty Attribute if the
// `tensor_type` is not a QuantizedType or the quantization fails.
ElementsAttr Quantize(Attribute real_value, Type tensor_type);
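
// For example (illustrative values only): quantizing the real value 0.5 into a
// tensor of type !quant.uniform<i8:f32, 0.25> stores round(0.5 / 0.25) = 2.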

// Quantizes the elements in "legacy mode", where it calls TOCO's methods to
// quantize values with float scale.
ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);

// Returns the quantized type for an element attribute. The quantization
// parameters in this type are based on the min and max element of the
// attribute. When the elements in the `attr` are not floating-point, or the
// value range doesn't straddle zero, an empty type is returned. The min/max
// are adjusted to be symmetric if the `symmetric` flag is set to true, and
// `symmetric` can only be set to true when it is signed and narrow_range.
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
                                      unsigned num_bits, bool is_signed,
                                      bool narrow_range,
                                      bool legacy_float_scale = false);

// Returns the per-channel quantized type for an element attribute.
// `quant_dim` defines the quantization axis. The channel min/max are adjusted
// to be symmetric if the `symmetric` flag is set to true, and `symmetric` can
// only be set to true when it is signed and narrow_range.
Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
                                             bool symmetric, unsigned num_bits,
                                             bool is_signed, bool narrow_range,
                                             bool legacy_float_scale = false);

// Returns the quantized type of a bias input, given the quantized types of
// other operands which are multiply-accumulated (the bias is added to the
// accumulated value).
quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types,
    bool legacy_float_scale = false);
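
// In the common TFLite convention the bias scale is the product of the input
// and weight scales, e.g. an input scale of 0.5 and a weight scale of 0.25
// give a bias scale of 0.125 (with a wider, typically 32-bit, storage type).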

// Propagates quantization parameters across ops in this function, satisfying
// the quantization specification of the ops. This method assumes the initial
// quantization parameters are stored as adjacent quantize and dequantize ops,
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops into this function. Set `disable_per_channel` to true to
// avoid per-channel quantization even if the op supports it. Set
// `infer_tensor_ranges` to true to infer quantization parameters from the
// activation ops and weight constants; this is only used for post-training
// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale = false);

// The function might contain more stats ops than required, and it would
// introduce requantize ops if the calibration stats have conflicts. This
// method tries to remove all the redundant stats ops.
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

// Given quantization parameters for int8, computes the quantization parameters
// for uint if it is required, and wraps the result in a UniformQuantizedType.
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min = -128,
                                                int64_t storage_max = 127);
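
// For example, TFLite fixes the int8 output of ops such as Softmax to a scale
// of 1/256 and a zero point of -128, which a caller could request as:
//
//   GetFixedOutputRange(/*is_signed=*/true, /*bit_width=*/8, tensor_type,
//                       /*scale=*/1.0 / 256, /*zero_point=*/-128);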

// Extracts min and max values from the DenseFPElementsAttr and stores them
// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
// `dim_size` is the number of channels and `slice_size` is the size of the
// slice per channel. When `symmetric` is true, the range is expanded to
// [-M, M].
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
                           int slice_size, bool symmetric,
                           SmallVectorImpl<double>& mins,
                           SmallVectorImpl<double>& maxs);
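
// For example (hypothetical shapes), extracting per-channel ranges from a
// [2, 3] tensor along dimension 0 would use dim_size = 2 and slice_size = 3,
// producing one (min, max) pair per row.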

// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
                      ArrayRef<double> max, int quant_dim,
                      int storage_type_width, bool narrow_range, bool is_signed,
                      bool legacy_float_scale = false);
}  // namespace quant
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_