/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_

#include "public/gemmlowp.h"
#include "tensorflow/core/kernels/neon/types.h"

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

namespace tensorflow {
namespace neon {

// Implementation of float DepthwiseConv

template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct FloatDepthwiseConvKernel {};
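// Each specialization below implements the inner loop of the depthwise
// convolution for one combination of template parameters:
// - kAllowStrided: whether the kernel supports stride > 1; when false,
//   input_ptr_increment is ignored and input pixels are assumed contiguous.
// - kFixedInputDepth: input depth known at compile time; 0 means "any",
//   with the depth taken from the runtime input_depth argument instead.
// - kFixedDepthMultiplier: depth multiplier known at compile time.
// The dispatch in DepthwiseConv at the bottom of this file selects the
// matching specialization, falling back to the generic row accumulator
// when none applies.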

#ifdef USE_NEON

template <>
struct FloatDepthwiseConvKernel<false, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Load the filters
    float32x4_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vld1q_f32(filter_ptr + 4 * i);
    }
    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      acc[0] = vmlaq_f32(acc[0], input[0], filter[0]);
      acc[1] = vmlaq_f32(acc[1], input[1], filter[1]);
      acc[2] = vmlaq_f32(acc[2], input[2], filter[0]);
      acc[3] = vmlaq_f32(acc[3], input[3], filter[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<false, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    const float32x2_t filters = vld1_f32(filter_ptr);
    const float32x4_t filters_dup2 = vcombine_f32(filters, filters);
    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the inputs
      float32x4_t input[4];
      for (int i = 0; i < 4; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 16;
      // Load the accumulators from acc_buffer
      float32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 4; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the inputs
      float32x4_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vld1q_f32(input_ptr + 4 * i);
      }
      input_ptr += 8;
      // Load the accumulators from acc_buffer
      float32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the inputs
      const float32x4_t input = vld1q_f32(input_ptr);
      input_ptr += 4;
      // Load the accumulators from acc_buffer
      float32x4_t acc = vld1q_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmlaq_f32(acc, input, filters_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time
    for (; outp < num_output_pixels; outp++) {
      // Load the inputs
      const float32x2_t input = vld1_f32(input_ptr);
      input_ptr += 2;
      // Load the accumulators from acc_buffer
      float32x2_t acc = vld1_f32(acc_buffer_ptr);
      // Multiply-accumulate
      acc = vmla_f32(acc, input, filters);
      // Store the accumulators back to acc_buffer
      vst1_f32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        float32x4_t input[4];
        for (int i = 0; i < 4; i++) {
          input[i] = vld1q_f32(local_input_ptr + 4 * i);
        }
        local_input_ptr += 16;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 4; i++) {
          acc[i] = vmlaq_f32(acc[i], input[i], filter[i]);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters
        float32x4_t filter;
        filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        float32x4_t input;
        input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x4_t acc;
        acc = vld1q_f32(acc_buffer_ptr);
        // Multiply-accumulate
        acc = vmlaq_f32(acc, input, filter);
        // Store the accumulators back to acc_buffer
        vst1q_f32(acc_buffer_ptr, acc);
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const float input_val = *local_input_ptr++;
        const float filter_val = *local_filter_ptr++;
        *acc_buffer_ptr++ += filter_val * input_val;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 2 input channels at a time.
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0);
        acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0);
        acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1);
        acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the filters
        float32x4_t filter[2];
        for (int i = 0; i < 2; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Load the accumulators from acc_buffer
        float32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val);
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct FloatDepthwiseConvKernel<true, 0, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const float* input_ptr, int input_ptr_increment,
                  const float* filter_ptr, float* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const float* local_filter_ptr = filter_ptr;
      const float* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters
        float32x4_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1q_f32(local_filter_ptr + 4 * i);
        }
        local_filter_ptr += 16;
        // Load the inputs
        float32x4x2_t input_dup2[2];
        for (int i = 0; i < 2; i++) {
          const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i);
          input_dup2[i] = vzipq_f32(input, input);
        }
        local_input_ptr += 8;
        // Load the accumulators from acc_buffer
        float32x4_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]);
        acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]);
        acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]);
        acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle 4 input channels at a time.
      for (; ic <= input_depth - 4; ic += 4) {
        // Load the filters
        float32x2_t filter[4];
        for (int i = 0; i < 4; i++) {
          filter[i] = vld1_f32(local_filter_ptr + 2 * i);
        }
        local_filter_ptr += 8;
        // Load the inputs
        const float32x4_t input = vld1q_f32(local_input_ptr);
        local_input_ptr += 4;
        // Load the accumulators from acc_buffer
        float32x2_t acc[4];
        for (int i = 0; i < 4; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate
        acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0);
        acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1);
        acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0);
        acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 4; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle 2 input channels at a time.
      for (; ic <= input_depth - 2; ic += 2) {
        // Load the filters
        const float32x4_t filter = vld1q_f32(local_filter_ptr);
        local_filter_ptr += 4;
        // Load the inputs
        const float32x2_t input = vld1_f32(local_input_ptr);
        local_input_ptr += 2;
        // Load the accumulators from acc_buffer
        float32x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1_f32(acc_buffer_ptr + 2 * i);
        }
        // Multiply-accumulate
        acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0);
        acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1);
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1_f32(acc_buffer_ptr + 2 * i, acc[i]);
        }
        acc_buffer_ptr += 4;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the inputs
        const float input_val = *local_input_ptr++;
        // Multiply-accumulate
        for (int i = 0; i < 2; i++) {
          acc_buffer_ptr[i] += local_filter_ptr[i] * input_val;
        }
        local_filter_ptr += 2;
        acc_buffer_ptr += 2;
      }
      input_ptr += input_ptr_increment;
    }
  }
};
#endif

// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
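// This is called once per filter row (filter_y) by DepthwiseConv below; each
// call adds that filter row's contribution to the accumulators for the
// output columns in [out_x_buffer_start, out_x_buffer_end).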
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
                                const float* input_data, int pad_width,
                                int depth_multiplier, int filter_width,
                                const float* filter_data,
                                int out_x_buffer_start, int out_x_buffer_end,
                                int output_depth, float* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
  gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
  // Sanity check parameters. This is important in particular to ensure
  // that we keep the number of template instantiations minimal, so we don't
  // increase binary size unnecessarily.
  static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
  static_assert(kFixedInputDepth || kAllowStrided, "");
  DCHECK(stride == 1 || kAllowStrided);
  if (kFixedInputDepth) {
    DCHECK_EQ(input_depth, kFixedInputDepth);
  }
  if (kFixedDepthMultiplier) {
    DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
  }
  DCHECK_EQ(output_depth, input_depth * depth_multiplier);
  const int input_ptr_increment = stride * input_depth;
  const float* filter_base_ptr = filter_data;
  for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
    // For the current (filter_x, filter_y) point in the filter,
    // compute the boundaries of the corresponding output row segment.
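    // The bounds follow from requiring 0 <= in_x < input_width, where
    // in_x = out_x * stride - pad_width + filter_x. Solving for out_x:
    //   out_x >= ceil((pad_width - filter_x) / stride)
    //   out_x <  ceil((pad_width + input_width - filter_x) / stride)
    // which the ceiling divisions below compute; the stride == 2 and
    // stride == 4 cases are special-cased, presumably so the divisions
    // are by compile-time constants.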
    int out_x_loop_start_unclamped = 0;
    int out_x_loop_end_unclamped = 0;
    if (kAllowStrided) {
      if (stride == 2) {
        out_x_loop_start_unclamped = (pad_width - filter_x + 1) / 2;
        out_x_loop_end_unclamped = (pad_width + input_width - filter_x + 1) / 2;
      } else if (stride == 4) {
        out_x_loop_start_unclamped = (pad_width - filter_x + 3) / 4;
        out_x_loop_end_unclamped = (pad_width + input_width - filter_x + 3) / 4;
      } else {
        out_x_loop_start_unclamped =
            (pad_width - filter_x + stride - 1) / stride;
        out_x_loop_end_unclamped =
            (pad_width + input_width - filter_x + stride - 1) / stride;
      }
    } else {
      out_x_loop_start_unclamped = pad_width - filter_x;
      out_x_loop_end_unclamped = pad_width + input_width - filter_x;
    }
    // The kernel will have to iterate on the segment of the
    // output row that starts at out_x_loop_start and ends at out_x_loop_end.
    const int out_x_loop_start =
        std::max(out_x_buffer_start, out_x_loop_start_unclamped);
    const int out_x_loop_end =
        std::min(out_x_buffer_end, out_x_loop_end_unclamped);

    float* acc_buffer_ptr =
        acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
    const float* input_ptr = input_data + in_x_origin * input_depth;
    const int num_output_pixels = out_x_loop_end - out_x_loop_start;
    FloatDepthwiseConvKernel<kAllowStrided, kFixedInputDepth,
                             kFixedDepthMultiplier>::Run(num_output_pixels,
                                                         input_depth,
                                                         depth_multiplier,
                                                         input_ptr,
                                                         input_ptr_increment,
                                                         filter_base_ptr,
                                                         acc_buffer_ptr);
    filter_base_ptr += output_depth;
  }
}

// Generic fallback of FloatDepthwiseConvAccumRow: portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
    int stride, int input_depth, int input_width, const float* input_data,
    int pad_width, int depth_multiplier, int filter_width,
    const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
    int output_depth, float* acc_buffer) {
  gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");

  VLOG(1) << "DepthwiseConv2d using slow path with "
          << "stride = " << stride << ", "
          << "input_depth = " << input_depth << ", "
          << "depth_multiplier = " << depth_multiplier << ".";

  const float* filter_base_ptr = filter_data;
  for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
    const int out_x_loop_start = std::max(
        out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
    const int out_x_loop_end =
        std::min(out_x_buffer_end,
                 (pad_width + input_width - filter_x + stride - 1) / stride);

    float* acc_buffer_ptr =
        acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
    const float* input_ptr = input_data + in_x_origin * input_depth;
    const int input_ptr_increment = (stride - 1) * input_depth;
    for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
      const float* filter_ptr = filter_base_ptr;
      for (int ic = 0; ic < input_depth; ++ic) {
        const float input_val = *input_ptr++;
        for (int m = 0; m < depth_multiplier; m++) {
          const float filter_val = *filter_ptr++;
          *acc_buffer_ptr++ += filter_val * input_val;
        }
      }
      input_ptr += input_ptr_increment;
    }
    filter_base_ptr += output_depth;
  }
}

// Initializes the accumulator buffer with bias values.
inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
                                       const float* bias_data,
                                       float* acc_buffer) {
  // TODO(benoitjacob): This might need optimized specializations
  // for small output_depth values, if that ever becomes an important
  // case (like it was for some quantized DepthwiseConv cases).
  for (int i = 0; i < num_output_pixels; i++) {
    memcpy(acc_buffer + i * output_depth, bias_data,
           sizeof(acc_buffer[0]) * output_depth);
  }
}

template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
  const int input_height = ArraySize(input_dims, 2);
  const int input_width = ArraySize(input_dims, 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int filter_height = ArraySize(filter_dims, 2);
  const int filter_width = ArraySize(filter_dims, 1);
  const int output_height = ArraySize(output_dims, 2);
  const int output_width = ArraySize(output_dims, 1);
  DCHECK(output_depth == input_depth * depth_multiplier);

  static const int kAccBufferMaxSize = 1024;
  float acc_buffer[kAccBufferMaxSize];
  DCHECK_GE(kAccBufferMaxSize, output_depth)
      << "kAccBufferMaxSize is too small for this model!";
  const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth;
  const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
  DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, kAccBufferActualSize);
  DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize);
  DCHECK_GE(kOutputPixelsInAccBuffer, 1);

  // row_accum_func will point to the core accumulation function to be used
  // for this DepthwiseConv op.
  auto* row_accum_func = FloatDepthwiseConvAccumRowGeneric;

  const int kMaxFixedDepthMultiplier = 8;
  int fixed_depth_multiplier = 0;
  if (depth_multiplier <= kMaxFixedDepthMultiplier) {
    fixed_depth_multiplier = depth_multiplier;
  }
  // kMaxUnrolling is the max number of output values that we aim to handle
  // in one unrolled iteration of the inner loop. For practical performance
  // reasons, it is limited by the number of available registers. We could
  // fine-tune it depending on the architecture, but that's not worth doing
  // since this whole code is not very optimized to begin with. The
  // present value reflects what's realistic on ARM 32bit NEON with 16 128-bit
  // vector registers.
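  // For example, input_depth = 8 with depth_multiplier = 1 produces 8 output
  // values per pixel and still qualifies for a fixed-depth kernel (the
  // <false, 8, 1> specialization above), while input_depth = 16 would not.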
  const int kMaxUnrolling = 8;
  int fixed_input_depth = 0;
  if (fixed_depth_multiplier &&
      input_depth * fixed_depth_multiplier <= kMaxUnrolling) {
    fixed_input_depth = input_depth;
  }
#define TF_NEON_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
                                         FIXED_DEPTH_MULTIPLIER)           \
  if ((stride == 1 || ALLOW_STRIDED) &&                                    \
      fixed_input_depth == FIXED_INPUT_DEPTH &&                            \
      fixed_depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                  \
    row_accum_func =                                                       \
        FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,       \
                                   FIXED_DEPTH_MULTIPLIER>;                \
  }

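  // At most one of the instantiations below can match, since their
  // (FIXED_INPUT_DEPTH, FIXED_DEPTH_MULTIPLIER) pairs are all distinct;
  // if none matches, row_accum_func keeps its generic fallback value.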
#ifdef USE_NEON
  TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
  TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 8)
  TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
  TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
  TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
#endif  // USE_NEON

#undef TF_NEON_USE_DEPTHWISECONV_KERNEL

  // Now that we have determined row_accum_func, we can start work.
  float* output_ptr = output_data;
  for (int b = 0; b < batches; ++b) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      const int in_y_origin = (out_y * stride) - pad_height;
      const int filter_y_start = std::max(0, -in_y_origin);
      const int filter_y_end =
          std::min(filter_height, input_height - in_y_origin);
      for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
           out_x_buffer_start += kOutputPixelsInAccBuffer) {
        const int out_x_buffer_end = std::min(
            output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
        // We call a 'pixel' a group of activations that share all but the
        // 'depth'/'channel' coordinate. num_output_pixels is the number of
        // output pixels that we will accumulate in this loop iteration.
        const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
        // Initialize our local accumulator with the bias values, so we don't
        // have to add them later.
        DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
                                   acc_buffer);
        // Accumulation loop. Most of the time should be spent in here.
        for (int filter_y = filter_y_start; filter_y < filter_y_end;
             ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          row_accum_func(stride, input_depth, input_width,
                         input_data + in_y * input_dims.strides[2] +
                             b * input_dims.strides[3],
                         pad_width, depth_multiplier, filter_width,
                         filter_data + filter_y * filter_dims.strides[2],
                         out_x_buffer_start, out_x_buffer_end, output_depth,
                         acc_buffer);
        }
        // Finished accumulating. Now store to destination.
        const int num_output_values = output_depth * num_output_pixels;
        int i = 0;
// TODO(benoitjacob) optimized code goes here
#ifdef USE_NEON
        // Handle 16 values at a time
        for (; i <= num_output_values - 16; i += 16) {
          float32x4_t acc[4];
          for (int k = 0; k < 4; k++) {
            acc[k] = vld1q_f32(acc_buffer + i + 4 * k);
          }
          if (Ac == FusedActivationFunctionType::kRelu) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(0.f), acc[k]);
            }
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(0.f),
                                 vminq_f32(vdupq_n_f32(6.f), acc[k]));
            }
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            for (int k = 0; k < 4; k++) {
              acc[k] = vmaxq_f32(vdupq_n_f32(-1.f),
                                 vminq_f32(vdupq_n_f32(1.f), acc[k]));
            }
          }
          for (int k = 0; k < 4; k++) {
            vst1q_f32(output_ptr + 4 * k, acc[k]);
          }
          output_ptr += 16;
        }
        // Handle 4 values at a time
        for (; i <= num_output_values - 4; i += 4) {
          float32x4_t acc = vld1q_f32(acc_buffer + i);
          if (Ac == FusedActivationFunctionType::kRelu) {
            acc = vmaxq_f32(vdupq_n_f32(0.f), acc);
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            acc = vmaxq_f32(vdupq_n_f32(0.f), vminq_f32(vdupq_n_f32(6.f), acc));
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            acc =
                vmaxq_f32(vdupq_n_f32(-1.f), vminq_f32(vdupq_n_f32(1.f), acc));
          }
          vst1q_f32(output_ptr, acc);
          output_ptr += 4;
        }
#endif
        // Handle leftover values, one by one. This is very slow.
        for (; i < num_output_values; i++) {
          float acc = acc_buffer[i];
          if (Ac == FusedActivationFunctionType::kRelu) {
            acc = std::max(0.f, acc);
          } else if (Ac == FusedActivationFunctionType::kRelu6) {
            acc = std::max(0.f, std::min(6.f, acc));
          } else if (Ac == FusedActivationFunctionType::kRelu1) {
            acc = std::max(-1.f, std::min(1.f, acc));
          }
          *output_ptr++ = acc;
        }
      }
    }
  }
}
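
// A minimal usage sketch (assuming the Dims<4> convention visible above:
// dimension 0 = depth, 1 = width, 2 = height, 3 = batch, with strides[]
// giving each dimension's element stride; Dims<4> itself is defined in
// types.h, not shown here):
//
//   // 8x8 input with 2 channels, 3x3 filter, depth_multiplier = 1:
//   DepthwiseConv<FusedActivationFunctionType::kRelu>(
//       input_data, input_dims, filter_data, filter_dims, bias_data,
//       bias_dims, /*stride=*/1, /*pad_width=*/1, /*pad_height=*/1,
//       /*depth_multiplier=*/1, output_data, output_dims);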

}  // end namespace neon
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_