/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H
#define ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
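/** Loads one row of a 3x3 kernel, broadcasting each of the three coefficients
 * across all four lanes of its own vector so it can be multiplied lane-wise
 * against the input.
 *
 * @param[in] ptr Pointer to the first of three consecutive row coefficients.
 *
 * @return The three broadcast coefficient vectors.
 */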
inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

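/** Computes a 3x3 convolution of three input rows against a 3x3 kernel for the
 * given stride, producing one row of output.
 *
 * @param[in] in_top Pointer to the top input row.
 * @param[in] in_mid Pointer to the middle input row.
 * @param[in] in_low Pointer to the bottom input row.
 * @param[in] m0     Broadcast coefficients of the top kernel row (see load_matrix_row).
 * @param[in] m1     Broadcast coefficients of the middle kernel row.
 * @param[in] m2     Broadcast coefficients of the bottom kernel row.
 *
 * @return The convolved values, packed according to the stride.
 */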
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2);

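// Stride 1: loads twelve floats per input row and computes eight contiguous outputs.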
template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
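    // out.val[0] accumulates outputs 0-3 and out.val[1] outputs 4-7; vextq_f32
    // shifts the loaded window by one and two elements so the second and third
    // kernel columns line up with their inputs.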
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

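// Stride 2: reuse the stride-1 result and compact the even outputs (0, 2, 4, 6)
// into out.val[0]; only out.val[0] is meaningful afterwards (see store_results<2>).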
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

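// Stride 3: reuse the stride-1 result and compact outputs 0 and 3 into the low
// half of out.val[0]; only those two lanes are meaningful (see store_results<3>).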
template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

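/** Stores the output of convolve_3x3, writing only the lanes that are valid
 * for the given stride.
 *
 * @param[out] buffer Pointer to the output buffer.
 * @param[in]  values Values to store.
 */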
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

// Explicit specializations are not implicitly inline, so they are marked inline
// here to avoid multiple-definition errors when this header is included in more
// than one translation unit.
template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

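/** Computes the number of input elements consumed per iteration, i.e. the
 * number of output elements written scaled by the stride.
 *
 * @param[in] num_elems_written_per_iteration Number of output elements written per iteration.
 *
 * @return Number of input elements processed per iteration.
 */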
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H */