1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
19 
20 #include "gemmlowp.h"
21 #include "../common.h"
22 #include "../types.h"
23 
24 namespace android {
25 namespace nn {
26 namespace optimized_ops {
27 
28 // Implementation of float DepthwiseConv
29 
// Kernel specialized on (strided-ness, input depth, depth multiplier).
// The primary template is intentionally empty: only the explicit
// specializations below define Run(), so an unsupported combination is a
// compile-time error rather than a silently-wrong fallback. A template
// parameter of 0 for kFixedInputDepth/kFixedDepthMultiplier means "generic"
// (the runtime value is used instead).
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct FloatDepthwiseConvKernel {};
32 
33 #ifdef USE_NEON
34 
template <>
struct FloatDepthwiseConvKernel<false, 8, 1> {
  // Kernel for stride 1, input_depth == 8, depth_multiplier == 1.
  // With stride 1 and multiplier 1, consecutive output pixels consume
  // consecutive groups of 8 input floats, so input_ptr_increment is unused.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters: 8 coefficients held in two q-registers, reused for
    // every output pixel.
    float32x4_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    int outp = 0;
    // Handle 2 output pixels at a time (2 pixels * 8 channels = 16 floats).
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate: filter[0]/filter[1] pair with each pixel's
      // low/high 4 channels respectively.
      acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
      acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
      acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
      acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time (tail of the loop above).
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};
95 
template <>
struct FloatDepthwiseConvKernel<false, 2, 1> {
  // Kernel for stride 1, input_depth == 2, depth_multiplier == 1.
  // The 2 filter coefficients are duplicated across a q-register so that one
  // vmlaq handles two output pixels at once; input_ptr_increment is unused
  // because stride-1 pixels are contiguous.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    const float32x2_t filters = vld1_f32(filter_ptr);
    // filters_dup2 = {f0, f1, f0, f1}: processes 2 pixels per q-register.
    const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
    int outp = 0;
    // Handle 8 output pixels at a time (8 pixels * 2 channels = 16 floats).
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 4; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      const float32x4_t input = vld1q_f32(input_ptr);
      input_ptr += 4;
      // Load the accumulators from acc_buffer
      float32x4_t acc = vld1q_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmlaq_f32(acc, input, filters_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time, using d-registers for the final pixel.
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float32x2_t input = vld1_f32(input_ptr);
      input_ptr += 2;
      // Load the accumulators from acc_buffer
      float32x2_t acc = vld1_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmla_f32(acc, input, filters);
      // Store the accumulators back to acc_buffer
      vst1_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};
178 
template <>
struct FloatDepthwiseConvKernel<true, 0, 1> {
  // Strided kernel for generic input_depth, depth_multiplier == 1.
  // Processes one output pixel at a time, vectorizing over input channels;
  // input_ptr_increment (stride * input_depth) advances to the next pixel.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        float32x4_t input[4];
        for (int i = 0; i < 4; i++) {
          input[i] = vld1q_f32(local_input_ptr + 4 * i);
        }
        local_input_ptr += 16;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 4; i++) {
          acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters
        float32x4_t filter;
        filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        float32x4_t input;
        input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x4_t acc;
        acc = vld1q_f32(acc_buffer_ptr);
        // Multiply-accumulate
        acc = vmlaq_f32(acc, input, filter);
        // Store the accumulators back to acc_buffer
        vst1q_f32(acc_buffer_ptr, acc);
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time (scalar tail, < 4 channels left).
      for (; ic < input_depth; ic++) {
        const float input_val = *local_input_ptr++;
        const float filter_val = *local_filter_ptr++;
        *acc_buffer_ptr++ += filter_val * input_val;
      }
      // Advance to the next (possibly strided) input pixel.
      input_ptr += input_ptr_increment;
    }
  }
};
247 
template <>
struct FloatDepthwiseConvKernel<true, 0, 8> {
  // Strided kernel for generic input_depth, depth_multiplier == 8.
  // Every input channel fans out to 8 consecutive accumulators.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 2 input channels at a time (2 channels * 8 outputs = 16 MACs).
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs: lane 0 = channel ic, lane 1 = channel ic+1.
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate: broadcast each input lane against that
        // channel's 8 filter coefficients.
        acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
        acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
        acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
        acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time (odd input_depth tail).
      for (; ic < input_depth; ic++) {
        // Load the filters
        float32x4_t filter[2];
        for (int i = 0; i < 2; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Load the accumulators from acc_buffer
        float32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Advance to the next (possibly strided) input pixel.
      input_ptr += input_ptr_increment;
    }
  }
};
314 
template <>
struct FloatDepthwiseConvKernel<true, 0, 2> {
  // Strided kernel for generic input_depth, depth_multiplier == 2.
  // Every input channel fans out to 2 consecutive accumulators, so each
  // input value must be duplicated to line up with interleaved filter pairs.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time (8 channels * 2 outputs = 16 MACs).
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs; vzipq turns {a,b,c,d} into {a,a,b,b} / {c,c,d,d},
        // duplicating each channel to match its 2 filter coefficients.
        float32x4x2_t input_dup2[2];
        for (int i = 0; i < 2; i++) {
          const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
          input_dup2[i] = vzipq_f32(input, input);
        }
        local_input_ptr += 8;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
        acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
        acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
        acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters: one d-register (2 coefficients) per channel.
        float32x2_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1_f32(local_filter_ptr + 2 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float32x4_t input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x2_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate: lane i of input scales channel i's pair.
        acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
        acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
        acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
        acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle 2 input channels at a time.
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        const float32x4_t filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate
        acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
        acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time (scalar tail).
      for (; ic < input_depth; ic++) {
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
        }
        local_filter_ptr += 2;
        acc_buffer_ptr += 2;
      }
      // Advance to the next (possibly strided) input pixel.
      input_ptr += input_ptr_increment;
    }
  }
};
420 
template <>
struct FloatDepthwiseConvKernel<true, 1, 8> {
  // Strided kernel for input_depth == 1, depth_multiplier == 8: the single
  // input value of each pixel is broadcast against all 8 filter coefficients.
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the filters
      float32x4_t filter[2];
      for (int i = 0; i < 2; i++) {
        filter[i] = vld1q_f32(filter_ptr + 4 * i);
      }
      // Load the inputs: a single scalar per pixel (input_depth == 1).
      const float input_val = *input_ptr;
      input_ptr += input_ptr_increment;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};
