/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantizeUtils.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

namespace mlir {

// This includes the interface class definition. It cannot be in a namespace
// because TableGen does not emit the namespace when it is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.cc.inc"

namespace quant {

constexpr double kNearZeroTolerance = 1.0e-6;
constexpr double kSmallestHalfRange = kNearZeroTolerance / 2;
using QType = quant::QuantizedType;

const char kQuantTraitAttr[] = "_tfl_quant_trait";
const absl::string_view QuantTraitValues[] = {"fully_quantizable",
                                              "not_quantizable"};

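// Returns true if `op` should not be quantized: terminators, quant dialect
// cast ops, and ops whose results are marked not quantizable, unless the op
// is explicitly annotated as "fully_quantizable" via `kQuantTraitAttr`.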
bool IsOpNotQuantizable(Operation* op) {
  // If it is a terminator, not quantizable, or any op from the mlir quant
  // ops dialect, we shouldn't rewrite.
  bool attr_enforced_quantizable =
      op->hasAttrOfType<StringAttr>(kQuantTraitAttr) &&
      op->getAttrOfType<StringAttr>(kQuantTraitAttr).getValue().str() ==
          QuantTraitValues[QuantizationTrait::FullyQuantizable];
  bool prop_enforced_no_quantizable =
      op->hasTrait<OpTrait::quant::NoQuantizableResult>();

  return op->hasTrait<OpTrait::IsTerminator>() ||
         llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op) ||
         (!attr_enforced_quantizable && prop_enforced_no_quantizable);
}

// This method expands the range to be larger than or equal to 1.0e-6, if it is
// very small (< 1.0e-6). This is to prevent very large quantized values caused
// by such a small range.
static void ExpandVerySmallRange(ArrayRef<double> mins, ArrayRef<double> maxs,
                                 SmallVectorImpl<double>* effective_mins,
                                 SmallVectorImpl<double>* effective_maxs) {
  for (auto arg : llvm::zip(mins, maxs)) {
    double min = std::get<0>(arg);
    double max = std::get<1>(arg);
    // If the range is wide enough, use the same min/max.
    if ((max - min) > kNearZeroTolerance) {
      effective_mins->push_back(min);
      effective_maxs->push_back(max);
      continue;
    }

    // The range is small. Expand the range to straddle 0.0 and to be at least
    // 1.0e-6 wide.
    effective_mins->push_back(std::min(min, -kSmallestHalfRange));
    effective_maxs->push_back(std::max(max, kSmallestHalfRange));
  }
}

// Returns the quantized type for the
// input_type/min/max/storage_type_width/narrow_range.
// This is the entry point to the Quant dialect and is used for quantizing
// both activations and weights.
Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
                      ArrayRef<double> max, int quant_dim,
                      int storage_type_width, bool narrow_range, bool is_signed,
                      bool legacy_float_scale) {
  auto converter =
      quant::ExpressedToQuantizedConverter::forInputType(input_type);

  // Expand the range to prevent extremely small scales and large quantized
  // integers, which can cause overflow. This leads to a scale of
  // 7.843137254901961e-9 with 8 bits.
  SmallVector<double, 4> effective_mins, effective_maxs;
  ExpandVerySmallRange(min, max, &effective_mins, &effective_maxs);

  quant::QuantizedType quantizedEleType;
  if (min.size() == 1 && max.size() == 1 && quant_dim == -1) {
    quantizedEleType = quant::fakeQuantAttrsToType(
        builder.getUnknownLoc(), storage_type_width, effective_mins[0],
        effective_maxs[0], narrow_range, converter.expressedType, is_signed);
    if (legacy_float_scale) {
      quantizedEleType =
          DownCastScale(quantizedEleType, effective_mins[0], effective_maxs[0],
                        builder.getUnknownLoc());
    }
  } else if (min.size() == max.size()) {
    auto shape = input_type.dyn_cast<ShapedType>();
    if (!shape || shape.getRank() <= quant_dim ||
        static_cast<int64_t>(min.size()) != shape.getDimSize(quant_dim)) {
      return {};
    }
    // The quantization dim is set to the last dimension.
    quantizedEleType = quant::fakeQuantAttrsToType(
        builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins,
        effective_maxs, narrow_range, converter.expressedType, is_signed);
    if (legacy_float_scale) {
      quantizedEleType = DownCastScale(quantizedEleType, effective_mins,
                                       effective_maxs, builder.getUnknownLoc());
    }
  }
  if (!quantizedEleType) return {};
  return converter.convert(quantizedEleType);
}

// TODO(fengliuai): promote this utility method to mlir QuantOps.
TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
  auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
  if (!factor_values) return {};
  auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
  if (!ele_type) return {};
  if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    ArrayRef<double> scales = qtype.getScales();
    // Broadcasting hasn't been implemented yet.
    if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
      return {};
    SmallVector<double, 4> new_scales;
    new_scales.reserve(scales.size());
    auto scales_iter = scales.begin();
    for (const auto& f : factor_values) {
      new_scales.push_back(*(scales_iter++) *
                           std::fabs(FloatAttr::getValueAsDouble(f)));
    }
    // We are assuming symmetric quantization.
    auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
        qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
        new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
        qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
    if (auto new_type = new_ele_type.castFromExpressedType(
            quant::QuantizedType::castToExpressedType(input))) {
      return TypeAttr::get(new_type);
    }
  }
  // Currently, we only support per-axis quantized type.
  return {};
}

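// Converts the min/max attributes (either scalar FloatAttrs or
// DenseFPElementsAttrs) to doubles and wraps the derived quantized type in a
// TypeAttr. Returns a null attribute if the attributes cannot be parsed or no
// quantized type can be derived.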
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
                              Attribute max, int quant_dim,
                              IntegerAttr num_bits, BoolAttr narrow_range,
                              bool is_signed, bool legacy_float_scale) {
  SmallVector<double, 4> min_value, max_value;
  auto mins = min.dyn_cast<DenseFPElementsAttr>();
  auto maxs = max.dyn_cast<DenseFPElementsAttr>();
  if (mins && maxs) {
    min_value.reserve(mins.getNumElements());
    max_value.reserve(maxs.getNumElements());
    for (auto it = mins.begin(), e = mins.end(); it != e; ++it) {
      min_value.push_back(FloatAttr::getValueAsDouble(*it));
    }
    for (auto it = maxs.begin(), e = maxs.end(); it != e; ++it) {
      max_value.push_back(FloatAttr::getValueAsDouble(*it));
    }
  } else {
    auto fmin = min.dyn_cast<FloatAttr>();
    auto fmax = max.dyn_cast<FloatAttr>();
    if (fmin && fmax) {
      min_value.push_back(fmin.getValueAsDouble());
      max_value.push_back(fmax.getValueAsDouble());
    } else {
      return {};
    }
  }
  Type final_type = GetQuantizedType(
      builder, input_type, min_value, max_value, quant_dim, num_bits.getInt(),
      narrow_range.getValue(), is_signed, legacy_float_scale);
  if (!final_type) return {};
  return TypeAttr::get(final_type);
}

// Repeats the content of `data` multiple times to resize to `target_size`.
// Note that this only broadcasts across one dimension. Returns true if the
// broadcast fails, i.e. `target_size` is not a multiple of the current size.
template <typename T>
static bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
  int size = data.size();
  if (size != target_size) {
    if (target_size % size != 0) return true;
    data.reserve(target_size);
    for (int i = 1, e = target_size / size; i != e; ++i) {
      data.insert(data.end(), data.begin(), data.begin() + size);
    }
  }
  return false;
}

// Changes the axis of the input per-channel quantized type to match the
// dimension of the target type. Returns nullptr if it fails.
static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
    ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
    Type target, int quant_dim) {
  auto shaped = target.dyn_cast<RankedTensorType>();
  if (!shaped) return {};
  ArrayRef<int64_t> new_shape = shaped.getShape();

  SmallVector<double, 4> scales(qtype.getScales().begin(),
                                qtype.getScales().end());
  SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
                                      qtype.getZeroPoints().end());

  if (new_shape.size() == shape.size()) {  // same rank
    // Broadcast the scales and zero points to match the target size, which is
    // usually the axis-th dimension of the target type. Currently, it covers
    // two cases:
    // - for Transpose, the data layout is changed so `dim[axis]` still equals
    //   `scales_size`, and the broadcast is skipped;
    // - for Reshape, the data layout isn't changed but the innermost dimension
    //   is expanded to cover the last two original dimensions, so we just need
    //   to repeat the `scales` dim[2] times to cover the new dim length.
    //
    // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus
    // we have to repeat each element in the `scales` locally dim[3] times.
    if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
        BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
      return {};
    }
  } else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) {
    // This is a trivial shift left, so we shift the quant_dim as well.
    if (std::equal(shape.begin(), shape.end(), new_shape.begin()) &&
        quant_dim == -1) {
      quant_dim = shape.size() + quant_dim;
    } else {
      return {};
    }
  } else {
    return {};
  }

  return quant::UniformQuantizedPerAxisType::get(
      qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
      scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
      qtype.getStorageTypeMax());
}

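// Casts the quantized element type carried by `source` onto the expressed
// `target` type, remapping the quantized axis through ResetAxisAndBroadcast
// for per-axis types. Returns a null attribute on failure.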
TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
                                                TypeAttr source, Type target,
                                                int axis) {
  auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
  if (!source_type) return {};
  auto src_ele_type = source_type.getElementType();
  auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();

  // Reset the quantization dimensions if it is per-axis.
  if (auto per_axis =
          qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
    qtype =
        ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
  }
  if (!qtype) return {};
  Type final_type = qtype.castFromExpressedType(target);
  if (!final_type) return {};
  return TypeAttr::get(final_type);
}

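// Scans the float values in `values` and fills `mins`/`maxs` with the
// per-channel minimum and maximum, where the channel index is derived from
// the flattened element index, `slice_size`, and `dim_size`. The ranges are
// expanded to include 0.0, and mirrored around 0.0 when `symmetric` is set.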
void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
                           int slice_size, bool symmetric,
                           SmallVectorImpl<double>& mins,
                           SmallVectorImpl<double>& maxs) {
  // If all the element values are the same, we don't need to scan the content.
  if (values.isSplat()) {
    double single_value =
        FloatAttr::getValueAsDouble(values.getSplatValue<llvm::APFloat>());

    // When the single value isn't 0.0, we expand it to a range that includes
    // this single value and 0.0. This gives us a scale and zero point that
    // work for both this value and 0.0.
    if (single_value < 0.0) {
      mins[0] = single_value;
      maxs[0] = symmetric ? -single_value : 0.0;
    } else if (single_value > 0.0) {
      mins[0] = symmetric ? -single_value : 0.0;
      maxs[0] = single_value;
    } else {
      mins[0] = maxs[0] = single_value;
    }
    for (int i = 1; i < dim_size; ++i) {
      mins[i] = mins[0];
      maxs[i] = maxs[0];
    }
  } else {
    int64_t flatten_index = 0;
    for (auto it = values.begin(), e = values.end(); it != e;
         ++it, ++flatten_index) {
      double ele_value = FloatAttr::getValueAsDouble(*it);
      int slice_index = flatten_index / slice_size;
      int channel_index = slice_index % dim_size;
      mins[channel_index] = std::min(mins[channel_index], ele_value);
      maxs[channel_index] = std::max(maxs[channel_index], ele_value);
    }
    // Expand the ranges to include 0.
    for (int i = 0; i < dim_size; ++i) {
      maxs[i] = std::max(maxs[i], 0.0);
      mins[i] = std::min(mins[i], 0.0);
    }
    if (symmetric) {
      for (int i = 0; i < dim_size; ++i) {
        maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i]));
        mins[i] = -maxs[i];
      }
    }
  }
}

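// Returns the per-tensor quantized element type for the weight constant
// `attr`, derived from the min/max of its float values. Returns a null type
// if `attr` is not a float constant or the symmetric/signed/narrow_range
// combination is invalid.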
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
                                      unsigned num_bits, bool is_signed,
                                      bool narrow_range,
                                      bool legacy_float_scale) {
  Builder builder(attr.getContext());
  // `symmetric` can only be used when it is `signed` and `narrow_range`.
  if (symmetric && (!is_signed || !narrow_range)) return {};

  SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
  SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
  auto fp = attr.dyn_cast<DenseFPElementsAttr>();
  if (!fp) return {};

  // Computes the effective min/max values of the attribute values.
  ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins,
                        maxs);

  auto type = GetQuantizedType(builder, attr.getType(), mins[0], maxs[0],
                               /*quant_dim=*/-1, num_bits, narrow_range,
                               is_signed, legacy_float_scale);
  if (auto ele_type = type.dyn_cast_or_null<TensorType>())
    return ele_type.getElementType();

  return {};
}

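// Same as GetUniformQuantizedTypeForWeight, but computes one min/max pair per
// channel along `quant_dim` and returns a per-axis quantized element type.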
Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
                                             bool symmetric, unsigned num_bits,
                                             bool is_signed, bool narrow_range,
                                             bool legacy_float_scale) {
  Builder builder(attr.getContext());
  auto shape = attr.getType().cast<ShapedType>().getShape();
  if (static_cast<int>(shape.size()) <= quant_dim) return {};
  // `symmetric` can only be used when it is `signed` and `narrow_range`.
  if (symmetric && (!is_signed || !narrow_range)) return {};

  int dim_size = shape[quant_dim];
  int slice_size = std::accumulate(std::next(shape.begin(), quant_dim + 1),
                                   shape.end(), 1, std::multiplies<int64_t>());
  SmallVector<double, 4> mins(dim_size, std::numeric_limits<double>::max());
  SmallVector<double, 4> maxs(dim_size, std::numeric_limits<double>::min());
  auto fp = attr.dyn_cast<DenseFPElementsAttr>();
  if (!fp) return {};

  // Computes the effective min/max values of the attribute values.
  ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs);

  auto type =
      GetQuantizedType(builder, attr.getType(), mins, maxs, quant_dim, num_bits,
                       narrow_range, is_signed, legacy_float_scale);
  if (auto ele_type = type.dyn_cast_or_null<TensorType>())
    return ele_type.getElementType();

  return {};
}

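// Derives the quantized type for a bias tensor from the quantized types of
// the ops feeding it: the bias scale is the product of the operand scales,
// with a signed 32-bit storage type. Returns a null type if the operand types
// are inconsistent.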
quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types,
    bool legacy_float_scale) {
  if (op_types.empty()) return {};

  size_t axis_size = 1;
  int32_t quant_dim = -1;
  Type expressed_type;
  // Requires all the op types to be valid UniformQuantizedTypes or
  // UniformQuantizedPerAxisTypes with the same expressed type. For all the
  // UniformQuantizedPerAxisTypes, the quantization dimension index and
  // dimension sizes have to be the same.
  for (auto op_type : op_types) {
    if (!op_type) return {};
    if (expressed_type && expressed_type != op_type.getExpressedType()) {
      return {};
    }
    expressed_type = op_type.getExpressedType();

    if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
      if ((axis_size != 1 && axis_size != type.getScales().size())) return {};
      if (quant_dim != -1 && quant_dim != type.getQuantizedDimension())
        return {};
      axis_size = type.getScales().size();
      quant_dim = type.getQuantizedDimension();
    } else if (!op_type.isa<quant::UniformQuantizedType>()) {
      return {};
    }
  }

  // The scale from the UniformQuantizedTypes is broadcast if there are
  // UniformQuantizedPerAxisTypes.
  llvm::SmallVector<double, 4> scales(axis_size, 1.0);
  for (auto op_type : op_types) {
    if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
      for (auto index_scale : llvm::enumerate(type.getScales())) {
        scales[index_scale.index()] *= index_scale.value();
      }
    } else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
      for (int index = 0, e = axis_size; index != e; ++index) {
        scales[index] *= type.getScale();
      }
    }
  }
  if (legacy_float_scale) {
    for (int i = 0; i < scales.size(); ++i) {
      scales[i] = static_cast<float>(scales[i]);
    }
  }

  // Builds the result quantized type, which has a signed 32-bit storage type.
  Builder builder(expressed_type.getContext());
  IntegerType storage_type = builder.getIntegerType(32);
  int64_t storage_type_min =
      quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32);
  int64_t storage_type_max =
      quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32);
  if (axis_size == 1) {
    return quant::UniformQuantizedType::getChecked(
        builder.getUnknownLoc(),
        /*flags=*/true, storage_type, expressed_type, scales[0],
        /*zeroPoint=*/0, storage_type_min, storage_type_max);
  } else {
    llvm::SmallVector<int64_t, 4> zero_points(axis_size, 0);
    // Assume the bias is a 1-D tensor and set the quantization dim to the
    // last dimension, which is 0. If the bias rank is larger than 1, this
    // returned quantized type couldn't be used to quantize the bias.
    return quant::UniformQuantizedPerAxisType::getChecked(
        builder.getUnknownLoc(),
        /*flags=*/true, storage_type, expressed_type, scales, zero_points,
        /*quantizedDimension=*/0, storage_type_min, storage_type_max);
  }
}

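// Quantizes the float constant `real_value` to the storage type of the
// quantized element type of `tensor_type` using the legacy TFLite quantizer
// routines: 8-bit symmetric (per-tensor or per-channel), 16-bit symmetric,
// and 32-bit bias quantization. Other 8-bit ranges fall back to Quantize().
// Returns a null attribute on failure.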
ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) {
  if (!real_value.isa<DenseFPElementsAttr>() ||
      !quant::QuantizedType::getQuantizedElementType(tensor_type)) {
    return {};
  }
  auto real_values_attr = real_value.cast<DenseFPElementsAttr>();
  auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type);
  std::vector<float> real_values;
  llvm::SmallVector<APInt, 8> quantized_attr;
  real_values.reserve(real_values_attr.getNumElements());
  quantized_attr.reserve(real_values_attr.getNumElements());
  std::transform(real_values_attr.begin(), real_values_attr.end(),
                 std::back_inserter(real_values), [&](APFloat value) -> float {
                   return value.convertToFloat();
                 });
  ShapedType new_dense_type =
      q_type.castExpressedToStorageType(real_values_attr.getType())
          .dyn_cast_or_null<ShapedType>();
  int width = q_type.getStorageType().dyn_cast<mlir::IntegerType>().getWidth();

  if (width == 8 && q_type.getStorageTypeMax() == 127 &&
      q_type.getStorageTypeMin() == -127) {
    std::vector<int8_t> quantized_values(real_values_attr.getNumElements());
    if (q_type.isa<UniformQuantizedType>()) {
      float min, max, scale;
      tflite::tensor_utils::SymmetricQuantizeFloats(
          real_values.data(), real_values.size(), quantized_values.data(), &min,
          &max, &scale);
    } else if (auto uniform_type =
                   q_type.dyn_cast<UniformQuantizedPerAxisType>()) {
      std::vector<float> scales_inv;
      std::vector<int32_t> dimension;
      dimension.insert(dimension.end(), new_dense_type.getShape().begin(),
                       new_dense_type.getShape().end());
      std::transform(uniform_type.getScales().begin(),
                     uniform_type.getScales().end(),
                     std::back_inserter(scales_inv),
                     [](float scale) { return 1.0 / scale; });

      tflite::optimize::utils::SymmetricPerChannelQuantizeValues(
          real_values.data(), scales_inv, dimension,
          uniform_type.getQuantizedDimension(), &quantized_values);
    } else {
      return {};
    }
    std::transform(quantized_values.begin(), quantized_values.end(),
                   std::back_inserter(quantized_attr),
                   [&](int8_t value) -> APInt {
                     return APInt(8, value, /*isSigned=*/true);
                   });
    return DenseElementsAttr::get(new_dense_type, quantized_attr);
  } else if (width == 8) {
    // This can be a state tensor, or an actual constant tensor with an
    // asymmetric range. For a state tensor, assigning correct quantization
    // parameters is sufficient, and constants with an asymmetric range aren't
    // correctly quantized by the legacy quantizer, so call the new Quantize.
    return Quantize(real_value, tensor_type);
  } else if (width == 16) {
    if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
      auto quantized_values =
          tflite::optimize::utils::SymmetricQuantizeFloatsToInt16(
              real_values.data(), real_values.size(), uniform_type.getScale());
      std::transform(quantized_values.begin(), quantized_values.end(),
                     std::back_inserter(quantized_attr),
                     [&](int16_t value) -> APInt {
                       return APInt(16, value, /*isSigned=*/true);
                     });
      return DenseElementsAttr::get(new_dense_type, quantized_attr);
    }
  } else if (width == 32) {
    std::vector<float> scales;
    if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
      scales.push_back(uniform_type.getScale());
    } else if (auto uniform_type =
                   q_type.dyn_cast<UniformQuantizedPerAxisType>()) {
      scales.insert(scales.end(), uniform_type.getScales().begin(),
                    uniform_type.getScales().end());
    } else {
      return {};
    }
    auto quantized_bias =
        tflite::optimize::utils::SymmetricBiasQuantize<std::int32_t>(
            real_values.data(), real_values.size(), scales);
    std::transform(quantized_bias.begin(), quantized_bias.end(),
                   std::back_inserter(quantized_attr),
                   [&](int32_t value) -> APInt {
                     return APInt(32, value, /*isSigned=*/true);
                   });
    return DenseElementsAttr::get(new_dense_type, quantized_attr);
  }
  return {};
}

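// Quantizes `real_value` with the quantized element type of `tensor_type`
// using the MLIR Quant dialect helper, and returns the result as an
// ElementsAttr (or a null attribute if `tensor_type` isn't quantized).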
ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
  if (auto q_type =
          quant::QuantizedType::getQuantizedElementType(tensor_type)) {
    Type converted_type;
    return quant::quantizeAttr(real_value, q_type, converted_type)
        .dyn_cast<ElementsAttr>();
  }
  return {};
}

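// Recomputes the scales (and, for asymmetric storage ranges, the zero points)
// of `type` from the given min/max in single-precision float, which matches
// the legacy float-scale behavior. Scales smaller than kNearZeroTolerance are
// clamped for 8-bit types. The scalar overload forwards to the vector
// overload below.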
QuantizedType DownCastScale(QuantizedType type, double min, double max,
                            Location loc) {
  SmallVector<double, 1> mins = {min};
  SmallVector<double, 1> maxs = {max};
  return DownCastScale(type, mins, maxs, loc);
}

QuantizedType DownCastScale(QuantizedType type,
                            const SmallVectorImpl<double>& mins,
                            const SmallVectorImpl<double>& maxs, Location loc) {
  SmallVector<double, 4> scales(mins.size());
  SmallVector<int64_t, 4> zero_points(mins.size());
  if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
    zero_points.push_back(q_type.getZeroPoint());
  } else if (auto q_type = type.dyn_cast<UniformQuantizedPerAxisType>()) {
    zero_points = {q_type.getZeroPoints().begin(),
                   q_type.getZeroPoints().end()};
  }
  for (int i = 0; i < mins.size(); ++i) {
    scales[i] = (static_cast<float>(maxs[i]) - static_cast<float>(mins[i])) /
                (type.getStorageTypeMax() - type.getStorageTypeMin());
    if (scales[i] < kNearZeroTolerance &&
        type.getStorageTypeIntegralWidth() == 8) {
      emitWarning(loc) << "The scale " << scales[i] << " is too small, and "
                       << "might cause overflow for bias. Forcing to use scale "
                       << kNearZeroTolerance;
      scales[i] = kNearZeroTolerance;
    } else if (type.getStorageTypeMax() != -type.getStorageTypeMin()) {
      // Only applies to an asymmetric quantized range with the original scale.
      float zero_point_from_min =
          type.getStorageTypeMin() - mins[i] / scales[i];
      if (zero_point_from_min < type.getStorageTypeMin()) {
        zero_points[i] = static_cast<int64_t>(type.getStorageTypeMin());
      } else if (zero_point_from_min > type.getStorageTypeMax()) {
        zero_points[i] = static_cast<int64_t>(type.getStorageTypeMax());
      } else {
        zero_points[i] = static_cast<int64_t>(std::round(zero_point_from_min));
      }
    }
  }
  if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
    return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(),
                                     q_type.getExpressedType(), scales[0],
                                     zero_points[0], q_type.getStorageTypeMin(),
                                     q_type.getStorageTypeMax());
  } else if (auto q_type = type.dyn_cast<UniformQuantizedPerAxisType>()) {
    return UniformQuantizedPerAxisType::get(
        q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
        scales, zero_points, q_type.getQuantizedDimension(),
        q_type.getStorageTypeMin(), q_type.getStorageTypeMax());
  }
  return type;
}

// A heuristic to determine whether the scales need to come from operands or
// from results for the ops with the `SameOperandsAndResultsScale` property.
// The current implementation is based on the number of operands.
static bool PreferResultScale(Operation* op) {
  int float_operands = 0;
  for (auto operand : op->getOperands()) {
    if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
      if (operand_type.getElementType().isa<FloatType>()) {
        if (++float_operands > 1) return true;
      }
    }
  }
  return false;
}

// The stats ops of some ops can be redundant. The current implementation
// only considers the ops with restricted output params.
static bool IsStatsRedundant(Operation* op,
                             OpQuantSpecGetter op_quant_spec_getter) {
  return llvm::isa<FixedOutputRangeInterface>(op);
}

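// Removes quant.stats ops in `func` whose quantization parameters can be
// derived from other ops (fixed output ranges or same-scale propagation).
// Returns true if an unexpected op is found among the redundant stats ops.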
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter) {
  llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
  llvm::DenseSet<Operation*> redundant_stats_ops;

  // Step 0: remove the quant::StatisticsOp ops which are used by the
  // tfl.quantize op in case they override the information from training
  // FakeQuant ops.
  func.walk([&](quant::QuantizeCastOp q) {
    auto input_op = q.arg().getDefiningOp();
    if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
      q.setOperand(stats.arg());
      if (stats.use_empty()) stats.erase();
    }
  });

  // Step 1: forward pass: propagate any value scales which are not produced
  // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
  // which are produced by the ops with the `FixedOutputRangeInterface`.
  // Note that we don't propagate across the multiple-operands
  // `SameOperandsAndResultsScale` ops like `concatenation`.
  func.walk(
      [&](quant::StatisticsOp stats_op) { all_stats_ops.push_back(stats_op); });

  while (!all_stats_ops.empty()) {
    quant::StatisticsOp stats_op = all_stats_ops.back();
    all_stats_ops.pop_back();

    if (auto def = stats_op.arg().getDefiningOp()) {
      if (IsStatsRedundant(def, op_quant_spec_getter)) {
        redundant_stats_ops.insert(stats_op);
      }
    }

    for (auto user : stats_op.getResult().getUsers()) {
      // We don't propagate this parameter down if it has multiple operands.
      // We want to use the result parameter scales instead.
      if (llvm::dyn_cast<SameScalesOpInterface>(user) &&
          !PreferResultScale(user)) {
        for (Value res : user->getResults()) {
          if (res.hasOneUse()) {
            if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
                    *res.getUsers().begin())) {
              // Quantization parameters can be propagated to next_stats.
              redundant_stats_ops.insert(next_stats);
              // Add next_stats to the work list so propagation can continue.
              all_stats_ops.push_back(next_stats);
            }
          }
        }
      }
    }
  }

  // Step 2: backward pass: for the ops skipped in the forward pass, propagate
  // their result scales backwards as far as possible.
  func.walk([&](quant::StatisticsOp stats_op) {
    if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
      all_stats_ops.push_back(stats_op);
    }
  });

  while (!all_stats_ops.empty()) {
    quant::StatisticsOp stats_op = all_stats_ops.back();
    all_stats_ops.pop_back();

    if (auto def = stats_op.arg().getDefiningOp()) {
      if (llvm::dyn_cast<SameScalesOpInterface>(def)) {
        for (auto input : def->getOperands()) {
          if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
                  input.getDefiningOp())) {
            redundant_stats_ops.insert(next_stats);
            all_stats_ops.push_back(next_stats);
          }
        }
      }
    }
  }

  // Step 3: remove all the redundant stats ops.
  for (auto it : redundant_stats_ops) {
    if (!llvm::isa<quant::StatisticsOp>(it)) return true;
    auto stats_op = llvm::cast<quant::StatisticsOp>(it);
    stats_op.getResult().replaceAllUsesWith(stats_op.arg());
    stats_op.erase();
  }

  // Returns false if the steps finish without errors.
  return false;
}

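// Verifies that the quantized operands and results of `op` (which implements
// SameScalesOpInterface) carry compatible quantization parameters, emitting
// an op error when the same-scale constraint is violated.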
LogicalResult VerifySameScales(Operation* op) {
  auto same_scale_op = llvm::cast<SameScalesOpInterface>(op);

  llvm::SmallVector<QuantizedType, 4> collected_quant_params;
  for (auto input : op->getOperands()) {
    auto quant_params =
        UniformQuantizedType::getQuantizedElementType(input.getType());
    // Skip non-quantizable operands.
    if (quant_params) {
      collected_quant_params.push_back(quant_params);
    }
  }

  for (auto output : op->getResults()) {
    auto quant_params =
        UniformQuantizedType::getQuantizedElementType(output.getType());
    // Skip non-quantizable results.
    if (quant_params) {
      collected_quant_params.push_back(quant_params);
    }
  }

  if (collected_quant_params.size() <= 1) return success();
  for (int i = 1; i < collected_quant_params.size(); i++) {
    auto expected_params = collected_quant_params[0];
    auto compared_params = collected_quant_params[i];
    // Identical quantization parameters are always ok.
    if (expected_params == compared_params) continue;
    // If the quantization parameters are not the same, it is still ok as long
    // as they have the same storage type and the op interface doesn't require
    // the same-scale constraint for this storage type.
    if ((expected_params.isSigned() == compared_params.isSigned() &&
         expected_params.getStorageTypeIntegralWidth() ==
             compared_params.getStorageTypeIntegralWidth()) &&
        !same_scale_op.RequiredSameOperandsAndResultsScale(
            expected_params.isSigned(),
            expected_params.getStorageTypeIntegralWidth()))
      continue;

    std::string err_msg =
        "quantization parameters violate the same scale constraint: ";
    llvm::raw_string_ostream os(err_msg);
    collected_quant_params[0].print(os);
    os << " vs. ";
    collected_quant_params[i].print(os);
    os.flush();
    return op->emitOpError(err_msg);
  }
  return success();
}

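// Builds the fixed 8-bit output quantized type for `tensor_type` from the
// given scale/zero point/storage range, shifting the parameters by 128 when
// the type is unsigned. Returns a null type for non-float element types or
// bit widths other than 8.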
quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
                                                Type tensor_type, double scale,
                                                int64_t zero_point,
                                                int64_t storage_min,
                                                int64_t storage_max) {
  auto result_type = tensor_type.cast<ShapedType>();
  if (!result_type.getElementType().isa<FloatType>()) return {};
  Builder builder(result_type.getContext());

  // Only 8-bit storage types are supported.
  if (bit_width != 8) return {};
  IntegerType storage_type = builder.getIntegerType(bit_width);
  if (!is_signed) {
    zero_point += 128;
    storage_min += 128;
    storage_max += 128;
  }
  return quant::UniformQuantizedType::getChecked(
      builder.getUnknownLoc(), is_signed, storage_type,
      result_type.getElementType(), scale, zero_point, storage_min,
      storage_max);
}

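// Converts a signed quantized tensor type to the equivalent unsigned type by
// shifting the zero points and storage range, keeping the scales unchanged.
// Returns a null type if `signed_tensor_type` is not signed quantized.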
Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc) {
  auto qtype = QType::getQuantizedElementType(signed_tensor_type);
  if (!qtype || !qtype.isSigned()) return {};

  int num_bits = qtype.getStorageTypeIntegralWidth();
  // This is a negative value, and will be subtracted from the zero points and
  // storage ranges below.
  int64_t offset =
      QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) -
      QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits);

  auto flags = !quant::QuantizationFlags::Signed;
  QType new_qtype;
  if (auto uqtype = qtype.dyn_cast<quant::UniformQuantizedType>()) {
    new_qtype = quant::UniformQuantizedType::getChecked(
        loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
        uqtype.getScale(), uqtype.getZeroPoint() - offset,
        uqtype.getStorageTypeMin() - offset,
        uqtype.getStorageTypeMax() - offset);
  } else if (auto aqtype =
                 qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    auto zero_points = aqtype.getZeroPoints();
    llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
                                                  zero_points.end());
    for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
      new_zero_points[i] -= offset;
    }
    new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
        loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
        aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
        aqtype.getStorageTypeMin() - offset,
        aqtype.getStorageTypeMax() - offset);
  }
  return new_qtype.castFromExpressedType(
      QType::castToExpressedType(signed_tensor_type));
}

}  // namespace quant
}  // namespace mlir