• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 
10 #include <cmath>
11 #include <algorithm>
12 
13 namespace armnn
14 {
15 
16 using OffsetScalePair = std::pair<float, int>;
17 
18 struct IQuantizationScheme
19 {
20     virtual OffsetScalePair ComputeScheme(double min, double max) const = 0;
21 
22     virtual int NumBits() const = 0;
23 
24     virtual DataType GetDataType() const = 0;
25 
~IQuantizationSchemearmnn::IQuantizationScheme26     virtual ~IQuantizationScheme() {}
27 };
28 
29 struct QAsymmU8QuantizationScheme : IQuantizationScheme
30 {
ComputeSchemearmnn::QAsymmU8QuantizationScheme31     OffsetScalePair ComputeScheme(double min, double max) const override
32     {
33         if (min > max)
34         {
35             throw InvalidArgumentException("min > max will result in invalid quantization.");
36         }
37 
38         double highest = (1 << NumBits()) - 1;
39 
40         min = std::min(0.0, min); // min <= 0.0
41         max = std::max(0.0, max); // max >= 0.0
42 
43         // To avoid dividing by zero when quantizing a zero filled tensor
44         if (min == 0.0 && max == 0.0)
45         {
46             max = 1.0;
47         }
48 
49         // Assumes quantization range [0-highest]
50         double scale = (max-min) / highest;
51         double offset = -min / scale;
52 
53         // Clamp offset [0-highest]
54         offset = std::max(0.0, std::min(highest, offset));
55 
56         return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)));
57     }
58 
NumBitsarmnn::QAsymmU8QuantizationScheme59     int NumBits() const override { return 8; }
60 
GetDataTypearmnn::QAsymmU8QuantizationScheme61     DataType GetDataType() const override { return DataType::QAsymmU8; }
62 };
63 
64 struct QAsymmS8QuantizationScheme : IQuantizationScheme
65 {
ComputeSchemearmnn::QAsymmS8QuantizationScheme66     OffsetScalePair ComputeScheme(double min, double max) const override
67     {
68         if (min > max)
69         {
70             throw InvalidArgumentException("min > max will result in invalid quantization.");
71         }
72 
73         double highest = (1 << NumBits()) - 1;
74 
75         min = std::min(0.0, min); // min <= 0.0
76         max = std::max(0.0, max); // max >= 0.0
77 
78         // To avoid dividing by zero when quantizing a zero filled tensor
79         if (min == 0.0 && max == 0.0)
80         {
81             max = 1.0;
82         }
83 
84         // Assumes quantization range [0-255]
85         double scale = (max-min) / highest ;
86         double offset = - min / scale;
87 
88         //Clamp 0 to Highest
89         offset = std::max(0.0, std::min(highest, offset));
90 
91         //-128 on offset to cast to signed range
92         return std::make_pair(static_cast<float>(scale), static_cast<int>(std::round(offset)-128));
93     }
94 
NumBitsarmnn::QAsymmS8QuantizationScheme95     int NumBits() const override { return 8; }
96 
GetDataTypearmnn::QAsymmS8QuantizationScheme97     DataType GetDataType() const override { return DataType::QAsymmS8; }
98 };
99 
100 struct QSymmS8QuantizationScheme : IQuantizationScheme
101 {
ComputeSchemearmnn::QSymmS8QuantizationScheme102     OffsetScalePair ComputeScheme(double min, double max) const override
103     {
104         if (min > max)
105         {
106             throw InvalidArgumentException("min > max will result in invalid quantization.");
107         }
108 
109         // To avoid dividing by zero when quantizing a zero filled tensor
110         if (min == 0.0 && max == 0.0)
111         {
112             max = 1.0;
113         }
114 
115         double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
116 
117         double extent = std::max(std::abs(min), std::abs(max));
118         double scale = extent / highest;
119 
120         return std::make_pair(static_cast<float>(scale), 0);
121     }
122 
NumBitsarmnn::QSymmS8QuantizationScheme123     int NumBits() const override { return 8; }
124 
GetDataTypearmnn::QSymmS8QuantizationScheme125     DataType GetDataType() const override { return DataType::QSymmS8; }
126 };
127 
128 struct QSymm16QuantizationScheme : IQuantizationScheme
129 {
ComputeSchemearmnn::QSymm16QuantizationScheme130     OffsetScalePair ComputeScheme(double min, double max) const override
131     {
132         if (min > max)
133         {
134             throw InvalidArgumentException("min > max will result in invalid quantization.");
135         }
136 
137         // To avoid dividing by zero when quantizing a zero filled tensor
138         if (min == 0.0 && max == 0.0)
139         {
140             max = 1.0;
141         }
142 
143         double highest = (1 << (NumBits()-1)) - 1; // (numbits-1) accounts for the sign bit
144 
145         double extent = std::max(std::abs(min), std::abs(max));
146         double scale = extent / highest;
147 
148         return std::make_pair(static_cast<float>(scale), 0);
149 
150     }
151 
NumBitsarmnn::QSymm16QuantizationScheme152     int NumBits() const override { return 16; }
153 
GetDataTypearmnn::QSymm16QuantizationScheme154     DataType GetDataType() const override { return DataType::QSymmS16; }
155 };
156 
157 } // namespace armnn
158