1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
25 
26 #include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
27 #include "src/core/NEON/wrapper/wrapper.h"
28 
29 #include "arm_compute/core/Error.h"
30 #include "arm_compute/core/Helpers.h"
31 #include "arm_compute/core/IAccessWindow.h"
32 #include "arm_compute/core/ITensor.h"
33 #include "arm_compute/core/Types.h"
34 #include "arm_compute/core/Utils.h"
35 #include "arm_compute/core/Validate.h"
36 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
37 #include "src/core/AccessWindowStatic.h"
38 #include "src/core/CPP/Validate.h"
39 #include "src/core/NEON/NEFixedPoint.h"
40 #include "src/core/helpers/AutoConfiguration.h"
41 #include "src/core/helpers/WindowHelpers.h"
42 
43 #include <algorithm>
44 
45 using namespace arm_compute::detail;
46 
47 namespace arm_compute
48 {
49 namespace
50 {
51 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
52 template <unsigned int stridex>
53 float16x8_t internal_vld1q(const float16_t *in);
54 
55 template <>
56 float16x8_t internal_vld1q<1>(const float16_t *in)
57 {
58     return vld1q_f16(in);
59 }
60 
61 template <>
62 float16x8_t internal_vld1q<2>(const float16_t *in)
63 {
64     const float16x8x2_t tmp = vld2q_f16(in);
65     return tmp.val[0];
66 }
67 
68 template <>
69 float16x8_t internal_vld1q<3>(const float16_t *in)
70 {
71     const float16x8x3_t tmp = vld3q_f16(in);
72     return tmp.val[0];
73 }
74 
75 inline float16x8_t internal_vdupq_n(float16_t v)
76 {
77     return vdupq_n_f16(v);
78 }
79 
80 inline void internal_vst1q(float16_t *p, const float16x8_t &v)
81 {
82     vst1q_f16(p, v);
83 }
84 
85 float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
86 {
87     return vmulq_f16(x, y);
88 }
89 
90 inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
91 {
92     return vaddq_f16(x, vmulq_f16(y, z));
93 }
94 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
95 
96 template <unsigned int stridex>
97 float32x4_t internal_vld1q(const float *in);
98 
99 template <>
100 float32x4_t internal_vld1q<1>(const float *in)
101 {
102     return vld1q_f32(in);
103 }
104 
105 template <>
106 float32x4_t internal_vld1q<2>(const float *in)
107 {
108     const float32x4x2_t tmp = vld2q_f32(in);
109     return tmp.val[0];
110 }
111 
112 template <>
113 float32x4_t internal_vld1q<3>(const float *in)
114 {
115     const float32x4x3_t tmp = vld3q_f32(in);
116     return tmp.val[0];
117 }
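// Note on the internal_vld1q<stridex> helpers above: stride 1 uses a plain vld1q load,
// while strides 2 and 3 use the de-interleaving loads vld2q/vld3q and keep only val[0],
// i.e. elements {0,2,4,...} or {0,3,6,...} of the source row. This gathers exactly the
// input samples that contribute to one vector of consecutive output elements at that stride.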
118 
119 inline float32x4_t internal_vdupq_n(float v)
120 {
121     return vdupq_n_f32(v);
122 }
123 
124 inline void internal_vst1q(float *p, const float32x4_t &v)
125 {
126     vst1q_f32(p, v);
127 }
128 
129 float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
130 {
131     return vmulq_f32(x, y);
132 }
133 
134 inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
135 {
136     return vmlaq_f32(x, y, z);
137 }
138 
139 constexpr int small_tensor_size_optim = 8;
140 inline bool run_optim_small_tensor_info(const ITensorInfo *t)
141 {
142     return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
143 }
144 
145 inline bool run_optim_small_tensor(const ITensor *t)
146 {
147     return run_optim_small_tensor_info(t->info());
148 }
149 
150 // Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
151 // For a large Z, as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
152 // store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
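// For example (illustrative only): with a 7x7x832 input and stride 1, each thread keeps up to
// 8 rows of partial sums (two float32x4_t per row, i.e. 8 output columns) live in registers while
// accumulating over all 832 input planes, and only then stores the finished output plane.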
153 template <unsigned int stridex>
154 class convolver_w1x1_i8x8_f32
155 {
156 public:
157     static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
158     {
159         ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
160         ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);
161 
162         const int          input_stride_x  = input->info()->strides_in_bytes().x();
163         const int          input_stride_y  = input->info()->strides_in_bytes().y();
164         const int          input_stride_z  = input->info()->strides_in_bytes().z();
165         const int          output_stride_y = output->info()->strides_in_bytes().y();
166         const int          output_stride_z = output->info()->strides_in_bytes().z();
167         const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
168         const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
169         const int          output_h        = output->info()->dimension(1);
170         const int          range_z         = window.z().end() - window.z().start();
171         const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
172         const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
173         const unsigned int conv_pad_left   = conv_info.pad_left();
174         const unsigned int conv_pad_top    = conv_info.pad_top();
175 
176         // setup output window for the iterator
177         Window window_out = window;
178         window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
179         window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
180         window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
181 
182         // setup input window for the iterator
183         Window window_in = window;
184         // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
185         window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
186         window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
187         window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
188 
189         Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
190         Iterator out(output, window_out);
191         Iterator in(input, window_in);
192         Iterator k(weights, window_k);
193 
194         const uint8_t *k_ptr = k.ptr();
195 
196         execute_window_loop(window_out, [&](const Coordinates & id)
197         {
198             const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
199             uint8_t       *out_ptr   = out.ptr();
200             int            ih        = 0;
201             int            oh        = 0;
202             std::array<float32x4_t, 8> accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
203             std::array<float32x4_t, 8> accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
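            // accum0[oh] and accum1[oh] together hold one 8-wide output row (columns 0-3 and 4-7);
            // since width and height are both <= 8, a whole output plane fits in these 16 vector accumulators.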
204             for(int oz = 0; oz < range_z; ++oz)
205             {
206                 accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
207                 accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
208                 auto p_out_base                                                                               = out_ptr + oz * output_stride_z;
209                 for(int p = 0; p < kernel_depth; ++p)
210                 {
211                     const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
212                     const auto vk0   = internal_vdupq_n(*k_val);
213                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
214                     {
215                         const int offset_xy = ih * input_stride_y;
216                         auto      in_val    = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
217                         auto      v_in0     = internal_vld1q<stridex>(in_val);
218                         auto      v_in1     = internal_vld1q<stridex>(in_val + 4);
219                         accum0[oh]          = vmlaq_f32(accum0[oh], vk0, v_in0);
220                         accum1[oh]          = vmlaq_f32(accum1[oh], vk0, v_in1);
221                     }
222                 }
223                 for(oh = 0; oh < output_h; ++oh)
224                 {
225                     auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
226                     vst1q_f32(p_out, accum0[oh]);
227                     vst1q_f32(p_out + 4, accum1[oh]);
228                 }
229             }
230         },
231         in, out);
232     }
233 };
234 
235 template <typename T1, typename T2, unsigned int stridex>
236 class convolver_1x1
237 {
238 public:
239     static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
240                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
241     {
242         const int          input_stride_x  = input->info()->strides_in_bytes().x();
243         const int          input_stride_y  = input->info()->strides_in_bytes().y();
244         const int          input_stride_z  = input->info()->strides_in_bytes().z();
245         const int          output_stride_y = output->info()->strides_in_bytes().y();
246         const int          output_stride_z = output->info()->strides_in_bytes().z();
247         const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
248         const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
249         const int          output_w        = output->info()->dimension(0);
250         const int          output_h        = output->info()->dimension(1);
251         const int          range_z         = window.z().end() - window.z().start();
252         const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
253         const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
254         const unsigned int conv_pad_left   = conv_info.pad_left();
255         const unsigned int conv_pad_top    = conv_info.pad_top();
256 
257         // setup output window for the iterator
258         Window window_out = window;
259         window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
260         window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
261         window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
262 
263         // setup input window for the iterator
264         Window window_in = window;
265         // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
266         window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
267         window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
268         window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
269 
270         Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
271         Iterator out(output, window_out);
272         Iterator in(input, window_in);
273         Iterator k(weights, window_k);
274 
275         const uint8_t *k_ptr = k.ptr();
276 
277         execute_window_loop(window_out, [&](const Coordinates & id)
278         {
279             /*
280                 For a detailed explanation of how the algorithm works, refer to the convolver_3x3 class.
281             */
282             const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
283             uint8_t       *out_ptr   = out.ptr();
284             int            ih        = 0;
285             int            oh        = 0;
286             for(int oz = 0; oz < range_z; ++oz)
287             {
288                 auto p_out_base = out_ptr + oz * output_stride_z;
289                 // Step 1
290                 {
291                     const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
292                     const auto vk    = internal_vdupq_n(*k_val);
293                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
294                     {
295                         const int offset_xy = ih * input_stride_y;
296                         auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
297                         auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
298                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
299                         {
300                             internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
301                         }
302                     }
303                 }
304 
305                 // Step 2
306                 for(int p = 1; p < kernel_depth; ++p)
307                 {
308                     const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
309                     const auto vk    = internal_vdupq_n(*k_val);
310                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
311                     {
312                         const int offset_xy = ih * input_stride_y;
313                         auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
314                         auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
315                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
316                         {
317                             internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
318                         }
319                     }
320                 }
321             }
322         },
323         in, out);
324     }
325 };
326 
327 template <unsigned int stridex>
328 float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
329                            const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
330 
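// Helpers for the 5x5 convolver: load_matrix_hi/load_matrix_lo broadcast the first three / last two
// weights of a kernel row into vectors (vld1q_dup), while load_input reads 12 consecutive input
// values as three float32x4_t.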
331 inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
332 {
333     const float32x4x3_t m00 =
334     {
335         {
336             vld1q_dup_f32(m0),
337             vld1q_dup_f32(m1),
338             vld1q_dup_f32(m2)
339         }
340     };
341     return m00;
342 }
343 
344 inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
345 {
346     const float32x4x2_t m00 =
347     {
348         {
349             vld1q_dup_f32(m3),
350             vld1q_dup_f32(m4)
351         }
352     };
353     return m00;
354 }
355 
356 inline float32x4x3_t load_input(const float *const in)
357 {
358     const float32x4x3_t vin =
359     {
360         {
361             vld1q_f32(in),
362             vld1q_f32(in + 4),
363             vld1q_f32(in + 8)
364         }
365     };
366     return vin;
367 }
368 
369 template <>
370 inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
371                                      const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
372 {
373     const float32x4x3_t vin0 = load_input(in_0);
374     const float32x4x3_t vin1 = load_input(in_1);
375     const float32x4x3_t vin2 = load_input(in_2);
376     const float32x4x3_t vin3 = load_input(in_3);
377     const float32x4x3_t vin4 = load_input(in_4);
378     const float32x4x3_t m00  = load_matrix_hi(m0, 1 + m0, 2 + m0);
379     const float32x4x2_t m01  = load_matrix_lo(3 + m0, 4 + m0);
380     const float32x4x3_t m10  = load_matrix_hi(m1, 1 + m1, 2 + m1);
381     const float32x4x2_t m11  = load_matrix_lo(3 + m1, 4 + m1);
382     const float32x4x3_t m20  = load_matrix_hi(m2, 1 + m2, 2 + m2);
383     const float32x4x2_t m21  = load_matrix_lo(3 + m2, 4 + m2);
384     const float32x4x3_t m30  = load_matrix_hi(m3, 1 + m3, 2 + m3);
385     const float32x4x2_t m31  = load_matrix_lo(3 + m3, 4 + m3);
386     const float32x4x3_t m40  = load_matrix_hi(m4, 1 + m4, 2 + m4);
387     const float32x4x2_t m41  = load_matrix_lo(3 + m4, 4 + m4);
388 
389     float32x4x2_t out =
390     {
391         {
392             vmulq_f32(vin0.val[0], m00.val[0]),
393             vmulq_f32(vin0.val[1], m00.val[0])
394         }
395     };
396 
397     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
398     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
399     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
400     out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);
401 
402     out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
403     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
404     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
405     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
406     out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);
407 
408     out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
409     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
410     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
411     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
412     out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);
413 
414     out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
415     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
416     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
417     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
418     out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);
419 
420     out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
421     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
422     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
423     out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
424     out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);
425 
426     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
427     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
428     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
429     out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);
430 
431     out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
432     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
433     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
434     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
435     out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);
436 
437     out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
438     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
439     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
440     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
441     out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);
442 
443     out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
444     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
445     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
446     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
447     out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);
448 
449     out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
450     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
451     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
452     out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
453     out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);
454 
455     return out;
456 }
457 
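// The stride-2 and stride-3 specializations below reuse the stride-1 result and then compact it by
// copying every 2nd (or 3rd) lane into out.val[0], so that only the lanes corresponding to valid
// output positions for that stride are kept.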
458 template <>
459 inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
460                                      const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
461 {
462     float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
463     out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
464     out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
465     out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
466     return out;
467 }
468 
469 template <>
470 inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
471                                      const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
472 {
473     float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
474     out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
475     return out;
476 }
477 
478 template <typename T1, typename T2, unsigned int stridex>
479 class convolver_3x3
480 {
481 public:
482     static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
483                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
484     {
485         ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
486         const int          input_stride_x  = input->info()->strides_in_bytes().x();
487         const int          input_stride_y  = input->info()->strides_in_bytes().y();
488         const int          input_stride_z  = input->info()->strides_in_bytes().z();
489         const int          output_stride_y = output->info()->strides_in_bytes().y();
490         const int          output_stride_z = output->info()->strides_in_bytes().z();
491         const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
492         const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
493         const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
494         const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
495         const int          output_w        = output->info()->dimension(0);
496         const int          output_h        = output->info()->dimension(1);
497         const int          num_planes_z    = window.z().end() - window.z().start();
498         const int          delta_input     = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
499         const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
500         const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
501         const unsigned int conv_pad_left   = conv_info.pad_left();
502         const unsigned int conv_pad_top    = conv_info.pad_top();
503 
504         // setup output window for the iterator
505         Window window_out = window;
506         window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
507         window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
508         window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
509 
510         // setup input window for the iterator
511         Window window_in = window;
512         // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
513         window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
514         window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
515         window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
516 
517         Window window_k = calculate_max_window(*weights->info(), Steps(1u));
518 
519         Iterator out(output, window_out);
520         Iterator in(input, window_in);
521         Iterator k(weights, window_k);
522 
523         const uint8_t *k_ptr = k.ptr();
524 
525         execute_window_loop(window_out, [&](const Coordinates & id)
526         {
527             const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
528             uint8_t       *out_ptr   = out.ptr();
529             int            ih        = 0;
530             int            oh        = 0;
531             /*
532                     Each thread executing this kernel computes one or more planes of the output volume.
533 
534                     For example, if the 3rd dimension of the output volume is 32, the first thread computes the output for Z = [0,7], the second thread for Z = [8,15],
535                     the third thread for Z = [16,23] and the fourth thread for Z = [24,31].
536 
537                     The algorithm's outer loop iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: the main benefit is
538                     that we set up the NEON registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to doing this for every XY value.
539 
540                     The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
541                         1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
542                         2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
543             */
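            // Illustrative sketch of the loop nest implemented below (comment only, not extra code):
            //   for oz in [0, num_planes_z):                                  // output planes owned by this thread
            //       Step 1 (p == 0): out[oz]  = convolve_3x3(in[0], k[oz][0]) // initialise the plane
            //       Step 2 (p >= 1): out[oz] += convolve_3x3(in[p], k[oz][p]) // accumulate the remaining planes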
544             for(int oz = 0; oz < num_planes_z; ++oz)
545             {
546                 const int zoffset    = id.z() + oz;
547                 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
548                 // Step 1
549                 {
550                     const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
551                     const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
552                     const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
553                     const auto vk_r0    = load_matrix_row(ptr_k_r0);
554                     const auto vk_r1    = load_matrix_row(ptr_k_r1);
555                     const auto vk_r2    = load_matrix_row(ptr_k_r2);
556                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
557                     {
558                         auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
559                         auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
560                         auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
561                         auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
562                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
563                             in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
564                         {
565                             convolve_3x3<false>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
566                         }
567                     }
568                 }
569                 // Step 2
570                 for(int p = 1; p < kernel_depth; ++p)
571                 {
572                     const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
573                     const uint8_t *input_base = input_ptr + p * input_stride_z;
574                     const auto     ptr_k_r0   = reinterpret_cast<const T1 *>(ptr_k_base);
575                     const auto     ptr_k_r1   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
576                     const auto     ptr_k_r2   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
577                     const auto     vk_r0      = load_matrix_row(ptr_k_r0);
578                     const auto     vk_r1      = load_matrix_row(ptr_k_r1);
579                     const auto     vk_r2      = load_matrix_row(ptr_k_r2);
580                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
581                     {
582                         auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
583                         auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
584                         auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
585                         auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
586                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
587                             in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
588                         {
589                             convolve_3x3<true>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
590                         }
591                     }
592                 }
593             }
594         },
595         in, out);
596     }
597 };
598 
599 template <typename T1, typename T2, unsigned int stridex>
600 class convolver_5x5
601 {
602 public:
603     static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
604                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
605     {
606         ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
607         const int          input_stride_x  = input->info()->strides_in_bytes().x();
608         const int          input_stride_y  = input->info()->strides_in_bytes().y();
609         const int          input_stride_z  = input->info()->strides_in_bytes().z();
610         const int          output_stride_y = output->info()->strides_in_bytes().y();
611         const int          output_stride_z = output->info()->strides_in_bytes().z();
612         const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
613         const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
614         const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
615         const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
616         const int          output_w        = output->info()->dimension(0);
617         const int          output_h        = output->info()->dimension(1);
618         const int          num_planes_z    = window.z().end() - window.z().start();
619         const int          delta_input     = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
620         const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
621         const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
622         const unsigned int conv_pad_left   = conv_info.pad_left();
623         const unsigned int conv_pad_top    = conv_info.pad_top();
624 
625         // setup output window for the iterator
626         Window window_out = window;
627         window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
628         window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
629         window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
630 
631         // setup input window for the iterator
632         Window window_in = window;
633         // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
634         window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
635         window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
636         window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
637 
638         Window window_k = calculate_max_window(*weights->info(), Steps(1u));
639 
640         Iterator out(output, window_out);
641         Iterator in(input, window_in);
642         Iterator k(weights, window_k);
643 
644         const uint8_t *k_ptr = k.ptr();
645 
646         execute_window_loop(window_out, [&](const Coordinates & id)
647         {
648             const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
649             uint8_t       *out_ptr   = out.ptr();
650             int            ih        = 0;
651             int            oh        = 0;
652             for(int oz = 0; oz < num_planes_z; ++oz)
653             {
654                 const int zoffset    = id.z() + oz;
655                 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
656                 // Step 1
657                 {
658                     const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
659                     const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
660                     const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
661                     const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
662                     const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
663                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
664                     {
665                         auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
666                         auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
667                         auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
668                         auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
669                         auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
670                         auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
671                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
672                             in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
673                         {
674                             auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
675                             store_results<stridex>(p_out, vres);
676                         }
677                     }
678                 }
679                 // Step 2
680                 for(int p = 1; p < kernel_depth; ++p)
681                 {
682                     const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
683                     const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
684                     const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
685                     const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
686                     const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
687 
688                     for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
689                     {
690                         auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
691                         auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
692                         auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
693                         auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
694                         auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
695                         auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
696                         for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
697                             in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
698                         {
699                             auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
700                             accumulate_results<stridex>(p_out, vres);
701                         }
702                     }
703                 }
704             }
705         },
706         in, out);
707     }
708 };
709 
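// Reduces a float32x4_t to a single float by adding its high and low halves and then the two
// remaining lanes, i.e. it returns v[0] + v[1] + v[2] + v[3].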
710 float vreduce(const float32x4_t &v)
711 {
712     auto v0    = wrapper::vgethigh(v);
713     auto v1    = wrapper::vgetlow(v);
714     auto v_out = wrapper::vadd(v0, v1);
715 
716     float a = wrapper::vgetlane(v_out, 0);
717     float b = wrapper::vgetlane(v_out, 1);
718     return a + b;
719 }
720 
721 template <typename T1, typename T2>
722 inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
723                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
724 {
725     const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
726     switch(conv_stride_x)
727     {
728         case 1:
729             convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
730             break;
731         case 2:
732             convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
733             break;
734         case 3:
735             convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
736             break;
737         default:
738             ARM_COMPUTE_ERROR("Not implemented");
739     }
740 }
741 
742 template <>
743 inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
744                                        const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
745 {
746     const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
747     if(run_optim_small_tensor(input))
748     {
749         switch(conv_stride_x)
750         {
751             case 1:
752                 convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
753                 break;
754             case 2:
755                 convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
756                 break;
757             case 3:
758                 convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
759                 break;
760             default:
761                 ARM_COMPUTE_ERROR("Not implemented");
762         }
763     }
764     else
765     {
766         switch(conv_stride_x)
767         {
768             case 1:
769                 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
770                 break;
771             case 2:
772                 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
773                 break;
774             case 3:
775                 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
776                 break;
777             default:
778                 ARM_COMPUTE_ERROR("Not implemented");
779         }
780     }
781 }
782 
783 template <typename T1, typename T2>
784 inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
785                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
786 {
787     const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
788     switch(conv_stride_x)
789     {
790         case 1:
791             convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
792             break;
793         case 2:
794             convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
795             break;
796         case 3:
797             convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
798             break;
799         default:
800             ARM_COMPUTE_ERROR("Not implemented");
801     }
802 }
803 
804 template <typename T1, typename T2>
805 inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
806                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
807 {
808     const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
809     switch(conv_stride_x)
810     {
811         case 1:
812             convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
813             break;
814         case 2:
815             convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
816             break;
817         case 3:
818             convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
819             break;
820         default:
821             ARM_COMPUTE_ERROR("Not implemented");
822     }
823 }
824 
825 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
826 {
827     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
828     ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
829     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
830     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
831     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
832 
833     const DataLayout data_layout = input->data_layout();
834     const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
835     const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
836     const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
837 
838     ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
839     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
840     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
841     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
842     ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
843     ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));
844 
845     // Checks performed when output is configured
846     if(output->total_size() != 0)
847     {
848         TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
849 
850         DataType data_type = input->data_type();
851 
852         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
853         ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
854     }
855 
856     return Status{};
857 }
858 
859 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
860                                                         unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
861 {
862     ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
863 
864     const DataLayout data_layout = input->data_layout();
865     const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
866 
867     // Calculate right and bottom border
868     unsigned int kernel_size   = weights->dimension(width_idx);
869     const int    conv_stride_x = std::get<0>(conv_info.stride());
870     const int    conv_stride_y = std::get<1>(conv_info.stride());
871     const int    input_width   = input->dimension(width_idx);
872 
873     Window win{};
874     bool   window_changed = false;
875 
876     if(data_layout == DataLayout::NCHW)
877     {
878         switch(kernel_size)
879         {
880             case 1:
881             {
882                 switch(input->data_type())
883                 {
884 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
885                     case DataType::F16:
886                         num_elems_written_per_iteration = 8;
887                         break;
888 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
889                     case DataType::F32:
890                         if(run_optim_small_tensor_info(input))
891                         {
892                             num_elems_written_per_iteration = 8;
893                         }
894                         else
895                         {
896                             num_elems_written_per_iteration = 4;
897                         }
898                         break;
899                     default:
900                         ARM_COMPUTE_ERROR("Data type not supported.");
901                         break;
902                 }
903                 num_weight_elems_read_per_row = kernel_size;
904                 num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
905                 break;
906             }
907             case 3:
908                 switch(input->data_type())
909                 {
910                     case DataType::F32:
911                         num_weight_elems_read_per_row   = 4 + kernel_size - 1;
912                         num_elems_read_per_iteration    = 12;
913                         num_elems_written_per_iteration = 16 >> conv_stride_x;
914                         break;
915 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
916                     case DataType::F16:
917                         num_weight_elems_read_per_row   = 8 + kernel_size - 1;
918                         num_elems_read_per_iteration    = 24;
919                         num_elems_written_per_iteration = 32 >> conv_stride_x;
920                         break;
921 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
922                     default:
923                         ARM_COMPUTE_ERROR("Data type not supported.");
924                         break;
925                 }
926                 break;
927             case 5:
928             {
929                 switch(input->data_type())
930                 {
931                     case DataType::F32:
932                         num_weight_elems_read_per_row   = 4 + kernel_size - 1;
933                         num_elems_read_per_iteration    = 12;
934                         num_elems_written_per_iteration = 16 >> conv_stride_x;
935                         break;
936                     default:
937                         ARM_COMPUTE_ERROR("Data type not supported.");
938                         break;
939                 }
940             }
941             break;
942             default:
943             {
944                 ARM_COMPUTE_ERROR("Not implemented");
945                 break;
946             }
947         }
948 
949         // Calculate right pad
950         int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
951         int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
952         int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
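        // Illustrative example (hypothetical values): for a 3x3 F32 kernel with conv_stride_x = 1,
        // pad_left = 1 and input/output widths of 16, the values above are
        // num_elems_written_per_iteration = 8 and num_elems_read_per_iteration = 12, so
        // start_x = 3 / 2 - 1 = 0, end_x = ceil_to_multiple(16, 8) * 1 = 16 and
        // upper_bound_w = ceil_to_multiple(0 + 16, 12) - 16 = 8, i.e. 8 elements of right
        // border are requested so that the widest vector read stays inside padded memory.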
953 
954         // Calculate border
955         const unsigned int conv_pad_left   = conv_info.pad_left();
956         const unsigned int conv_pad_top    = conv_info.pad_top();
957         const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
958         const unsigned int conv_pad_bottom = conv_info.pad_bottom();
959 
960         border_size.left   = conv_pad_left;
961         border_size.top    = conv_pad_top;
962         border_size.right  = conv_pad_right;
963         border_size.bottom = conv_pad_bottom;
964 
965         // Configure window
966         win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
967 
968         AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
969                                            num_elems_read_per_iteration, kernel_size,
970                                            conv_stride_x, conv_stride_y);
971         AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
972         AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
973         window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
974         output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
975     }
976     else
977     {
978         // Configure the NHWC window without any padding
979         win = calculate_max_window(*output, Steps());
980         Coordinates coord;
981         coord.set_num_dimensions(output->num_dimensions());
982         output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
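        // No border or extra tensor padding is requested here: the NHWC kernels clamp their
        // input coordinates against the tensor borders themselves.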
983     }
984 
985     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
986     return std::make_pair(err, win);
987 }
988 
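// Returns true when neither the input nor the weights carry any left/right padding on the x
// dimension (the channel dimension in NHWC). run() uses this to decide whether the optimized
// NHWC path, which walks W and C as one contiguous run of elements, can be taken.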
989 bool have_zero_x_internal_padding(ITensorInfo *input, ITensorInfo *weights)
990 {
991     return (input->padding().left == 0 && weights->padding().left == 0 && input->padding().right == 0 && weights->padding().right == 0);
992 }
993 
994 } // namespace
995 
996 template <typename T>
997 void NEDirectConvolutionLayerKernel::convolve_nhwc_optimized(const Window &window)
998 {
999     // This function assumes that the input and the weights have no padding in the channel dimension
1000 
1001     // Declare useful types
1002     using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
1003     using vector_type = typename vtype::type;
1004     using tag_type    = typename vtype::tag_type;
1005 
1006     // Scalar quantities
1007     const int element_size   = _input->info()->element_size();
1008     const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size;
1009     const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size;
1010     const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size;
1011     const int input_dim_w    = _input->info()->dimension(1);
1012     const int input_dim_h    = _input->info()->dimension(2);
1013 
1014     const int output_stride_c = _output->info()->strides_in_bytes().x();
1015 
1016     const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size;
1017     const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size;
1018     const int          kernel_dim_w    = _weights->info()->dimension(1);
1019     const int          kernel_dim_h    = _weights->info()->dimension(2);
1020 
1021     const int conv_pad_top  = _conv_info.pad_top();
1022     const int conv_pad_left = _conv_info.pad_left();
1023     const int conv_stride_w = std::get<0>(_conv_info.stride());
1024     const int conv_stride_h = std::get<1>(_conv_info.stride());
1025 
1026     // Set up the window used by the output iterator
1027     Window window_out = window;
1028     window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
1029 
1030     // Set up the window used by the weights iterator
1031     Window window_w = calculate_max_window(*_weights->info(), Steps());
1032     window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
1033     window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
1034     window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
1035 
1036     Iterator out(_output, window_out);
1037     Iterator wei(_weights, window_w);
1038 
1039     constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
1040     /*
1041      * This implementation vectorizes over the full WC plane of the input and the weights
1042      * by treating it as one flat series of elements. For example, with a 3x3 weight tensor,
1043      * 3 channels and floating-point vector operations of 4 elements at a time, the first
1044      * load takes the 3 channel elements of the first width position plus the first
1045      * element of the second one. The 9 elements of each single WC weight plane therefore
1046      * require two 4-element vector operations and one final single-element operation.
1047      *
1048      * This works because, when the input vector to be multiplied with the weights is built,
1049      * exactly the required elements are loaded in the same order. Therefore the
1050      * multiplication operates on the correct input/weight elements.
1051      */
1052     execute_window_loop(window_out, [&](const Coordinates & id)
1053     {
1054         /*
1055          * Here we create theoretical indices which are then validated against the
1056          * borders, for both the input and the weights.
1057          * As a reminder, this loop visits each output point in NHW; C is handled
1058          * in the weights loop.
1059          */
1060         // We compute the theoretical input starting points
1061         const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
1062         const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
1063         const int in_w_end_t   = in_w_start_t + kernel_dim_w;
1064         const int in_h_end_t   = in_h_start_t + kernel_dim_h;
1065 
1066         // We are computing the valid initial and ending input points by checking the borders
1067         const int in_w_start = std::max(in_w_start_t, 0);
1068         const int in_h_start = std::max(in_h_start_t, 0);
1069         const int in_w_end   = std::min(in_w_end_t, input_dim_w);
1070         const int in_h_end   = std::min(in_h_end_t, input_dim_h);
1071 
1072         // We use the input points to select the valid weight points
1073         const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w;
1074         const int index_h_start  = in_h_start - in_h_start_t;
1075         const int index_wc_end   = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w;
1076         const int index_h_end    = kernel_dim_h - (in_h_end_t - in_h_end);
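        // Illustrative example (hypothetical values): for a 3x3 kernel with conv_pad_left = 1, at
        // the first output column (id.y() == 0) we get in_w_start_t = -1 and in_w_start = 0, so
        // index_wc_start = 1 * kernel_stride_w and the first weight column is skipped, matching
        // the part of the input that falls outside the left border.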
1077 
1078         execute_window_loop(window_w, [&](const Coordinates & id_w)
1079         {
1080             /*
1081              * This is the loop over the weights, and it goes along N (the batches).
1082              * As a reminder, the batches of the weights map to the
1083              * channels of the output.
1084              */
1085             const T *in_ptr_row = reinterpret_cast<const T *>(_input->buffer() + _input->info()->offset_first_element_in_bytes())
1086                                   + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h;
1087             const T *weights_ptr_row = reinterpret_cast<const T *>(wei.ptr()) + index_h_start * kernel_stride_h;
1088             uint8_t *out_ptr         = out.ptr() + id_w[3] * output_stride_c;
1089 
1090             T out_temp = static_cast<T>(0);
1091             for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h)
1092             {
1093                 const T    *in_ptr_mover = in_ptr_row;
1094                 int         index_wc     = index_wc_start;
1095                 vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
1096                 for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
1097                 {
1098                     const auto src_vec = wrapper::vloadq(in_ptr_mover);
1099                     const auto w_vec   = wrapper::vloadq(weights_ptr_row + index_wc);
1100                     out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
1101                 }
1102                 out_temp += vreduce(out_temp_vec);
1103                 for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover)
1104                 {
1105                     const auto src_val = *(in_ptr_mover);
1106                     const auto w_val   = *(weights_ptr_row + index_wc);
1107                     out_temp += src_val * w_val;
1108                 }
1109             }
1110             *(reinterpret_cast<T *>(out_ptr)) = out_temp;
1111         },
1112         wei);
1113     },
1114     out);
1115 }
1116 
1117 template <typename T>
1118 void NEDirectConvolutionLayerKernel::convolve_nhwc(const Window &window)
1119 {
1120     // Declare useful types
1121     using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
1122     using vector_type = typename vtype::type;
1123     using tag_type    = typename vtype::tag_type;
1124 
1125     // Scalar quantities
1126     const int element_size   = _input->info()->element_size();
1127     const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size;
1128     const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size;
1129     const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size;
1130     const int input_dim_w    = _input->info()->dimension(1);
1131     const int input_dim_h    = _input->info()->dimension(2);
1132 
1133     const int output_stride_c = _output->info()->strides_in_bytes().x();
1134 
1135     const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size;
1136     const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size;
1137     const int          kernel_dim_w    = _weights->info()->dimension(1);
1138     const int          kernel_dim_h    = _weights->info()->dimension(2);
1139 
1140     const int conv_pad_top  = _conv_info.pad_top();
1141     const int conv_pad_left = _conv_info.pad_left();
1142     const int conv_stride_w = std::get<0>(_conv_info.stride());
1143     const int conv_stride_h = std::get<1>(_conv_info.stride());
1144 
1145     // Set up the window used by the output iterator
1146     Window window_out = window;
1147     window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
1148 
1149     // Set up the window used by the weights iterator
1150     Window window_w = calculate_max_window(*_weights->info(), Steps());
1151     window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
1152     window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
1153     window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
1154 
1155     Iterator out(_output, window_out);
1156     Iterator wei(_weights, window_w);
1157 
1158     constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
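    // With 128-bit vectors this is 4 elements per iteration for F32; the innermost channel loop
    // below processes C in chunks of this size and handles any remainder with a scalar tail loop.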
1159 
1160     execute_window_loop(window_out, [&](const Coordinates & id)
1161     {
1162         // We compute the theoretical input starting points
1163         const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
1164         const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
1165         const int in_w_end_t   = in_w_start_t + kernel_dim_w;
1166         const int in_h_end_t   = in_h_start_t + kernel_dim_h;
1167 
1168         // We are computing the valid initial and ending input points by checking the borders
1169         const int in_w_start = std::max(in_w_start_t, 0);
1170         const int in_h_start = std::max(in_h_start_t, 0);
1171         const int in_w_end   = std::min(in_w_end_t, input_dim_w);
1172         const int in_h_end   = std::min(in_h_end_t, input_dim_h);
1173 
1174         // We use the input points to select the valid weight points
1175         const int wei_w_start = in_w_start - in_w_start_t;
1176         const int wei_h_start = in_h_start - in_h_start_t;
1177         const int wei_w_end   = kernel_dim_w - (in_w_end_t - in_w_end);
1178         const int wei_h_end   = kernel_dim_h - (in_h_end_t - in_h_end);
1179 
1180         const int      index_c_end  = _weights->info()->dimension(0);
1181         const T *const in_ptr_start = reinterpret_cast<const T *>(_input->buffer() + _input->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n;
1182 
1183         execute_window_loop(window_w, [&](const Coordinates & id_w)
1184         {
1185             const T *const weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
1186             uint8_t       *out_ptr           = out.ptr() + id_w[3] * output_stride_c;
1187 
1188             T out_temp = static_cast<T>(0);
1189             for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
1190             {
1191                 const T *const in_ptr_row      = in_ptr_start + index_in_h * input_stride_h;
1192                 const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h;
1193                 for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
1194                 {
1195                     const T    *in_ptr_mover      = in_ptr_row + index_in_w * input_stride_w;
1196                     const T    *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
1197                     int         index_c           = 0;
1198                     vector_type out_temp_vec      = wrapper::vdup_n(static_cast<T>(0), tag_type());
1199                     for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration)
1200                     {
1201                         const auto src_vec = wrapper::vloadq(in_ptr_mover);
1202                         const auto w_vec   = wrapper::vloadq(weights_ptr_mover);
1203                         out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
1204                     }
1205                     out_temp += vreduce(out_temp_vec);
1206                     for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover)
1207                     {
1208                         const auto src_val = *(in_ptr_mover);
1209                         const auto w_val   = *(weights_ptr_mover);
1210                         out_temp += src_val * w_val;
1211                     }
1212                 }
1213             }
1214             *(reinterpret_cast<T *>(out_ptr)) = out_temp;
1215         },
1216         wei);
1217     },
1218     out);
1219 }
1220 
1221 NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
1222     : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
1223       _num_elems_written_per_iteration(0)
1224 {
1225 }
1226 
1227 BorderSize NEDirectConvolutionLayerKernel::border_size() const
1228 {
1229     return _border_size;
1230 }
1231 
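// A minimal usage sketch (assumed, not taken from this file; the tensor objects, shapes and the
// scheduler call are illustrative):
//   NEDirectConvolutionLayerKernel kernel;
//   kernel.configure(&src, &weights, &dst, PadStrideInfo(1 /*stride_x*/, 1 /*stride_y*/, 1 /*pad_x*/, 1 /*pad_y*/));
//   NEScheduler::get().schedule(&kernel, Window::DimY);
// src, weights and dst are assumed to be already initialized ITensors whose shapes pass validate() below.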
1232 void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1233 {
1234     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
1235 
1236     _input       = input;
1237     _weights     = weights;
1238     _output      = output;
1239     _conv_info   = conv_info;
1240     _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
1241 
1242     const unsigned int conv_pad_left   = conv_info.pad_left();
1243     const unsigned int conv_pad_top    = conv_info.pad_top();
1244     const unsigned int conv_pad_right  = conv_info.pad_right();
1245     const unsigned int conv_pad_bottom = conv_info.pad_bottom();
1246     if(_input->info()->data_layout() == DataLayout::NCHW)
1247     {
1248         _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
1249     }
1250     else
1251     {
1252         _border_size = BorderSize(0);
1253     }
1254 
1255     // Get convolved dimensions
1256     TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
1257 
1258     DataType data_type = input->info()->data_type();
1259 
1260     // Auto-initialize the output if not yet initialized
1261     auto_init_if_empty(*output->info(), output_shape, 1, data_type);
1262 
1263     // Perform validation step
1264     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));
1265 
1266     // Configure kernel window
1267     auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
1268                                                     _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
1269     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1270     INEKernel::configure(win_config.second);
1271 }
1272 
1273 Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1274 {
1275     unsigned int num_weight_elems_read_per_row   = 0;
1276     unsigned int num_elems_read_per_iteration    = 0;
1277     unsigned int num_elems_written_per_iteration = 0;
1278     BorderSize   border_size                     = {};
1279     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
1280     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1281                                                               weights->clone().get(),
1282                                                               output->clone().get(),
1283                                                               conv_info,
1284                                                               num_weight_elems_read_per_row,
1285                                                               num_elems_read_per_iteration,
1286                                                               num_elems_written_per_iteration,
1287                                                               border_size)
1288                                 .first);
1289 
1290     return Status{};
1291 }
1292 
1293 void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
1294 {
1295     ARM_COMPUTE_UNUSED(info);
1296     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1297     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1298     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
1299 
1300     const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
1301 
1302     if(_input->info()->data_layout() == DataLayout::NCHW)
1303     {
1304         switch(kernel_size)
1305         {
1306             case 1:
1307             {
1308                 switch(_input->info()->data_type())
1309                 {
1310                     case DataType::F32:
1311                         convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1312                         break;
1313 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1314                     case DataType::F16:
1315                         convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1316                         break;
1317 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1318                     default:
1319                         ARM_COMPUTE_ERROR("Data type not supported");
1320                         break;
1321                 }
1322                 break;
1323             }
1324             case 3:
1325             {
1326                 switch(_input->info()->data_type())
1327                 {
1328                     case DataType::F32:
1329                         convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1330                         break;
1331 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1332                     case DataType::F16:
1333                         convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1334                         break;
1335 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1336                     default:
1337                         ARM_COMPUTE_ERROR("Data type not supported");
1338                         break;
1339                 }
1340                 break;
1341             }
1342             case 5:
1343             {
1344                 switch(_input->info()->data_type())
1345                 {
1346                     case DataType::F32:
1347                         convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1348                         break;
1349                     default:
1350                         ARM_COMPUTE_ERROR("Data type not supported");
1351                         break;
1352                 }
1353                 break;
1354             }
1355             default:
1356             {
1357                 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1358                 break;
1359             }
1360         }
1361     }
1362     else
1363     {
1364         switch(_input->info()->data_type())
1365         {
1366             case DataType::F32:
1367             {
1368                 if(have_zero_x_internal_padding(_input->info(), _weights->info()))
1369                 {
1370                     convolve_nhwc_optimized<float>(window);
1371                 }
1372                 else
1373                 {
1374                     convolve_nhwc<float>(window);
1375                 }
1376                 break;
1377             }
1378             default:
1379                 ARM_COMPUTE_ERROR("Data type not supported");
1380                 break;
1381         }
1382     }
1383 }
1384 } // namespace arm_compute