/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 // This file defines the op traits used in the MLIR TensorFlow Lite dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
19 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
20 
21 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
22 #include "mlir/Support/LLVM.h"  // from @llvm-project
23 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
24 
25 using QuantizedType = mlir::quant::QuantizedType;
26 using UniformQuantizedType = mlir::quant::UniformQuantizedType;
27 
28 namespace mlir {
29 namespace quant {
// Verifies that the operands and results of the op satisfy the same-scales
// constraint. Note that this constraint can only be applied to some storage
// types of the op.
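// For example, an op that requires matching operand and result scales can
// call VerifySameScales(op) from its verifier to enforce the constraint.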
LogicalResult VerifySameScales(Operation* op);
}  // namespace quant

// This include provides the interface class definition. It can't be placed in
// a namespace because TableGen does not emit the namespace when the interface
// is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"

namespace OpTrait {
namespace quant {

// The base class that all quantization-related op traits implement. Derived
// traits (e.g. AccumulatorUniformScale below) override these hooks as needed.
template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
  // Whether the index-th operand is a bias; false by default.
  static bool IsBias(int index) { return false; }
  // Whether the op can be quantized; true by default.
  static bool IsQuantizable() { return true; }
};

// This class provides the API for TFL ops that have a fixed output value
// range. This is used as a trait like this:
//
//   class SoftmaxOp
//       : public Op<SoftmaxOp,
//           OpTrait::quant::FixedResultUniformScale<
//               8, -128, 390625, -8, 0, 255, false>::Impl> {
//
// TODO(fengliuai): create a better way to express floating point scale in the
// template argument list.
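// In the example above, the result scale is encoded by the ScaleMantissa and
// ScaleExp parameters as ScaleMantissa * 10^ScaleExp, i.e.
// 390625 * 10^-8 = 1/256 (see GetResultQuantizedType below).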
template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
          int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
class FixedResultUniformScale {
 public:
  template <typename ConcreteType>
  class Impl
      : public QuantizationSpecTraitBase<
            ConcreteType, FixedResultUniformScale<
                              BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
                              StorageTypeMin, StorageTypeMax, Sign>::Impl> {
   public:
    // Returns the fixed quantized element type for the index-th result, or a
    // null type if the result element type isn't a float type.
    QuantizedType GetResultQuantizedType(int index) {
      auto op = this->getOperation();
      auto result_type =
          op->getResult(index).getType().template cast<ShapedType>();
      if (!result_type.getElementType().template isa<FloatType>()) return {};
      Builder builder(op->getContext());
      IntegerType storage_type = builder.getIntegerType(BitWidth);
      const double scale = static_cast<double>(ScaleMantissa) *
                           std::pow(10.0, static_cast<double>(ScaleExp));
      return UniformQuantizedType::getChecked(
          Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
          StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
    }
  };
};

// This class provides the API for TFL ops that have a bias input. This is
// used as a trait like this:
//
//   class Conv2DOp
//       : public Op<Conv2DOp,
//                   OpTrait::quant::AccumulatorUniformScale<2, 0, 1>::Impl>
//
// TODO(fengliuai): support a configurable accumulator bit width.
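// In the example above, operand 2 is the bias operand and operands 0 and 1
// are the non-bias operands (see IsBias and GetAllNonBiasOperands below).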
template <int Bias, int... Operands>
class AccumulatorUniformScale {
 public:
  template <typename ConcreteType>
  class Impl
      : public QuantizationSpecTraitBase<
            ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
   public:
    // Whether the index-th operand is a bias.
    static bool IsBias(int index) { return index == Bias; }

    // Returns the indexes of all the non-bias operands.
    static std::vector<int> GetAllNonBiasOperands() {
      return std::vector<int>({Operands...});
    }
  };
};

// The trait to specify the operand index of the coefficient for an affine op
// and also the quantization dimension if per-axis quantization is supported.
// If the quantization dimension is -1, per-axis quantization isn't supported.
//
//   class Conv2DOp
//       : public Op<Conv2DOp, OpTrait::quant::AffineOpCoefficient<0>::Impl>
//
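// In the example above, the coefficient is quantized per-axis along dimension
// 0, and the coefficient operand index defaults to 1 (see the template
// parameters below).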
template <int QuantDim, int OperandIndex = 1>
class AffineOpCoefficient {
 public:
  template <typename ConcreteType>
  class Impl
      : public TraitBase<ConcreteType,
                         AffineOpCoefficient<QuantDim, OperandIndex>::Impl> {
   public:
    static int GetCoefficientOperandIndex() { return OperandIndex; }
    static int GetQuantizationDim() { return QuantDim; }
  };
};

// This class provides the API for TFL ops that shouldn't be quantized. This is
// used as a trait like this:
//
//   class LessOp : public Op<LessOp, OpTrait::quant::NoQuantizableResult> {
//
template <typename ConcreteType>
class NoQuantizableResult
    : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {
 public:
  // Overrides the base-class default so the op is reported as not quantizable.
  static bool IsQuantizable() { return false; }
};

}  // namespace quant
}  // namespace OpTrait
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_