1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
17
#include <algorithm>
#include <cmath>
19
20 #include "absl/types/optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
26 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
30
31 namespace mlir {
32 namespace quant {
33
34 constexpr int k8Bits = 8;
35 constexpr int k32Bits = 32;
36 constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
37
DeviceTarget(MLIRContext * ctx)38 DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
39 f32_ = FloatType::getF32(ctx_);
40 i8_ = IntegerType::get(ctx_, k8Bits);
41 i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
42 i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
43 i32_ = IntegerType::get(ctx_, k32Bits);
44 i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
45 i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
46 any_ = AnyQuantizedType();
47 qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
48 qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
49 qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
50 assert(qi8n_ == qi8n_);
51 }
52
GetKernelSpec(llvm::StringRef kernel,const KernelSpecs::Signature & signature) const53 Optional<KernelSpec> DeviceTarget::GetKernelSpec(
54 llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
55 auto kernel_specs_it = specs_.find(kernel);
56 if (kernel_specs_it == specs_.end()) return llvm::None;
57 return kernel_specs_it->getValue().Find(signature);
58 }
59
GetDecomposeFn(QuantizeRegionOp op) const60 ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
61 auto kernel_specs_it = specs_.find(op.logical_kernel());
62 if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
63 return kernel_specs_it->second.GetDecomposeFn();
64 }
65
AppendToSignature(Type spec,KernelSpecs::Signature * signature)66 void DeviceTarget::AppendToSignature(Type spec,
67 KernelSpecs::Signature* signature) {
68 if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
69 signature->push_back(AnyQuantizedType::get(
70 quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
71 quant.getStorageTypeMin(), quant.getStorageTypeMax()));
72 } else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
73 signature->push_back(any);
74 } else { // float
75 signature->push_back(AnyQuantizedType());
76 }
77 }
78
RegisterKernel(llvm::StringRef kernel,const KernelSpecs::Signature & signature,const ScaleFn & fn,const ScaleDecomposeFn & dfn)79 LogicalResult DeviceTarget::RegisterKernel(
80 llvm::StringRef kernel, const KernelSpecs::Signature& signature,
81 const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
82 return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
83 }
84
85 namespace ph = std::placeholders;
86
RegisterKernel(llvm::StringRef kernel,const KernelSpecs::Signature & signature,const ScaleConstraintType constraint)87 LogicalResult DeviceTarget::RegisterKernel(
88 llvm::StringRef kernel, const KernelSpecs::Signature& signature,
89 const ScaleConstraintType constraint) {
90 if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
91 switch (constraint) {
92 case ScaleConstraintType::OutputInputSameScale:
93 specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
94 ph::_1, ph::_2, ph::_3, ph::_4));
95 return success();
96 default:
97 return failure();
98 }
99 }
100
DecomposeMultiplyAccumulateScale(Operation * op,quant::QuantizedMultipliers * input_multipliers,quant::QuantizedMultipliers * output_multipliers,quant::QuantizedRanges * output_ranges)101 LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
102 Operation* op, quant::QuantizedMultipliers* input_multipliers,
103 quant::QuantizedMultipliers* output_multipliers,
104 quant::QuantizedRanges* output_ranges) {
105 auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
106 if (!rop) return failure();
107
108 llvm::SmallVector<Type, 4> input_specs, out_specs;
109 for (auto spec : rop.input_specs()) {
110 input_specs.push_back(spec.cast<TypeAttr>().getValue());
111 }
112 for (auto spec : rop.output_specs()) {
113 out_specs.push_back(spec.cast<TypeAttr>().getValue());
114 }
115
116 auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
117 // TODO(fengliuai): handles the PerAxis QuantizedType.
118 auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
119 auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
120 auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
121 if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();
122
123 double scale_product = in_spec.getScale() * w_spec.getScale();
124 if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();
125
126 // input multipliers
127 input_multipliers->append(3, kUnitQuantizedMultiplier);
128
129 // output multipliers
130 double real_multiplier = scale_product / o_spec.getScale();
131 output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));
132
133 // output ranges
134 auto min = rop->getAttrOfType<FloatAttr>("min");
135 auto max = rop->getAttrOfType<FloatAttr>("max");
136 output_ranges->push_back(quant::CalculateQuantizedRange(
137 o_spec.getScale(), o_spec.getZeroPoint(),
138 (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
139 (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
140 o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
141
142 return success();
143 }
144
DecomposeSameScale(Operation * op,quant::QuantizedMultipliers * input_multipliers,quant::QuantizedMultipliers * output_multipliers,quant::QuantizedRanges * output_ranges)145 LogicalResult DeviceTarget::DecomposeSameScale(
146 Operation* op, quant::QuantizedMultipliers* input_multipliers,
147 quant::QuantizedMultipliers* output_multipliers,
148 quant::QuantizedRanges* output_ranges) {
149 auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
150 if (!rop) return failure();
151
152 // input multipliers
153 for (int i = 0; i < op->getNumOperands(); ++i) {
154 input_multipliers->push_back(kUnitQuantizedMultiplier);
155 }
156
157 // output multipliers
158 for (int i = 0; i < op->getNumResults(); ++i) {
159 output_multipliers->push_back(kUnitQuantizedMultiplier);
160 }
161
162 auto o_spec = rop.output_specs()[0]
163 .cast<TypeAttr>()
164 .getValue()
165 .dyn_cast<quant::UniformQuantizedType>();
166 if (!o_spec) return failure();
167
168 // output ranges
169 auto min = rop->getAttrOfType<FloatAttr>("min");
170 auto max = rop->getAttrOfType<FloatAttr>("max");
171 output_ranges->push_back(quant::CalculateQuantizedRange(
172 o_spec.getScale(), o_spec.getZeroPoint(),
173 (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
174 (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
175 o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));
176
177 return success();
178 }
179
180 } // namespace quant
181 } // namespace mlir
182