/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H
#define ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads one row of a 3x3 kernel as three broadcast vectors (float).
 *
 * Each of the three weights is duplicated across all four lanes of its vector,
 * so it can be multiplied against a whole vector of input values at once.
 *
 * @param[in] ptr Pointer to the first of three consecutive row weights.
 *
 * @return The broadcast weights { w0, w1, w2 }.
 */
inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

/** Performs a 3x3 convolution of three input rows against a 3x3 kernel (float).
 *
 * The stride-1 specialization produces eight consecutive output values in the
 * two returned vectors; the stride-2 and stride-3 specializations reuse it and
 * compact the surviving outputs into the low lanes of val[0].
 *
 * @param[in] in_top Pointer to the top input row (at least 12 readable floats).
 * @param[in] in_mid Pointer to the middle input row (at least 12 readable floats).
 * @param[in] in_low Pointer to the bottom input row (at least 12 readable floats).
 * @param[in] m0     Broadcast weights of the top kernel row (see load_matrix_row()).
 * @param[in] m1     Broadcast weights of the middle kernel row.
 * @param[in] m2     Broadcast weights of the bottom kernel row.
 *
 * @return Two vectors holding the convolved values.
 */
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2);
template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    // Load 12 consecutive values from each input row; the vextq_f32 calls below
    // form the shift-by-one and shift-by-two windows needed by the 3x3 kernel.
    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    // Outputs 0-3 accumulate in out.val[0], outputs 4-7 in out.val[1].
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    // Compute the stride-1 result and keep every second value: the outputs for
    // positions 0, 2, 4 and 6 are compacted into the four lanes of val[0].
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
{
    // Compute the stride-1 result and keep every third value: the output for
    // position 3 (lane 3 of val[0]) is moved next to the one for position 0.
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

/** Stores the convolved values for the given stride.
 *
 * @param[in] buffer Pointer to the output buffer.
 * @param[in] values Values produced by convolve_3x3() with the same stride.
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    // Stride 1: all eight computed values are valid.
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    // Stride 2: only the four compacted values in val[0] are valid.
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    // Stride 3: only the two compacted values in the low half of val[0] are valid.
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

/** Returns the number of input elements consumed per iteration for the given stride.
 *
 * @param[in] num_elems_written_per_iteration Number of output elements written per iteration.
 *
 * @return Number of input elements the row pointers must advance by per iteration.
 */
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}
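
// A minimal usage sketch, not part of the original kernels: it shows how
// load_matrix_row(), convolve_3x3<1>(), store_results<1>() and
// get_input_num_elems_processed<1>() are intended to compose for one output row
// of a stride-1 3x3 convolution. The function name, the row-major packing of
// 'weights' as 9 consecutive floats, the 8-outputs-per-iteration choice and the
// assumption that 12 floats are readable from each row pointer at every
// iteration are illustrative, not part of this header's contract; border and
// leftover-column handling is omitted.
inline void convolve_3x3_row_stride1_example(const float *in_top, const float *in_mid, const float *in_low,
                                             const float *weights, float *out, unsigned int out_width)
{
    // Broadcast each row of the 3x3 kernel into three duplicated vectors.
    const float32x4x3_t m0 = load_matrix_row(weights);
    const float32x4x3_t m1 = load_matrix_row(weights + 3);
    const float32x4x3_t m2 = load_matrix_row(weights + 6);

    const unsigned int num_elems_written_per_iteration = 8;
    const int          input_step                      = get_input_num_elems_processed<1>(num_elems_written_per_iteration);

    for(unsigned int x = 0; x + num_elems_written_per_iteration <= out_width; x += num_elems_written_per_iteration)
    {
        // Eight outputs per iteration: four in val[0], four in val[1].
        const float32x4x2_t vres = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
        store_results<1>(out + x, vres);
        in_top += input_step;
        in_mid += input_step;
        in_low += input_step;
    }
}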
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H */