/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>

20 #include "absl/types/optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
26 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h"
30 
31 namespace mlir {
32 namespace quant {
33 
34 constexpr int k8Bits = 8;
35 constexpr int k32Bits = 32;
36 constexpr unsigned kSigned = quant::QuantizationFlags::Signed;
37 
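// Caches the element types and default quantized types used when building
// kernel signatures: qi8_ is full-range signed 8-bit, qi8n_ is narrow-range
// signed 8-bit (the minimum storage value is excluded), and qi32_ is signed
// 32-bit, all expressed in f32.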
DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) {
  f32_ = FloatType::getF32(ctx_);
  i8_ = IntegerType::get(ctx_, k8Bits);
  i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits);
  i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits);
  i32_ = IntegerType::get(ctx_, k32Bits);
  i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits);
  i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits);
  any_ = AnyQuantizedType();
  qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_);
  qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_);
  qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_);
  assert(qi8n_ == qi8n_);
}

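// Returns the kernel spec registered for `kernel` under `signature`, or
// llvm::None if no specs are registered for that kernel.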
Optional<KernelSpec> DeviceTarget::GetKernelSpec(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature) const {
  auto kernel_specs_it = specs_.find(kernel);
  if (kernel_specs_it == specs_.end()) return llvm::None;
  return kernel_specs_it->getValue().Find(signature);
}

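// Returns the scale decomposition function registered for the logical kernel
// of `op`; returns a null ScaleDecomposeFn if the kernel is unknown.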
ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const {
  auto kernel_specs_it = specs_.find(op.logical_kernel());
  if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr);
  return kernel_specs_it->second.GetDecomposeFn();
}

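// Normalizes `spec` into the AnyQuantizedType form used by signatures:
// uniform quantized types keep their flags, storage/expressed types, and
// storage range; existing AnyQuantizedType specs pass through; anything else
// (e.g. float) maps to the default AnyQuantizedType.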
void DeviceTarget::AppendToSignature(Type spec,
                                     KernelSpecs::Signature* signature) {
  if (auto quant = spec.dyn_cast_or_null<UniformQuantizedType>()) {
    signature->push_back(AnyQuantizedType::get(
        quant.getFlags(), quant.getStorageType(), quant.getExpressedType(),
        quant.getStorageTypeMin(), quant.getStorageTypeMax()));
  } else if (auto any = spec.dyn_cast_or_null<AnyQuantizedType>()) {
    signature->push_back(any);
  } else {  // float
    signature->push_back(AnyQuantizedType());
  }
}

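// Registers `kernel` under `signature` with a custom scale constraint whose
// scale is computed by `fn`. The decompose function `dfn` is not recorded
// here.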
LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleFn& fn, const ScaleDecomposeFn& dfn) {
  return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn});
}

namespace ph = std::placeholders;

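// Registers `kernel` under `signature` with a predefined scale constraint.
// Only OutputInputSameScale is supported: it installs DecomposeSameScale as
// the kernel's decomposition function; other constraints fail.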
LogicalResult DeviceTarget::RegisterKernel(
    llvm::StringRef kernel, const KernelSpecs::Signature& signature,
    const ScaleConstraintType constraint) {
  if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure();
  switch (constraint) {
    case ScaleConstraintType::OutputInputSameScale:
      specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale,
                                        ph::_1, ph::_2, ph::_3, ph::_4));
      return success();
    default:
      return failure();
  }
}

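// Decomposes the scales of a multiply-accumulate kernel (input, weight, bias
// -> output). The bias scale must match the product of the input and weight
// scales, and the output is rescaled by
//   real_multiplier = input_scale * weight_scale / output_scale
// (e.g. input_scale = 0.5, weight_scale = 0.02, output_scale = 0.1 gives
// real_multiplier = 0.1), which is converted to a quantized multiplier. The
// output range is derived from the optional "min"/"max" attributes on the
// region op.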
LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale(
    Operation* op, quant::QuantizedMultipliers* input_multipliers,
    quant::QuantizedMultipliers* output_multipliers,
    quant::QuantizedRanges* output_ranges) {
  auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
  if (!rop) return failure();

  llvm::SmallVector<Type, 4> input_specs, out_specs;
  for (auto spec : rop.input_specs()) {
    input_specs.push_back(spec.cast<TypeAttr>().getValue());
  }
  for (auto spec : rop.output_specs()) {
    out_specs.push_back(spec.cast<TypeAttr>().getValue());
  }

  auto in_spec = input_specs[0].dyn_cast<quant::UniformQuantizedType>();
  // TODO(fengliuai): handle the PerAxis QuantizedType.
  auto w_spec = input_specs[1].dyn_cast<quant::UniformQuantizedType>();
  auto b_spec = input_specs[2].dyn_cast<quant::UniformQuantizedType>();
  auto o_spec = out_specs[0].dyn_cast<quant::UniformQuantizedType>();
  if (!in_spec || !w_spec || !b_spec || !o_spec) return failure();

  double scale_product = in_spec.getScale() * w_spec.getScale();
  if (std::fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure();

  // input multipliers
  input_multipliers->append(3, kUnitQuantizedMultiplier);

  // output multipliers
  double real_multiplier = scale_product / o_spec.getScale();
  output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier));

  // output ranges
  auto min = rop->getAttrOfType<FloatAttr>("min");
  auto max = rop->getAttrOfType<FloatAttr>("max");
  output_ranges->push_back(quant::CalculateQuantizedRange(
      o_spec.getScale(), o_spec.getZeroPoint(),
      (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
      (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));

  return success();
}

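// Decomposition for kernels constrained to use the same scale on inputs and
// outputs: all operands and results get the unit multiplier, and the output
// range is derived from the first output spec and the optional "min"/"max"
// attributes on the region op.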
LogicalResult DeviceTarget::DecomposeSameScale(
    Operation* op, quant::QuantizedMultipliers* input_multipliers,
    quant::QuantizedMultipliers* output_multipliers,
    quant::QuantizedRanges* output_ranges) {
  auto rop = llvm::dyn_cast<quant::QuantizeRegionOp>(op);
  if (!rop) return failure();

  // input multipliers
  for (int i = 0; i < op->getNumOperands(); ++i) {
    input_multipliers->push_back(kUnitQuantizedMultiplier);
  }

  // output multipliers
  for (int i = 0; i < op->getNumResults(); ++i) {
    output_multipliers->push_back(kUnitQuantizedMultiplier);
  }

  auto o_spec = rop.output_specs()[0]
                    .cast<TypeAttr>()
                    .getValue()
                    .dyn_cast<quant::UniformQuantizedType>();
  if (!o_spec) return failure();

  // output ranges
  auto min = rop->getAttrOfType<FloatAttr>("min");
  auto max = rop->getAttrOfType<FloatAttr>("max");
  output_ranges->push_back(quant::CalculateQuantizedRange(
      o_spec.getScale(), o_spec.getZeroPoint(),
      (min ? absl::optional<double>(min.getValueAsDouble()) : absl::nullopt),
      (max ? absl::optional<double>(max.getValueAsDouble()) : absl::nullopt),
      o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax()));

  return success();
}

}  // namespace quant
}  // namespace mlir