/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/size_util.h"

namespace ruy {

// Enumeration to designate which dimension is the 'channels', for MulParams
// features that are 'per-channel', namely the bias-vector and the quantized
// multiplier.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'.
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination
  // matrix'.
  kCol
};

namespace detail {
// Forward declaration only: specialized further below on
// (AccumScalar, DstScalar) to select which fields exist for each mode.
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}

// MulParams describes all about a matrix multiplication that
// isn't encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance,
// the choice of accumulator type, AccumScalar). Some of that information is
// encoded as runtime values (for instance, the optional bias vector).
//
// Template parameters:
// AccumScalar: Accumulator type. The type of accumulators used to compute the
// dot-products before being ultimately casted to the destination type.
// DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
// * If DstScalar is floating-point then AccumScalar must also be.
// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
//   in that integral case, there is a mode switch:
//   - If DstScalar is std::int32_t then the multiplier_* fields are all
//     disabled, and ruy::Mul will just return raw (unscaled) accumulators.
//   - If DstScalar is not std::int32_t then the multiplier_* fields are
//     enabled, and ruy::Mul will use them to scale internal std::int32_t
//     accumulators before casting them to the DstScalar type. The default
//     values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The explanation of this warning is that as of early 2021, we still don't
// know whether it is advisable to let this code as-is have normative value,
// or whether that would become advisable after some specific final change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]).
We don't 81 // know that this particular reference code is inherently better than other 82 // forms that could perform better on these architectures --- in fact, the 83 // alternative that was proposed in [2] as better performing on ARM Cortex-M 84 // is also inherently more accurate thanks to rounding only once, but it would 85 // perform worse on both ARM NEON, and x86. 86 // 87 // In fact, if we look at other hardware architectures beyond current Ruy 88 // targets, namely "hardware accelerators", it becomes clear that there is no 89 // hope for any form of this to be efficiently implementable simultaneously on 90 // all current relevant hardware. Indeed, some accelerators prefer to perform 91 // the multiplication in IEEE float32, others in IEEE float16, others in 92 // bfloat16, others in 16-bit fixed-point... 93 // 94 // See: 95 // [1] https://github.com/google/ruy/pull/227 96 // [2] https://github.com/tensorflow/tensorflow/issues/25087 97 template <typename tAccumScalar, typename tDstScalar> 98 class MulParams final { 99 public: 100 using AccumScalar = tAccumScalar; 101 using DstScalar = tDstScalar; 102 103 // The bias vector data, if not null. bias()104 const AccumScalar* bias() const { return storage_.bias; } set_bias(const AccumScalar * ptr)105 void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } 106 // Only for non-floating-point cases. The fixed-point part of the multiplier 107 // by which accumulators are multiplied before being casted to the destination 108 // type. This is a fixed-point quantity with 0 integer bits. Since 109 // (as explained in the class comment) AccumScalar must be std::int32_t, 110 // that means that the fixed-point format is Q0.31. For example, 111 // a multiplier_fixedpoint value of 2^30 has the effect of multiplying 112 // by one half (1/2). More generally, the effect is to multiply by 113 // (multiplier_fixedpoint / (2^31)). 
multiplier_fixedpoint()114 AccumScalar multiplier_fixedpoint() const { 115 return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; 116 } set_multiplier_fixedpoint(const AccumScalar value)117 void set_multiplier_fixedpoint(const AccumScalar value) { 118 set_perchannel(false); 119 storage_.multiplier_fixedpoint = value; 120 } 121 // Only for non-floating-point cases. The exponent part of the aforementioned 122 // multiplier. multiplier_exponent()123 int multiplier_exponent() const { 124 return storage_.perchannel ? 0 : storage_.multiplier_exponent; 125 } set_multiplier_exponent(const int value)126 void set_multiplier_exponent(const int value) { 127 set_perchannel(false); 128 storage_.multiplier_exponent = value; 129 } 130 // Per-channel variant of multiplier_fixedpoint. Setting this switches 131 // to per-channel mode, where `multiplier_fixedpoint` and 132 // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` 133 // and `multiplier_exponent_perchannel` are used instead. 134 // 135 // This must point to a buffer of as many values as there are rows or columns 136 // in the destination matrix, whichever is the channels dimension. Each 137 // channel of the destination matrix will use the corresponding buffer element 138 // instead of multiplier_fixedpoint. multiplier_fixedpoint_perchannel()139 const AccumScalar* multiplier_fixedpoint_perchannel() const { 140 return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel 141 : nullptr; 142 } set_multiplier_fixedpoint_perchannel(const AccumScalar * ptr)143 void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) { 144 set_perchannel(true); 145 storage_.multiplier_fixedpoint_perchannel = ptr; 146 } 147 // Per-channel variant of multiplier_exponent. Same comments as for 148 // multiplier_fixedpoint_perchannel. multiplier_exponent_perchannel()149 const int* multiplier_exponent_perchannel() const { 150 return storage_.perchannel ? 
storage_.multiplier_exponent_perchannel 151 : nullptr; 152 } set_multiplier_exponent_perchannel(const int * ptr)153 void set_multiplier_exponent_perchannel(const int* ptr) { 154 set_perchannel(true); 155 storage_.multiplier_exponent_perchannel = ptr; 156 } 157 // min clamp bound of destination values. clamp_min()158 DstScalar clamp_min() const { return storage_.clamp_min; } set_clamp_min(const DstScalar value)159 void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; } 160 // max clamp bound of destination values. clamp_max()161 DstScalar clamp_max() const { return storage_.clamp_max; } set_clamp_max(const DstScalar value)162 void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; } 163 // Designates which dimension is the 'channels', for per-channel features 164 // such as bias-addition and per-channel quantization multipliers. channel_dimension()165 ChannelDimension channel_dimension() const { 166 return storage_.channel_dimension; 167 } set_channel_dimension(ChannelDimension value)168 void set_channel_dimension(ChannelDimension value) { 169 storage_.channel_dimension = value; 170 } 171 // Specifies the upward rounding of the allocated capacity of per-channel 172 // buffers such as bias vectors and per-channel quantization multipliers. 173 // The unit is matrix entries, not bytes. 174 // 175 // This value must be a power of two. 176 // 177 // The default value, 1, means no upward rounding, meaning that the buffers 178 // are not required to have a capacity greater than the size of the 179 // corresponding matrix dimension, i.e. the number of rows (respectively 180 // columns) of the destination matrix if `channel_dimension()` is kRow 181 // (respectively kCol). 182 // 183 // Higher values allow the implementation to assume that it is OK to access 184 // these buffers a little past this boundary, which is useful in SIMD 185 // optimized kernels. 
In practice, when this value is lower than what the 186 // kernel requires, ruy has to internally reallocate and copy per-channel 187 // buffers. When this value is high enough, this reallocation and copy is 188 // avoided. 189 // 190 // When a value greater than 1 is specified, the tail region of the buffer 191 // (past the end of the values actually corresponding to channels) is required 192 // to be zero-initialized. 193 // 194 // As of 2020, values as high as 16 may be useful on some CPU architectures 195 // (corresponding to the widest kernels used on any CPU architecture). perchannel_buffers_capacity_rounding()196 int perchannel_buffers_capacity_rounding() const { 197 return 1 << storage_.perchannel_buffers_capacity_rounding_log2; 198 } set_perchannel_buffers_capacity_rounding(int value)199 void set_perchannel_buffers_capacity_rounding(int value) { 200 // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two. 201 storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value); 202 } 203 204 private: 205 detail::MulParamsStorage<AccumScalar, DstScalar> storage_; 206 set_perchannel(bool perchannel)207 void set_perchannel(bool perchannel) { 208 storage_.perchannel = perchannel; 209 } 210 }; 211 212 namespace detail { 213 214 // Floating-point case. 
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  // Optional bias vector; null means no bias.
  const AccumScalar* bias = nullptr;
  // Clamp bounds default to +/- infinity, i.e. no clamping by default.
  DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  // log2 of the capacity rounding; 0 means a rounding of 1 (no rounding).
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr bool perchannel = false;
};

