• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_KERNEL_COMMON_H_
17 #define RUY_RUY_KERNEL_COMMON_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <type_traits>
22 
23 #include "ruy/apply_multiplier.h"
24 #include "ruy/check_macros.h"
25 #include "ruy/mat.h"
26 #include "ruy/matrix.h"
27 #include "ruy/mul_params.h"
28 #include "ruy/opt_set.h"
29 #include "ruy/path.h"
30 #include "ruy/platform.h"
31 #include "ruy/profiler/instrumentation.h"
32 #include "ruy/side_pair.h"
33 #include "ruy/size_util.h"
34 #include "ruy/tune.h"
35 
36 namespace ruy {
37 
38 template <Path ThePath, typename LhsScalar, typename RhsScalar,
39           typename AccumScalar, typename DstScalar>
40 struct Kernel;
41 
// Defines Kernel<CHILD, ...> for every combination of scalar types as a
// subclass of Kernel<PARENT, ...> with the same scalar types, so the CHILD
// path reuses the PARENT path's kernel. The only member added is an explicit
// constructor that forwards its Tuning argument to the parent. The expansion
// is a complete declaration that already ends in ';'.
#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                 \
  template <typename LhsT, typename RhsT, typename AccumT, typename DstT> \
  struct Kernel<CHILD, LhsT, RhsT, AccumT, DstT>                          \
      : Kernel<PARENT, LhsT, RhsT, AccumT, DstT> {                        \
    explicit Kernel(Tuning tuning)                                        \
        : Kernel<PARENT, LhsT, RhsT, AccumT, DstT>(tuning) {}             \
  };
51 
52 // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
53 //
54 // In other cases, we still define (empty) versions, so that dummy kernels
55 // can use the classes in function signatures.
56 #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
57     RUY_PLATFORM_X86
58 
// Bit flags OR-ed into KernelParams*::flags, recording which optional inputs
// (bias, LHS/RHS sums, per-channel multipliers) are present and how the
// output is post-processed. They are kept as plain macros because their
// exact token text may be stringified into assembly — NOTE(review): confirm
// before turning them into typed constants.
#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20

// Numeric codes for the destination scalar type, stored in
// KernelParams8bit::dst_type_id.
#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4

// Maps a destination scalar type to its RUY_ASM_TYPE_ID_* code, exposed as
// the static constexpr int member `kValue`.
//
// The primary template intentionally has no `kValue`: only std::uint8_t,
// std::int8_t, std::int16_t and std::int32_t are supported, and naming
// DstTypeId<T>::kValue for any other T fails to compile.
template <typename DstScalar>
struct DstTypeId {};

// Supplies `kValue` to the DstTypeId specializations below.
template <int TypeId>
struct DstTypeIdValue {
  static constexpr int kValue = TypeId;
};

template <>
struct DstTypeId<std::int32_t> : DstTypeIdValue<RUY_ASM_TYPE_ID_INT32> {};

template <>
struct DstTypeId<std::int16_t> : DstTypeIdValue<RUY_ASM_TYPE_ID_INT16> {};

template <>
struct DstTypeId<std::int8_t> : DstTypeIdValue<RUY_ASM_TYPE_ID_INT8> {};

template <>
struct DstTypeId<std::uint8_t> : DstTypeIdValue<RUY_ASM_TYPE_ID_UINT8> {};
93 
// Argument block for the 8-bit kernels, describing one invocation over a
// range of LhsCols x RhsCols destination blocks.
//
// NOTE(review): the field order and sizes presumably form a layout that
// kernel code (possibly assembly) addresses by fixed offsets — confirm before
// reordering or resizing fields. Sizing the per-channel arrays with
// kMaxBlockChannels leaves the layout unchanged whenever LhsCols >= RhsCols.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {
  // Largest destination scalar size, in bytes, that dst_tmp_buf can hold.
  static constexpr int kMaxDstTypeSize = 4;
  // Number of entries a per-channel vector needs for one destination block.
  // The channel dimension may be rows (LhsCols entries) or columns (RhsCols
  // entries, see RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL), so the per-channel
  // arrays below are sized for the larger of the two. Sizing them by LhsCols
  // alone would be too small for column channels when RhsCols > LhsCols.
  static constexpr int kMaxBlockChannels =
      LhsCols > RhsCols ? LhsCols : RhsCols;

  const std::int32_t* bias;
  const std::int32_t* lhs_sums;
  const std::int32_t* rhs_sums;
  const std::int8_t* lhs_base_ptr;
  const std::int32_t* multiplier_fixedpoint;
  const std::int32_t* multiplier_exponent;
  const std::int8_t* rhs_base_ptr;
  void* dst_base_ptr;
  std::int32_t lhs_zero_point;
  std::int32_t rhs_zero_point;
  std::int32_t dst_zero_point;
  std::int32_t prod_zp_depth;
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  std::int32_t depth;
  std::int32_t clamp_min;
  std::int32_t clamp_max;
  std::uint8_t flags;        // Bitwise OR of RUY_ASM_FLAG_* values.
  std::uint8_t dst_type_id;  // One of the RUY_ASM_TYPE_ID_* values.
  // All-zero vector, usable as the bias vector when no bias is supplied.
  const std::int32_t zero_data[kMaxBlockChannels] = {0};
  // Raw bytes for one full LhsCols x RhsCols destination block of scalars of
  // up to kMaxDstTypeSize bytes each.
  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  // Storage for a single (non-per-channel) multiplier replicated once per
  // channel, so it can be indexed like a per-channel array.
  std::int32_t multiplier_fixedpoint_buf[kMaxBlockChannels];
  std::int32_t multiplier_exponent_buf[kMaxBlockChannels];
};
129 
130 template <typename DstScalar, int LhsCols, int RhsCols>
131 void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
132                           const PMat<std::int8_t>& rhs,
133                           const MulParams<std::int32_t, DstScalar>& mul_params,
134                           int start_row, int start_col, int end_row,
135                           int end_col, Mat<DstScalar>* dst,
136                           KernelParams8bit<LhsCols, RhsCols>* params) {
137   using Params = KernelParams8bit<LhsCols, RhsCols>;
138 
139   static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
140 
141   const int depth = lhs.layout.rows;
142   RUY_DCHECK_EQ(start_row % LhsCols, 0);
143   RUY_DCHECK_EQ(start_col % RhsCols, 0);
144   RUY_DCHECK_EQ(end_row % LhsCols, 0);
145   RUY_DCHECK_EQ(end_col % RhsCols, 0);
146 
147   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
148   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
149   params->flags = 0;
150   params->bias = params->zero_data;
151   if (mul_params.bias()) {
152     params->bias = mul_params.bias();
153     params->flags |= RUY_ASM_FLAG_HAS_BIAS;
154   }
155   if (lhs.sums) {
156     params->lhs_sums = lhs.sums;
157     params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
158   }
159   if (rhs.sums) {
160     params->rhs_sums = rhs.sums;
161     params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
162   }
163   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
164     params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
165   }
166   params->start_row = start_row;
167   params->start_col = start_col;
168   params->last_row = end_row - LhsCols;
169   params->last_col = end_col - RhsCols;
170   params->lhs_stride = lhs.layout.stride;
171   params->rhs_stride = rhs.layout.stride;
172   params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
173   params->lhs_zero_point = lhs.zero_point;
174   params->rhs_zero_point = rhs.zero_point;
175   params->dst_zero_point = dst->zero_point;
176   params->depth = depth;
177   params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
178   params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
179   if (mul_params.multiplier_fixedpoint_perchannel()) {
180     params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
181     params->multiplier_fixedpoint =
182         mul_params.multiplier_fixedpoint_perchannel();
183     params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
184   } else {
185     params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
186     params->multiplier_exponent = params->multiplier_exponent_buf;
187     for (int i = 0; i < LhsCols; i++) {
188       params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
189       params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
190     }
191   }
192   params->clamp_min = mul_params.clamp_min();
193   params->clamp_max = mul_params.clamp_max();
194   params->dst_rows = dst->layout.rows;
195   params->dst_cols = dst->layout.cols;
196 
197   RUY_DCHECK_LT(params->last_row, params->dst_rows);
198   RUY_DCHECK_LT(params->last_col, params->dst_cols);
199 
200   params->dst_type_id = DstTypeId<DstScalar>::kValue;
201   params->dst_base_ptr =
202       dst->data.get() + start_col * dst->layout.stride + start_row;
203 }
204 
// Argument block for the float kernels, describing one invocation over a
// range of LhsCols x RhsCols destination blocks.
//
// NOTE(review): the field order and sizes presumably form a layout that
// kernel code (possibly assembly) addresses by fixed offsets — confirm before
// reordering or resizing fields. Sizing zero_data with kMaxBlockChannels
// leaves the layout unchanged whenever LhsCols >= RhsCols.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {
  // Number of entries a per-channel vector needs for one destination block.
  // The channel dimension may be rows (LhsCols entries) or columns (RhsCols
  // entries, see RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL), so zero_data is
  // sized for the larger of the two.
  static constexpr int kMaxBlockChannels =
      LhsCols > RhsCols ? LhsCols : RhsCols;

  const float* lhs_base_ptr;
  const float* rhs_base_ptr;
  float* dst_base_ptr;
  const float* bias;
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  std::int32_t depth;
  float clamp_min;
  float clamp_max;
  std::uint8_t flags;  // Bitwise OR of RUY_ASM_FLAG_* values.
  // All-zero vector, usable as the bias vector when no bias is supplied.
  const float zero_data[kMaxBlockChannels] = {0};
  // Room for one full LhsCols x RhsCols destination block.
  float dst_tmp_buf[LhsCols * RhsCols];
};
227 
228 template <int LhsCols, int RhsCols>
229 inline void MakeKernelParamsFloat(const PMat<float>& lhs,
230                                   const PMat<float>& rhs,
231                                   const MulParams<float, float>& mul_params,
232                                   int start_row, int start_col, int end_row,
233                                   int end_col, Mat<float>* dst,
234                                   KernelParamsFloat<LhsCols, RhsCols>* params) {
235   const int depth = lhs.layout.rows;
236   RUY_DCHECK_EQ(start_row % LhsCols, 0);
237   RUY_DCHECK_EQ(start_col % RhsCols, 0);
238   RUY_DCHECK_EQ(end_row % LhsCols, 0);
239   RUY_DCHECK_EQ(end_col % RhsCols, 0);
240 
241   params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
242   params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
243   params->dst_base_ptr =
244       dst->data.get() + start_col * dst->layout.stride + start_row;
245 
246   std::uint8_t flags = 0;
247   params->bias = params->zero_data;
248   if (mul_params.bias()) {
249     params->bias = mul_params.bias();
250     flags |= RUY_ASM_FLAG_HAS_BIAS;
251   }
252   if (mul_params.channel_dimension() == ChannelDimension::kCol) {
253     flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
254   }
255   params->flags = flags;
256   params->start_row = start_row;
257   params->start_col = start_col;
258   params->last_row = end_row - LhsCols;
259   params->last_col = end_col - RhsCols;
260   params->lhs_stride = sizeof(float) * lhs.layout.stride;
261   params->rhs_stride = sizeof(float) * rhs.layout.stride;
262   params->dst_stride = sizeof(float) * dst->layout.stride;
263   params->depth = depth;
264   params->clamp_min = mul_params.clamp_min();
265   params->clamp_max = mul_params.clamp_max();
266   params->dst_rows = dst->layout.rows;
267   params->dst_cols = dst->layout.cols;
268 
269   RUY_DCHECK_LT(params->last_row, params->dst_rows);
270   RUY_DCHECK_LT(params->last_col, params->dst_cols);
271 }
272 
273 #else  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
274        // RUY_OPT(ASM)) || RUY_PLATFORM_X86
275 
// Data-free stand-ins used on configurations without these kernels. They
// exist only so that code such as dummy kernels can still name the types in
// function signatures; the block-shape parameters are accepted but unused.
template <int, int>
struct KernelParams8bit {};

template <int, int>
struct KernelParamsFloat {};
281 
282 #endif  // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
283         //  RUY_OPT(ASM)) || RUY_PLATFORM_X86
284 
285 }  // namespace ruy
286 
287 #endif  // RUY_RUY_KERNEL_COMMON_H_
288