• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <iterator>
21 #include <limits>
22 #include <numeric>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
28 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
29 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
30 #include "mlir/Dialect/Quant/QuantizeUtils.h"  // from @llvm-project
31 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
32 #include "mlir/IR/Attributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
40 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
41 #include "tensorflow/lite/tools/optimize/quantization_utils.h"
42 
43 namespace mlir {
44 
45 // This includes the interface class definition. It couldn't be in a namespace
46 // because the table gen doesn't emit the namespace when it is used.
47 #include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.cc.inc"
48 
49 namespace quant {
50 
// Tolerance below which a min/max range width (or a scale) is treated as
// degenerate and replaced by a wider value.
constexpr double kNearZeroTolerance = 1.0e-6;
// Half of kNearZeroTolerance: the distance on each side of 0.0 that a
// degenerate range is widened to, so the widened range spans the tolerance.
constexpr double kSmallestHalfRange = 0.5 * kNearZeroTolerance;
53 
54 // This method expands the range to be larger than or equal to 1.0e-6, if it is
55 // very small (< 1.0e-6). This is to prevent very large quantized value by this
56 // range.
ExpandVerySmallRange(ArrayRef<double> mins,ArrayRef<double> maxs,SmallVectorImpl<double> * effective_mins,SmallVectorImpl<double> * effective_maxs)57 static void ExpandVerySmallRange(ArrayRef<double> mins, ArrayRef<double> maxs,
58                                  SmallVectorImpl<double>* effective_mins,
59                                  SmallVectorImpl<double>* effective_maxs) {
60   for (auto arg : llvm::zip(mins, maxs)) {
61     double min = std::get<0>(arg);
62     double max = std::get<1>(arg);
63     // The range is wide, then use the same min/max.
64     if ((max - min) > kNearZeroTolerance) {
65       effective_mins->push_back(min);
66       effective_maxs->push_back(max);
67       continue;
68     }
69 
70     // The range is small. Expands the range to stride 0.0 and also at least
71     // 1.0e-6.
72     effective_mins->push_back(std::min(min, -kSmallestHalfRange));
73     effective_maxs->push_back(std::max(max, kSmallestHalfRange));
74   }
75 }
76 
77 // Returns the quantized type for the
78 // input_type/min/max/storag_type_width/narrow_range.
79 // This is entry point to the Quant dialect and used for both quantizing
80 // activations and weights.
GetQuantizedType(Builder builder,Type input_type,ArrayRef<double> min,ArrayRef<double> max,int quant_dim,int storage_type_width,bool narrow_range,bool is_signed,bool legacy_float_scale)81 Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
82                       ArrayRef<double> max, int quant_dim,
83                       int storage_type_width, bool narrow_range, bool is_signed,
84                       bool legacy_float_scale) {
85   auto converter =
86       quant::ExpressedToQuantizedConverter::forInputType(input_type);
87 
88   // Expand the range to prevent extremely small scales and large quantized
89   // integers which can cause overflow. This leads to scale
90   // 7.843137254901961e-9 with 8 bits.
91   SmallVector<double, 4> effective_mins, effective_maxs;
92   ExpandVerySmallRange(min, max, &effective_mins, &effective_maxs);
93 
94   quant::QuantizedType quantizedEleType;
95   if (min.size() == 1 && max.size() == 1 && quant_dim == -1) {
96     quantizedEleType = quant::fakeQuantAttrsToType(
97         builder.getUnknownLoc(), storage_type_width, effective_mins[0],
98         effective_maxs[0], narrow_range, converter.expressedType, is_signed);
99     if (legacy_float_scale) {
100       quantizedEleType =
101           DownCastScale(quantizedEleType, effective_mins[0], effective_maxs[0],
102                         builder.getUnknownLoc());
103     }
104   } else if (min.size() == max.size()) {
105     auto shape = input_type.dyn_cast<ShapedType>();
106     if (!shape || shape.getRank() <= quant_dim ||
107         static_cast<int64_t>(min.size()) != shape.getDimSize(quant_dim)) {
108       return {};
109     }
110     // TODO(b/141508873): the quantization dim is set to the last dimension.
111     quantizedEleType = quant::fakeQuantAttrsToType(
112         builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins,
113         effective_maxs, narrow_range, converter.expressedType, is_signed);
114     if (legacy_float_scale) {
115       quantizedEleType = DownCastScale(quantizedEleType, effective_mins,
116                                        effective_maxs, builder.getUnknownLoc());
117     }
118   }
119   if (!quantizedEleType) return {};
120   return converter.convert(quantizedEleType);
121 }
122 
123 // TODO(fengliuai): promote this utility method to mlir QuantOps.
RescaleQuantizedType(Type input,Attribute factor)124 TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
125   auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
126   if (!factor_values) return {};
127   auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
128   if (!ele_type) return {};
129   if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
130     ArrayRef<double> scales = qtype.getScales();
131     // Broadcasting hasn't been implemented yet.
132     if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
133       return {};
134     SmallVector<double, 4> new_scales;
135     new_scales.reserve(scales.size());
136     auto scales_iter = scales.begin();
137     for (const auto& f : factor_values) {
138       new_scales.push_back(*(scales_iter++) *
139                            std::fabs(FloatAttr::getValueAsDouble(f)));
140     }
141     // We are assuming symmetric quantization.
142     auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
143         qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
144         new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
145         qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
146     if (auto new_type = new_ele_type.castFromExpressedType(
147             quant::QuantizedType::castToExpressedType(input))) {
148       return TypeAttr::get(new_type);
149     }
150   }
151   // Currently, we only support per-axis quantized type.
152   return {};
153 }
154 
GetQuantizedTypeAttr(Builder builder,Type input_type,Attribute min,Attribute max,int quant_dim,IntegerAttr num_bits,BoolAttr narrow_range,bool is_signed,bool legacy_float_scale)155 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
156                               Attribute max, int quant_dim,
157                               IntegerAttr num_bits, BoolAttr narrow_range,
158                               bool is_signed, bool legacy_float_scale) {
159   SmallVector<double, 4> min_value, max_value;
160   auto mins = min.dyn_cast<DenseFPElementsAttr>();
161   auto maxs = max.dyn_cast<DenseFPElementsAttr>();
162   if (mins && maxs) {
163     min_value.reserve(mins.getNumElements());
164     max_value.reserve(maxs.getNumElements());
165     for (auto it = mins.begin(), e = mins.end(); it != e; ++it) {
166       min_value.push_back(FloatAttr::getValueAsDouble(*it));
167     }
168     for (auto it = maxs.begin(), e = maxs.end(); it != e; ++it) {
169       max_value.push_back(FloatAttr::getValueAsDouble(*it));
170     }
171   } else {
172     auto fmin = min.dyn_cast<FloatAttr>();
173     auto fmax = max.dyn_cast<FloatAttr>();
174     if (fmin && fmax) {
175       min_value.push_back(fmin.getValueAsDouble());
176       max_value.push_back(fmax.getValueAsDouble());
177     } else {
178       return {};
179     }
180   }
181   Type final_type = GetQuantizedType(
182       builder, input_type, min_value, max_value, quant_dim, num_bits.getInt(),
183       narrow_range.getValue(), is_signed, legacy_float_scale);
184   if (!final_type) return {};
185   return TypeAttr::get(final_type);
186 }
187 
188 // Repeats the content of `data` multiple times to resize to `target_size`.
189 // Note that this only broadcast across one dimension.
190 template <typename T>
BroadcastVector(int target_size,SmallVectorImpl<T> & data)191 static bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
192   int size = data.size();
193   if (size != target_size) {
194     if (target_size % size != 0) return true;
195     data.reserve(target_size);
196     for (int i = 1, e = target_size / size; i != e; ++i) {
197       data.insert(data.end(), data.begin(), data.begin() + size);
198     }
199   }
200   return false;
201 }
202 
203 // Changes the axis of the input per-channel quantized type to match the
204 // dimension of the target type. Returns nullptr if it fails.
ResetAxisAndBroadcast(ArrayRef<int64_t> shape,quant::UniformQuantizedPerAxisType qtype,Type target,int quant_dim)205 static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
206     ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
207     Type target, int quant_dim) {
208   auto shaped = target.dyn_cast<RankedTensorType>();
209   if (!shaped) return {};
210   ArrayRef<int64_t> new_shape = shaped.getShape();
211 
212   SmallVector<double, 4> scales(qtype.getScales().begin(),
213                                 qtype.getScales().end());
214   SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
215                                       qtype.getZeroPoints().end());
216 
217   if (new_shape.size() == shape.size()) {  // same rank
218     // Broadcast the scales and zero points to match the target size, which is
219     // usually the axis-th dimension of the target type. Currently, it covers
220     // two cases:
221     // - for Transpose, the data layout is changed so the `dim[axis]` still
222     // equals to the `scales_size`. The broadcast skips;
223     // - for Reshape, the data layout isn't changed but the innermost dimension
224     // is expand to cover the last two original dimensions. Thus we just need to
225     // be repeated the `scales` dim[2] times to covers the new dim length.
226     //
227     // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
228     // have to repeat each elements in the `scales` locally dim[3] times.
229     if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
230         BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
231       return {};
232     }
233   } else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) {
234     // This is a trivial shift left, then we shift the quant_dim as well.
235     if (std::equal(shape.begin(), shape.end(), new_shape.begin()) &&
236         quant_dim == -1) {
237       quant_dim = shape.size() + quant_dim;
238     } else {
239       return {};
240     }
241   } else {
242     return {};
243   }
244 
245   return quant::UniformQuantizedPerAxisType::get(
246       qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
247       scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
248       qtype.getStorageTypeMax());
249 }
250 
CastQuantizedTypeAttrFromExpressedType(Builder builder,TypeAttr source,Type target,int axis)251 TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
252                                                 TypeAttr source, Type target,
253                                                 int axis) {
254   auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
255   if (!source_type) return {};
256   auto src_ele_type = source_type.getElementType();
257   auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
258 
259   // Reset the quantization dimensions if it is per-axis.
260   if (auto per_axis =
261           qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
262     qtype =
263         ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
264   }
265   if (!qtype) return {};
266   Type final_type = qtype.castFromExpressedType(target);
267   if (!final_type) return {};
268   return TypeAttr::get(final_type);
269 }
270 
ExtractMinMaxFromAttr(DenseFPElementsAttr values,int dim_size,int slice_size,bool symmetric,SmallVectorImpl<double> & mins,SmallVectorImpl<double> & maxs)271 void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
272                            int slice_size, bool symmetric,
273                            SmallVectorImpl<double>& mins,
274                            SmallVectorImpl<double>& maxs) {
275   // If all the element values are same we don't need to scan the content.
276   if (values.isSplat()) {
277     double single_value =
278         FloatAttr::getValueAsDouble(values.getSplatValue<llvm::APFloat>());
279 
280     // When the single value isn't 0.0, we expand it to a range to include
281     // this single value and 0.0. This will give us a scale and zero point
282     // works for both this value and 0.0.
283     if (single_value < 0.0) {
284       mins[0] = single_value;
285       maxs[0] = symmetric ? -single_value : 0.0;
286     } else if (single_value > 0.0) {
287       mins[0] = symmetric ? -single_value : 0.0;
288       maxs[0] = single_value;
289     } else {
290       mins[0] = maxs[0] = single_value;
291     }
292     for (int i = 1; i < dim_size; ++i) {
293       mins[i] = mins[0];
294       maxs[i] = maxs[0];
295     }
296   } else {
297     int64_t flatten_index = 0;
298     for (auto it = values.begin(), e = values.end(); it != e;
299          ++it, ++flatten_index) {
300       double ele_value = FloatAttr::getValueAsDouble(*it);
301       int slice_index = flatten_index / slice_size;
302       int channel_index = slice_index % dim_size;
303       mins[channel_index] = std::min(mins[channel_index], ele_value);
304       maxs[channel_index] = std::max(maxs[channel_index], ele_value);
305     }
306     // Expand range to include 0.
307     for (int i = 0; i < dim_size; ++i) {
308       maxs[i] = std::max(maxs[i], 0.0);
309       mins[i] = std::min(mins[i], 0.0);
310     }
311     if (symmetric) {
312       for (int i = 0; i < dim_size; ++i) {
313         maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i]));
314         mins[i] = -maxs[i];
315       }
316     }
317   }
318 }
319 
GetUniformQuantizedTypeForWeight(ElementsAttr attr,bool symmetric,unsigned num_bits,bool is_signed,bool narrow_range,bool legacy_float_scale)320 Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
321                                       unsigned num_bits, bool is_signed,
322                                       bool narrow_range,
323                                       bool legacy_float_scale) {
324   Builder builder(attr.getContext());
325   // `symmetric` can only be used when it is `signed` and `narrow_range`.
326   if (symmetric && (!is_signed || !narrow_range)) return {};
327 
328   SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
329   SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
330   auto fp = attr.dyn_cast<DenseFPElementsAttr>();
331   if (!fp) return {};
332 
333   // Computes the effective min/max values of the attribute values.
334   ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins,
335                         maxs);
336 
337   auto type = GetQuantizedType(builder, attr.getType(), mins[0], maxs[0],
338                                /*quant_dim=*/-1, num_bits, narrow_range,
339                                is_signed, legacy_float_scale);
340   if (auto ele_type = type.dyn_cast_or_null<TensorType>())
341     return ele_type.getElementType();
342 
343   return {};
344 }
345 
GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr,int quant_dim,bool symmetric,unsigned num_bits,bool is_signed,bool narrow_range,bool legacy_float_scale)346 Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
347                                              bool symmetric, unsigned num_bits,
348                                              bool is_signed, bool narrow_range,
349                                              bool legacy_float_scale) {
350   Builder builder(attr.getContext());
351   auto shape = attr.getType().cast<ShapedType>().getShape();
352   if (static_cast<int>(shape.size()) <= quant_dim) return {};
353   // `symmetric` can only be used when it is `signed` and `narrow_range`.
354   if (symmetric && (!is_signed || !narrow_range)) return {};
355 
356   int dim_size = shape[quant_dim];
357   int slice_size = std::accumulate(std::next(shape.begin(), quant_dim + 1),
358                                    shape.end(), 1, std::multiplies<int64_t>());
359   SmallVector<double, 4> mins(dim_size, std::numeric_limits<double>::max());
360   SmallVector<double, 4> maxs(dim_size, std::numeric_limits<double>::min());
361   auto fp = attr.dyn_cast<DenseFPElementsAttr>();
362   if (!fp) return {};
363 
364   // Computes the effective min/max values of the attribute values.
365   ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs);
366 
367   auto type =
368       GetQuantizedType(builder, attr.getType(), mins, maxs, quant_dim, num_bits,
369                        narrow_range, is_signed, legacy_float_scale);
370   if (auto ele_type = type.dyn_cast_or_null<TensorType>())
371     return ele_type.getElementType();
372 
373   return {};
374 }
375 
GetUniformQuantizedTypeForBias(const std::vector<quant::QuantizedType> & op_types,bool legacy_float_scale)376 quant::QuantizedType GetUniformQuantizedTypeForBias(
377     const std::vector<quant::QuantizedType>& op_types,
378     bool legacy_float_scale) {
379   if (op_types.empty()) return {};
380 
381   size_t axis_size = 1;
382   int32_t quant_dim = -1;
383   Type expressed_type;
384   // Requires all the op types are valid UniformQuantizedTypes or
385   // UniformQuantizedPerAxisTypes and also have same expressed type. For all
386   // the UniformQuantizedPerAxisTypes, the quantization dimension index and
387   // dimension sizes are same.
388   for (auto op_type : op_types) {
389     if (!op_type) return {};
390     if (expressed_type && expressed_type != op_type.getExpressedType()) {
391       return {};
392     }
393     expressed_type = op_type.getExpressedType();
394 
395     if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
396       if ((axis_size != 1 && axis_size != type.getScales().size())) return {};
397       if (quant_dim != -1 && quant_dim != type.getQuantizedDimension())
398         return {};
399       axis_size = type.getScales().size();
400       quant_dim = type.getQuantizedDimension();
401     } else if (!op_type.isa<quant::UniformQuantizedType>()) {
402       return {};
403     }
404   }
405 
406   // The scale from the UniformQuantizedTypes is broadcasted if there are
407   // UniformQuantizedPerAxisTypes.
408   llvm::SmallVector<double, 4> scales(axis_size, 1.0);
409   for (auto op_type : op_types) {
410     if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
411       for (auto index_scale : llvm::enumerate(type.getScales())) {
412         scales[index_scale.index()] *= index_scale.value();
413       }
414     } else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
415       for (int index = 0, e = axis_size; index != e; ++index) {
416         scales[index] *= type.getScale();
417       }
418     }
419   }
420   if (legacy_float_scale) {
421     for (int i = 0; i < scales.size(); ++i) {
422       scales[i] = static_cast<float>(scales[i]);
423     }
424   }
425 
426   // Builds the result quantized type, which has signed 32 bits storage type.
427   Builder builder(expressed_type.getContext());
428   IntegerType storage_type = builder.getIntegerType(32);
429   int64_t storage_type_min =
430       quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32);
431   int64_t storage_type_max =
432       quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32);
433   if (axis_size == 1) {
434     return quant::UniformQuantizedType::getChecked(
435         /*flags=*/true, storage_type, expressed_type, scales[0],
436         /*zeroPoint=*/0, storage_type_min, storage_type_max,
437         builder.getUnknownLoc());
438   } else {
439     llvm::SmallVector<int64_t, 4> zero_points(axis_size, 0);
440     // TODO(b/141508873): Assume the bias is a 1-D tensor, and set the
441     // quantization dim to the last dimension, which is 0. If the bias rank is
442     // larger than 1, this returned quantized type couldn't be used to
443     // quantize the bias.
444     return quant::UniformQuantizedPerAxisType::getChecked(
445         /*flags=*/true, storage_type, expressed_type, scales, zero_points,
446         /*quantizedDimension=*/0, storage_type_min, storage_type_max,
447         builder.getUnknownLoc());
448   }
449 }
450 
QuantizeLegacy(Attribute real_value,Type tensor_type)451 ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) {
452   if (!real_value.isa<DenseFPElementsAttr>() ||
453       !quant::QuantizedType::getQuantizedElementType(tensor_type)) {
454     return {};
455   }
456   auto real_values_attr = real_value.cast<DenseFPElementsAttr>();
457   auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type);
458   std::vector<float> real_values;
459   llvm::SmallVector<APInt, 8> quantized_attr;
460   real_values.reserve(real_values_attr.getNumElements());
461   quantized_attr.reserve(real_values_attr.getNumElements());
462   std::transform(real_values_attr.begin(), real_values_attr.end(),
463                  std::back_inserter(real_values), [&](APFloat value) -> float {
464                    return value.convertToFloat();
465                  });
466   ShapedType new_dense_type =
467       q_type.castExpressedToStorageType(real_values_attr.getType())
468           .dyn_cast_or_null<ShapedType>();
469   int width = q_type.getStorageType().dyn_cast<mlir::IntegerType>().getWidth();
470 
471   if (width == 8 && q_type.getStorageTypeMax() == 127 &&
472       q_type.getStorageTypeMin() == -127) {
473     std::vector<int8_t> quantized_values(real_values_attr.getNumElements());
474     if (q_type.isa<UniformQuantizedType>()) {
475       float min, max, scale;
476       tflite::tensor_utils::SymmetricQuantizeFloats(
477           real_values.data(), real_values.size(), quantized_values.data(), &min,
478           &max, &scale);
479     } else if (auto uniform_type =
480                    q_type.dyn_cast<UniformQuantizedPerAxisType>()) {
481       std::vector<float> scales_inv;
482       std::vector<int32_t> dimension;
483       dimension.insert(dimension.end(), new_dense_type.getShape().begin(),
484                        new_dense_type.getShape().end());
485       std::transform(uniform_type.getScales().begin(),
486                      uniform_type.getScales().end(),
487                      std::back_inserter(scales_inv),
488                      [](float scale) { return 1.0 / scale; });
489 
490       tflite::optimize::utils::SymmetricPerChannelQuantizeValues(
491           real_values.data(), scales_inv, dimension,
492           uniform_type.getQuantizedDimension(), &quantized_values);
493     } else {
494       return {};
495     }
496     std::transform(quantized_values.begin(), quantized_values.end(),
497                    std::back_inserter(quantized_attr),
498                    [&](int8_t value) -> APInt {
499                      return APInt(8, value, /*isSigned=*/true);
500                    });
501     return DenseElementsAttr::get(new_dense_type, quantized_attr);
502   } else if (width == 16) {
503     if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
504       auto quantized_values =
505           tflite::optimize::utils::SymmetricQuantizeFloatsToInt16(
506               real_values.data(), real_values.size(), uniform_type.getScale());
507       std::transform(quantized_values.begin(), quantized_values.end(),
508                      std::back_inserter(quantized_attr),
509                      [&](int16_t value) -> APInt {
510                        return APInt(16, value, /*isSigned=*/true);
511                      });
512       return DenseElementsAttr::get(new_dense_type, quantized_attr);
513     }
514   } else if (width == 32) {
515     std::vector<float> scales;
516     if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
517       scales.push_back(uniform_type.getScale());
518     } else if (auto uniform_type =
519                    q_type.dyn_cast<UniformQuantizedPerAxisType>()) {
520       scales.insert(scales.end(), uniform_type.getScales().begin(),
521                     uniform_type.getScales().end());
522     } else {
523       return {};
524     }
525     auto quantized_bias =
526         tflite::optimize::utils::SymmetricBiasQuantize<std::int32_t>(
527             real_values.data(), real_values.size(), scales);
528     std::transform(quantized_bias.begin(), quantized_bias.end(),
529                    std::back_inserter(quantized_attr),
530                    [&](int32_t value) -> APInt {
531                      return APInt(32, value, /*isSigned=*/true);
532                    });
533     return DenseElementsAttr::get(new_dense_type, quantized_attr);
534   }
535   return {};
536 }
537 
Quantize(Attribute real_value,Type tensor_type)538 ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
539   if (auto q_type =
540           quant::QuantizedType::getQuantizedElementType(tensor_type)) {
541     Type converted_type;
542     return quant::quantizeAttr(real_value, q_type, converted_type)
543         .dyn_cast<ElementsAttr>();
544   }
545   return {};
546 }
547 
DownCastScale(QuantizedType type,double min,double max,Location loc)548 QuantizedType DownCastScale(QuantizedType type, double min, double max,
549                             Location loc) {
550   SmallVector<double, 1> mins = {min};
551   SmallVector<double, 1> maxs = {max};
552   return DownCastScale(type, mins, maxs, loc);
553 }
554 
DownCastScale(QuantizedType type,const SmallVectorImpl<double> & mins,const SmallVectorImpl<double> & maxs,Location loc)555 QuantizedType DownCastScale(QuantizedType type,
556                             const SmallVectorImpl<double>& mins,
557                             const SmallVectorImpl<double>& maxs, Location loc) {
558   SmallVector<double, 4> scales(mins.size());
559   for (int i = 0; i < mins.size(); ++i) {
560     scales[i] = (static_cast<float>(maxs[i]) - static_cast<float>(mins[i])) /
561                 (type.getStorageTypeMax() - type.getStorageTypeMin());
562     if (scales[i] < kNearZeroTolerance &&
563         type.getStorageTypeIntegralWidth() == 8) {
564       emitWarning(loc) << "The scale " << scales[i] << " is too small, and "
565                        << "might cause overflow for bias. Forcing to use scale "
566                        << kNearZeroTolerance;
567       scales[i] = kNearZeroTolerance;
568     }
569   }
570   if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
571     return UniformQuantizedType::get(
572         q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
573         scales[0], q_type.getZeroPoint(), q_type.getStorageTypeMin(),
574         q_type.getStorageTypeMax());
575   } else if (auto q_type = type.dyn_cast<UniformQuantizedPerAxisType>()) {
576     return UniformQuantizedPerAxisType::get(
577         q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
578         scales, q_type.getZeroPoints(), q_type.getQuantizedDimension(),
579         q_type.getStorageTypeMin(), q_type.getStorageTypeMax());
580   }
581   return type;
582 }
583 
584 // A heuristic to determine whether the scales needs to be from operands or
585 // from results for the ops with the `SameOperandsAndResultsScale` property.
586 // The current implementation is based on the number of operands.
PreferResultScale(Operation * op)587 static bool PreferResultScale(Operation* op) {
588   int float_operands = 0;
589   for (auto operand : op->getOperands()) {
590     if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
591       if (operand_type.getElementType().isa<FloatType>()) {
592         if (++float_operands > 1) return true;
593       }
594     }
595   }
596   return false;
597 }
598 
599 // The stats op of some of the ops can be redundant. The current implementation
600 // only considers the ops with restricted output params.
IsStatsRedundant(Operation * op,OpQuantSpecGetter op_quant_spec_getter)601 static bool IsStatsRedundant(Operation* op,
602                              OpQuantSpecGetter op_quant_spec_getter) {
603   return llvm::isa<FixedOutputRangeInterface>(op);
604 }
605 
RemoveRedundantStatsOps(mlir::FuncOp func,OpQuantSpecGetter op_quant_spec_getter)606 bool RemoveRedundantStatsOps(mlir::FuncOp func,
607                              OpQuantSpecGetter op_quant_spec_getter) {
608   llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
609   llvm::DenseSet<Operation*> redundant_stats_ops;
610 
611   // Step 0: remove the quant::StatisticsOp which are used by the tfl.quantize
612   // op in case it overrides the information from training FakeQuant ops.
613   func.walk([&](quant::QuantizeCastOp q) {
614     auto input_op = q.arg().getDefiningOp();
615     if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
616       q.setOperand(stats.arg());
617       if (stats.use_empty()) stats.erase();
618     }
619   });
620 
621   // Step 1: forward pass: propagate any value scales which are not produces
622   // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
623   // which are produced by the ops with the `FixedOutputRangeInterface`.
624   // Note that we don't propagate across the multiple-operands
625   // `SameOperandsAndResultsScale` ops like `concatenation`.
626   func.walk(
627       [&](quant::StatisticsOp stats_op) { all_stats_ops.push_back(stats_op); });
628 
629   while (!all_stats_ops.empty()) {
630     quant::StatisticsOp stats_op = all_stats_ops.back();
631     all_stats_ops.pop_back();
632 
633     if (auto def = stats_op.arg().getDefiningOp()) {
634       if (IsStatsRedundant(def, op_quant_spec_getter)) {
635         redundant_stats_ops.insert(stats_op);
636       }
637     }
638 
639     for (auto user : stats_op.getResult().getUsers()) {
640       // We don't propagate this parameter down if it has multiple operands.
641       // We want to use the result parameter scales instead.
642 
643       if (llvm::dyn_cast<SameScalesOpInterface>(user) &&
644           !PreferResultScale(user)) {
645         for (Value res : user->getResults()) {
646           if (res.hasOneUse()) {
647             if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
648                     *res.getUsers().begin())) {
649               // quantization parameters can be propagated to next_stats
650               redundant_stats_ops.insert(next_stats);
651               // add next_stats to the work list so propagation can
652               // continue.
653               all_stats_ops.push_back(next_stats);
654             }
655           }
656         }
657       }
658     }
659   }
660 
661   // Step 2: backward pass: For the ops skiped in the forward pass, propagate
662   // its results scale backwards as far as possible.
663   func.walk([&](quant::StatisticsOp stats_op) {
664     if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
665       all_stats_ops.push_back(stats_op);
666     }
667   });
668 
669   while (!all_stats_ops.empty()) {
670     quant::StatisticsOp stats_op = all_stats_ops.back();
671     all_stats_ops.pop_back();
672 
673     if (auto def = stats_op.arg().getDefiningOp()) {
674       if (llvm::dyn_cast<SameScalesOpInterface>(def)) {
675         for (auto input : def->getOperands()) {
676           if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
677                   input.getDefiningOp())) {
678             redundant_stats_ops.insert(next_stats);
679             all_stats_ops.push_back(next_stats);
680           }
681         }
682       }
683     }
684   }
685 
686   // Step3: Remove all the redundant stats ops
687   for (auto it : redundant_stats_ops) {
688     if (!llvm::isa<quant::StatisticsOp>(it)) return true;
689     auto stats_op = llvm::cast<quant::StatisticsOp>(it);
690     stats_op.getResult().replaceAllUsesWith(stats_op.arg());
691     stats_op.erase();
692   }
693 
694   // Returns false if the steps finish without errors.
695   return false;
696 }
697 
VerifySameScales(Operation * op)698 LogicalResult VerifySameScales(Operation* op) {
699   auto same_scale_op = llvm::cast<SameScalesOpInterface>(op);
700 
701   llvm::SmallVector<QuantizedType, 4> collected_quant_params;
702   for (auto input : op->getOperands()) {
703     auto quant_params =
704         UniformQuantizedType::getQuantizedElementType(input.getType());
705     // Skip non-quantizable operands.
706     if (quant_params) {
707       collected_quant_params.push_back(quant_params);
708     }
709   }
710 
711   for (auto output : op->getResults()) {
712     auto quant_params =
713         UniformQuantizedType::getQuantizedElementType(output.getType());
714     // Skip non-quantizable results.
715     if (quant_params) {
716       collected_quant_params.push_back(quant_params);
717     }
718   }
719 
720   if (collected_quant_params.size() <= 1) return success();
721   for (int i = 1; i < collected_quant_params.size(); i++) {
722     auto expected_params = collected_quant_params[0];
723     auto compared_paras = collected_quant_params[i];
724     // Same quantization parameters are always ok.
725     if (expected_params == compared_paras) continue;
726     // If the quantization parameters are not the same, as long as it has the
727     // same storage type and the op interface doesn't require same scale
728     // constraint for this storage type, it is still ok.
729     if ((expected_params.isSigned() == compared_paras.isSigned() &&
730          expected_params.getStorageTypeIntegralWidth() ==
731              compared_paras.getStorageTypeIntegralWidth()) &&
732         !same_scale_op.RequiredSameOperandsAndResultsScale(
733             expected_params.isSigned(),
734             expected_params.getStorageTypeIntegralWidth()))
735       continue;
736 
737     std::string err_msg =
738         "quantization parameters violate the same scale constraint: ";
739     llvm::raw_string_ostream os(err_msg);
740     collected_quant_params[0].print(os);
741     os << " vs. ";
742     collected_quant_params[i].print(os);
743     os.flush();
744     return op->emitOpError(err_msg);
745   }
746   return success();
747 }
748 
GetFixedOutputRange(bool is_signed,int bit_width,Type tensor_type,double scale,int64_t zero_point,int64_t storage_min,int64_t storage_max)749 quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
750                                                 Type tensor_type, double scale,
751                                                 int64_t zero_point,
752                                                 int64_t storage_min,
753                                                 int64_t storage_max) {
754   auto result_type = tensor_type.cast<ShapedType>();
755   if (!result_type.getElementType().isa<FloatType>()) return {};
756   Builder builder(result_type.getContext());
757 
758   // Only support 8-bits
759   if (bit_width != 8) return {};
760   IntegerType storage_type = builder.getIntegerType(bit_width);
761   if (!is_signed) {
762     zero_point += 128;
763     storage_min += 128;
764     storage_max += 128;
765   }
766   return quant::UniformQuantizedType::getChecked(
767       is_signed, storage_type, result_type.getElementType(), scale, zero_point,
768       storage_min, storage_max, builder.getUnknownLoc());
769 }
770 }  // namespace quant
771 }  // namespace mlir
772