1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This header file defines common utils used by TFLite transformation
17 // passes to work with op attributes.
18
19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
20 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
21
22 #include <string>
23 #include <unordered_map>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/Twine.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
33 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
34 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
35 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
36 #include "mlir/IR/Attributes.h" // from @llvm-project
37 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
38 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
41 #include "mlir/IR/MLIRContext.h" // from @llvm-project
42 #include "mlir/IR/Matchers.h" // from @llvm-project
43 #include "mlir/IR/PatternMatch.h" // from @llvm-project
44 #include "mlir/Support/LLVM.h" // from @llvm-project
45 #include "mlir/Support/LogicalResult.h" // from @llvm-project
46 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
47
48 namespace mlir {
49 namespace quant {
50
51 // A unit attribute can be attached to the quantize/dequantize ops which are
52 // added by the quantization passes. These ops can be removed erased without
53 // losing accuracy.
54 constexpr char kVolatileOpAttrName[] = "volatile";
55
56 enum QuantizationTrait { FullyQuantizable, NotQuantizable };
57 extern const char kQuantTraitAttr[];
58 extern const absl::string_view QuantTraitValues[];
59
60 using QuantParams = quant::QuantizedType;
61 using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
62 using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
63 using AccumulatorScaleFunc =
64 std::function<QuantParams(const std::vector<QuantParams>&, bool)>;
65 using StringSet = absl::flat_hash_set<std::string>;
66
67 // Quantization spec of an op, driving the quantization algorithm.
68 struct OpQuantSpec {
69 // Maps the operand index of a bias input to its quantization specifications,
70 // including the non-bias operand indexes and the method retrieving
71 // quantization parameters from list of parameters of the non-bias operands.
72 // This map is empty if the op doesn't have a bias operand.
73 std::unordered_map<int, std::pair<std::vector<int>, AccumulatorScaleFunc>>
74 biases_params;
75
76 // Quantization parameters for value restricted outputs. This is the
77 // "hard-coded" parameters and should be used unconditionally for the
78 // quantized op. This vector is empty if the op doesn't have value restricted
79 // outputs.
80 llvm::DenseMap<SignedInteger, QuantParamsForResults> restricted_output_params;
81
82 // Coefficient operand index and whether supporting per-channel quantization.
83 // For QAT, this information is carried by the FakeQuant*/QDQ ops, but
84 // post-training quantization, the quantization parameters need to be inferred
85 // from the tensor content and op property. A "-1" value indicates the
86 // operand doesn't support per-channel quantization.
87 llvm::DenseMap<int, int> coeff_op_quant_dim;
88 };
89
90 // A function signature for getting the particular OpQuantSpec for the provided
91 // op.
92 typedef std::unique_ptr<OpQuantSpec> (*OpQuantSpecGetter)(Operation* op);
93
94 // Re-calculates scales again in float instead of simply downcasting existing
95 // scales.
96 QuantizedType DownCastScale(QuantizedType type,
97 const SmallVectorImpl<double>& mins,
98 const SmallVectorImpl<double>& maxs, Location loc);
99
100 QuantizedType DownCastScale(QuantizedType type, double min, double max,
101 Location loc);
102
103 bool IsOpNotQuantizable(Operation* op);
104
105 // Specialized version of location to string for flatbuffer exported locations.
GetTensorNameFromLoc(Location loc)106 inline std::string GetTensorNameFromLoc(Location loc) {
107 if (auto name_loc = loc.dyn_cast<NameLoc>()) {
108 return name_loc.getName().str();
109 }
110 return "";
111 }
112
113 template <typename Q, typename DQ>
114 struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
ConvertStatsToQDQsConvertStatsToQDQs115 ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
116 bool legacy_float_scale, MLIRContext* context)
117 : OpRewritePattern<quant::StatisticsOp>(context),
118 num_bits(num_bits),
119 narrow_range(narrow_range),
120 is_signed(is_signed),
121 legacy_float_scale(legacy_float_scale) {}
122
matchAndRewriteConvertStatsToQDQs123 LogicalResult matchAndRewrite(quant::StatisticsOp op,
124 PatternRewriter& rewriter) const override {
125 Type expressed = op.getType().cast<ShapedType>().getElementType();
126 quant::QuantizedType quant_type;
127 SmallVector<double, 4> mins, maxs;
128
129 if (op.axisStats().hasValue()) {
130 int stats_num = op.axisStats()->getNumElements();
131 if (stats_num == 0 || stats_num % 2 != 0) return failure();
132 auto stats = op.axisStats()->dyn_cast<DenseFPElementsAttr>();
133 if (!stats) return failure();
134
135 for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
136 double rmin = FloatAttr::getValueAsDouble(*it++);
137 double rmax = FloatAttr::getValueAsDouble(*it);
138 // The default nudging implementation of mlir quant library might cause
139 // clamping during inference if the calibration range isn't wide enough.
140 // So here we adjust the range to include 0.0.
141 rmin = std::min(rmin, 0.0);
142 rmax = std::max(rmax, 0.0);
143 TensorRangeSanityCheck(op, rmin, rmax);
144 mins.push_back(rmin);
145 maxs.push_back(rmax);
146 }
147 quant_type =
148 quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
149 maxs, narrow_range, expressed, is_signed);
150 if (legacy_float_scale) {
151 quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc());
152 }
153 } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
154 double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
155 double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
156 // The default nudging implementation of mlir quant library might cause
157 // clamping during inference if the calibration range isn't wide enough.
158 // So here we adjust the range to include 0.0.
159 rmin = std::min(rmin, 0.0);
160 rmax = std::max(rmax, 0.0);
161 TensorRangeSanityCheck(op, rmin, rmax);
162 quant_type =
163 quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
164 narrow_range, expressed, is_signed);
165 if (legacy_float_scale) {
166 quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc());
167 }
168 } else {
169 return failure();
170 }
171
172 rewriter.setInsertionPointAfter(op.getOperation());
173 Type result_type = quant_type.castFromExpressedType(op.getType());
174 auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
175 q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr());
176
177 auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
178 op.getResult().replaceAllUsesWith(dq);
179 q.getOperation()->replaceUsesOfWith(dq, op.arg());
180 op.erase();
181
182 return success();
183 }
184
185 private:
186 int num_bits;
187 bool narrow_range;
188 bool is_signed;
189 bool legacy_float_scale;
190
191 // Emits an op warning message if the calibrated range is larger than 10.0 and
192 // the storage type is less than or equal to 8 bits.
TensorRangeSanityCheckConvertStatsToQDQs193 void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
194 double max) const {
195 double range = std::fabs(max - min);
196 if (num_bits <= 8 && range >= 10.0) {
197 op.emitWarning(
198 "Tensor range is too wide to be quantized. Use tf.clip_by_value or "
199 "tf.relu6 to narrow the tensor range. Range: " +
200 std::to_string(range) + ", bit width: " + std::to_string(num_bits));
201 }
202 }
203 };
204
205 // A base rewrite pattern which matches any N-in-M-out operations with
206 // quantization parameters propagated to at least one of its operands. The
207 // quantization parameters are annotated by the Q/DQ op pairs. Each
208 // matched pattern are rewritten by its quantized alternatives.
209 //
210 // The concrete pattern, extends from this base pattern, can specify whether it
211 // allows "hybrid" operands or results. These "hybrid" operands and results
212 // don't have quantization parameters propagated to, so will be in float in the
213 // quantized results. The concrete pattern should define the following two
214 // functions:
215 //
216 // bool AllowHybridOperand() const
217 // bool AllowHybridResult() const
218 //
219 // Full integer quantization disallows "hybrid" operands or results.
220 // Weight quantization allows "hybrid" operands and results.
221 template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER,
222 typename RootOp = DQ>
223 struct QuantizationPattern : public RewritePattern {
224 using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER, RootOp>;
225
226 explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
227 float error_tolerance, bool whole_model_verify,
228 bool log_if_failed = false,
229 const StringSet& ops_blocklist = {},
230 const StringSet& nodes_blocklist = {})
231 // Set the score to a large number so it is always preferred.
232 : RewritePattern(RootOp::getOperationName(), 300, context),
233 enable_verify(enable_verify),
234 error_tolerance(error_tolerance),
235 whole_model_verify(whole_model_verify),
236 log_if_failed(log_if_failed),
237 ops_blocklist(ops_blocklist),
238 nodes_blocklist(nodes_blocklist) {}
239
matchAndRewriteQuantizationPattern240 LogicalResult matchAndRewrite(Operation* op,
241 PatternRewriter& rewriter) const override {
242 llvm::SmallVector<Operation*, 4> quantized_ops;
243
244 // Collect all the quantized ops as the user / def of the root op.
245 if (std::is_same<RootOp, DQ>::value) {
246 if (op->getNumResults() != 1) {
247 return failure();
248 }
249 auto users = op->getResult(0).getUsers();
250 quantized_ops.append(users.begin(), users.end());
251 } else if (std::is_same<RootOp, Q>::value) {
252 if (op->getNumOperands() != 1) {
253 return failure();
254 }
255 Value quantize_operand = op->getOperand(0);
256 if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) {
257 // The input of this Q op has been quantized, i.e. rescale.
258 return failure();
259 }
260 DenseFPElementsAttr attr;
261 if (matchPattern(quantize_operand, m_Constant(&attr))) {
262 // Const->Q pattern will be handled seperately.
263 return failure();
264 }
265 if (Operation* quantized_op = quantize_operand.getDefiningOp()) {
266 quantized_ops.push_back(quantized_op);
267 }
268 }
269
270 // Rewrite the quantized ops from floating-point to quantized version.
271 for (Operation* quantized_op : quantized_ops) {
272 // If it is requantize op, we shouldn't rewrite this op.
273 if (llvm::isa<Q, DQ>(quantized_op)) {
274 return failure();
275 }
276
277 // If it is terminator or not quantizable or any ops form the mlir quant
278 // ops dialect, we shouldn't rewrite.
279 if (IsOpNotQuantizable(quantized_op)) {
280 return failure();
281 }
282
283 if (!ops_blocklist.empty() &&
284 (ops_blocklist.find(quantized_op->getName().getStringRef().str()) !=
285 ops_blocklist.end())) {
286 return failure();
287 }
288
289 if (!nodes_blocklist.empty()) {
290 if (auto name_loc = quantized_op->getLoc().dyn_cast<NameLoc>()) {
291 std::string sloc = name_loc.getName().str();
292 if (!sloc.empty() &&
293 (nodes_blocklist.find(sloc) != nodes_blocklist.end())) {
294 return failure();
295 }
296 }
297 }
298
299 // An op with float inputs and outputs are expected when it's used by a
300 // NumericVerify op. Skip this op and look at next users.
301 if (enable_verify) {
302 bool used_by_verifier = false;
303 for (auto result : quantized_op->getResults()) {
304 if (used_by_verifier) break;
305 for (auto user : result.getUsers()) {
306 if (llvm::isa<VERIFIER>(user)) {
307 used_by_verifier = true;
308 break;
309 }
310 }
311 }
312 if (used_by_verifier) continue;
313 }
314
315 // Collect all the quantized inputs and "clone" the matched op by these
316 // inputs.
317 SmallVector<Value, 4> inputs;
318 inputs.reserve(quantized_op->getNumOperands());
319 for (auto operand : quantized_op->getOperands()) {
320 Type operand_type = operand.getType();
321 if (operand_type.isa<NoneType>()) {
322 inputs.push_back(operand);
323 continue;
324 }
325
326 auto ele_type = operand.getType().cast<TensorType>().getElementType();
327 if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
328 inputs.push_back(op_inst.input());
329 } else if (!ele_type.isF32()) {
330 // If the operand is an integer tensor, then it doesn't require the
331 // DQ op in the pattern.
332 inputs.push_back(operand);
333 } else if (static_cast<const ConcretTy*>(this)->AllowHybridOperand()) {
334 inputs.push_back(operand);
335 } else {
336 return failure();
337 }
338 }
339
340 // Collect all the quantized outputs and replace them by the results of
341 // the new quantized op.
342 llvm::SmallDenseMap<Value, int> outputs_replaced;
343 SmallVector<Type, 4> output_types;
344 output_types.reserve(quantized_op->getNumResults());
345 for (auto enumerated_result :
346 llvm::enumerate(quantized_op->getResults())) {
347 Value result = enumerated_result.value();
348 Type result_type = result.getType();
349 // Add this to the test coverage once we create test ops with none type
350 // results.
351 if (result_type.isa<NoneType>()) {
352 outputs_replaced.insert({result, enumerated_result.index()});
353 output_types.push_back(result_type);
354 continue;
355 }
356 Type result_ele_type =
357 result.getType().cast<TensorType>().getElementType();
358 // If the user is the Quantize op, it must be the only user.
359 if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
360 auto user = llvm::cast<Q>(*result.user_begin());
361 outputs_replaced.insert({user.output(), enumerated_result.index()});
362 output_types.push_back(user.getType());
363 } else if (!result_ele_type.isF32()) {
364 // If the result is an integer tensor, then it doesn't require the
365 // D op in the pattern.
366 outputs_replaced.insert({result, enumerated_result.index()});
367 output_types.push_back(result.getType());
368 } else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
369 outputs_replaced.insert({result, enumerated_result.index()});
370 output_types.push_back(result.getType());
371 } else {
372 return failure();
373 }
374 }
375
376 rewriter.setInsertionPointAfter(quantized_op);
377 OperationState new_state(quantized_op->getLoc(),
378 quantized_op->getName().getStringRef(), inputs,
379 output_types, quantized_op->getAttrs());
380 for (int i = 0; i < quantized_op->getNumRegions(); ++i) {
381 new_state.addRegion();
382 }
383 Operation* new_op = rewriter.createOperation(new_state);
384 if (quantized_op->getNumRegions() != 0) {
385 for (auto indexed_regions :
386 llvm::enumerate(quantized_op->getRegions())) {
387 Region& target_region = new_op->getRegion(indexed_regions.index());
388 BlockAndValueMapping mapping;
389 indexed_regions.value().cloneInto(&target_region, mapping);
390 }
391 }
392 for (auto output : outputs_replaced) {
393 output.getFirst().replaceAllUsesWith(
394 new_op->getResult(output.getSecond()));
395 }
396
397 // To verify the numericals, the original floating-point ops are
398 // preserved in the graph. The result of these floating-point ops are sent
399 // to a numeric verifier op as the reference.
400 if (enable_verify) {
401 // For constant operands, the floating-point constant is duplicated in
402 // case it is quantized.
403 for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
404 auto def = new_op->getOperand(i).getDefiningOp();
405 if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
406 DenseFPElementsAttr attr;
407 if (!matchPattern(q.input(), m_Constant(&attr))) {
408 continue;
409 }
410 auto cst = rewriter.create<ConstantOp>(new_op->getLoc(), attr);
411 quantized_op->setOperand(i, cst.getResult());
412 }
413 }
414
415 for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
416 if (!quantized_op->getResult(i)
417 .getType()
418 .cast<ShapedType>()
419 .getElementType()
420 .isa<FloatType>()) {
421 continue;
422 }
423 rewriter.setInsertionPointAfter(new_op);
424 FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
425 BoolAttr log = rewriter.getBoolAttr(log_if_failed);
426 // Verify the quantized value by sending the result to the verifier.
427 rewriter.create<VERIFIER>(
428 quantized_op->getLoc(), new_op->getResult(i).getType(),
429 new_op->getResult(i), quantized_op->getResult(i), tolerance, log);
430
431 if (!whole_model_verify) continue;
432
433 // Find the Dequantize/Dequantize users of the new op results, and
434 // replace the usage. Then all the floating-point ops are connected.
435 // N.B. the return op will use this floating-point result.
436 for (auto user : new_op->getResult(i).getUsers()) {
437 // Skip the Requantize op, and we know it has a single user.
438 if (llvm::isa<Q>(user)) {
439 user = *user->getResult(0).getUsers().begin();
440 }
441 if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
442 dequantize.getResult().replaceAllUsesWith(
443 quantized_op->getResult(i));
444 }
445 }
446 }
447 }
448 }
449 return success();
450 }
451
452 bool enable_verify;
453 float error_tolerance;
454 bool whole_model_verify;
455 bool log_if_failed;
456 const StringSet ops_blocklist;
457 const StringSet nodes_blocklist;
458 };
459
460 // Converts quantized tensor type with signed integer type to quantized tensor
461 // type with unsigned integer type.
462 Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc);
463
464 // Converts quantize ops with unsigned quantized types to these with signed
465 // quantized types and preserves the scales.
466 template <typename Q>
467 struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
468 using BaseType = ConvertUnsignedToSigned<Q>;
469 using QType = quant::QuantizedType;
470
ConvertUnsignedToSignedConvertUnsignedToSigned471 explicit ConvertUnsignedToSigned(MLIRContext* context)
472 : OpRewritePattern<Q>(context, 1) {}
473
matchAndRewriteConvertUnsignedToSigned474 LogicalResult matchAndRewrite(Q op,
475 PatternRewriter& rewriter) const override {
476 Type output_type = op.getResult().getType();
477 auto qtype = QType::getQuantizedElementType(output_type);
478 if (!qtype || qtype.isSigned()) return failure();
479
480 int num_bits = qtype.getStorageTypeIntegralWidth();
481 // This is a positive value, and will be applied on zero points and fixed
482 // point ranges.
483 int64_t offset =
484 QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) -
485 QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits);
486
487 auto flags = quant::QuantizationFlags::Signed;
488 QType new_qtype;
489 if (auto uqtype = qtype.template dyn_cast<quant::UniformQuantizedType>()) {
490 new_qtype = quant::UniformQuantizedType::getChecked(
491 op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
492 uqtype.getScale(), uqtype.getZeroPoint() - offset,
493 uqtype.getStorageTypeMin() - offset,
494 uqtype.getStorageTypeMax() - offset);
495 } else if (auto aqtype = qtype.template dyn_cast<
496 quant::UniformQuantizedPerAxisType>()) {
497 auto zero_points = aqtype.getZeroPoints();
498 llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
499 zero_points.end());
500 for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
501 new_zero_points[i] -= offset;
502 }
503 new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
504 op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(),
505 aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
506 aqtype.getStorageTypeMin() - offset,
507 aqtype.getStorageTypeMax() - offset);
508 } else {
509 return failure();
510 }
511
512 if (!new_qtype) return failure();
513 Type new_output_type = new_qtype.castFromExpressedType(
514 QType::castToExpressedType(output_type));
515 rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
516 return success();
517 }
518 };
519
520 // Fold Extra Requantize ops if the preceding ops has free scale requirement.
521 template <typename RQ>
522 struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
FoldTrivalRequantizeOpFoldTrivalRequantizeOp523 explicit FoldTrivalRequantizeOp(MLIRContext* context)
524 : OpRewritePattern<RQ>(context, 1) {}
525
matchAndRewriteFoldTrivalRequantizeOp526 LogicalResult matchAndRewrite(RQ op,
527 PatternRewriter& rewriter) const override {
528 Value pre_quantized = op.input();
529 auto pre_quantized_type =
530 quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
531 if (!pre_quantized_type) return failure();
532
533 Operation* def = pre_quantized.getDefiningOp();
534 if (!def) return failure();
535 if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
536 def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
537 return failure();
538 }
539
540 op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
541
542 llvm::SmallVector<Type, 4> new_output_types;
543 for (auto result : def->getResults()) {
544 if (result.hasOneUse() && *result.getUsers().begin() == op) {
545 new_output_types.push_back(op.qtype());
546 } else {
547 new_output_types.push_back(result.getType());
548 }
549 }
550
551 // Remove this rescale op.
552 rewriter.replaceOp(op, {pre_quantized});
553
554 // Replace the output scale of the preceding op.
555 rewriter.setInsertionPointAfter(def);
556 OperationState new_state(def->getLoc(), def->getName().getStringRef(),
557 def->getOperands(), new_output_types,
558 def->getAttrs());
559 Operation* new_op = rewriter.createOperation(new_state);
560
561 rewriter.replaceOp(def, new_op->getResults());
562 return success();
563 }
564 };
565
566 // Given a quantized type `input`, magnifying its scales by the factor stored in
567 // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
568 // dimension size of `input` or isn't floating-point, nullptr will be returned.
569 TypeAttr RescaleQuantizedType(Type input, Attribute factor);
570
571 // Converts the min/max/num_bits/narrow_range information to a
572 // QuantizedType, and then returns the attribute containing the QuantizedType.
573 // The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and
574 // returns UniformQuantizedType or UniformQuantizedPerAxisType respectively.
575 // `narrow_range` is set to true for weights and `is_signed` is set to true
576 // if it is using signed int symmetric quantization.
577 //
578 // Note that this method may broadcast min and max to match the dimension length
579 // of `input_type`, if the `quant_dim` is valid. On the other hand, the
580 // symmetry of min and max is not adjusted by this method. The QAT workflow
581 // should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
582 // if symmetric quantization is required.
583 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
584 Attribute max, int quant_dim,
585 IntegerAttr num_bits, BoolAttr narrow_range,
586 bool is_signed, bool legacy_float_scale = false);
587
588 // Casts the `target` type to a quantized type by using the quantization
589 // parameters from the type in the `source` type attribute.
590 // Examples:
591 // f32 -> !quant.uniform<i8:f32, 1.0>
592 // tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
593 // The result is wrapped by a type attribute. Returns nullptr if the cast
594 // isn't valid.
595 //
596 // `axis` is to specify the quantization dimension in the `target` and only
597 // used if the element type of `source` is a per-channel quantized type. During
598 // the casting, the quantization dimension of the result type needs to be set
599 // this new `axis` value.
600 TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
601 TypeAttr source, Type target,
602 int axis);
603
604 // Quantizes the elements in the attribute `real_value` by the quantization
605 // parameters in `tensor_type`. Returns empty Attribute if the
606 // `tensor_type` is not a QuantizedType or the quantization fails.
607 ElementsAttr Quantize(Attribute real_value, Type tensor_type);
608
609 // Quantizes the elements in "legacy mode", where it calls TOCO's methods to
610 // to quantize values with float scale.
611 ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
612
613 // Returns the quantized type for an element attribute. The quantization
614 // parameters in this type is based on the min and max element of the
615 // attribute. When the elements in the `attr` are not in floating-point, or
616 // the value range isn't straddling zero, an empty type is returned. The min/max
617 // are adjusted to be symmetric if `symmetric` flag is set to True. And
618 // `symmetric` can only be set to true when it is signed and narrow_range.
619 Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
620 unsigned num_bits, bool is_signed,
621 bool narrow_range,
622 bool legacy_float_scale = false);
623
624 // Returns the per channel quantized type for an element attribute.
625 // `quant_dim` defines the quantization axis. The channel min/max are adjusted
626 // to be symmetric if `symmetric` flag is set to True. And `symmetric` can only
627 // be set to true when it is signed and narrow_range.
628 Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
629 bool symmetric, unsigned num_bits,
630 bool is_signed, bool narrow_range,
631 bool legacy_float_scale = false);
632
633 // Returns the quantized type of a bias input, given the quantized types of
634 // other operands which are multiply-accumulated (the bias is added to the
635 // accumulated value).
636 quant::QuantizedType GetUniformQuantizedTypeForBias(
637 const std::vector<quant::QuantizedType>& op_types,
638 bool legacy_float_scale = false);
639
640 // Propagates quantization parameters across ops in this function and satisfy
641 // the quantization specification of the ops. This methods assumes the initial
642 // quantization parameters are stored as adjacent quantize and dequantize ops
643 // and the propagation results are materialized by inserting pairs of quantize
644 // and dequantize ops to this function. Set `disable_per_channel` to true to not
645 // use per channel quantization even the op supports it.
646 // Setting `infer_tensor_range` to true, to infer quantization parameters from
647 // the activation ops and weight constants. This is only used for post-training
648 // quantization.
649 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
650 bool disable_per_channel,
651 OpQuantSpecGetter op_quant_spec_getter,
652 bool infer_tensor_ranges,
653 bool legacy_float_scale = false);
654
655 // The function might contain more stats ops than required, and it will
656 // introduce requantize if the calibration stats have conflicts. This method
657 // tries to remove all the redundant stats ops.
658 bool RemoveRedundantStatsOps(mlir::FuncOp func,
659 OpQuantSpecGetter op_quant_spec_getter);
660
661 // Given quantization parameters for int8, compute the quantization parameters
662 // for uint if it is required, and wrap the result in an UniformQuantizedType.
663 quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
664 Type tensor_type, double scale,
665 int64_t zero_point,
666 int64_t storage_min = -128,
667 int64_t storage_max = 127);
668
669 // Extrace min and max values from the DenseFPElementsAttr, and stores them into
670 // `mins` and `maxs`. When mins and maxs are extracted per-channel, `dim_size`
671 // is number of channels and `slice_size` is the size of slice per each channel.
672 // When `symmetric` is true, the range is expanded to [-M, M].
673 void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
674 int slice_size, bool symmetric,
675 SmallVectorImpl<double>& mins,
676 SmallVectorImpl<double>& maxs);
677
678 // Returns the quantized type for the
679 // input_type/min/max/storag_type_width/narrow_range.
680 Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
681 ArrayRef<double> max, int quant_dim,
682 int storage_type_width, bool narrow_range, bool is_signed,
683 bool legacy_float_scale = false);
684 } // namespace quant
685 } // namespace mlir
686
687 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
688