/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <arm_neon.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

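// Applies the ReLU activation (max(x, 0)) to eight values held in two
// float32x4_t registers.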
static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  *out_h = vmaxq_f32(*out_h, *zero);
  *out_l = vmaxq_f32(*out_l, *zero);
}

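// Applies the ReLU activation to four values held in a single float32x4_t
// register.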
static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  *x = vmaxq_f32(*x, *zero);
}

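// Scalar ReLU: clamps x to zero in place when it is negative.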
#define CLAMP_0(x) (x = x > 0 ? x : 0)

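// Computes a single output node as the dot product of the inputs and weights
// plus the bias, with num_inputs a multiple of 8. ReLU is applied unless this
// is the output layer.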
static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
  }
#if defined(__aarch64__)
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

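// Computes a single output node for an arbitrary num_inputs greater than 8:
// blocks of 8 are accumulated with NEON and the remainder is handled by a
// scalar tail loop. ReLU is always applied.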
static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t vadd = vdupq_n_f32(0);

  float total = *layer_bias;
  int j = num_inputs;
  int in = 0;
  while (j > 7) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
    in += 8;
    j -= 8;
  }

#if defined(__aarch64__)
  total += vaddvq_f32(vadd);

#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

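// Computes a single output node for small input counts (at least 4, fewer
// than 8): on AArch64 the first four products are vectorized and the rest are
// accumulated with scalars; on 32-bit Arm the loop is fully scalar. ReLU is
// always applied.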
static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float total = *layer_bias;
#if defined(__aarch64__)
  const float32x4_t v_inputs = vld1q_f32(inputs);
  const float32x4_t v_weights = vld1q_f32(weights);
  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
  total += vaddvq_f32(vadd);
  int in = 4;
#else
  int in = 0;
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

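// Computes a single output node with num_inputs a multiple of 4. ReLU is
// applied unless this is the output layer.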
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
    const float32x4_t v_weights = vld1q_f32(&weights[in]);
    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
  }

#if defined(__aarch64__)
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

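// Computes four output nodes at once with num_inputs a multiple of 4: the four
// per-output accumulators are reduced with pairwise adds into a single vector,
// the biases are added, and ReLU is applied unless this is the output layer.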
static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);

  float32x4_t mul0[2] = { zero, zero };
  float32x4_t mul1[2] = { zero, zero };
  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);

    for (int i = 0; i < 2; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
    }
  }
  for (int i = 0; i < 2; i++)
#if defined(__aarch64__)
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
#endif

  outputs = vaddq_f32(outputs, hh);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

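// Computes eight output nodes at once with num_inputs a multiple of 4, using
// eight accumulators that are pairwise-reduced into a low and a high vector of
// four sums each. ReLU is applied unless this is the output layer.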
static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
  float32x4_t out_l = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t mul0[4] = { zero, zero, zero, zero };
  float32x4_t mul1[4] = { zero, zero, zero, zero };

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);
    for (int i = 0; i < 4; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
    }
  }
  for (int i = 0; i < 4; i++)
#if defined(__aarch64__)
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh0 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
  const float32x4_t hh1 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
#endif

  out_h = vaddq_f32(out_h, hh1);
  out_l = vaddq_f32(out_l, hh0);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(&output_nodes[4], out_h);
  vst1q_f32(output_nodes, out_l);
}

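// Computes four output nodes at once with num_inputs a multiple of 8. ReLU is
// applied unless this is the output layer.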
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t add[4] = { zero, zero, zero, zero };
  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);

    for (int i = 0; i < 4; i++) {
      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
    }
  }
#if defined(__aarch64__)
  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
#else
  const float32x4_t hadd_h =
      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
  const float32x4_t hadd_l =
      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
  const float32x4_t haddhadd =
      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
#endif

  outputs = vaddq_f32(outputs, haddhadd);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Process the hidden layers; the final iteration handles the output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

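    // Choose the widest NEON kernel the layer dimensions allow; layers with
    // fewer than four inputs fall back to plain C.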
    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs > 8) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (num_inputs >= 4) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i)
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        // ReLU as activation function.
        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
        output_nodes[node] = val;
      }
    }
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}