453 
template <>
struct FloatDepthwiseConvKernel<true, 0, 16> {
  // Strided kernel for generic input_depth, depth_multiplier == 16.
  // Each input channel is broadcast against 16 filter coefficients
  // (four q-registers of accumulators per channel).
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      for (int ic = 0; ic < input_depth; ic++) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 4; i++) {
          acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Advance to the next (possibly strided) input pixel.
      input_ptr += input_ptr_increment;
    }
  }
};
491 #endif
492 
493 // Accumulates the effect of one row of the filter, on a segment of one row
494 // of the output, accessing the corresponding one row of the input.
495 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
496 void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
497                                 const float* input_data, int pad_width,
498                                 int depth_multiplier, int filter_width,
499                                 const float* filter_data,
500                                 int out_x_buffer_start, int out_x_buffer_end,
501                                 int output_depth, float* acc_buffer) {
502 #ifdef GEMMLOWP_PROFILING
503   gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
504 #endif
505   // Sanity check parameters. This is important in particular to ensure
506   // that we keep the number of template instantiations minimal, so we don't
507   // increase binary size unnecessarily.
508   static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
509   static_assert(kFixedInputDepth || kAllowStrided, "");
510   DCHECK(stride == 1 || kAllowStrided);
511   if (kFixedInputDepth) {
512     DCHECK_EQ(input_depth, kFixedInputDepth);
513   }
514   if (kFixedDepthMultiplier) {
515     DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
516   }
517   DCHECK_EQ(output_depth, input_depth * depth_multiplier);
518   const int input_ptr_increment = stride * input_depth;
519   const float* filter_base_ptr = filter_data;
520   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
521     // For the current (filter_x, filter_y) point in the filter,
522     // compute the boundaries of the corresponding output row segment.
523     int out_x_loop_start_unclampled = 0;
524     int out_x_loop_end_unclampled = 0;
525     if (kAllowStrided) {
526       if (stride == 2) {
527         out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2;
528         out_x_loop_end_unclampled =
529             (pad_width + input_width - filter_x + 1) / 2;
530       } else if (stride == 4) {
531         out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4;
532         out_x_loop_end_unclampled =
533             (pad_width + input_width - filter_x + 3) / 4;
534       } else {
535         out_x_loop_start_unclampled =
536             (pad_width - filter_x + stride - 1) / stride;
537         out_x_loop_end_unclampled =
538             (pad_width + input_width - filter_x + stride - 1) / stride;
539       }
540     } else {
541       out_x_loop_start_unclampled = pad_width - filter_x;
542       out_x_loop_end_unclampled = pad_width + input_width - filter_x;
543     }
544     // The kernel will have to iterate on the segment of the
545     // output row that starts at out_x_loop_start and out_x_loop_end.
546     const int out_x_loop_start =
547         std::max(out_x_buffer_start, out_x_loop_start_unclampled);
548     const int out_x_loop_end =
549         std::min(out_x_buffer_end, out_x_loop_end_unclampled);
550 
551     float* acc_buffer_ptr =
552         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
553     const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
554     const float* input_ptr = input_data + in_x_origin * input_depth;
555     const int num_output_pixels = out_x_loop_end - out_x_loop_start;
556     FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
557                              kFixedDepthMultiplier>::Run(num_output_pixels,
558                                                          input_depth,
559                                                          depth_multiplier,
560                                                          input_ptr,
561                                                          input_ptr_increment,
562                                                          filter_base_ptr,
563                                                          acc_buffer_ptr);
564     filter_base_ptr += output_depth;
565   }
566 }
567 
568 // generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
569 inline void FloatDepthwiseConvAccumRowGeneric(
570     int stride, int input_depth, int input_width, const float* input_data,
571     int pad_width, int depth_multiplier, int filter_width,
572     const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
573     int output_depth, float* acc_buffer) {
574   gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
575   const float* filter_base_ptr = filter_data;
576   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
577     const int out_x_loop_start = std::max(
578         out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
579     const int out_x_loop_end =
580         std::min(out_x_buffer_end,
581                  (pad_width + input_width - filter_x + stride - 1) / stride);
582 
583     float* acc_buffer_ptr =
584         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
585     const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
586     const float* input_ptr = input_data + in_x_origin * input_depth;
587     const int input_ptr_increment = (stride - 1) * input_depth;
588     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
589       const float* filter_ptr = filter_base_ptr;
590       for (int ic = 0; ic < input_depth; ++ic) {
591         const float input_val = *input_ptr++;
592         for (int m = 0; m < depth_multiplier; m++) {
593           const float filter_val = *filter_ptr++;
594           *acc_buffer_ptr++ += filter_val * input_val;
595         }
596       }
597       input_ptr += input_ptr_increment;
598     }
599     filter_base_ptr += output_depth;
600   }
601 }
602 
603 // Initializes the accumulator buffer with bias values.
604 inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
605                                        const float* bias_data,
606                                        float* acc_buffer) {
607   for (int i = 0; i < num_output_pixels; i++) {
608     memcpy(acc_buffer + i * output_depth, bias_data,
609            sizeof(acc_buffer[0]) * output_depth);
610   }
611 }
612 
// Float DepthwiseConv entry point. The fused activation Ac is a template
// parameter so the activation branches below are resolved at compile time.
// Layout follows Dims<4> convention: dimension 0 is depth/channel, 1 is
// width, 2 is height, 3 is batch.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
  const int input_height = ArraySize(input_dims, 2);
  const int input_width = ArraySize(input_dims, 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int filter_height = ArraySize(filter_dims, 2);
  const int filter_width = ArraySize(filter_dims, 1);
  const int output_height = ArraySize(output_dims, 2);
  const int output_width = ArraySize(output_dims, 1);
  DCHECK(output_depth == input_depth * depth_multiplier);

  // Fixed-size stack buffer of accumulators. Output rows are processed in
  // chunks of kOutputPixelsInAccBuffer pixels so the working set stays in
  // this buffer regardless of output_width.
  static const int kAccBufferMaxSize = 1024;
  float acc_buffer[kAccBufferMaxSize];
  DCHECK_GE(kAccBufferMaxSize, output_depth);
  const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
  const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
  DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, kAccBufferActualSize);
  DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
  DCHECK_GE(kOutputPixelsInAccBuffer, 1);

  // row_accum_func will point to the core accumulation function to be used
  // for this DepthwiseConv op. Defaults to the slow generic fallback; the
  // macro below swaps in a specialized NEON kernel when one matches.
  auto* row_accum_func = FloatDepthwiseConvAccumRowGeneric;

  const int kMaxFixedDepthMultiplier = 16;
  int fixed_depth_multiplier = 0;
  if (depth_multiplier <= kMaxFixedDepthMultiplier) {
    fixed_depth_multiplier = depth_multiplier;
  }
  // kMaxUnrolling is the max number of output values that we aim to handle
  // in one unrolled iteration of the inner loop. For practical performance
  // reasons, it is limited by the number of available registers. We could
  // fine-tune it depending on the architecture, but that's not worth doing
  // since this whole code is not very optimized to begin with. The
  // present value reflects what's realistic on ARM 32bit NEON with 16 128-bit
  // vector registers.
  const int kMaxUnrolling = 8;
  int fixed_input_depth = 0;
  if (fixed_depth_multiplier &&
      input_depth * fixed_depth_multiplier <= kMaxUnrolling) {
    fixed_input_depth = input_depth;
  }
