/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <type_traits>

#include "tensorflow/lite/experimental/ruy/check_macros.h"
#include "tensorflow/lite/experimental/ruy/common.h"
#include "tensorflow/lite/experimental/ruy/internal_matrix.h"
#include "tensorflow/lite/experimental/ruy/matrix.h"
#include "tensorflow/lite/experimental/ruy/opt_set.h"
#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/platform.h"
#include "tensorflow/lite/experimental/ruy/profiler/instrumentation.h"
#include "tensorflow/lite/experimental/ruy/side_pair.h"
#include "tensorflow/lite/experimental/ruy/size_util.h"
#include "tensorflow/lite/experimental/ruy/spec.h"
#include "tensorflow/lite/experimental/ruy/tune.h"

namespace ruy {

template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
struct Kernel {};

template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
                    const PackedMatrix<RhsScalar>& rhs, const Spec& spec,
                    int start_row, int start_col, int end_row, int end_col,
                    Matrix<DstScalar>* dst) {
  using Kernel = Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>;
  Kernel kernel(tuning);
#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
  using LhsLayout = typename Kernel::LhsLayout;
  using RhsLayout = typename Kernel::RhsLayout;
#endif
  // end_row and end_col may be larger than the dst dimensions. That is
  // because kernels write directly to the destination matrix, whose
  // dimensions may not be a multiple of the kernel dimensions, and we try to
  // keep this annoyance localized as an implementation detail in kernels by
  // allowing rounded-up values to be passed down as far as possible.
  // The assertions below encode this contract.
  RUY_DCHECK_LE(0, start_row);
  RUY_DCHECK_LE(start_row, end_row);
  RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
  RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
  RUY_DCHECK_LE(0, start_col);
  RUY_DCHECK_LE(start_col, end_col);
  RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
  RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
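  // Worked example of the contract above (values purely illustrative, not
  // taken from any particular path): with dst->layout.rows == 10 and
  // LhsLayout::kCols == 4, a caller may pass end_row == 12 (the next multiple
  // of 4). That satisfies 12 < 10 + 4, and the kernel is then responsible
  // for not writing past row 10.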
#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
  kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst);
#else
  for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
    int block_end_col = std::min(col + RhsLayout::kCols, end_col);
    for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
      int block_end_row = std::min(row + LhsLayout::kCols, end_row);
      kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst);
    }
  }
#endif
}

// Main entry point for kernels.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void RunKernel(Tuning tuning, const SidePair<PMatrix>& src, void* spec,
               const SidePair<int>& start, const SidePair<int>& end,
               DMatrix* dst) {
  Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
  RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>(
      tuning, ToPackedMatrix<LhsScalar>(src[Side::kLhs]),
      ToPackedMatrix<RhsScalar>(src[Side::kRhs]),
      *static_cast<const Spec*>(spec), start[Side::kLhs], start[Side::kRhs],
      end[Side::kLhs], end[Side::kRhs], &mdst);
}

// Copied from gemmlowp/fixedpoint.
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  std::int64_t a_64(a);
  std::int64_t b_64(b);
  std::int64_t ab_64 = a_64 * b_64;
  std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  std::int32_t ab_x2_high32 =
      static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
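// Illustrative note (added, not part of the original source): interpreting
// the int32 arguments as Q31 fixed-point fractions, this returns their
// product in Q31, i.e. roughly (a * b) / 2^31 with round-to-nearest, and
// saturates the single overflow case INT32_MIN * INT32_MIN. For example,
//   SaturatingRoundingDoublingHighMul(1 << 30, 1 << 30) == 1 << 29
// which is 0.5 * 0.5 == 0.25 in Q31.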

inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) {
  std::int32_t sign = numerator >= 0 ? 1 : -1;
  std::int32_t abs_numerator = std::abs(numerator);
  std::int32_t mask = (1LL << exponent) - 1;
  std::int32_t remainder = abs_numerator & mask;
  std::int32_t threshold = mask >> 1;
  std::int32_t abs_result =
      (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0);
  return sign * abs_result;
}
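// Illustrative examples (added, not part of the original source): this
// divides numerator by 2^exponent, rounding to nearest with halves rounded
// away from zero:
//   RoundingDivideByPOT(5, 1)  ==  3   //  2.5  rounds to  3
//   RoundingDivideByPOT(-5, 1) == -3   // -2.5  rounds to -3
//   RoundingDivideByPOT(5, 2)  ==  1   //  1.25 rounds to  1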

// Copied from TF Lite code.
inline std::int32_t MultiplyByQuantizedMultiplier(
    std::int32_t x, std::int32_t quantized_multiplier, int shift) {
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
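// Illustrative note (added, not part of the original source): the effective
// scale is approximately quantized_multiplier * 2^(shift - 31). A positive
// shift is applied as a left shift on x before the Q31 multiplication; a
// negative shift becomes a rounding right shift afterwards. For example,
// with quantized_multiplier == 1 << 30 (0.5 in Q31) and shift == 1, the
// scale is 1.0, so MultiplyByQuantizedMultiplier(6, 1 << 30, 1) == 6.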

// Helper to apply a fixed-point multiplier.  Only 'applicable' if AccumScalar
// is int32 (i.e. in all cases except floating-point) and if the destination is
// not int32 (i.e. unless the user wants to get raw accumulators).
template <typename Spec,
          bool IsApplicable =
              std::is_same<typename Spec::AccumScalar, std::int32_t>::value &&
              !std::is_same<typename Spec::DstScalar, std::int32_t>::value>
struct ApplyMultiplierImpl {};

// Specialization in non-applicable case: do nothing, just check that values
// are default.
template <typename Spec>
struct ApplyMultiplierImpl<Spec, false> {
  using AccumScalar = typename Spec::AccumScalar;
  using DstScalar = typename Spec::DstScalar;
  static void Run(const Spec& spec, int row, AccumScalar* accum) {
    RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
    RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
  }
};

template <typename Spec>
struct ApplyMultiplierImpl<Spec, true> {
  using AccumScalar = typename Spec::AccumScalar;
  using DstScalar = typename Spec::DstScalar;
  static void Run(const Spec& spec, int row, AccumScalar* accum) {
    AccumScalar m = spec.multiplier_fixedpoint_perchannel
                        ? spec.multiplier_fixedpoint_perchannel[row]
                        : spec.multiplier_fixedpoint;
    int e = spec.multiplier_exponent_perchannel
                ? spec.multiplier_exponent_perchannel[row]
                : spec.multiplier_exponent;
    *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
  }
};

template <typename Spec>
void ApplyMultiplier(const Spec& spec, int row,
                     typename Spec::AccumScalar* accum) {
  ApplyMultiplierImpl<Spec>::Run(spec, row, accum);
}
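// Minimal usage sketch (added, hypothetical values for illustration only):
// with a spec such as BasicSpec<std::int32_t, std::int8_t>, ApplyMultiplier
// rescales a raw int32 accumulator into the destination's quantized scale:
//   std::int32_t accum = ...;          // raw accumulator for output row i
//   ApplyMultiplier(spec, i, &accum);  // per-channel or uniform rescale
// With DstScalar == std::int32_t (raw accumulators requested), it is a no-op
// that only checks the multiplier fields are left at their defaults.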

template <typename LhsScalar, typename RhsScalar, typename DstScalar,
          typename Spec>
struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
  using AccumScalar = typename Spec::AccumScalar;
  using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
  using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
  explicit Kernel(Tuning) {}
  void Run(const PackedMatrix<LhsScalar>& lhs,
           const PackedMatrix<RhsScalar>& rhs, const Spec& spec, int start_row,
           int start_col, int end_row, int end_col,
           Matrix<DstScalar>* dst) const {
    // See the comment in RunKernelTyped. end_row may be larger than
    // dst->layout.rows. It's the responsibility of the kernel to avoid
    // overrunning dst boundaries, which we do here by computing
    // clamped_end_row.
    int clamped_end_row = std::min(end_row, dst->layout.rows);
    int clamped_end_col = std::min(end_col, dst->layout.cols);
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, clamped_end_row);
    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
    RUY_DCHECK_LE(clamped_end_row, end_row);
    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, clamped_end_col);
    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
    RUY_DCHECK_LE(clamped_end_col, end_col);
    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
    profiler::ScopeLabel label("Kernel (Standard Cpp)");
    const int depth = lhs.layout.rows;
    for (int i = start_row; i < clamped_end_row; i++) {
      for (int j = start_col; j < clamped_end_col; j++) {
        using AccumScalar = typename Spec::AccumScalar;
        AccumScalar accum = 0;
        for (int k = 0; k < depth; k++) {
          AccumScalar lhs_val = Element(lhs, k, i);
          AccumScalar rhs_val = Element(rhs, k, j);
          accum += lhs_val * rhs_val;
        }
        if (spec.bias) {
          accum += spec.bias[i];
        }
        if (lhs.zero_point) {
          accum -= lhs.zero_point * rhs.sums[j];
        }
        if (rhs.zero_point) {
          accum -= rhs.zero_point * lhs.sums[i];
        }
        if (lhs.zero_point && rhs.zero_point) {
          accum += lhs.zero_point * rhs.zero_point * depth;
        }
        ApplyMultiplier(spec, i, &accum);
        accum += dst->zero_point;
        accum = std::min<AccumScalar>(accum, spec.clamp_max);
        accum = std::max<AccumScalar>(accum, spec.clamp_min);
        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
      }
    }
  }
};

#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                  \
  template <typename LhsScalar, typename RhsScalar, typename DstScalar,   \
            typename Spec>                                                 \
  struct Kernel<CHILD, LhsScalar, RhsScalar, DstScalar, Spec>              \
      : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec> {            \
    explicit Kernel(Tuning tuning)                                         \
        : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
  };

#if RUY_PLATFORM(NEON)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
#elif RUY_PLATFORM(X86)
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42)
RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2)
RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512)
RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni)
#endif
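
// Note for readers (added, illustrative): each RUY_INHERIT_KERNEL(PARENT,
// CHILD) invocation defines a partial specialization of Kernel for the CHILD
// path that simply derives from the PARENT path's kernel. So on NEON, for
// instance, any (LhsScalar, RhsScalar, DstScalar, Spec) combination without
// a hand-written kNeonDotprod kernel falls back to the kNeon kernel, which in
// turn falls back to the kStandardCpp kernel above.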

// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
//
// In other cases, we still define (empty) versions, so that dummy kernels
// can use the classes in function signatures.
#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
     RUY_OPT_ENABLED(RUY_OPT_ASM)) ||                    \
    RUY_PLATFORM(X86)

#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10

#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4

template <typename DstScalar>
struct DstTypeId {};

template <>
struct DstTypeId<std::uint8_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
};

template <>
struct DstTypeId<std::int8_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
};

template <>
struct DstTypeId<std::int16_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
};

template <>
struct DstTypeId<std::int32_t> {
  static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
};

template <int LhsCols, int RhsCols>
struct KernelParams8bit {
  static constexpr int kMaxDstTypeSize = 4;

  const std::int32_t* bias;
  const std::int32_t* lhs_sums;
  const std::int32_t* rhs_sums;
  const std::int8_t* lhs_base_ptr;
  const std::int32_t* multiplier_fixedpoint;
  const std::int32_t* multiplier_exponent;
  const std::int8_t* rhs_base_ptr;
  void* dst_base_ptr;
  std::int32_t lhs_zero_point;
  std::int32_t rhs_zero_point;
  std::int32_t dst_zero_point;
  std::int32_t prod_zp_depth;
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  std::int32_t depth;
  std::int32_t clamp_min;
  std::int32_t clamp_max;
  std::uint8_t flags;
  std::uint8_t dst_type_id;
  const std::int32_t zero_data[LhsCols] = {0};
  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  std::int32_t multiplier_fixedpoint_buf[LhsCols];
  std::int32_t multiplier_exponent_buf[LhsCols];
};

template <typename DstScalar, int LhsCols, int RhsCols>
void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
                          const PackedMatrix<std::int8_t>& rhs,
                          const BasicSpec<std::int32_t, DstScalar>& spec,
                          int start_row, int start_col, int end_row,
                          int end_col, Matrix<DstScalar>* dst,
                          KernelParams8bit<LhsCols, RhsCols>* params) {
  using Params = KernelParams8bit<LhsCols, RhsCols>;

  static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");

  const int depth = lhs.layout.rows;
  RUY_DCHECK_EQ(start_row % LhsCols, 0);
  RUY_DCHECK_EQ(start_col % RhsCols, 0);
  RUY_DCHECK_EQ(end_row % LhsCols, 0);
  RUY_DCHECK_EQ(end_col % RhsCols, 0);

  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
  params->flags = 0;
  params->bias = params->zero_data;
  if (spec.bias) {
    params->bias = spec.bias;
    params->flags |= RUY_ASM_FLAG_HAS_BIAS;
  }
  if (lhs.sums) {
    params->lhs_sums = lhs.sums;
    params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
  }
  if (rhs.sums) {
    params->rhs_sums = rhs.sums;
    params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
  }
  params->start_row = start_row;
  params->start_col = start_col;
  params->last_row = end_row - LhsCols;
  params->last_col = end_col - RhsCols;
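  // Note (added comment, not in the original source): in this 8-bit path the
  // packed lhs/rhs strides are passed through unchanged, since their element
  // stride already equals their byte stride (sizeof(std::int8_t) == 1),
  // whereas dst_stride is converted to bytes via sizeof(DstScalar).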
  params->lhs_stride = lhs.layout.stride;
  params->rhs_stride = rhs.layout.stride;
  params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
  params->lhs_zero_point = lhs.zero_point;
  params->rhs_zero_point = rhs.zero_point;
  params->dst_zero_point = dst->zero_point;
  params->depth = depth;
  params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
  if (spec.multiplier_fixedpoint_perchannel) {
    params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
    params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
    params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel;
    params->multiplier_exponent = spec.multiplier_exponent_perchannel;
  } else {
    if (spec.multiplier_exponent > 0) {
      params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
    }
    params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
    params->multiplier_exponent = params->multiplier_exponent_buf;
    for (int i = 0; i < LhsCols; i++) {
      params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint;
      params->multiplier_exponent_buf[i] = spec.multiplier_exponent;
    }
  }
  params->clamp_min = spec.clamp_min;
  params->clamp_max = spec.clamp_max;
  params->dst_rows = dst->layout.rows;
  params->dst_cols = dst->layout.cols;

  RUY_DCHECK_LT(params->last_row, params->dst_rows);
  RUY_DCHECK_LT(params->last_col, params->dst_cols);

  params->dst_type_id = DstTypeId<DstScalar>::kValue;
  params->dst_base_ptr =
      dst->data.get() + start_col * dst->layout.stride + start_row;
}

template <int LhsCols, int RhsCols>
struct KernelParamsFloat {
  const float* lhs_base_ptr;
  const float* rhs_base_ptr;
  float* dst_base_ptr;
  const float* bias;
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  std::int32_t depth;
  float clamp_min;
  float clamp_max;
  std::uint8_t flags;
  const float zero_data[LhsCols] = {0};
  float dst_tmp_buf[LhsCols * RhsCols];
};

template <int LhsCols, int RhsCols>
inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
                                  const PackedMatrix<float>& rhs,
                                  const BasicSpec<float, float>& spec,
                                  int start_row, int start_col, int end_row,
                                  int end_col, Matrix<float>* dst,
                                  KernelParamsFloat<LhsCols, RhsCols>* params) {
  const int depth = lhs.layout.rows;
  RUY_DCHECK_EQ(start_row % LhsCols, 0);
  RUY_DCHECK_EQ(start_col % RhsCols, 0);
  RUY_DCHECK_EQ(end_row % LhsCols, 0);
  RUY_DCHECK_EQ(end_col % RhsCols, 0);

  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
  params->dst_base_ptr =
      dst->data.get() + start_col * dst->layout.stride + start_row;

  std::uint8_t flags = 0;
  params->bias = params->zero_data;
  if (spec.bias) {
    params->bias = spec.bias;
    flags |= RUY_ASM_FLAG_HAS_BIAS;
  }
  params->flags = flags;
  params->start_row = start_row;
  params->start_col = start_col;
  params->last_row = end_row - LhsCols;
  params->last_col = end_col - RhsCols;
  params->lhs_stride = sizeof(float) * lhs.layout.stride;
  params->rhs_stride = sizeof(float) * rhs.layout.stride;
  params->dst_stride = sizeof(float) * dst->layout.stride;
  params->depth = depth;
  params->clamp_min = spec.clamp_min;
  params->clamp_max = spec.clamp_max;
  params->dst_rows = dst->layout.rows;
  params->dst_cols = dst->layout.cols;

  RUY_DCHECK_LT(params->last_row, params->dst_rows);
  RUY_DCHECK_LT(params->last_col, params->dst_cols);
}

#else  // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
       //  RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)

template <int LhsCols, int RhsCols>
struct KernelParams8bit {};

template <int LhsCols, int RhsCols>
struct KernelParamsFloat {};

#endif  // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
        //  RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)

}  // namespace ruy

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_