/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Transform pass helpers for LSTMs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass for LSTM.
//
namespace mlir {
namespace TFL {

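// 2^15; equal in magnitude to the most negative int16 value.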
constexpr double power_of_two_scale = 32768.0;

// Same as the ordering in //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
constexpr const char* intermediate_attributes[] = {
    "input_to_input_intermediate", "input_to_forget_intermediate",
    "input_to_cell_intermediate", "input_to_output_intermediate",
    "effective_hidden_scale_intermediate"};

// Calculates the minimum power of two that is not less than the value.
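// For example, a value of 5.0 yields 8.0 and a value of 0.3 yields 0.5.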
inline double PowerOfTwoBound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

// Returns the element type of LSTM's intermediate tensor designated by the
// index.
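// For instance, index 4 reads the type stored in the
// "effective_hidden_scale_intermediate" attribute; out-of-range indices and
// missing attributes yield nullptr.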
template <typename LstmOp>
inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) {
  if (tensor_index < 0 || tensor_index > 4) return nullptr;
  TypeAttr attr = op->template getAttrOfType<TypeAttr>(
      intermediate_attributes[tensor_index]);
  if (!attr) {
    return nullptr;
  }
  return QuantizedType::getQuantizedElementType(attr.getValue());
}

namespace operator_property = ::tflite::optimize::operator_property;
using Q = quant::QuantizeCastOp;
using DQ = quant::DequantizeCastOp;

template <typename LstmOp>
LogicalResult GetLstmProperty(
    LstmOp op, operator_property::OpVariant* lstm_variant,
    operator_property::OperatorProperty* op_property) {
  if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
    lstm_variant->op_code = tflite::BuiltinOperator_LSTM;
  } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(op.getOperation())) {
    lstm_variant->op_code =
        tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
  } else {
    op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
    return failure();
  }
  lstm_variant->use_projection =
      !op.projection_weights().getType().template isa<NoneType>();
  lstm_variant->use_peephole =
      !op.cell_to_output_weights().getType().template isa<NoneType>();
  lstm_variant->use_layer_norm =
      !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();

  *op_property = operator_property::GetOperatorProperty(*lstm_variant);

  // TODO(b/176258587) move this to operator_property.cc if this is needed in
  // other components, too.
  bool use_cifg =
      op.input_to_input_weights().getType().template isa<NoneType>();
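  // With coupled input and forget gates (CIFG) there is no input gate, so drop
  // the input-gate-related inputs and the input_to_input intermediate from the
  // operator property.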
  if (use_cifg) {
    const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
    const int cifg_non_intermediate = 0;
    op_property->inputs.erase(
        std::remove_if(
            op_property->inputs.begin(), op_property->inputs.end(),
            [&](std::pair<int, operator_property::TensorProperty> input) {
              return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
            }),
        op_property->inputs.end());
    op_property->intermediates.erase(
        std::remove_if(op_property->intermediates.begin(),
                       op_property->intermediates.end(),
                       [&](std::pair<int, operator_property::TensorProperty>
                               intermediate) {
                         return intermediate.first == cifg_non_intermediate;
                       }),
        op_property->intermediates.end());
  }
  return success();
}

template <typename SourceOp>
struct PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
 public:
  explicit PrepareLstmOutputScale(MLIRContext* context)
      : OpRewritePattern<SourceOp>(context) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;

    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }
    if (lstm_property.restrict_scale.size() != 1) {
      op.emitError() << "The LSTM's operator property expects exactly one "
                     << "restrict scale requirement. Got "
                     << lstm_property.restrict_scale.size()
                     << " restrict scale requirements.";
      return failure();
    }

    // Use the same scale for the input and output specified in restrict_scale.
    const std::vector<int>& tensors = lstm_property.restrict_scale[0];
    if (tensors.size() != 2) {
      op.emitError(
          "Unexpected restrict_scale from operator property."
          " Should only have a pair of indices.");
      return failure();
    }
    return processRestrictScale(op, tensors[0], tensors[1], rewriter);
  }

 private:
  // The LSTM's recurrent input activation and its output are quantized with
  // the collective range of both tensors, because the input activation of the
  // very first inference is, in theory, not reflected in the output and would
  // otherwise not be captured.
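  // For example, stats of {-1.0, 2.0} on the recurrent input and {-3.0, 1.0}
  // on the output are both replaced with the merged stats {-3.0, 2.0}.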
  LogicalResult processRestrictScale(SourceOp op, int input_index,
                                     int output_index,
                                     PatternRewriter& rewriter) const {
    assert(output_index == 0);
    if (!op.getResult().hasOneUse()) {
      op.emitError()
          << "output " << output_index
          << " should have only one use, which should be quant.stats.";
      return failure();
    }

    llvm::SmallVector<quant::StatisticsOp, 2> stats_ops = {
        llvm::dyn_cast_or_null<quant::StatisticsOp>(
            op.getOperand(input_index).getDefiningOp()),
        llvm::dyn_cast_or_null<quant::StatisticsOp>(
            *op.getResult().getUsers().begin()),
    };

    if (!stats_ops[0] || !stats_ops[1]) {
      return failure();  // Already converted to Q-DQ pair.
    }

    llvm::SmallVector<llvm::APFloat, 4> min_max_values;

    for (auto& stats_op : stats_ops) {
      auto values = stats_op.layerStats()
                        .dyn_cast<DenseFPElementsAttr>()
                        .getValues<llvm::APFloat>();
      min_max_values.insert(min_max_values.end(), values.begin(), values.end());
    }

    // Nothing to do if the min and max values of the two stats are already
    // the same.
    if (min_max_values[0] == min_max_values[2] &&
        min_max_values[1] == min_max_values[3]) {
      return failure();
    }

    mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
        mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
        {llvm::minimum(min_max_values[0], min_max_values[2]),
         llvm::maximum(min_max_values[1], min_max_values[3])});
    mlir::ElementsAttr axis_stats;
    mlir::IntegerAttr axis;
    for (auto& stats_op : stats_ops) {
      rewriter.setInsertionPointAfter(stats_op);
      rewriter.replaceOpWithNewOp<quant::StatisticsOp>(
          stats_op, stats_op.arg(), layer_stats, axis_stats, axis);
    }
    return success();
  }
};

template <typename SourceOp>
struct ConvertOpStatsToQDQs : public OpRewritePattern<SourceOp> {
 public:
  explicit ConvertOpStatsToQDQs(MLIRContext* context,
                                const QuantizationSpecs& quant_specs,
                                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit),
        quant_specs(quant_specs) {}

 protected:
  QuantizationSpecs quant_specs;

  LogicalResult processInputs(
      SourceOp op, const operator_property::OpVariant& op_variant,
      const operator_property::OperatorProperty& op_property,
      PatternRewriter& rewriter) const {
    for (auto& enumerated_inputs : op_property.inputs) {
      int index = enumerated_inputs.first;
      auto& tensor_property = enumerated_inputs.second;

      Value input = op.getOperand(index);

      if (input.getDefiningOp() == nullptr) continue;

      // TODO(b/172517537): make this work with non-PTQ case.
      if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
        // Tensors with a derived scale are biases and are handled during
        // propagation.
        if (tensor_property.use_derived_scale) continue;
        // For weights, use the quantization scale inferred from the values.
        if (failed(processConstantOp(op, input.getDefiningOp(), index,
                                     tensor_property, rewriter))) {
          return failure();
        }
      } else {
        if (auto stats_op =
                llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
          if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
                                    rewriter))) {
            return failure();
          }
        } else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
                   !llvm::isa<SameScalesOpInterface>(input.getDefiningOp())) {
          // Inputs defined by a DequantizeCast (the StatisticsOp has already
          // been converted to a Q-DQ pair) or by an op with the same scale
          // requirement are left alone; anything else is an error.
          // TODO(b/172517537): make this work with non-PTQ case.
          op.emitError() << "Input " << index
                         << " should be from DequantizeCast, Statistics, or "
                         << "ops with the same scale requirement.";
          input.getDefiningOp()->emitError();
          return failure();
        }
      }
    }
    return success();
  }

  LogicalResult processConstantOp(
      SourceOp op, Operation* const_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    // Non-float tensors are not weights and do not require quantization.
    auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return success();

    DenseFPElementsAttr attr;
    if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
      const_op->emitError("Not a constant op.");
      return failure();
    }

    UniformQuantizedType quant_type = nullptr;
    // When the number of bits is 10 (instead of 16), quantize the tensor to
    // [-512, 512] instead of [-32767, 32767].
    // For now this behavior is specific to SVDF, where 6 bits are reserved
    // for the reduce operation after the element-wise multiplication between
    // state and time weights.
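    // For example, a weight tensor whose largest absolute value is 1.0 gets
    // scale 1.0 / 512 with storage bounds [-512, 512], held in an int16.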
    if (tensor_property.number_of_bits == 10) {
      SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
      SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
      // Computes the effective min/max values of the attribute values.
      quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1,
                                   /*symmetric=*/true, mins, maxs);
      double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits);
      quant_type = UniformQuantizedType::getChecked(
          quant::QuantizationFlags::Signed, rewriter.getIntegerType(16),
          attr.getType().getElementType(), scale, /*zeroPoint=*/0,
          llvm::minIntN(10), -llvm::minIntN(10), const_op->getLoc());
    } else {
      quant_type =
          quant::GetUniformQuantizedTypeForWeight(
              attr, /*symmetric=*/true,
              /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
              /*narrow_range=*/true, quant_specs.legacy_float_scale)
              .template dyn_cast<quant::UniformQuantizedType>();
    }
    if (!quant_type) {
      const_op->emitError("Failed to get quantized type");
      return failure();
    }

    // TODO(b/172517537): duplicate the constant when the bias is shared.
    Type expressed_type = const_op->getResult(0).getType();
    Type cast_type = quant_type.castFromExpressedType(expressed_type);
    rewriter.setInsertionPointAfter(const_op);
    auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
                                const_op->getResult(0));
    auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
    op.setOperand(input_index, dq.getResult());
    return success();
  }

  LogicalResult replaceStatsOp(
      SourceOp op, quant::StatisticsOp stats_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
      // TODO(b/172517537): check if other tensors should go through this
      // check too.
      op.emitError() << "Input tensor [" << input_index
                     << "] is a state tensor, but has more than one use.";
      return failure();
    }
    auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
    if (!stats || stats.getNumElements() != 2) {
      stats_op.emitError("Stats should have 2 values.");
      return failure();
    }
    quant::QuantizedType quant_type;
    double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
    double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
    // Make sure the range includes zero.
    min = std::min(min, 0.0);
    max = std::max(max, 0.0);
    Type expressed = getElementTypeOrSelf(stats_op.getType());

    if (tensor_property.extend_to_power_of_two) {
      if (tensor_property.number_of_bits != 16) {
        op.emitError(
            "extended power of 2 scale is only supported for 16-bit"
            " quantization.");
        return failure();
      }

      double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
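      // For example, |min| = 1.2 and |max| = 5.0 give bound = 8.0, so the
      // scale below is 8.0 / 32768 = 2^-12, an exact power of two.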
      // Set flags to 1 for signed type.
      quant_type = UniformQuantizedType::getChecked(
          quant::QuantizationFlags::Signed,
          rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
          /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
          /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
          llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
    } else {
      // int16 uses range [-32767, 32767]
      if (tensor_property.number_of_bits == 16) {
        max = std::max(std::abs(min), std::abs(max));
        min = -max;
        quant_type = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/true, expressed,
            /*isSigned=*/true);
      } else {
        quant_type = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/false, expressed,
            /*isSigned=*/true);
      }
      if (quant_specs.legacy_float_scale) {
        quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc());
      }
    }
    rewriter.setInsertionPointAfter(stats_op);
    Type result_type = quant_type.castFromExpressedType(stats_op.getType());
    auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
    rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
    return success();
  }
};

// Quantize LSTM according to its quantization recipe.
template <typename SourceOp>
struct ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs<SourceOp> {
 public:
  ConvertLstmStatsToQDQs(MLIRContext* context,
                         const QuantizationSpecs& quant_specs)
      : ConvertOpStatsToQDQs<SourceOp>(context, quant_specs) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;
    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }

    if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
        failed(ConvertOpStatsToQDQs<SourceOp>::processInputs(
            op, lstm_variant, lstm_property, rewriter))) {
      return failure();
    }

    return success();
  }

 private:
  LogicalResult processIntermediates(
      SourceOp op, const operator_property::OpVariant& lstm_variant,
      const operator_property::OperatorProperty& lstm_property) const {
    for (auto& enumerated_intermediates : lstm_property.intermediates) {
      int index = enumerated_intermediates.first;
      auto& tensor_property = enumerated_intermediates.second;
      // Intermediate tensors 0, 1, 2, 3 are only used with layer
      // normalization.
      if (!lstm_variant.use_layer_norm && index != 4) {
        continue;
      }

      TypeAttr attr =
          op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
      auto quant_type = GetIntermediateElementType<SourceOp>(op, index);
      if (!quant_type) {
        // Intermediate tensor 4 is optional, unless the LSTM uses projection.
        if (index == 4 && !lstm_variant.use_projection) {
          return success();
        }
        op.emitError() << intermediate_attributes[index]
                       << " is not quantized.";
        return failure();
      }
      auto calibrated_type =
          quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
      if (!calibrated_type) {
        int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
        if (tensor_property.number_of_bits != num_storage_bits) {
          op.emitError() << intermediate_attributes[index]
                         << " is expected to be quantized with "
                         << tensor_property.number_of_bits << " bits, but got "
                         << num_storage_bits << " bits instead.";
          return failure();
        }
        continue;  // Skip if it is already quantized.
      }
      quant::UniformQuantizedType qtype;
      if (tensor_property.number_of_bits == 8) {
        qtype = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits,
            calibrated_type.getMin(), calibrated_type.getMax(),
            /*narrowRange=*/false, calibrated_type.getExpressedType(),
            /*isSigned=*/this->quant_specs.IsSignedInferenceType());
        if (this->quant_specs.legacy_float_scale) {
          qtype = quant::DownCastScale(qtype, calibrated_type.getMin(),
                                       calibrated_type.getMax(), op.getLoc())
                      .template cast<UniformQuantizedType>();
        }
      } else if (tensor_property.number_of_bits == 16) {
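        // 16-bit intermediates use a symmetric, narrow range [-max, max], so
        // the zero point is exactly 0.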
        double max = std::max(std::abs(calibrated_type.getMin()),
                              std::abs(calibrated_type.getMax()));
        qtype = quant::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, -max, max,
            /*narrowRange=*/true, calibrated_type.getExpressedType(),
            /*isSigned=*/true);
      } else {
        op.emitError() << "Unsupported quantization bits: "
                       << tensor_property.number_of_bits;
        return failure();
      }
      op->setAttr(intermediate_attributes[index],
                  TypeAttr::get(qtype.castFromExpressedType(
                      qtype.castToExpressedType(attr.getValue()))));
    }
    return success();
  }
};

// Returns a function that computes the quantized type of a bias input.
// The bias scale is the product of the given scale and the scales taken from
// the quantization types of the other operands.
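// For example, with scale = 0.25 and a bias scale of 0.5 derived from the
// other operands, the returned quantization parameters have scale 0.125.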
inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale(
    double scale) {
  return [=](const std::vector<quant::QuantParams>& quant_params,
             bool legacy_float_scale) -> quant::QuantParams {
    if (auto qtype =
            GetUniformQuantizedTypeForBias(quant_params, legacy_float_scale)
                .dyn_cast_or_null<UniformQuantizedType>()) {
      return quant::UniformQuantizedType::get(
          qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
          qtype.getScale() * scale, qtype.getZeroPoint(),
          qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
    }
    return {};
  };
}

// Returns the quantization spec for an LSTM based on its operator properties.
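// Only biases_params is populated: each input that uses a derived scale is
// mapped to the tensors it derives from and to a scale function built from
// the intermediate-tensor scales and the recipe's factors.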
template <typename LstmOp>
std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
  operator_property::OpVariant lstm_variant;
  operator_property::OperatorProperty lstm_property;
  if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
    return nullptr;
  }

  auto spec = absl::make_unique<quant::OpQuantSpec>();

  for (const auto& enumerated_inputs : lstm_property.inputs) {
    int index = enumerated_inputs.first;
    auto& tensor_property = enumerated_inputs.second;
    if (tensor_property.use_derived_scale) {
      double scale = 1.0;
      for (int tensor_index :
           tensor_property.derived_scale.intermediate_tensors) {
        auto quant_type = GetIntermediateElementType<LstmOp>(op, tensor_index);
        if (!quant_type ||
            !quant_type.template isa<quant::UniformQuantizedType>()) {
          op->emitError() << "While processing derived scale, intermediate "
                          << intermediate_attributes[tensor_index]
                          << " is not quantized.";
          return nullptr;
        }
        scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
                     .getScale();
      }
      for (float factor : tensor_property.derived_scale.factors) {
        scale *= factor;
      }
      spec->biases_params.emplace(
          index,
          std::make_pair(tensor_property.derived_scale.input_tensors,
                         GetUniformQuantizedTypeForBiasWithScale(scale)));
    }
  }
  return spec;
}

struct ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs<TFL::SVDFOp> {
 public:
  explicit ConvertSvdfStatsToQDQs(MLIRContext* context,
                                  const QuantizationSpecs& quant_specs_param)
      : ConvertOpStatsToQDQs<TFL::SVDFOp>(context, quant_specs_param) {}
  LogicalResult matchAndRewrite(TFL::SVDFOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant op_variant;
    op_variant.op_code = tflite::BuiltinOperator_SVDF;
    auto op_property = operator_property::GetOperatorProperty(op_variant);
    return ConvertOpStatsToQDQs<TFL::SVDFOp>::processInputs(
        op, op_variant, op_property, rewriter);
  }
};

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER