• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_KERNEL_COMMON_H_
17 #define RUY_RUY_KERNEL_COMMON_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <type_traits>
22 
23 #include "ruy/apply_multiplier.h"
24 #include "ruy/check_macros.h"
25 #include "ruy/mat.h"
26 #include "ruy/matrix.h"
27 #include "ruy/mul_params.h"
28 #include "ruy/opt_set.h"
29 #include "ruy/path.h"
30 #include "ruy/platform.h"
31 #include "ruy/profiler/instrumentation.h"
32 #include "ruy/side_pair.h"
33 #include "ruy/size_util.h"
34 #include "ruy/tune.h"
35 
36 namespace ruy {
37 
// Primary Kernel template, keyed on the code path and on the four scalar
// types (LHS, RHS, accumulator, destination). It is only declared here, never
// defined in this section: every usable Kernel is a (partial) specialization,
// such as the ones generated by RUY_INHERIT_KERNEL below.
38 template <Path ThePath, typename LhsScalar, typename RhsScalar,
39           typename AccumScalar, typename DstScalar>
40 struct Kernel;
41 
// RUY_INHERIT_KERNEL(PARENT, CHILD) defines, for all scalar-type combinations,
// the partial specialization Kernel<CHILD, ...> as a struct deriving from
// Kernel<PARENT, ...>. The only member it adds is an explicit constructor that
// forwards its Tuning argument to the PARENT kernel's constructor, so the CHILD
// path reuses the PARENT implementation unchanged.
//
// Note: the macro's template-parameter list names DstScalar before
// AccumScalar, unlike the primary template. This is harmless because the
// parameters are passed on by name in the primary template's order.
42 #define RUY_INHERIT_KERNEL(PARENT, CHILD)                               \
43   template <typename LhsScalar, typename RhsScalar, typename DstScalar, \
44             typename AccumScalar>                                       \
45   struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar>    \
46       : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> {  \
47     explicit Kernel(Tuning tuning)                                      \
48         : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \
49               tuning) {}                                                \
50   };
51 
52 // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
53 //
54 // In other cases, we still define (empty) versions, so that dummy kernels
55 // can use the classes in function signatures.
56 #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
57     RUY_PLATFORM_X86
58 
// Bit values OR-ed into the `flags` byte of KernelParams8bit and
// KernelParamsFloat by the MakeKernelParams* functions below:
//   HAS_BIAS                 - mul_params.bias() is non-null.
//   HAS_LHS_SUMS             - lhs.sums is non-null (8-bit only).
//   HAS_RHS_SUMS             - rhs.sums is non-null (8-bit only).
//   HAS_PERCHANNEL           - per-channel multipliers were supplied (8-bit
//                              only).
//   NEEDS_LEFT_SHIFT         - set unconditionally by MakeKernelParams8bit.
//   CHANNEL_DIMENSION_IS_COL - mul_params.channel_dimension() is
//                              ChannelDimension::kCol.
// NOTE(review): these are plain macros rather than an enum, presumably so the
// values can be spliced into inline-assembly strings - confirm against the
// kernel sources before changing any value.
59 #define RUY_ASM_FLAG_HAS_BIAS 0x1
60 #define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
61 #define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
62 #define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
63 #define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
64 #define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20
65 
// Numeric identifiers for the destination element type. DstTypeId<T>::kValue
// maps a C++ type to one of these, and MakeKernelParams8bit stores that value
// in KernelParams8bit::dst_type_id, whose dst_base_ptr is an untyped void*.
66 #define RUY_ASM_TYPE_ID_UINT8 1
67 #define RUY_ASM_TYPE_ID_INT8 2
68 #define RUY_ASM_TYPE_ID_INT16 3
69 #define RUY_ASM_TYPE_ID_INT32 4
70 
// Compile-time map from a destination scalar type to its RUY_ASM_TYPE_ID_*
// value, exposed as the static member kValue.
//
// The primary template is deliberately empty: using DstTypeId<T>::kValue with
// a type other than uint8_t, int8_t, int16_t or int32_t fails to compile.
71 template <typename DstScalar>
72 struct DstTypeId {};
73 
74 template <>
75 struct DstTypeId<std::uint8_t> {
76   static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
77 };
78 
79 template <>
80 struct DstTypeId<std::int8_t> {
81   static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
82 };
83 
84 template <>
85 struct DstTypeId<std::int16_t> {
86   static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
87 };
88 
89 template <>
90 struct DstTypeId<std::int32_t> {
91   static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
92 };
93 
// Flat parameter block describing one 8-bit kernel invocation; it is filled in
// by MakeKernelParams8bit. LhsCols and RhsCols are the kernel's block
// dimensions and size the embedded buffers.
//
// Field summary (as populated by MakeKernelParams8bit):
//   bias                   - caller-supplied bias, or zero_data when there is
//                            none.
//   lhs_sums / rhs_sums    - the packed matrices' sums vectors; meaningful
//                            only when the matching HAS_*_SUMS flag is set.
//   lhs_base_ptr /
//   rhs_base_ptr           - packed data advanced to start_row / start_col.
//   multiplier_fixedpoint /
//   multiplier_exponent    - per-channel arrays from MulParams, or the
//                            *_buf members below filled with the uniform
//                            multiplier.
//   dst_base_ptr           - untyped pointer to the first destination element
//                            of the block; dst_type_id tells the element type.
//   *_zero_point           - zero points of lhs, rhs and dst.
//   prod_zp_depth          - lhs_zero_point * rhs_zero_point * depth.
//   start_row / start_col  - first row / column of the block.
//   last_row / last_col    - end_row - LhsCols / end_col - RhsCols.
//   dst_rows / dst_cols    - full destination dimensions.
//   lhs_stride / rhs_stride- layout strides, copied unscaled.
//   dst_stride             - destination stride scaled by sizeof(DstScalar).
//   depth                  - lhs.layout.rows.
//   clamp_min / clamp_max  - clamp bounds from MulParams.
//   flags                  - OR of RUY_ASM_FLAG_* bits.
//   dst_type_id            - one of RUY_ASM_TYPE_ID_*.
//   zero_data              - LhsCols zeros, used as the default bias.
//   dst_tmp_buf            - LhsCols * RhsCols elements of the widest
//                            destination type (kMaxDstTypeSize bytes each);
//                            not written in this section - presumably kernel
//                            scratch space, confirm in the kernel sources.
//
// NOTE(review): member order and types are presumably mirrored by hard-coded
// offsets in the assembly kernels; do not reorder without checking them.
94 template <int LhsCols, int RhsCols>
95 struct KernelParams8bit {
96   static constexpr int kMaxDstTypeSize = 4;
97 
98   const std::int32_t* bias;
99   const std::int32_t* lhs_sums;
100   const std::int32_t* rhs_sums;
101   const std::int8_t* lhs_base_ptr;
102   const std::int32_t* multiplier_fixedpoint;
103   const std::int32_t* multiplier_exponent;
104   const std::int8_t* rhs_base_ptr;
105   void* dst_base_ptr;
106   std::int32_t lhs_zero_point;
107   std::int32_t rhs_zero_point;
108   std::int32_t dst_zero_point;
109   std::int32_t prod_zp_depth;
110   std::int32_t start_row;
111   std::int32_t start_col;
112   std::int32_t last_row;
113   std::int32_t last_col;
114   std::int32_t dst_rows;
115   std::int32_t dst_cols;
116   std::int32_t lhs_stride;
117   std::int32_t rhs_stride;
118   std::int32_t dst_stride;
119   std::int32_t depth;
120   std::int32_t clamp_min;
121   std::int32_t clamp_max;
122   std::uint8_t flags;
123   std::uint8_t dst_type_id;
124   const std::int32_t zero_data[LhsCols] = {0};
125   std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
126   std::int32_t multiplier_fixedpoint_buf[LhsCols];
127   std::int32_t multiplier_exponent_buf[LhsCols];
128 };
129 
130 template <typename DstScalar, int LhsCols, int RhsCols>
131 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
132                           const PMat<std::int8_t>& rhs,
133                           const MulParams<std::int32_t, DstScalar>& mul_params,
134                           int start_row, int start_col, int end_row,
135                           int end_col, Mat<DstScalar>* dst,
136                           KernelParams8bit<LhsCols, RhsCols>* params) {
137   using Params = KernelParams8bit<LhsCols, RhsCols>;
138 
139   static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
140 
141   const int depth = lhs.layout.rows;
142   RUY_DCHECK_EQ(start_row % LhsCols, 0);
143   RUY_DCHECK_EQ(start_col % RhsCols, 0);
144   RUY_DCHECK_EQ(end_row % LhsCols, 0);
145   RUY_DCHECK_EQ(end_col % RhsCols, 0);
146 
147   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
148   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
149   params->flags = 0;
150   params->bias = params->zero_data;
151   if (mul_params.bias()) {
152     params->bias = mul_params.bias();
153     params->flags |= RUY_ASM_FLAG_HAS_BIAS;
154   }
155   if (lhs.sums) {
156     params->lhs_sums = lhs.sums;
157     params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
158   }
159   if (rhs.sums) {
160     params->rhs_sums = rhs.sums;
161     params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
162   }
163   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
164     params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
165   }
166   params->start_row = start_row;
167   params->start_col = start_col;
168   params->last_row = end_row - LhsCols;
169   params->last_col = end_col - RhsCols;
170   params->lhs_stride = lhs.layout.stride;
171   params->rhs_stride = rhs.layout.stride;
172   params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
173   params->lhs_zero_point = lhs.zero_point;
174   params->rhs_zero_point = rhs.zero_point;
175   params->dst_zero_point = dst->zero_point;
176   params->depth = depth;
177   params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
178   params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
179   if (mul_params.multiplier_fixedpoint_perchannel()) {
180     // Temporary release-assert to debug some crashes in an application.
181     RUY_CHECK(mul_params.multiplier_exponent_perchannel());
182     params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
183     params->multiplier_fixedpoint =
184         mul_params.multiplier_fixedpoint_perchannel();
185     params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
186   } else {
187     params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
188     params->multiplier_exponent = params->multiplier_exponent_buf;
189     for (int i = 0; i < LhsCols; i++) {
190       params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
191       params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
192     }
193   }
194   params->clamp_min = mul_params.clamp_min();
195   params->clamp_max = mul_params.clamp_max();
196   params->dst_rows = dst->layout.rows;
197   params->dst_cols = dst->layout.cols;
198 
199   RUY_DCHECK_LT(params->last_row, params->dst_rows);
200   RUY_DCHECK_LT(params->last_col, params->dst_cols);
201 
202   params->dst_type_id = DstTypeId<DstScalar>::kValue;
203   params->dst_base_ptr =
204       dst->data.get() + start_col * dst->layout.stride + start_row;
205 
206   // Temporary release-asserts to debug some crashes in an application.
207   RUY_CHECK(params->multiplier_fixedpoint);
208   RUY_CHECK(params->multiplier_exponent);
209   RUY_CHECK(params->bias);
210 }
211 
// Flat parameter block describing one float kernel invocation; it is filled in
// by MakeKernelParamsFloat. LhsCols and RhsCols are the kernel's block
// dimensions and size the embedded buffers.
//
// Field summary (as populated by MakeKernelParamsFloat):
//   lhs_base_ptr / rhs_base_ptr - packed data advanced to start_row /
//                                 start_col.
//   dst_base_ptr                - first destination element of the block.
//   bias                        - caller-supplied bias, or zero_data when
//                                 there is none.
//   start_row / start_col       - first row / column of the block.
//   last_row / last_col         - end_row - LhsCols / end_col - RhsCols.
//   dst_rows / dst_cols         - full destination dimensions.
//   lhs_stride / rhs_stride /
//   dst_stride                  - layout strides scaled by sizeof(float).
//   depth                       - lhs.layout.rows.
//   clamp_min / clamp_max       - clamp bounds from MulParams.
//   flags                       - OR of RUY_ASM_FLAG_HAS_BIAS and
//                                 RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL.
//   zero_data                   - LhsCols zeros, used as the default bias.
//   dst_tmp_buf                 - LhsCols * RhsCols floats; not written in
//                                 this section - presumably kernel scratch
//                                 space, confirm in the kernel sources.
//
// NOTE(review): member order and types are presumably mirrored by hard-coded
// offsets in the assembly kernels; do not reorder without checking them.
212 template <int LhsCols, int RhsCols>
213 struct KernelParamsFloat {
214   const float* lhs_base_ptr;
215   const float* rhs_base_ptr;
216   float* dst_base_ptr;
217   const float* bias;
218   std::int32_t start_row;
219   std::int32_t start_col;
220   std::int32_t last_row;
221   std::int32_t last_col;
222   std::int32_t dst_rows;
223   std::int32_t dst_cols;
224   std::int32_t lhs_stride;
225   std::int32_t rhs_stride;
226   std::int32_t dst_stride;
227   std::int32_t depth;
228   float clamp_min;
229   float clamp_max;
230   std::uint8_t flags;
231   const float zero_data[LhsCols] = {0};
232   float dst_tmp_buf[LhsCols * RhsCols];
233 };
234 
235 template <int LhsCols, int RhsCols>
236 inline void MakeKernelParamsFloat(const PMat<float>& lhs,
237                                   const PMat<float>& rhs,
238                                   const MulParams<float, float>& mul_params,
239                                   int start_row, int start_col, int end_row,
240                                   int end_col, Mat<float>* dst,
241                                   KernelParamsFloat<LhsCols, RhsCols>* params) {
242   const int depth = lhs.layout.rows;
243   RUY_DCHECK_EQ(start_row % LhsCols, 0);
244   RUY_DCHECK_EQ(start_col % RhsCols, 0);
245   RUY_DCHECK_EQ(end_row % LhsCols, 0);
246   RUY_DCHECK_EQ(end_col % RhsCols, 0);
247 
248   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
249   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
250   params->dst_base_ptr =
251       dst->data.get() + start_col * dst->layout.stride + start_row;
252 
253   std::uint8_t flags = 0;
254   params->bias = params->zero_data;
255   if (mul_params.bias()) {
256     params->bias = mul_params.bias();
257     flags |= RUY_ASM_FLAG_HAS_BIAS;
258   }
259   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
260     flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
261   }
262   params->flags = flags;
263   params->start_row = start_row;
264   params->start_col = start_col;
265   params->last_row = end_row - LhsCols;
266   params->last_col = end_col - RhsCols;
267   params->lhs_stride = sizeof(float) * lhs.layout.stride;
268   params->rhs_stride = sizeof(float) * rhs.layout.stride;
269   params->dst_stride = sizeof(float) * dst->layout.stride;
270   params->depth = depth;
271   params->clamp_min = mul_params.clamp_min();
272   params->clamp_max = mul_params.clamp_max();
273   params->dst_rows = dst->layout.rows;
274   params->dst_cols = dst->layout.cols;
275 
276   RUY_DCHECK_LT(params->last_row, params->dst_rows);
277   RUY_DCHECK_LT(params->last_col, params->dst_cols);
278 }
279 
280 #else  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
281        // RUY_OPT(ASM)) || RUY_PLATFORM_X86
282 
// Empty stand-ins used when the real parameter structs above are compiled
// out, so that dummy kernels can still name these types in their function
// signatures.
283 template <int LhsCols, int RhsCols>
284 struct KernelParams8bit {};
285 
286 template <int LhsCols, int RhsCols>
287 struct KernelParamsFloat {};
288 
289 #endif  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
290         //  RUY_OPT(ASM)) || RUY_PLATFORM_X86
291 
292 }  // namespace ruy
293 
294 #endif  // RUY_RUY_KERNEL_COMMON_H_
295