// Selects the FloatDepthwiseConvAccumRow instantiation matching the runtime
// parameters. Later matching invocations overwrite earlier ones.
#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
                                        FIXED_DEPTH_MULTIPLIER)           \
  if ((stride_width == 1 || ALLOW_STRIDED) &&                             \
      fixed_input_depth == FIXED_INPUT_DEPTH &&                           \
      fixed_depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                 \
    row_accum_func =                                                      \
        FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,      \
                                   FIXED_DEPTH_MULTIPLIER>;               \
  }

#ifdef USE_NEON
  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8)
  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
  TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16)
  TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
#endif  // USE_NEON

#undef TFMINI_USE_DEPTHWISECONV_KERNEL

  // Now that we have determined row_accum_func, we can start work.
  float* output_ptr = output_data;
  for (int b = 0; b < batches; ++b) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      const int in_y_origin = (out_y * stride_height) - pad_height;
      // Restrict the filter rows to those whose input row is in bounds
      // (handles top/bottom padding).
      const int filter_y_start = std::max(0, -in_y_origin);
      const int filter_y_end =
          std::min(filter_height, input_height - in_y_origin);
      for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
           out_x_buffer_start += kOutputPixelsInAccBuffer) {
        const int out_x_buffer_end = std::min(
            output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
        // We call a 'pixel' a group of activation that share all but the
        // 'depth'/'channel' coordinate. num_output_pixels is the number of
        // output pixels that we will accumulate in this loop iteration.
        const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
        // Initialize our local accumulator with the bias values, so we don't
        // have to add them later.
        DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
                                   acc_buffer);
        // Accumulation loop. Most of the time should be spent in here.
        for (int filter_y = filter_y_start; filter_y < filter_y_end;
             ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          row_accum_func(stride_width, input_depth, input_width,
                         input_data + in_y * input_dims.strides[2] +
                             b * input_dims.strides[3],
                         pad_width, depth_multiplier, filter_width,
                         filter_data + filter_y * filter_dims.strides[2],
                         out_x_buffer_start, out_x_buffer_end, output_depth,
                         acc_buffer);
        }
        // Finished accumulating int32 values. Now need to convert them to
        // the final form: apply the fused activation and store to the output.
        const int num_output_values = output_depth * num_output_pixels;
        int i = 0;
#ifdef USE_NEON
        // Handle 16 values at a time
        for (; i <= num_output_values - 16; i += 16) {
          float32x4_t acc[4];
          for (int k = 0; k < 4; k++) {
            acc[k] = vld1q_f32(acc_buffer + i + 4 * k);
          }
          // The activation branches below are compile-time constant on Ac.
          if (Ac == FusedActivationFunctionType::kRelu) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(0.f), acc[k]);
            }
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(0.f),
                                 vminq_f32(vdupq_n_f32(6.f), acc[k]));
            }
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(-1.f),
                                 vminq_f32(vdupq_n_f32(1.f), acc[k]));
            }
          }
          for (int k = 0; k < 4; k++) {
            vst1q_f32(output_ptr + 4 * k, acc[k]);
          }
          output_ptr += 16;
        }
        // Handle 4 values at a time
        for (; i <= num_output_values - 4; i += 4) {
          float32x4_t acc = vld1q_f32(acc_buffer + i);
          if (Ac == FusedActivationFunctionType::kRelu) {
            acc = vmaxq_f32(vdupq_n_f32(0.f), acc);
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            acc = vmaxq_f32(vdupq_n_f32(0.f), vminq_f32(vdupq_n_f32(6.f), acc));
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            acc =
                vmaxq_f32(vdupq_n_f32(-1.f), vminq_f32(vdupq_n_f32(1.f), acc));
          }
          vst1q_f32(output_ptr, acc);
          output_ptr += 4;
        }
#endif
        // Handle leftover values, one by one. This is very slow.
        for (; i < num_output_values; i++) {
          float acc = acc_buffer[i];
          if (Ac == FusedActivationFunctionType::kRelu) {
            acc = std::max(0.f, acc);
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            acc = std::max(0.f, std::min(6.f, acc));
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            acc = std::max(-1.f, std::min(1.f, acc));
          }
          *output_ptr++ = acc;
        }
      }
    }
  }
}
783 
784 }  // namespace optimized_ops
785 }  // namespace nn
786 }  // namespace android
787 
788 
789 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
790