// Specialization for the integer-quantized type, with down-quantization of
// int32 accumulators to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");

  // Optional bias vector; null means no bias.
  const AccumScalar* bias = nullptr;
  // The `perchannel` flag below selects which member of each union is active.
  union {
    const AccumScalar* multiplier_fixedpoint_perchannel;
    // Let the default multiplier be effectively a multiplication by 1, so that
    // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
    // 1 is not exactly representable in fixedpoint with 0 integer bits, but
    // using the highest representable value is a sufficiently good
    // approximation: since this specialization of MulParams is for the case
    // where DstScalar is at least 2x narrower than AccumScalar, the values
    // for which there would be a difference will get saturated anyway.
    AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
  };
  union {
    const int* multiplier_exponent_perchannel;
    // See the above comment about the default value of multiplier_fixedpoint.
    int multiplier_exponent = 0;
  };
  // Clamp bounds default to the full range of DstScalar, i.e. no clamping
  // beyond the saturating cast itself.
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  // True when the *_perchannel union members are the active ones.
  bool perchannel = false;
  // log2 of the capacity rounding; 0 means a rounding of 1 (no rounding).
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};

// Specialization used in the integer case when outputting raw int32
// accumulators, without down-quantization to a narrower destination scalar
// type. In this case, the feature of clamping destination values is not
// available.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  // Optional bias vector; null means no bias.
  const AccumScalar* bias = nullptr;
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  // log2 of the capacity rounding; 0 means a rounding of 1 (no rounding).
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr DstScalar clamp_min =
      std::numeric_limits<DstScalar>::lowest();
  static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  static constexpr bool perchannel = false;
};

}  // namespace detail

}  // namespace ruy

#endif  // RUY_RUY_MUL_PARAMS_H_