/*
 * Copyright (c) 2018-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"

#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/AssemblyUtils.h"
#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
using namespace arm_compute::experimental;

namespace
{
struct free_delete
{
    void operator()(void *x)
    {
        free(x);
    }
};

struct Params
{
    unsigned int M;
    unsigned int N;
    unsigned int K;
    unsigned int batches;
    unsigned int multis;
    unsigned int sections;
    bool         indirect;
};

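// Derive the arm_gemm problem sizes from the ACL tensor shapes: M/N come from the output,
// K from the input, while batches/multis/sections describe how the extra dimensions are
// folded depending on the convolution method and the GEMM3D settings.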
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    Params p;
    p.M        = d->tensor_shape().y();
    p.K        = a->tensor_shape().x();
    p.N        = d->tensor_shape().x();
    p.batches  = 1;
    p.multis   = 1;
    p.sections = 1;
    p.indirect = false;

    if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
    {
        p.indirect = true;
        p.sections = b->tensor_shape()[2] * b->tensor_shape()[3];
    }
    else
    {
        p.multis  = b->tensor_shape().z();
        p.batches = d->tensor_shape().total_size_upper(2) / p.multis;
    }

    // Update M in case of GEMM3D for output
    if(info.depth_output_gemm3d != 0)
    {
        p.M       = d->tensor_shape().y() * d->tensor_shape().z();
        p.batches = d->tensor_shape().total_size_upper(3) / p.multis;
    }

    return p;
}

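// Choose how the assembly kernel is scheduled: 1D over DimX by default, with all-dimension
// splitting for the interleaved 2D and quantized 2D methods.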
IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type)
{
    // Schedule assembly kernel
    const int         granule_threshold = 200;
    IScheduler::Hints scheduling_hint   = IScheduler::Hints(Window::DimX);
    if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32)
    {
        scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
    }
    else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8))
    {
        // GEMM_INTERLEAVED_2D supports 2D parallelism; IScheduler::split_dimensions_all signals to parallelise over all window dimensions
        scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
    }
    else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED))
    {
        // Special case for QASYMM8 to support 2D parallelism; the scheduler here may be tweaked differently compared to the FP32 case
        scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
    }

    return scheduling_hint;
}

/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
    /** Destructor */
    ~Fallback() = default;

    /** Initialise the function's input and output.
     *
     * @param[in]  a         Input tensor containing the Matrix A.
     * @param[in]  b         Input tensor containing the Matrix B.
     * @param[in]  c         Input tensor containing the Matrix C.
     * @param[out] d         Output tensor to store the result of matrix multiplication.
     * @param[in]  args      Matrix multiplication information.
     * @param[in]  gemm_info GEMM meta-data
     * @param[in]  os        Output stage meta-data.
     */
    void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                   arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                   const OutputStage &os = {});

    /** Set requantization data to be used
     *
     * @param[in] shifts      Requantization shifts
     * @param[in] multipliers Requantization multipliers
     *
     * @return A tuple with a flag indicating whether left shifts are needed and pointers to the left shift, right shift and multiplier data, respectively
     */
    std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
                                                                                            const std::vector<int32_t> &multipliers);

    // Inherited methods overridden:
    void run(ITensorPack &tensors) override;
    void prepare(ITensorPack &tensors) override;
    bool                             is_configured() const override;
    experimental::MemoryRequirements workspace() const override;
    bool                             isVarWeightsKernel() const override
    {
        if(!_gemm_kernel_asm)
            return false;
        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
        return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
    }

private:
    enum AuxTensorIdx
    {
        AsmGemmWorkspace = 0,
        Pretranspose,
        Count
    };

    /** Configure the indirect buffer
     *
     * @param[in]  a    Input tensor containing the Matrix A.
     * @param[in]  b    Input tensor containing the Matrix B.
     * @param[out] d    Output tensor to store the result of matrix multiplication.
     * @param[in]  info GEMM meta-data
     */
    void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
    /** Prepare the indirect buffer */
    void prepare_indirect_buffer(ITensorPack &tensors);

    /** Assembly Gemm kernel */
    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
    /** Optimised Arm® Neon™ kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
    /** Assembly GEMM workspace tensor info */
    TensorInfo _workspace_info{};
    /** Pre-transpose tensor info */
    TensorInfo _pretranspose_info{};
    /** Prepared flag */
    bool _is_prepared{ false };
    /** GEMM meta-data */
    AsmGemmInfo _gemm_info{};
    /** GEMM kernel description */
    arm_gemm::KernelDescription _kernel_info{};
    /** Per channel quantization shifts */
    std::vector<int32_t> _shifts{};
    std::vector<int32_t> right_shifts{};
    std::vector<int32_t> left_shifts{};
    /** Per channel quantization multipliers */
    std::vector<int32_t> _multipliers{};
    /** Indirect buffer */
    std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
    std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
    std::vector<TypeInput>           _indirect_pad{};
    arm_gemm::ConvolutionParameters  _cp{};
    experimental::MemoryRequirements _aux_mem{ Count };
    bool                             _B_pretranspose_required{ false };
    bool                             _is_b_constant{ true };
    bool                             _is_c_constant{ true };
};

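// arm_gemm consumes separate per-channel left- and right-shift arrays, so each signed shift s
// is negated and split here: max(-s, 0) goes into left_shifts, min(-s, 0) into right_shifts,
// and the returned flag records whether any left shift is actually needed.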
template <typename TypeInput, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
{
    _multipliers   = multipliers;
    _shifts        = shifts;
    bool need_left = false;
    for(const auto s : _shifts)
    {
        left_shifts.push_back(std::max(-s, int32_t(0)));
        right_shifts.push_back(std::min(-s, int32_t(0)));
        if(s < 0 && !need_left)
        {
            need_left = true;
        }
    }
    return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}

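// Populate the indirect buffer: for every (multi, batch, kernel point, output point) it stores
// a pointer to the corresponding input row, or a pointer to the zero-padding vector when the
// sampled location falls outside the input plane.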
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
    auto             a              = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    const TypeInput *A_ptr          = reinterpret_cast<TypeInput *>(a->buffer());
    const int        multis         = 1;
    const int        batches        = a->info()->tensor_shape().total_size_upper(3);
    const size_t     stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const size_t     batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const size_t     multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);

    const size_t output_hw    = _cp.output_height * _cp.output_width;
    const int    batch_size   = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
    const size_t batch_stride = batch_size / sizeof(TypeInput);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInput);

    for(int64_t m = 0; m < multis; m++)
    {
        for(int64_t b = 0; b < batches; b++)
        {
            for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
            {
                for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
                {
                    int64_t output_xy = (output_y * _cp.output_width) + output_x;

                    for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
                    {
                        for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
                        {
                            int64_t input_x   = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
                            int64_t input_y   = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;

                            if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                            {
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
                            }
                            else
                            {
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                            }
                        }
                    }
                }
            }
        }
    }
}

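// Translate the convolution geometry into arm_gemm::ConvolutionParameters and, for the indirect
// method, allocate the indirect buffers and set the per-(multi, batch, kernel point) argument
// pointers; the actual input pointers are filled in later by prepare_indirect_buffer().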
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));

    float zeropad = 0.f;
    if(is_data_type_quantized(a->data_type()))
    {
        zeropad = a->quantization_info().uniform().offset;
    }

    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);

    _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
            info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
          };

    if(info.method == AsmConvMethod::Conv)
    {
        _gemm_kernel_asm->set_convolution_parameters(_cp);
    }

    if(info.method == AsmConvMethod::Indirect)
    {
        const unsigned int multis    = 1;
        const unsigned int batches   = a->tensor_shape().total_size_upper(3);
        const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
        const unsigned int output_hw = _cp.output_width * _cp.output_height;

        using TypeInputPtr        = TypeInput *;
        const int    batch_size   = kernel_hw * output_hw * sizeof(TypeInputPtr);
        const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
        const int    multi_size   = batch_size * batches;
        const size_t multi_stride = multi_size / sizeof(TypeInputPtr);

        _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
        _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
        _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));

        // Set indirect argument
        int64_t pos = 0;
        for(int64_t m = 0; m < multis; m++)
        {
            for(int64_t b = 0; b < batches; b++)
            {
                for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
                {
                    (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
                }
            }
        }

        _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
    }
}

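// Ask arm_gemm for a kernel matching the arguments; if none is returned the fallback is left
// unconfigured and is_configured() stays false. Otherwise the kernel is wrapped in a
// CpuGemmAssemblyWrapperKernel, the workspace and pretranspose memory requirements are recorded,
// and indirect convolution is set up when requested.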
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                             arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                                                             const OutputStage &os)
{
    ARM_COMPUTE_UNUSED(c);

    _is_b_constant = b->are_values_constant();
    _is_c_constant = c ? c->are_values_constant() : true;

    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
    if(_gemm_kernel_asm == nullptr)
    {
        // Configuration not supported: leave the function unconfigured
        return;
    }

    arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();

    // arm_compute wrapper for the Gemm object (see above)
    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
    const unsigned int alignment      = 4096;
    _workspace_info                   = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
    _aux_mem[AsmGemmWorkspace]        = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);

    // If we disable the code in the braces below then ConvLayer deadlocks when threads > 1 and
    // the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        if(window_size < static_cast<unsigned int>(args._maxthreads))
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _gemm_info        = gemm_info;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
        _aux_mem[Pretranspose]                 = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
        _B_pretranspose_required               = true;
    }

    // Handle indirect GEMM convolution
    if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
    {
        configure_indirect(a, b, d, gemm_info);
    }
}

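// One-off preparation: bind the quantized bias (when C is S32), pretranspose B into the
// persistent auxiliary buffer when the kernel requires it, and build the indirect buffer
// for the indirect convolution method.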
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
    if(!_is_prepared)
    {
        auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
        auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);

        // Set up the matrix bias in the assembly kernel; it's just a pointer to matrix C.
        if(c && c->info()->data_type() == DataType::S32)
        {
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
            // Fixed format kernels need no pretranspose.
            ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
            const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();

            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
            _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);

            b->mark_as_unused();
        }

        if(_gemm_info.method == AsmConvMethod::Indirect)
        {
            prepare_indirect_buffer(tensors);
        }

        _is_prepared = true;
    }
}

template <typename TypeInput, typename TypeOutput, class OutputStage>
bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
{
    return _optimised_kernel != nullptr;
}

template <typename TypeInput, typename TypeOutput, class OutputStage>
experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
{
    return _aux_mem;
}

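// Gather the tensor pointers and strides, refresh the pretransposed B / quantized bias when the
// weights or bias are not constant, bind the workspace and thread count, then dispatch the
// wrapped arm_gemm kernel through NEScheduler with the method-specific scheduling hint.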
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
    auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
    auto d = tensors.get_tensor(TensorType::ACL_DST);

    int       lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
    int       ldb = 0;
    const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();

    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
    const size_t a_multi_idx = a_batch_idx + 1;
    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
    const size_t d_multi_idx = d_batch_idx + 1;

    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();

    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
    int       multi_stride_b = 0;
    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();

    auto             in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr;
    auto             out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());

    // Check if B is pre-transposed and de-reference if not
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
        multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
    }

    // If necessary, re-run the pretranspose (and bias setup) every time when either the weights or the biases are non-constant
    if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
    {
        if(c && c->info()->data_type() == DataType::S32)
        {
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_B_pretranspose_required)
        {
            const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
            const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();

            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);

            if(_is_b_constant)
            {
                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
            else
            {
                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
        }
    }

    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
    if(workspace.get()->buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
        const unsigned int split_dim   = scheduling_hint.split_dimension();
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            num_threads = window_size;
        }
        if(split_dim != IScheduler::split_dimensions_all)
        {
            // Make sure the kernel does not expect more threads than we can actually spawn
            const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
            num_threads                       = std::min(num_iterations, num_threads);
        }
        _gemm_kernel_asm->set_nthreads(num_threads);
    }

    // Prepare assembly kernel
    prepare(tensors);

    // Set up the matrix bias in the assembly kernel; it's just a pointer to matrix C.
    TypeOutput *bias = nullptr;
    if(c && c->info()->data_type() != DataType::S32)
    {
        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
    }

    if(_gemm_info.method == AsmConvMethod::Indirect)
    {
        in0_ptr        = nullptr;
        lda            = 0;
        batch_stride_a = 0;
        multi_stride_a = 0;
    }

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
                                 in1_ptr, ldb, multi_stride_b,
                                 out_ptr, ldd, batch_stride_d, multi_stride_d,
                                 bias, 0);
    // Schedule
    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}

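// Build a non-quantized Fallback<TypeInput, TypeOutput> from the extracted parameters and hand
// it back through the IFallback pointer owned by CpuGemmAssemblyDispatch.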
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                     arm_gemm::Activation activation, const AsmGemmInfo &info)
{
    Params         p           = extract_parameters(a, b, d, info);
    const CPUInfo &ci          = NEScheduler::get().cpu_info();
    unsigned int   num_threads = NEScheduler::get().num_threads();

    arm_gemm::GemmConfig cfg;
    cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);

    // Create arm_gemm fallback
    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
    fallback->configure(a, b, c, d, args, info);
    arm_gemm = std::move(fallback);
}

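// Quantized variant: build the Requantize32 output stage from the GEMMLowp offsets and shifts
// (per-channel when more than one shift is provided, uniform otherwise) before configuring the
// Fallback.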
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                           arm_gemm::Activation activation, const AsmGemmInfo &info)
{
    ARM_COMPUTE_UNUSED(activation);
    Params             p           = extract_parameters(a, b, d, info);
    const CPUInfo     &ci          = NEScheduler::get().cpu_info();
    const unsigned int num_threads = NEScheduler::get().num_threads();

    arm_gemm::GemmConfig cfg;
    cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);

    // Create arm_gemm fallback
    auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();

    // Configure requantization info
    const int32_t                 negation = info.negated_offsets ? 1 : -1;
    const int32_t                 a_offset = -a->quantization_info().uniform().offset * negation;
    const int32_t                 b_offset = -b->quantization_info().uniform().offset * negation;
    const GEMMLowpOutputStageInfo os_info  = info.output_stage;

    arm_gemm::Requantize32 gemm_requant_info{};
    if(os_info.gemmlowp_shifts.size() > 1)
    {
        const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
        gemm_requant_info          = arm_gemm::Requantize32(nullptr, 0,
                                                            a_offset, b_offset, os_info.gemmlowp_offset,
                                                            (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr,
                                                            std::get<2>(requantize_data),
                                                            std::get<3>(requantize_data),
                                                            os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
    }
    else
    {
        gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
                                                   a_offset, b_offset, os_info.gemmlowp_offset,
                                                   -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
                                                   os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
    }

    // Configure fallback
    fallback->configure(a, b, c, d, args, info, gemm_requant_info);
    arm_gemm = std::move(fallback);
}
} // namespace

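// Illustrative call sequence (a sketch only; in practice the CPU GEMM operators drive this class
// and provide the tensors through an ITensorPack, which is omitted here):
//
//   CpuGemmAssemblyDispatch asm_glue;
//   if(bool(CpuGemmAssemblyDispatch::validate(a, b, c, d, asm_info)))
//   {
//       asm_glue.configure(a, b, c, d, asm_info);
//   }
//   if(asm_glue.is_configured())
//   {
//       // Allocate the buffers described by asm_glue.workspace(), pack the tensors, then:
//       asm_glue.prepare(tensors);
//       asm_glue.run(tensors);
//   }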
CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
    : _arm_gemm(nullptr)
{
}

Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
                                             const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_UNUSED(c);
    arm_gemm::Activation act         = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
    Params               p           = extract_parameters(a, b, d, info);
    const CPUInfo       &ci          = NEScheduler::get().cpu_info();
    unsigned int         num_threads = NEScheduler::get().num_threads();
    arm_gemm::GemmConfig cfg;
    cfg.weight_format                           = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
    arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
    arm_gemm::GemmArgs     args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
    switch(a->data_type())
    {
        case DataType::F32:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for F32 input");
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(d->data_type() == DataType::S32)
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
            }
            else
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for U8 input and U8 output");
            }
            break;
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            if(d->data_type() == DataType::S32)
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
            }
            else
            {
                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
                                                "We could not find an optimized kernel for S8 input and S8 output");
            }
            break;
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
        case DataType::BFLOAT16:
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for BFLOAT16 input and F32 output");
            break;
        }
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
                                            "We could not find an optimized kernel for F16 input and F16 output");
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Unsupported type. Could not find a kernel");
            break;
    }
    expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);

    return Status{};
}

Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_UNUSED(c, info);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");

#ifndef __aarch64__
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8,
                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8,
                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
    if(is_data_type_quantized_per_channel(b->data_type()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
    }
    else if(is_fixed_format_fast_math(info.weight_format))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    }
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
    arm_compute::WeightFormat expected_weight_format;
    const Status              ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
    if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
    {
        // Correctness check: if the format expected by the kernel is
        // not "any", make sure that the one found matches the format
        // intended by the caller.
        ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
                                        "The format expected by the kernel does not correspond with the one requested by the user.");
    }
    return ret;
}

bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
{
    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
    return act.type != arm_gemm::Activation::Type::None;
}

void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);

    // If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
    if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
    {
        return;
    }

    switch(a->data_type())
    {
        case DataType::F32:
            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
            break;
#ifdef __aarch64__
        case DataType::U8:
        case DataType::QASYMM8:
            if(d->data_type() == DataType::S32)
            {
                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
        case DataType::S8:
        case DataType::QASYMM8_SIGNED:
            if(d->data_type() == DataType::S32)
            {
                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
            }
            else
            {
                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
            }
            break;
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
        case DataType::BFLOAT16:
            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
            break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
            break;
    }
}

void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    _arm_gemm->prepare(tensors);
}

bool CpuGemmAssemblyDispatch::is_configured() const
{
    return _arm_gemm && _arm_gemm->is_configured();
}

void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    _arm_gemm->run(tensors);
}

experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    return _arm_gemm->workspace();
}
} // namespace cpu
} // namespace arm_compute