1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the op traits used in the MLIR TensorFlow Lite dialect. 17 18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ 19 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ 20 21 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project 22 #include "mlir/Support/LLVM.h" // from @llvm-project 23 #include "mlir/Support/LogicalResult.h" // from @llvm-project 24 25 using QuantizedType = mlir::quant::QuantizedType; 26 using UniformQuantizedType = mlir::quant::UniformQuantizedType; 27 28 namespace mlir { 29 namespace quant { 30 // Verify that the op satisfies the same operands and results scales 31 // constraints. Note that this constraint can only be applied on some 32 // storage types of the op. 33 LogicalResult VerifySameScales(Operation* op); 34 } // namespace quant 35 36 // This includes the interface class definition. It couldn't be in a namespace 37 // because the table gen doesn't emit the namespace when it is used. 38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc" 39 40 namespace OpTrait { 41 namespace quant { 42 43 // The base class that all the quantization related OpTrait implements. 44 template <typename ConcreteType, template <typename> class TraitType> 45 struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> { IsBiasQuantizationSpecTraitBase46 static bool IsBias(int index) { return false; } IsQuantizableQuantizationSpecTraitBase47 static bool IsQuantizable() { return true; } 48 }; 49 50 // This class provides the API for TFL ops that has a fixed output value range. 51 // This is used as a trait like this: 52 // 53 // class SoftmaxOp 54 // : public Op<SoftmaxOp, 55 // OpTrait::quant::FixedResultUniformScale< 56 // 8, -128, 390625, -8, 0, 255, false>::Impl> { 57 // 58 // TODO(fengliuai): create a better way to express floating point scale in the 59 // template argument list. 60 template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp, 61 int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign> 62 class FixedResultUniformScale { 63 public: 64 template <typename ConcreteType> 65 class Impl 66 : public QuantizationSpecTraitBase< 67 ConcreteType, FixedResultUniformScale< 68 BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, 69 StorageTypeMin, StorageTypeMax, Sign>::Impl> { 70 public: GetResultQuantizedType(int index)71 QuantizedType GetResultQuantizedType(int index) { 72 auto op = this->getOperation(); 73 auto result_type = 74 op->getResult(index).getType().template cast<ShapedType>(); 75 if (!result_type.getElementType().template isa<FloatType>()) return {}; 76 Builder builder(op->getContext()); 77 IntegerType storage_type = builder.getIntegerType(BitWidth); 78 const double scale = static_cast<double>(ScaleMantissa) * 79 ::pow(10.0, static_cast<double>(ScaleExp)); 80 return UniformQuantizedType::getChecked( 81 Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, 82 StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); 83 } 84 }; 85 }; 86 87 // This class provides the API for TFL ops that has input as bias. This is used 88 // as a trait like this: 89 // 90 // class Conv2DOp 91 // : public Op<Conv2DOp, OpTrait::quant::AccumulatorScale<2, 0, 1>::Impl> 92 // 93 // TODO(fengliuai): supports a configurable accumulator bit width. 94 template <int Bias, int... Operands> 95 class AccumulatorUniformScale { 96 public: 97 template <typename ConcreteType> 98 class Impl 99 : public QuantizationSpecTraitBase< 100 ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> { 101 public: 102 // Whether the index-th operand is a bias. IsBias(int index)103 static bool IsBias(int index) { return index == Bias; } 104 105 // Returns the indexes of all the non-bias operands. GetAllNonBiasOperands()106 static std::vector<int> GetAllNonBiasOperands() { 107 return std::vector<int>({Operands...}); 108 } 109 }; 110 }; 111 112 // The trait to specify the operand index of the coefficient for an affine op 113 // and also the quantization dimension if per-axis quantization is support. 114 // If the quantization dimension is -1, per-axis quantization isn't supported. 115 // 116 // class Conv2DOp 117 // : public Op<Conv2DOp, OpTrait::quant::AffineOpCoefficient<0>::Impl> 118 // 119 template <int QuantDim, int OperandIndex = 1> 120 class AffineOpCoefficient { 121 public: 122 template <typename ConcreteType> 123 class Impl 124 : public TraitBase<ConcreteType, 125 AffineOpCoefficient<QuantDim, OperandIndex>::Impl> { 126 public: GetCoefficientOperandIndex()127 static int GetCoefficientOperandIndex() { return OperandIndex; } GetQuantizationDim()128 static int GetQuantizationDim() { return QuantDim; } 129 }; 130 }; 131 132 // This class provides the API for TFL ops that shouldn't be quantized. This is 133 // used as a trait like this: 134 // 135 // class LessOp : public Op<LessOp, OpTrait::quant::NoQuantizableResult> { 136 // 137 template <typename ConcreteType> 138 class NoQuantizableResult 139 : public QuantizationSpecTraitBase<ConcreteType, NoQuantizableResult> {}; 140 141 } // namespace quant 142 } // namespace OpTrait 143 } // namespace mlir 144 145 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ 146