Computation with less than 8 bits in gemmlowp
*********************************************


Introduction
============

We assume familiarity with gemmlowp's low-precision uint8 computation
paradigm, which is described in doc/low-precision.txt.

This document is about the possibility of further reducing precision
below 8 bits.

This allows higher arithmetic throughput on some architectures,
at the cost of decreased accuracy.


Public interface
================


The BitDepthSetting parameter in the EightBitIntGemm interface
--------------------------------------------------------------

Accessing less-than-8-bit computation via the EightBitIntGemm interface is
very simple: EightBitIntGemm takes a BitDepthSetting enum that lets the
caller choose among a fixed set of supported bit-depth combinations.


The BitDepthParams parameter in the public/gemmlowp.h interface
---------------------------------------------------------------

The public/gemmlowp.h interface exposes more extensive control over
quantization, by means of a BitDepthParams template parameter,
which is a type parameter carrying information about:
  1. The LHS and RHS bit depth, which can be set arbitrarily and
     independently;
  2. The 'RoundingStrategy', which is the heuristic used to choose
     a rounding mode, based on the accumulation size (a.k.a. the
     "depth" dimension of the Gemm).
Details can be seen in public/bit_depth.h.


How does BitDepth{Setting,Params} affect input/output uint8 matrix data?
-------------------------------------------------------------------------

Input/output matrix data is all uint8's, ranging from 0 to 255, regardless of
the BitDepth{Setting,Params}.

So the BitDepth{Setting,Params} is only an internal detail. Its only purpose
is to allow gemmlowp to use lower precision internally; the input/output data
format is unaffected.

As far as the API contract goes, the only thing that the
BitDepth{Setting,Params} does is to relax the accuracy requirement.
With standard 8bit/8bit computation, gemmlowp is required to return the exact
result as specified in doc/low-precision.txt. With lower bit depths, gemmlowp
is no longer required to return an exact result.


Implementation
==============

Here we refer to the 3 stages of computation as described in doc/design.txt,
namely: packing, computation kernel, unpacking.

The general idea is that at the packing stage, we requantize input (Lhs/Rhs)
data to less-than-8-bit depths by scaling them, thus shrinking the range of
the packed matrix entries; for instance, if the Rhs bit depth is to be 5 bits
then packed Rhs matrix entries will be in the range [0 ... 31]. This then
allows the GEMM kernel to use narrower accumulators without risking overflow,
thus achieving higher arithmetic throughput. Finally, at the unpacking stage,
it only remains to scale the result values to compensate for the scalings
applied earlier.

Let us go into more detail for each of those stages:


Packing stage
-------------

The packing stage is where most of the work specific to the BitDepthParams
takes place.

Here, we have to scale input matrix values from their original range of
[0 ... 255] to the range specified by the BitDepthParams, which is
[0 ... (2^N)-1] where N is the number of bits for the matrix at hand
(Lhs or Rhs). For example, for a bit depth of 5 bits, we need to scale
down to [0 ... 31].
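
For illustration, the mapping being implemented is simply the following. This
floating-point helper is hypothetical and only shows what the scaling means,
not how gemmlowp's actual Requantize() implements it; the integer rounding
details are the subject of the "Requantization details" section below:

  #include <cmath>

  // Illustrative only: what requantizing to N bits means, as a mapping
  // from [0 ... 255] to [0 ... (2^N)-1].
  int RequantizeReference(int src /* in [0, 255] */, int bits) {
    const int maxval = (1 << bits) - 1;  // e.g. 31 for bits == 5
    return static_cast<int>(
        std::lround(src * static_cast<double>(maxval) / 255.0));
  }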

This scaling is what we call "requantization". The pedantic name matches
the fact that this is actually quite nontrivial to do correctly, i.e.
in such a way that the resulting accuracy will be good enough for real-world
applications. See the section below on requantization details.

Concretely, this work happens in PackingRegisterBlock::Pack(), which calls
Requantize(). This is in internal/pack.h. This code can be overridden for
specific architectures, see internal/pack_neon.h.

This requantization work is costly and makes packing slower. This means
that, at least in our approach, less-than-8-bit computation is only
interesting for large-enough, square-enough GEMMs where packing is only
a small fraction of the overall cost. In cases where packing overhead
is more prevalent (highly rectangular cases), less-than-8-bit is probably
a waste of time as long as we treat it as an internal computation detail.
What might help there would be to shrink the input/output data format,
to lower memory bandwidth usage.


Computation kernel stage
------------------------

In principle, the computation kernel stage simply doesn't have to care
about the bit depth at all. In fact, on architectures where we do not have
specific optimized kernels for less-than-8-bit cases, we simply use our
standard kernel there, and that's correct!

However, while the kernel doesn't have to know that the operands are on
less than 8 bits, it can use that information to make special optimizations
that would be incorrect in the general 8-bit case but become correct here
thanks to the more restricted range of inputs. That's the whole point of
this less-than-8-bit computation idea.

With Lhs entries guaranteed to be smaller than 2^N, and Rhs entries
guaranteed to be smaller than 2^M, each product is guaranteed to be
smaller than 2^(M+N). Thus, one may accumulate 2^(16-(M+N)) such products
and still be guaranteed that such an accumulator will be smaller than 2^16,
and thus can be stored as a uint16 without risking overflow.

For example, in the L7R5 case, the Lhs entries are on 7 bits (N=7) and the
Rhs entries are on 5 bits (M=5), so each product fits in 12 bits and one can
thus accumulate 16 (= 2^(16-12)) such products into uint16 accumulators
with no risk of overflow.

This means that a computation kernel may use uint16 accumulators for
several loop iterations (16 in the above example), provided that it is
allowed to assume that inputs are in such a restricted range.

After this fixed number of loop iterations, the kernel must accumulate
the local uint16 accumulators back into global uint32 accumulators.

On SIMD architectures with suitable uint16 arithmetic, this in principle
allows multiplying arithmetic throughput by up to 2x, since twice as many
accumulators now fit in each SIMD vector register. This is partially offset
by the cost of accumulating back into global uint32 accumulators every
several loop iterations, but our experience on ARM NEON has been that
we still get quite close to a 2x speedup. See internal/kernel_neon.h,
specifically NEON32Kernel12x4Depth2Assuming12BitProducts.
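
For concreteness, here is a plain scalar sketch of that accumulator
structure under the L7R5 assumptions (7-bit Lhs, 5-bit Rhs). This is only an
illustration of the idea, not the actual NEON kernel, and the function name
is made up for this example:

  #include <cstdint>

  // Dot product of one Lhs row and one Rhs column of length 'depth',
  // assuming Lhs entries < 2^7 and Rhs entries < 2^5, so that each
  // product is < 2^12 and any 16 products sum to less than 2^16.
  std::uint32_t DotProductWithLocalUint16Accumulator(
      const std::uint8_t* lhs, const std::uint8_t* rhs, int depth) {
    std::uint32_t global_acc = 0;    // global 32-bit accumulator
    int d = 0;
    while (d < depth) {
      std::uint16_t local_acc = 0;   // local 16-bit accumulator
      const int end = (d + 16 < depth) ? d + 16 : depth;
      // At most 16 iterations: no risk of uint16 overflow.
      for (; d < end; ++d) {
        local_acc += static_cast<std::uint16_t>(lhs[d] * rhs[d]);
      }
      // Accumulate back into the global 32-bit accumulator.
      global_acc += local_acc;
    }
    return global_acc;
  }

In a SIMD kernel, local_acc corresponds to uint16 vector registers, which is
where the up-to-2x throughput gain comes from.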


Unpacking stage
---------------

At the unpacking stage, it only remains to scale the result values
to compensate for the scaling of the inputs. This is easier because
now we are expanding the range instead of shrinking it, so we don't
need to worry about ways to minimize a loss of accuracy. We simply
need to multiply result values by a constant fraction, rounding to nearest.

Since the inputs were scaled by factors of (2^lhs_bits - 1)/255 and
(2^rhs_bits - 1)/255 respectively, the scaling of the outputs needs to be
by the following factor:

               255 * 255
  -----------------------------------
  (2^lhs_bits - 1) * (2^rhs_bits - 1)

For example, in the L7R5 case this factor is 65025 / (127 * 31), roughly 16.5.

This is done by a MultiplyByConstantFraction function, see internal/unpack.h.


Requantization details
======================

Here we go into more detail on the Requantize() function used at the packing
stage to requantize input matrix data. See this function in internal/pack.h.

It depends on the bit depth and on a rounding mode, and requantizes an input
value in [0 ... 255] to the range [0 ... (2^N)-1] specified by the bit depth N.


Naive, bad rounding, that's plainly biased
------------------------------------------

Naive and inaccurate ways to achieve this requantization include:
  1. By shifting right by (8-N) bits;
  2. By multiplying by ((2^N) - 1) and dividing by 255.

Both of those are biased in some way: 1. has the wrong "derivative", since it
approximates (((2^N) - 1) / 255) by ((2^N) / 256); 2. has bias since it
effectively implements rounding towards 0.

In practice, both of the above requantization functions give results that are
too inaccurate for the neural network that we tried (GoogLeNet).


Round-to-nearest rounding: unbiased in principle but not in practice
--------------------------------------------------------------------

The simplest fix is to avoid the bias in 2. by rounding to nearest instead
of rounding towards 0. This can be achieved by doing

  dst = (src * maxval + rounding_offset) / 255;

where maxval = ((2^N) - 1) is the highest requantized value, and the
rounding_offset can be set to

  rounding_offset = 127

to achieve rounding-to-nearest (while the above rounding towards 0
corresponds to rounding_offset = 0).

In principle, rounding-to-nearest is unbiased and optimal in various ways.

In practice, though, our input data is not random real numbers, but
already-quantized 8-bit values. That means that even in the best case, there
would be at most 255 different possible input values; in practice, we
generally see the input values distributed non-uniformly in that range, so
that a majority of input values tend to be in a much smaller range. See
test/test_data.cc.

Having a large part of the input values in a very small finite set means that
the corresponding rounding errors are also in a very small finite set, which
can be small enough that the mean of these rounding errors is significantly
different from 0. That rounding-to-nearest is "unbiased" only means that over
a sufficiently large set of input values, the bias would become arbitrarily
close to 0; here, the set of input values is effectively small enough that the
resulting bias is significant.

This leads to biasing the matrix product entries, resulting in an error that
grows linearly with the depth dimension of the GEMM.
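
For concreteness, here is the above formula written out as code, together
with a worked example of the bias on a small input set. This is only an
illustrative sketch with a made-up name, not the actual Requantize() in
internal/pack.h:

  #include <cstdint>

  // Requantizes src in [0, 255] to [0, maxval], where maxval = (2^N) - 1.
  // rounding_offset = 0 gives rounding towards 0; rounding_offset = 127
  // gives rounding to nearest.
  std::uint8_t RequantizeWithOffset(std::uint8_t src, int maxval,
                                    int rounding_offset) {
    return static_cast<std::uint8_t>((src * maxval + rounding_offset) / 255);
  }

  // Example of the bias: with N = 5 (maxval = 31), src = 3 requantizes to
  // (3 * 31 + 127) / 255 = 0, i.e. to 0 on the original [0, 255] scale,
  // a fixed error of -3 for that input value. If many entries of the input
  // matrix are equal to 3, these errors all have the same sign and do not
  // cancel out: that is the bias discussed above.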


Probabilistic rounding: unbiased even on small finite input distributions
--------------------------------------------------------------------------

To address that, we can instead use probabilistic rounding. The idea is that,
for instance, if we have to round the value 3.8 to the nearest integer, we can
round it to 3 with 20% probability and to 4 with 80% probability. If that
value 3.8 occurs many times, the mean requantized value will thus tend to 3.8.

This amounts to keeping the above requantization formula,

  dst = (src * maxval + rounding_offset) / 255;

but now the rounding_offset is a random value in [0 .. 254].

This guarantees zero bias no matter how small the distribution of input values
is.

On the other hand, the variance of the error term here is higher than with
rounding-to-nearest --- one can check that it is 2x higher.

So the error term coming from the Central Limit Theorem, which grows with
the square root of the accumulation depth, i.e. the GEMM depth, will be
2x higher here.

Still, for large enough GEMM depth, that is better than rounding-to-nearest,
which has an error term growing linearly with the GEMM depth.


Switching between rounding-to-nearest and probabilistic rounding
----------------------------------------------------------------

Thus, for fixed input values and bit depths, we expect that probabilistic
rounding will give more accurate results for large enough GEMM depths, while
rounding-to-nearest will be more accurate for smaller GEMM depths.

That is why we switch between these rounding modes based on GEMM depth;
see ChooseRoundingMode in internal/bit_depth_util.h.

It is based on a constant, kProbabilisticRoundingThreshold, defined
in internal/common.h and empirically determined. See the comment there.
It would be nice to better understand the statistics here and come up
with better heuristics for this switching.


Choice of pseudorandom number generator
---------------------------------------

We provide two PRNGs. The first is an 8-bit Xorshift. It is fast, naturally
produces values ranging over an interval of width 255, which is what we need
here (as opposed to an interval of width 256), and turns out, from empirical
tests, to produce better results than a linear congruential generator (LCG).
That's unfortunate, as an 8-bit LCG is faster (we confirmed that on various
ARM devices), but we need unbiasedness to be as close to perfect as we can
get.

The second is an "add-mod" sequence generator, which generates non-random
values in the sequence x = (x + 97) % 255. This generates a low-discrepancy
sequence that minimizes the "clumpiness" of the random offsets (thus, for
example, quantizing a 3x3 matrix will have a maximum additive error of about
200 from the random offsets). While not random, this sequence performs well
empirically for many quantizations. For information about why 97 is a good
value, see
  https://en.wikipedia.org/wiki/Low-discrepancy_sequence#Additive_recurrence
and
  http://mollwollfumble.blogspot.com/2011/03/subrandom-numbers.html
97/255 is about 0.38, close to 0.382, which is the best choice of increment
for such an additive recurrence. For discrete numbers, the increment must
also be relatively prime to the modulus; 97 is prime, so it is safely
relatively prime to 255. 107 is another near-optimal choice.

The low-discrepancy sequence generator is the default.
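
As a sketch of what the default (add-mod) generator amounts to, here is a
minimal version of it used as the source of rounding offsets. The class and
function names here are made up for the example; the actual implementations
live in internal/pack.h:

  #include <cstdint>

  // Low-discrepancy "add-mod" sequence: x = (x + 97) % 255, producing
  // rounding offsets in [0 .. 254]. Since 97 is relatively prime to 255,
  // the sequence cycles through all 255 possible offsets before repeating.
  class AddModRoundingOffsetGenerator {
   public:
    std::uint8_t get() {
      x_ = static_cast<std::uint8_t>((x_ + 97) % 255);
      return x_;
    }

   private:
    std::uint8_t x_ = 0;
  };

  // Probabilistic requantization: the same formula as before, but with a
  // fresh offset for each value being requantized.
  std::uint8_t ProbabilisticRequantize(std::uint8_t src, int maxval,
                                       AddModRoundingOffsetGenerator* gen) {
    return static_cast<std::uint8_t>((src * maxval + gen->get()) / 255);
  }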

More details and results are given in a comment on the default
PRNG in internal/pack.h. Interested users can change the
PRNG used by setting DefaultRoundingGenerator in bit_depth_util.h.