/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/size_util.h"

namespace ruy {

// Enumeration to designate which dimension is the 'channels', for MulParams
// features that are 'per-channel', namely the bias-vector and the quantized
// multiplier.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination matrix'
  kCol
};
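
// For example, when the channel dimension is kRow, element i of the
// per-channel bias vector is added to every entry in row i of the
// destination matrix.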

namespace detail {
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}

// MulParams describes everything about a matrix multiplication that
// isn't encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance,
// the choice of accumulator type, AccumScalar). Some of that information is
// encoded as runtime values (for instance, the optional bias vector).
//
// Template parameters:
// AccumScalar: Accumulator type. The type of accumulators used to compute the
// dot-products before they are ultimately cast to the destination type.
// DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
// * If DstScalar is floating-point then AccumScalar must also be.
// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover,
//   in that integral case, there is a mode switch:
//   - If DstScalar is std::int32_t then the multiplier_* fields are all
//     disabled, and ruy::Mul will just return raw (unscaled) accumulators.
//   - If DstScalar is not std::int32_t then the multiplier_* fields are
//     enabled, and ruy::Mul will use them to scale internal std::int32_t
//     accumulators before casting them to the DstScalar type. The default
//     values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
//   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//   Warning: this code is not meant to be bit-exact-normative.
//   Please refer to the class comment of ruy::MulParams, in mul_params.h.
//   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The reason for this warning is that, as of early 2021, we still don't know
// whether it is advisable to let this code, as-is, have normative value, or
// whether that would only become advisable after some specific final change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM such as Cortex-M (see [2]). We don't
// know that this particular reference code is inherently better than other
// forms that could perform better on these architectures. In fact, the
// alternative proposed in [2] as better-performing on ARM Cortex-M is also
// inherently more accurate thanks to rounding only once, but it would perform
// worse on both ARM NEON and x86.
//
// Moreover, if we look at hardware architectures beyond current Ruy targets,
// namely "hardware accelerators", it becomes clear that no single form of
// this code can be implemented efficiently on all currently relevant
// hardware: some accelerators prefer to perform the multiplication in IEEE
// float32, others in IEEE float16, others in bfloat16, others in 16-bit
// fixed-point...
//
// See:
//   [1] https://github.com/google/ruy/pull/227
//   [2] https://github.com/tensorflow/tensorflow/issues/25087
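//
// As a minimal usage sketch (illustrative only: `bias_data` and the
// multiplier values below are hypothetical, not defaults of this API), the
// quantized-int8 mode described above could be configured as:
//
//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
//   mul_params.set_bias(bias_data);                 // one entry per channel
//   mul_params.set_multiplier_fixedpoint(1 << 30);  // Q0.31: 2^30/2^31 = 0.5
//   mul_params.set_multiplier_exponent(1);          // then scale by 2^1
//   // Net effect: accumulators are scaled by 0.5 * 2 = 1 before the cast.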
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
 public:
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;

  // The bias vector data, if not null.
  const AccumScalar* bias() const { return storage_.bias; }
  void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
  // Only for non-floating-point cases. The fixed-point part of the multiplier
  // by which accumulators are multiplied before being cast to the destination
  // type. This is a fixed-point quantity with 0 integer bits. Since
  // (as explained in the class comment) AccumScalar must be std::int32_t,
  // that means that the fixed-point format is Q0.31. For example,
  // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
  // by one half (1/2). More generally, the effect is to multiply by
  // (multiplier_fixedpoint / (2^31)).
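  // As a worked example (illustrative numbers, not defaults): to scale
  // accumulators by 0.75, set multiplier_fixedpoint to 0.75 * 2^31 =
  // 1610612736 and multiplier_exponent to 0; to scale by 1.5, keep the same
  // fixedpoint value and set multiplier_exponent to 1, since 0.75 * 2^1 = 1.5.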
  AccumScalar multiplier_fixedpoint() const {
    return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
  }
  void set_multiplier_fixedpoint(const AccumScalar value) {
    set_perchannel(false);
    storage_.multiplier_fixedpoint = value;
  }
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent() const {
    return storage_.perchannel ? 0 : storage_.multiplier_exponent;
  }
  void set_multiplier_exponent(const int value) {
    set_perchannel(false);
    storage_.multiplier_exponent = value;
  }
  // Per-channel variant of multiplier_fixedpoint. Setting this switches
  // to per-channel mode, where `multiplier_fixedpoint` and
  // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
  // and `multiplier_exponent_perchannel` are used instead.
  //
  // This must point to a buffer of as many values as there are rows or columns
  // in the destination matrix, whichever is the channels dimension. Each
  // channel of the destination matrix will use the corresponding buffer
  // element instead of multiplier_fixedpoint.
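  //
  // For instance (an illustrative sketch; the data values are hypothetical):
  // for a destination matrix with 4 rows and channel_dimension() == kRow, one
  // could pass a pointer to 4 per-row Q0.31 multipliers:
  //
  //   static const std::int32_t kPerRow[4] = {1 << 30, 1 << 29, 1 << 30,
  //                                           1 << 28};
  //   mul_params.set_multiplier_fixedpoint_perchannel(kPerRow);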
  const AccumScalar* multiplier_fixedpoint_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                               : nullptr;
  }
  void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
    set_perchannel(true);
    storage_.multiplier_fixedpoint_perchannel = ptr;
  }
  // Per-channel variant of multiplier_exponent. Same comments as for
  // multiplier_fixedpoint_perchannel.
  const int* multiplier_exponent_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_exponent_perchannel
                               : nullptr;
  }
  void set_multiplier_exponent_perchannel(const int* ptr) {
    set_perchannel(true);
    storage_.multiplier_exponent_perchannel = ptr;
  }
  // Min clamp bound of destination values.
  DstScalar clamp_min() const { return storage_.clamp_min; }
  void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
  // Max clamp bound of destination values.
  DstScalar clamp_max() const { return storage_.clamp_max; }
  void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
  // Designates which dimension is the 'channels', for per-channel features
  // such as bias-addition and per-channel quantization multipliers.
  ChannelDimension channel_dimension() const {
    return storage_.channel_dimension;
  }
  void set_channel_dimension(ChannelDimension value) {
    storage_.channel_dimension = value;
  }
  // Specifies the upward rounding of the allocated capacity of per-channel
  // buffers such as bias vectors and per-channel quantization multipliers.
  // The unit is matrix entries, not bytes.
  //
  // This value must be a power of two.
  //
  // The default value, 1, means no upward rounding, meaning that the buffers
  // are not required to have a capacity greater than the size of the
  // corresponding matrix dimension, i.e. the number of rows (respectively
  // columns) of the destination matrix if `channel_dimension()` is kRow
  // (respectively kCol).
  //
  // Higher values allow the implementation to assume that it is OK to access
  // these buffers a little past this boundary, which is useful in
  // SIMD-optimized kernels. In practice, when this value is lower than what
  // the kernel requires, ruy has to internally reallocate and copy
  // per-channel buffers. When this value is high enough, this reallocation
  // and copy is avoided.
  //
  // When a value greater than 1 is specified, the tail region of the buffer
  // (past the end of the values actually corresponding to channels) is
  // required to be zero-initialized.
  //
  // As of 2020, values as high as 16 may be useful on some CPU architectures
  // (corresponding to the widest kernels used on any CPU architecture).
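  //
  // Worked example (illustrative numbers): with a destination matrix of 100
  // rows, channel_dimension() == kRow and a rounding value of 16, per-channel
  // buffers must have a capacity of at least 112 entries (100 rounded up to
  // the next multiple of 16), with entries [100, 112) zero-initialized.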
  int perchannel_buffers_capacity_rounding() const {
    return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
  }
  void set_perchannel_buffers_capacity_rounding(int value) {
    // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
    storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
  }

 private:
  detail::MulParamsStorage<AccumScalar, DstScalar> storage_;

  void set_perchannel(bool perchannel) {
    storage_.perchannel = perchannel;
  }
};
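
// A second usage sketch (illustrative; `bias_data` is a hypothetical buffer):
// in the raw-accumulator mode (DstScalar = std::int32_t), the multiplier and
// clamp settings are disabled, so typically only the bias and the channel
// dimension are configured:
//
//   ruy::MulParams<std::int32_t, std::int32_t> raw_params;
//   raw_params.set_bias(bias_data);
//   raw_params.set_channel_dimension(ruy::ChannelDimension::kCol);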

namespace detail {

// Floating-point case.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr bool perchannel = false;
};

// Specialization for the integer-quantized type, with down-quantization of
// int32 accumulators to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");

  const AccumScalar* bias = nullptr;
  union {
    const AccumScalar* multiplier_fixedpoint_perchannel;
    // Let the default multiplier be effectively a multiplication by 1, so that
    // the matmul behaves as a (saturating) plain integer matmul. Unfortunately,
    // 1 is not exactly representable in fixedpoint with 0 integer bits, but
    // using the highest representable value is a sufficiently good
    // approximation: since this specialization of MulParams is for the case
    // where DstScalar is at least 2x narrower than AccumScalar, the values
    // for which there would be a difference will get saturated anyway.
    AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
  };
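  // (Concretely: the default max() = 2^31 - 1 stands for (2^31 - 1) / 2^31,
  // i.e. an effective multiplier within 2^-31 of 1.)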
  union {
    const int* multiplier_exponent_perchannel;
    // See the above comment about the default value of multiplier_fixedpoint.
    int multiplier_exponent = 0;
  };
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  bool perchannel = false;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};

// Specialization used in the integer case when outputting raw int32
// accumulators, without down-quantization to a narrower destination scalar
// type. In this case, the feature of clamping destination values is not
// available.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  const AccumScalar* bias = nullptr;
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr DstScalar clamp_min =
      std::numeric_limits<DstScalar>::lowest();
  static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  static constexpr bool perchannel = false;
};

}  // namespace detail

}  // namespace ruy

#endif  // RUY_RUY_MUL_PARAMS_H_