1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 10 #include <cmath> 11 #include <algorithm> 12 13 namespace armnn 14 { 15 16 using OffsetScalePair = std::pair<float, int>; 17 18 struct IQuantizationScheme 19 { 20 virtual OffsetScalePair ComputeScheme(double min, double max) const = 0; 21 22 virtual int NumBits() const = 0; 23 24 virtual DataType GetDataType() const = 0; 25 ~IQuantizationSchemearmnn::IQuantizationScheme26 virtual ~IQuantizationScheme() {} 27 }; 28 29 struct QAsymmU8QuantizationScheme : IQuantizationScheme 30 { ComputeSchemearmnn::QAsymmU8QuantizationScheme31 OffsetScalePair ComputeScheme(double min, double max) const override 32 { 33 if (min > max) 34 { 35 throw InvalidArgumentException("min > max will result in invalid quantization."); 36 } 37 38 double highest = (1 << NumBits()) - 1; 39 40 min = std::min(0.0, min); // min <= 0.0 41 max = std::max(0.0, max); // max >= 0.0 42 43 // To avoid dividing by zero when quantizing a zero filled tensor 44 if (min == 0.0 && max == 0.0) 45 { 46 max = 1.0; 47 } 48 49 // Assumes quantization range [0-highest] 50 double scale = (max-min) / highest; 51 double offset = -min / scale; 52 53 // Clamp offset [0-highest] 54 offset = std::max(0.0, std::min(highest, offset)); 55 56 return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset))); 57 } 58 NumBitsarmnn::QAsymmU8QuantizationScheme59 int NumBits() const override { return 8; } 60 GetDataTypearmnn::QAsymmU8QuantizationScheme61 DataType GetDataType() const override { return DataType::QAsymmU8; } 62 }; 63 64 struct QAsymmS8QuantizationScheme : IQuantizationScheme 65 { ComputeSchemearmnn::QAsymmS8QuantizationScheme66 OffsetScalePair ComputeScheme(double min, double max) const override 67 { 68 if (min > max) 69 { 70 throw InvalidArgumentException("min > max will result in invalid quantization."); 71 } 72 73 double highest = (1 << NumBits()) - 1; 74 75 min = std::min(0.0, min); // min <= 0.0 76 max = std::max(0.0, max); // max >= 0.0 77 78 // To avoid dividing by zero when quantizing a zero filled tensor 79 if (min == 0.0 && max == 0.0) 80 { 81 max = 1.0; 82 } 83 84 // Assumes quantization range [0-255] 85 double scale = (max-min) / highest ; 86 double offset = - min / scale; 87 88 //Clamp 0 to Highest 89 offset = std::max(0.0, std::min(highest, offset)); 90 91 //-128 on offset to cast to signed range 92 return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)-128)); 93 } 94 NumBitsarmnn::QAsymmS8QuantizationScheme95 int NumBits() const override { return 8; } 96 GetDataTypearmnn::QAsymmS8QuantizationScheme97 DataType GetDataType() const override { return DataType::QAsymmS8; } 98 }; 99 100 struct QSymmS8QuantizationScheme : IQuantizationScheme 101 { ComputeSchemearmnn::QSymmS8QuantizationScheme102 OffsetScalePair ComputeScheme(double min, double max) const override 103 { 104 if (min > max) 105 { 106 throw InvalidArgumentException("min > max will result in invalid quantization."); 107 } 108 109 // To avoid dividing by zero when quantizing a zero filled tensor 110 if (min == 0.0 && max == 0.0) 111 { 112 max = 1.0; 113 } 114 115 double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit 116 117 double extent = std::max(std::abs(min), std::abs(max)); 118 double scale = extent / highest; 119 120 return std::make_pair(static_cast<float>(scale), 0); 121 } 122 NumBitsarmnn::QSymmS8QuantizationScheme123 int NumBits() const override { return 8; } 124 GetDataTypearmnn::QSymmS8QuantizationScheme125 DataType GetDataType() const override { return DataType::QSymmS8; } 126 }; 127 128 struct QSymm16QuantizationScheme : IQuantizationScheme 129 { ComputeSchemearmnn::QSymm16QuantizationScheme130 OffsetScalePair ComputeScheme(double min, double max) const override 131 { 132 if (min > max) 133 { 134 throw InvalidArgumentException("min > max will result in invalid quantization."); 135 } 136 137 // To avoid dividing by zero when quantizing a zero filled tensor 138 if (min == 0.0 && max == 0.0) 139 { 140 max = 1.0; 141 } 142 143 double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit 144 145 double extent = std::max(std::abs(min), std::abs(max)); 146 double scale = extent / highest; 147 148 return std::make_pair(static_cast<float>(scale), 0); 149 150 } 151 NumBitsarmnn::QSymm16QuantizationScheme152 int NumBits() const override { return 16; } 153 GetDataTypearmnn::QSymm16QuantizationScheme154 DataType GetDataType() const override { return DataType::QSymmS16; } 155 }; 156 157 } // namespace armnn 158