• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the op traits used in the MLIR TensorFlow Lite dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
19 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
20 
#include <cmath>
#include <cstdint>
#include <vector>

#include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
23 
24 namespace mlir {
25 namespace OpTrait {
26 namespace quant {
27 
28 using QuantizedType = mlir::quant::QuantizedType;
29 using UniformQuantizedType = mlir::quant::UniformQuantizedType;
30 
31 // The base class that all the quantization related OpTrait implements.
32 template <typename ConcreteType, template <typename> class TraitType>
33 struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {
IsBiasQuantizationSpecTraitBase34   static bool IsBias(int index) { return false; }
IsQuantizableQuantizationSpecTraitBase35   static bool IsQuantizable() { return true; }
36 };
37 
38 // This class provides the API for TFL ops that requires same input and output
39 // scale as the quantization results. This is used as a trait like this:
40 //
41 //   class TransposeOp
42 //       : public Op<TransposeOp, OpTrait::TFL::SameOperandsAndResultsScale> {
43 //
44 template <typename ConcreteType>
45 class SameOperandsAndResultsScale
46     : public QuantizationSpecTraitBase<ConcreteType,
47                                        SameOperandsAndResultsScale> {};
48 
49 // This class provides the API for TFL ops that has a fixed output value range.
50 // This is used as a trait like this:
51 //
52 //   class SoftmaxOp
53 //       : public Op<SoftmaxOp,
54 //           OpTrait::quant::FixedResultUniformScale<
55 //               8, -128, 390625, -8, 0, 255, false>::Impl> {
56 //
57 // TODO(fengliuai): create a better way to express floating point scale in the
58 // template argument list.
59 template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
60           int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
61 class FixedResultUniformScale {
62  public:
63   template <typename ConcreteType>
64   class Impl
65       : public QuantizationSpecTraitBase<
66             ConcreteType, FixedResultUniformScale<
67                               BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
68                               StorageTypeMin, StorageTypeMax, Sign>::Impl> {
69    public:
GetResultQuantizedType(int index)70     QuantizedType GetResultQuantizedType(int index) {
71       auto op = this->getOperation();
72       auto result_type =
73           op->getResult(index).getType().template cast<ShapedType>();
74       if (!result_type.getElementType().template isa<FloatType>()) return {};
75       Builder builder(op->getContext());
76       IntegerType storage_type = builder.getIntegerType(BitWidth);
77       const double scale = static_cast<double>(ScaleMantissa) *
78                            ::pow(10.0, static_cast<double>(ScaleExp));
79       return UniformQuantizedType::getChecked(
80           Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
81           StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
82     }
83   };
84 };
85 
86 // This class provides the API for TFL ops that has input as bias. This is used
87 // as a trait like this:
88 //
89 //   class Conv2DOp
90 //       : public Op<Conv2DOp, OpTrait::quant::AccumulatorScale<2, 0, 1>::Impl>
91 //
92 // TODO(fengliuai): supports a configurable accumulator bit width.
93 template <int Bias, int... Operands>
94 class AccumulatorUniformScale {
95  public:
96   template <typename ConcreteType>
97   class Impl
98       : public QuantizationSpecTraitBase<
99             ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
100    public:
101     // Whether the index-th operand is a bias.
IsBias(int index)102     static bool IsBias(int index) { return index == Bias; }
103 
104     // Returns the indexes of all the non-bias operands.
GetAllNonBiasOperands()105     static std::vector<int> GetAllNonBiasOperands() {
106       return std::vector<int>({Operands...});
107     }
108   };
109 };
110 
111 // The trait to specify the operand index of the coefficient for an affine op
112 // and also the quantization dimension if per-axis quantization is support.
113 // If the quantization dimension is -1, per-axis quantization isn't supported.
114 //
115 //   class Conv2DOp
116 //       : public Op<Conv2DOp, OpTrait::quant::AffineOpCoefficient<0>::Impl>
117 //
118 template <int QuantDim, int OperandIndex = 1>
119 class AffineOpCoefficient {
120  public:
121   template <typename ConcreteType>
122   class Impl
123       : public TraitBase<ConcreteType,
124                          AffineOpCoefficient<QuantDim, OperandIndex>::Impl> {
125    public:
GetCoefficientOperandIndex()126     static int GetCoefficientOperandIndex() { return OperandIndex; }
GetQuantizationDim()127     static int GetQuantizationDim() { return QuantDim; }
128   };
129 };
130 
131 // This class provides the API for TFL ops that shouldn't be quantized. This is
132 // used as a trait like this:
133 //
134 //   class LessOp : public Op<LessOp, OpTrait::quant::NoQuantizableResult> {
135 //
136 template <typename ConcreteType>
137 class NoQuantizableResult
138     : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {};
139 
140 }  // namespace quant
141 }  // namespace OpTrait
142 }  // namespace mlir
143 
144 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_
145