/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H

#include "src/core/AccessWindowStatic.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "support/Requires.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}
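
// Illustrative usage (a minimal sketch, not part of the original header): each
// returned vector broadcasts one filter weight to all lanes, so a full 3x3
// filter is loaded one row at a time:
//
//   const float weights[9] = { /* w00 .. w22, row-major */ };
//   const float32x4x3_t m0 = load_matrix_row(weights);     // w00, w01, w02
//   const float32x4x3_t m1 = load_matrix_row(weights + 3); // w10, w11, w12
//   const float32x4x3_t m2 = load_matrix_row(weights + 6); // w20, w21, w22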

/** Loads a 3x3 matrix as a row (uint8_t/int8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t/int8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);

    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    int32x4x3_t r =
    {
        {
            vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
        }
    };
    return r;
}

/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vget_low_s32(values.val[0]));
}
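
// Note on stridex (a sketch of the intended usage, not part of the original
// header): the convolution helpers below compute 8 outputs per iteration and
// compact the valid lanes in place, so each specialization writes back
// 8 (stride 1), 4 (stride 2) or 2 (stride 3) elements:
//
//   int32_t dst[8] = {};
//   const int32x4x2_t acc = { { vdupq_n_s32(1), vdupq_n_s32(2) } };
//   store_results<1>(dst, acc); // writes dst[0..7]
//   store_results<2>(dst, acc); // writes dst[0..3] only
//   store_results<3>(dst, acc); // writes dst[0..1] only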

/** Accumulates a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be accumulated.
 * @param[in] values Values that will be accumulated.
 *
 */
template <unsigned int stridex>
inline void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

/** Accumulates an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be accumulated.
 * @param[in] values Values that will be accumulated.
 *
 */
template <unsigned int stridex>
void accumulate_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
    vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stores a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

/** Accumulates a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be accumulated.
 * @param[in] values Values that will be accumulated.
 *
 */
template <unsigned int stridex>
inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out             = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out             = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    float32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    }

    return out;
}
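
// Illustrative call pattern (row and output pointers are assumed; a sketch
// only): with dilation_x == 2 each filter tap samples the input two elements
// apart, and the helper evaluates 8 positions then compacts the lanes for
// the requested stride:
//
//   const float32x4x2_t res = convolve_3x3_dilation(row0, row1, row2,
//                                                   m0, m1, m2,
//                                                   /* dilation_x */ 2,
//                                                   /* stridex    */ 1);
//   store_results<1>(dst, res);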

/** Perform a 3x3 convolution on float32.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                  unsigned int stridex, int input_offset = 0);

template <bool accumulate>
inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                         const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                         unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out =
    {
        {
            vdupq_n_f32(0.f),
            vdupq_n_f32(0.f)
        }
    };
    if(stridex == 2)
    {
        const float32x4x2_t vtop     = vld2q_f32(in_top);
        const float32x4x2_t vmid     = vld2q_f32(in_mid);
        const float32x4x2_t vlow     = vld2q_f32(in_low);
        const float32x4_t   vtop_end = vld1q_f32(in_top + 8);
        const float32x4_t   vmid_end = vld1q_f32(in_mid + 8);
        const float32x4_t   vlow_end = vld1q_f32(in_low + 8);

        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float32x4x3_t vtop =
        {
            {
                vld1q_f32(in_top),
                vld1q_f32(in_top + 4),
                vld1q_f32(in_top + 8)
            }
        };
        const float32x4x3_t vmid =
        {
            {
                vld1q_f32(in_mid),
                vld1q_f32(in_mid + 4),
                vld1q_f32(in_mid + 8)
            }
        };
        const float32x4x3_t vlow =
        {
            {
                vld1q_f32(in_low),
                vld1q_f32(in_low + 4),
                vld1q_f32(in_low + 8)
            }
        };
        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
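
// Illustrative usage (a sketch; row and output pointers are assumed): the
// accumulate template flag chooses between overwriting the output and adding
// to what is already there, which is how partial sums across input channels
// are combined:
//
//   convolve_3x3<false>(row0, row1, row2, dst, m0, m1, m2, /* stridex */ 1); // first channel
//   convolve_3x3<true>(row0, row1, row2, dst, m0, m1, m2, /* stridex */ 1);  // subsequent channels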

/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + dilation_x),
            wrapper::vload(in_top + 2 * dilation_x)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + dilation_x),
            wrapper::vload(in_mid + 2 * dilation_x)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + dilation_x),
            wrapper::vload(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))),
        }
    };

    int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
    out           = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
    out           = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);

    out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
    out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
    out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);

    out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
    out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
    out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    int32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
    }
    return out;
}

/** Perform a 3x3 convolution on 8-bit elements.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset Input quantization offset.
 *
 */
template < bool accumulate, typename T1, typename T2, REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                  unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + 8)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + 8)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + 8)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
        }
    };

    int32x4x2_t out
    {
        {
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
        }
    };

    // Compute lanes 0..3 of the output
    out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);

    // Compute lanes 4..7 of the output
    out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);

    if(stridex == 1)
    {
        accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
    }
    else if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
        accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
    }
}
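
// Quantization sketch (illustrative; variable names are assumptions, not
// library symbols): the 8-bit path effectively computes
// (input + input_offset) * (weight + weights_offset) in 32-bit lanes, so the
// offsets passed in are typically the negated zero points of the input and
// weight tensors:
//
//   const int32x4x3_t w0 = load_matrix_row(weights_ptr, -weights_zero_point);
//   convolve_3x3<false>(row0, row1, row2, acc_ptr, w0, w1, w2,
//                       /* stridex */ 1, -input_zero_point);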

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr            Pointer to a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out             = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out             = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}

/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    float16x8x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    }

    return out;
}

/** Perform a 3x3 convolution on float16.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
                         const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                         unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float16x8x2_t out =
    {
        {
            vdupq_n_f16(0),
            vdupq_n_f16(0)
        }
    };
    if(stridex == 2)
    {
        const float16x8x2_t vtop     = vld2q_f16(in_top);
        const float16x8x2_t vmid     = vld2q_f16(in_mid);
        const float16x8x2_t vlow     = vld2q_f16(in_low);
        const float16x8_t   vtop_end = vld1q_f16(in_top + 16);
        const float16x8_t   vmid_end = vld1q_f16(in_mid + 16);
        const float16x8_t   vlow_end = vld1q_f16(in_low + 16);

        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float16x8x3_t vtop =
        {
            {
                vld1q_f16(in_top),
                vld1q_f16(in_top + 8),
                vld1q_f16(in_top + 16)
            }
        };
        const float16x8x3_t vmid =
        {
            {
                vld1q_f16(in_mid),
                vld1q_f16(in_mid + 8),
                vld1q_f16(in_mid + 16)
            }
        };
        const float16x8x3_t vlow =
        {
            {
                vld1q_f16(in_low),
                vld1q_f16(in_low + 8),
                vld1q_f16(in_low + 16)
            }
        };
        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);

            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Get the number of elements processed on 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 * @param[in] stridex                         Stride value in elements across x.
 *
 * @return The number of elements processed.
 */
inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
{
    switch(stridex)
    {
        case 1:
            return num_elems_written_per_iteration;
        case 2:
            return num_elems_written_per_iteration << 1;
        case 3:
            return num_elems_written_per_iteration * 3;
        default:
            ARM_COMPUTE_ERROR("stridex not supported");
            return 0;
    }
}
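
// Worked example (illustrative): a kernel that writes 4 output elements per
// iteration with stridex == 2 consumes every second input column, so its
// input window must advance by 8 elements per iteration:
//
//   const int num_elems_read = get_input_num_elems_processed(4, 2); // == 8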
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */