/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_

#include <type_traits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace depthwise_conv {

// Implementation of quantized DepthwiseConv

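// The template parameters select a fixed-size micro-kernel:
//   kAllowStrided - whether Run() must honor input_ptr_increment when moving
//     from one output pixel to the next (the non-strided kernels simply
//     advance by the bytes they consume).
//   kFixedInputDepth - input depth handled at compile time; 0 means the depth
//     is read from the runtime input_depth argument.
//   kFixedDepthMultiplier - depth multiplier handled at compile time.
// Each specialization's Run() only adds int32 products of offset-adjusted
// inputs and filters into acc_buffer_ptr; no requantization is done here.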
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct QuantizedDepthwiseConvKernel {};
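
// A scalar sketch of the accumulation performed by the specializations below,
// for reference only (the name and function are illustrative, not part of the
// kernel API). It assumes input_ptr_increment is the per-pixel input stride
// in bytes, as in the kAllowStrided specializations, and that acc_buffer_ptr
// holds input_depth * depth_multiplier accumulators per output pixel.
inline void QuantizedDepthwiseConvAccumulateSketch(
    int num_output_pixels, int input_depth, int depth_multiplier,
    const uint8* input_ptr, int16 input_offset, int input_ptr_increment,
    const uint8* filter_ptr, int16 filter_offset, int32* acc_buffer_ptr) {
  for (int outp = 0; outp < num_output_pixels; outp++) {
    const uint8* local_input_ptr = input_ptr + outp * input_ptr_increment;
    for (int ic = 0; ic < input_depth; ic++) {
      const int16 input_val = local_input_ptr[ic] + input_offset;
      for (int m = 0; m < depth_multiplier; m++) {
        const int16 filter_val =
            filter_ptr[ic * depth_multiplier + m] + filter_offset;
        *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
      }
    }
  }
}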

#ifdef USE_NEON
template <>
struct QuantizedDepthwiseConvKernel<true, 8, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8x2_t filter_u8;
    filter_u8.val[0] = vld1_u8(filter_ptr);
    filter_u8.val[1] = vld1_u8(filter_ptr + 8);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])),
                            vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4x2_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
        acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += input_ptr_increment;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
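      // (vzipq_s16 of a vector with itself interleaves each element with a
      // copy of itself, so val[0] = [x0,x0,x1,x1,x2,x2,x3,x3] and
      // val[1] = [x4,x4,x5,x5,x6,x6,x7,x7], matching the two filter values
      // that apply to each input channel when depth_multiplier == 2.)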
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]),
                                  vget_low_s16(input_dup2.val[i]));
        acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]),
                                  vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
        vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
      acc[1] =
          vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
      acc[3] =
          vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      acc[0] = vld1q_s32(acc_buffer_ptr);
      acc[1] = vld1q_s32(acc_buffer_ptr + 4);

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc[0]);
      vst1q_s32(acc_buffer_ptr + 4, acc[1]);
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter),
                                   vget_low_s16(input_dup2.val[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter),
                                   vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4x2_t input_dup2 = vzip_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }
    int outp = 0;
    // Handle two output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate.
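      // (Input lanes 0-1 hold the two channels of the first pixel and lanes
      // 2-3 hold the second pixel; filter[0] covers output channels 0-7 and
      // filter[1] covers channels 8-15, so acc[0..3] accumulate the first
      // pixel and acc[4..7] the second.)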
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);

      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer.
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x2_t acc = vld1_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
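      // (NEON has no 2-lane widening multiply-accumulate, so the two live
      // accumulator lanes are widened to an int32x4 via vcombine, run through
      // vmlal_s16, and only the low half is kept; the upper lanes hold
      // don't-care values.)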
      acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer.
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x2_t acc = vld1_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const uint32 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const uint32 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vmlal_n_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i);
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      }
      input_ptr += 16;
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
        acc[2 * i + 1] =
            vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc;
      acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]),
                              vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]),
                              vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]),
                              vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]),
                              vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]),
                              vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]),
                              vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]),
                              vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]),
                              vget_high_s16(input), 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 3> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // We will have to duplicate bytes in a NEON register, 3-fold.
    // We will do that by register-level table-look-up using VTBL instructions.
    // Here we prepare the registers containing the table-lookup indices.
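    // (Each row of the table below repeats a source byte three times: row 0
    // yields [x0,x0,x0,x1,x1,x1,x2,x2], row 1 yields [x2,x3,x3,x3,x4,x4,x4,x5],
    // and row 2 yields [x5,x5,x6,x6,x6,x7,x7,x7], so together the three rows
    // triple each of the 8 input bytes.)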
    static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
                                                   {2, 3, 3, 3, 4, 4, 4, 5},
                                                   {5, 5, 6, 6, 6, 7, 7, 7}};
    uint8x8_t dup3_indices[3];
    for (int i = 0; i < 3; i++) {
      dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
    }

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        int16x8_t filter[3];
        uint8x8x3_t filter_u8;
        filter_u8.val[0] = vld1_u8(local_filter_ptr);
        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
        filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
        local_filter_ptr += 24;
        for (int i = 0; i < 3; i++) {
          const int16x8_t filter_s16 =
              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        }
        // Load the inputs, duplicate 3-fold, add input_offset.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;

        uint8x8_t input_u8_dup3[3];
        for (int i = 0; i < 3; i++) {
          input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
        }
        int16x8_t input_dup3[3];
        for (int i = 0; i < 3; i++) {
          const int16x8_t input_s16_dup3 =
              vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
          input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
        }
        // Load the accumulators from acc_buffer
        int32x4x3_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
          acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
        }
        // Multiply-accumulate
        for (int j = 0; j < 3; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
                                    vget_low_s16(filter[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
                                    vget_high_s16(filter[j]));
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
        }
        acc_buffer_ptr += 24;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 3; i++) {
          const int16 filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
        }
        local_filter_ptr += 3;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        int16x8_t filter[2];
        uint8x8x2_t filter_u8;
        filter_u8.val[0] = vld1_u8(local_filter_ptr);
        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
        local_filter_ptr += 16;
        for (int i = 0; i < 2; i++) {
          const int16x8_t filter_s16 =
              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        }
        // Load the inputs, add input_offset, duplicate 2-fold.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        const int16x8x2_t input_dup2 = vzipq_s16(input, input);
        // Load the accumulators from acc_buffer.
        int32x4x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
        }
        // Multiply-accumulate.
        for (int j = 0; j < 2; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
                                    vget_low_s16(input_dup2.val[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
                                    vget_high_s16(input_dup2.val[j]));
        }
        // Store the accumulators back to acc_buffer.
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the inputs.
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 2; i++) {
          const int16 filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
        }
        local_filter_ptr += 2;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
        // Load the filters, add filter_offset.
        uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
        uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
        local_filter_ptr += 16;
        int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
        int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
        filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
        filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
        // Load the inputs, add input_offset.
        uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
        uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
        local_input_ptr += 16;
        int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
        int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
        input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
        input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
        int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
        int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
        int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
        acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
        acc_1 =
            vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
        acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
        acc_3 =
            vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
        // Store the accumulators back to acc_buffer
        vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
        vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
        vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
        vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
        acc_buffer_ptr += 16;
      }
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
        local_filter_ptr += 8;
        const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
        const int16x8_t filter =
            vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        // Load the inputs, add input_offset.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
        acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        const int16 filter_val = *local_filter_ptr++ + filter_offset;
        *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 16, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++) {
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    }
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += input_ptr_increment;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]),
                                   vget_low_s16(filter[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]),
                                   vget_high_s16(filter[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 16> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++) {
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    }
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      uint8 input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_u8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
        acc[2 * i + 1] =
            vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 32> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
    uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
    uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2);
    uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3);
    int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
    int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
    int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2));
    int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3));
    filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
    filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
    filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset));
    filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      uint8 input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_u8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
      int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
      int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
      int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
      int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
      int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
      int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
      int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
      // Multiply-accumulate
      acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
      acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
      acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
      acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
      acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
      acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
      acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
      acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
      vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
      vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
      vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
      vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
      vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
      vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
      vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
      acc_buffer_ptr += 32;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 20> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
1219     // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8.
1220     // We load the first 16 bytes into filter_u8_{0,1} as usual.
1221     // Then we load the last 8 bytes into filter_u8_x (x for 'extra').
1222     // This is redundant: the first 4 bytes of filter_u8_x are the same
1223     // as the last 4 bytes of filter_u8_1.
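    // Concretely: filter_u8_0 covers filter bytes [0, 8), filter_u8_1 covers
    // [8, 16), and filter_u8_x covers [12, 20). Only the high half of filter_x
    // (bytes [16, 20), i.e. output channels 16..19) is consumed below.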
1224     uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
1225     uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
1226     uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4);
1227     int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1228     int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1229     int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x));
1230     filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
1231     filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
1232     filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset));
1233     // Handle one output pixel at a time.
1234     for (int outp = 0; outp < num_output_pixels; outp++) {
1235       uint8 input_u8 = *input_ptr;
1236       input_ptr += input_ptr_increment;
1237       int16 input = static_cast<int16>(input_u8 + input_offset);
1238       // Load the accumulators from acc_buffer
1239       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1240       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1241       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1242       int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1243       int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
1244       // Multiply-accumulate
1245       acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
1246       acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
1247       acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
1248       acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
1249       acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input);
1250       // Store the accumulators back to acc_buffer
1251       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1252       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1253       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1254       vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1255       vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
1256       acc_buffer_ptr += 20;
1257     }
1258   }
1259 };
1260 
1261 template <>
1262 struct QuantizedDepthwiseConvKernel<true, 1, 8> {
1263   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1264                   const uint8* input_ptr, int16 input_offset,
1265                   int input_ptr_increment, const uint8* filter_ptr,
1266                   int16 filter_offset, int32* acc_buffer_ptr) {
1267     // Load the filters, add filter_offset.
1268     const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
1269     const int16x8_t filter = vaddq_s16(
1270         vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset));
1271     // Handle one output pixel at a time.
1272     for (int outp = 0; outp < num_output_pixels; outp++) {
1273       uint8 input_u8 = *input_ptr;
1274       input_ptr += input_ptr_increment;
1275       int16 input = static_cast<int16>(input_u8 + input_offset);
1276       // Load the accumulators from acc_buffer
1277       int32x4_t acc[2];
1278       for (int i = 0; i < 2; i++) {
1279         acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1280       }
1281       // Multiply-accumulate
1282       acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
1283       acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
1284       // Store the accumulators back to acc_buffer
1285       for (int i = 0; i < 2; i++) {
1286         vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1287       }
1288       acc_buffer_ptr += 8;
1289     }
1290   }
1291 };
1292 
1293 template <>
1294 struct QuantizedDepthwiseConvKernel<true, 2, 1> {
1295   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1296                   const uint8* input_ptr, int16 input_offset,
1297                   int input_ptr_increment, const uint8* filter_ptr,
1298                   int16 filter_offset, int32* acc_buffer_ptr) {
1299     // Load the filters, add filter_offset.
1300     uint8x8_t filter_u8 = vdup_n_u8(0);
1301     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1302     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1303     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
1304     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
1305     const int16x4_t filter_s16 =
1306         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
1307     const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
1308 
1309     int outp = 0;
1310 
1311     // Handle 2 output pixels at a time.
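    // Each output pixel here has input_depth == 2 channels, so two pixels fill
    // exactly one int16x4 (pixel0.ch0, pixel0.ch1, pixel1.ch0, pixel1.ch1) and
    // a single vmlal_s16 updates all four accumulators (two pixels x two
    // channels) at once.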
1312     for (; outp <= num_output_pixels - 2; outp += 2) {
1313       // Load the accumulators from acc_buffer.
1314       int32x4_t acc = vld1q_s32(acc_buffer_ptr);
1315       // Load the inputs, add input_offset.
1316       uint16x4_t input_u16 = vdup_n_u16(0);
1317       input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
1318                                 input_u16, 0);
1319       input_ptr += input_ptr_increment;
1320       input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
1321                                 input_u16, 1);
1322       input_ptr += input_ptr_increment;
1323       const int16x4_t input_s16 = vreinterpret_s16_u16(
1324           vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16))));
1325       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1326 
1327       // Multiply-accumulate.
1328       acc = vmlal_s16(acc, filter, input);
1329       // Store the accumulators back to acc_buffer.
1330       vst1q_s32(acc_buffer_ptr, acc);
1331       acc_buffer_ptr += 4;
1332     }
1333 
1334     // Handle 1 output pixel at a time.
1335     for (; outp < num_output_pixels; outp++) {
1336       // Load the accumulators from acc_buffer.
1337       int32x2_t acc = vld1_s32(acc_buffer_ptr);
1338       // Load the inputs, add input_offset.
1339       uint8x8_t input_u8 = vdup_n_u8(0);
1340       input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
1341       input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
1342       input_ptr += input_ptr_increment;
1343       const int16x4_t input_s16 =
1344           vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1345       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1346 
1347       // Multiply-accumulate.
1348       acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
1349       // Store the accumulators back to acc_buffer.
1350       vst1_s32(acc_buffer_ptr, acc);
1351       acc_buffer_ptr += 2;
1352     }
1353   }
1354 };
1355 
1356 template <>
1357 struct QuantizedDepthwiseConvKernel<true, 4, 1> {
1358   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1359                   const uint8* input_ptr, int16 input_offset,
1360                   int input_ptr_increment, const uint8* filter_ptr,
1361                   int16 filter_offset, int32* acc_buffer_ptr) {
1362     if (num_output_pixels <= 0) {
1363       return;
1364     }
1365 
1366     // Load the filters, add filter_offset.
1367     uint8x8_t filter_u8 = vdup_n_u8(0);
1368     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1369     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1370     filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
1371     filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
1372     const int16x4_t filter_s16 =
1373         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
1374     const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
1375 
1376     int outp = 0;
1377 
1378     // Handle one output pixel at a time until the second-to-last pixel. We
1379     // stop early because each iteration reads eight input values while only
1380     // processing four.
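    // Concretely: each vld1_u8 below reads 8 bytes starting at input_ptr, but
    // only the first 4 (this pixel's channels) are used. For the final output
    // pixel the extra 4 bytes could lie past the end of the input row, hence
    // the lane-by-lane load after this loop.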
1381     for (; outp < num_output_pixels - 1; outp++) {
1382       // Load the accumulators from acc_buffer
1383       int32x4_t acc;
1384       acc = vld1q_s32(acc_buffer_ptr);
1385 
1386       // Load the inputs, add input_offset.
1387       uint8x8_t input_u8 = vld1_u8(input_ptr);
1388       input_ptr += input_ptr_increment;
1389       const int16x4_t input_s16 =
1390           vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1391       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1392       // Multiply-accumulate
1393       acc = vmlal_s16(acc, filter, input);
1394       // Store the accumulators back to acc_buffer
1395       vst1q_s32(acc_buffer_ptr, acc);
1396       acc_buffer_ptr += 4;
1397     }
1398 
1399     // Handle the last output pixel.
1400     // Load the accumulators from acc_buffer
1401     int32x4_t acc;
1402     acc = vld1q_s32(acc_buffer_ptr);
1403 
1404     // Load the inputs, add input_offset.
1405     uint8x8_t input_u8 = vdup_n_u8(0);
1406     input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
1407     input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
1408     input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
1409     input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
1410     const int16x4_t input_s16 =
1411         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1412     const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1413     // Multiply-accumulate
1414     acc = vmlal_s16(acc, filter, input);
1415     // Store the accumulators back to acc_buffer
1416     vst1q_s32(acc_buffer_ptr, acc);
1417   }
1418 };
1419 
1420 template <>
1421 struct QuantizedDepthwiseConvKernel<false, 12, 1> {
1422   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1423                   const uint8* input_ptr, int16 input_offset,
1424                   int input_ptr_increment, const uint8* filter_ptr,
1425                   int16 filter_offset, int32* acc_buffer_ptr) {
1426     // Load the filters, add filter_offset.
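    // The two 8-byte loads below overlap on purpose: filter_u8_0 covers filter
    // bytes [0, 8) and filter_u8_1 covers [4, 12), so after widening, filter_0,
    // filter_1 and filter_2 hold channels 0..3, 4..7 and 8..11 respectively.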
1427     uint8x8_t filter_u8_0 = vld1_u8(filter_ptr);
1428     uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4);
1429     int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1430     int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1431     filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset));
1432     filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset));
1433     int16x4_t filter_0 = vget_low_s16(filter_s16_0);
1434     int16x4_t filter_1 = vget_high_s16(filter_s16_0);
1435     int16x4_t filter_2 = vget_high_s16(filter_s16_1);
1436 
1437     // Handle one output pixel at a time.
1438     for (int outp = 0; outp < num_output_pixels; outp++) {
1439       // Load the inputs, add input_offset.
1440       uint8x8_t input_u8_0 = vld1_u8(input_ptr);
1441       uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
1442       input_ptr += input_ptr_increment;
1443       int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
1444       int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
1445       input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
1446       input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
1447 
1448       // Load the accumulators from acc_buffer
1449       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1450       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1451       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1452 
1453       // Multiply-accumulate
1454       acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
1455       acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
1456       acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
1457 
1458       // Store the accumulators back to acc_buffer
1459       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1460       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1461       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1462 
1463       acc_buffer_ptr += 12;
1464     }
1465   }
1466 };
1467 #endif  // USE_NEON
1468 
1469 // Accumulates the effect of one row of the filter on a segment of one row
1470 // of the output, accessing the corresponding row of the input.
1471 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
1472 void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
1473                                     int input_depth, int input_width,
1474                                     const uint8* input_data, int16 input_offset,
1475                                     int pad_width, int depth_multiplier,
1476                                     int filter_width, const uint8* filter_data,
1477                                     int16 filter_offset, int out_x_buffer_start,
1478                                     int out_x_buffer_end, int output_depth,
1479                                     int32* acc_buffer) {
1480   ruy::profiler::ScopeLabel label(__PRETTY_FUNCTION__);
1481   // Consistency-check the parameters. This matters in particular to keep the
1482   // number of template instantiations minimal, so we don't increase binary
1483   // size unnecessarily.
1484   static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
1485   static_assert(kFixedInputDepth || kAllowStrided, "");
1486   TFLITE_DCHECK(stride == 1 || kAllowStrided);
1487   if (kFixedInputDepth) {
1488     TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
1489   }
1490   if (kFixedDepthMultiplier) {
1491     TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
1492   }
1493   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
1494   const int input_ptr_increment = stride * input_depth;
1495   const uint8* filter_base_ptr = filter_data;
1496   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1497     // For the current (filter_x, filter_y) point in the filter,
1498     // compute the boundaries of the corresponding output row segment.
1499     int out_x_loop_start_unclamped = 0;
1500     int out_x_loop_end_unclamped = 0;
1501     if (kAllowStrided) {
1502       if (stride == 2) {
1503         out_x_loop_start_unclamped =
1504             (pad_width - dilation_factor * filter_x + 1) / 2;
1505         out_x_loop_end_unclamped =
1506             (pad_width + input_width - dilation_factor * filter_x + 1) / 2;
1507       } else if (stride == 4) {
1508         out_x_loop_start_unclamped =
1509             (pad_width - dilation_factor * filter_x + 3) / 4;
1510         out_x_loop_end_unclamped =
1511             (pad_width + input_width - dilation_factor * filter_x + 3) / 4;
1512       } else {
1513         out_x_loop_start_unclamped =
1514             (pad_width - dilation_factor * filter_x + stride - 1) / stride;
1515         out_x_loop_end_unclamped = (pad_width + input_width -
1516                                     dilation_factor * filter_x + stride - 1) /
1517                                    stride;
1518       }
1519     } else {
1520       out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x;
1521       out_x_loop_end_unclamped =
1522           pad_width + input_width - dilation_factor * filter_x;
1523     }
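    // In every branch this computes the range of output x positions whose input
    // for this filter tap is in bounds, i.e. out_x values satisfying
    // 0 <= out_x * stride - pad_width + dilation_factor * filter_x < input_width.
    // For example, with stride == 2, dilation_factor == 1, pad_width == 1 and
    // filter_x == 0, the start is (1 + 1) / 2 == 1: output x == 0 would read
    // input x == -1 (left padding), so the vector kernel starts at output x == 1.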
1524     // The kernel will have to iterate on the segment of the output row that
1525     // starts at out_x_loop_start and ends at out_x_loop_end.
1526     const int out_x_loop_start =
1527         std::max(out_x_buffer_start, out_x_loop_start_unclamped);
1528     const int out_x_loop_end =
1529         std::min(out_x_buffer_end, out_x_loop_end_unclamped);
1530 
1531     int32* acc_buffer_ptr =
1532         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1533     const int in_x_origin =
1534         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1535     const uint8* input_ptr = input_data + in_x_origin * input_depth;
1536     const int num_output_pixels = out_x_loop_end - out_x_loop_start;
1537     QuantizedDepthwiseConvKernel<
1538         kAllowStrided, kFixedInputDepth,
1539         kFixedDepthMultiplier>::Run(num_output_pixels, input_depth,
1540                                     depth_multiplier, input_ptr, input_offset,
1541                                     input_ptr_increment, filter_base_ptr,
1542                                     filter_offset, acc_buffer_ptr);
1543     filter_base_ptr += output_depth;
1544   }
1545 }
1546 
1547 // Generic fallback of QuantizedDepthwiseConvAccumRow: portable, non-templatized.
1548 inline void QuantizedDepthwiseConvAccumRowGeneric(
1549     int stride, int dilation_factor, int input_depth, int input_width,
1550     const uint8* input_data, int16 input_offset, int pad_width,
1551     int depth_multiplier, int filter_width, const uint8* filter_data,
1552     int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
1553     int output_depth, int32* acc_buffer) {
1554   ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)");
1555   const uint8* filter_base_ptr = filter_data;
1556   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1557     const int out_x_loop_start = std::max(
1558         out_x_buffer_start,
1559         (pad_width - dilation_factor * filter_x + stride - 1) / stride);
1560     const int out_x_loop_end = std::min(
1561         out_x_buffer_end,
1562         (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
1563             stride);
1564 
1565     int32* acc_buffer_ptr =
1566         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1567     const int in_x_origin =
1568         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1569     const uint8* input_ptr = input_data + in_x_origin * input_depth;
1570     const int input_ptr_increment = (stride - 1) * input_depth;
1571     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
1572       const uint8* filter_ptr = filter_base_ptr;
1573       for (int ic = 0; ic < input_depth; ++ic) {
1574         const int16 input_val = *input_ptr++ + input_offset;
1575         for (int m = 0; m < depth_multiplier; m++) {
1576           const int16 filter_val = *filter_ptr++ + filter_offset;
1577           *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
1578         }
1579       }
1580       input_ptr += input_ptr_increment;
1581     }
1582     filter_base_ptr += output_depth;
1583   }
1584 }
1585 
1586 // Initializes the accumulator buffer with bias values.
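// The buffer is laid out pixel-major: for each of the num_output_pixels output
// pixels it holds output_depth int32 accumulators, and each pixel's slice is
// initialized with a copy of bias_data[0..output_depth).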
1587 inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
1588                                        const int32* bias_data,
1589                                        int32* acc_buffer) {
1590   int i = 0;
1591 #ifdef USE_NEON
1592   if (output_depth == 1) {
1593     const int32x4_t b = vdupq_n_s32(bias_data[0]);
1594     for (; i <= num_output_pixels - 16; i += 16) {
1595       vst1q_s32(acc_buffer + i + 0, b);
1596       vst1q_s32(acc_buffer + i + 4, b);
1597       vst1q_s32(acc_buffer + i + 8, b);
1598       vst1q_s32(acc_buffer + i + 12, b);
1599     }
1600     for (; i <= num_output_pixels - 4; i += 4) {
1601       vst1q_s32(acc_buffer + i, b);
1602     }
1603   } else if (output_depth == 2) {
1604     int32x4_t b = vdupq_n_s32(bias_data[0]);
1605     b = vsetq_lane_s32(bias_data[1], b, 1);
1606     b = vsetq_lane_s32(bias_data[1], b, 3);
1607     for (; i <= num_output_pixels - 8; i += 8) {
1608       vst1q_s32(acc_buffer + 2 * i + 0, b);
1609       vst1q_s32(acc_buffer + 2 * i + 4, b);
1610       vst1q_s32(acc_buffer + 2 * i + 8, b);
1611       vst1q_s32(acc_buffer + 2 * i + 12, b);
1612     }
1613     for (; i <= num_output_pixels - 2; i += 2) {
1614       vst1q_s32(acc_buffer + 2 * i, b);
1615     }
1616   } else if (output_depth == 4) {
1617     const int32x4_t b = vld1q_s32(bias_data);
1618     for (; i <= num_output_pixels - 4; i += 4) {
1619       vst1q_s32(acc_buffer + 4 * i + 0, b);
1620       vst1q_s32(acc_buffer + 4 * i + 4, b);
1621       vst1q_s32(acc_buffer + 4 * i + 8, b);
1622       vst1q_s32(acc_buffer + 4 * i + 12, b);
1623     }
1624     for (; i < num_output_pixels; i++) {
1625       vst1q_s32(acc_buffer + 4 * i, b);
1626     }
1627   } else if (output_depth == 8) {
1628     const int32x4_t b0 = vld1q_s32(bias_data);
1629     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1630     for (; i <= num_output_pixels - 2; i += 2) {
1631       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1632       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1633       vst1q_s32(acc_buffer + 8 * i + 8, b0);
1634       vst1q_s32(acc_buffer + 8 * i + 12, b1);
1635     }
1636     for (; i < num_output_pixels; i++) {
1637       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1638       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1639     }
1640   } else if (output_depth == 16) {
1641     const int32x4_t b0 = vld1q_s32(bias_data);
1642     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1643     const int32x4_t b2 = vld1q_s32(bias_data + 8);
1644     const int32x4_t b3 = vld1q_s32(bias_data + 12);
1645     for (; i < num_output_pixels; i++) {
1646       vst1q_s32(acc_buffer + 16 * i + 0, b0);
1647       vst1q_s32(acc_buffer + 16 * i + 4, b1);
1648       vst1q_s32(acc_buffer + 16 * i + 8, b2);
1649       vst1q_s32(acc_buffer + 16 * i + 12, b3);
1650     }
1651   }
1652 #endif  // USE_NEON
1653   for (; i < num_output_pixels; i++) {
1654     memcpy(acc_buffer + i * output_depth, bias_data,
1655            sizeof(acc_buffer[0]) * output_depth);
1656   }
1657 }
1658 
1659 inline void DepthwiseConvGeneral(
1660     const DepthwiseParams& params, const RuntimeShape& input_shape,
1661     const uint8* input_data, const RuntimeShape& filter_shape,
1662     const uint8* filter_data, const RuntimeShape& bias_shape,
1663     const int32* bias_data, const RuntimeShape& output_shape,
1664     uint8* output_data, int thread_start, int thread_end, int thread_dim) {
1665   const int stride_width = params.stride_width;
1666   const int stride_height = params.stride_height;
1667   const int pad_width = params.padding_values.width;
1668   const int pad_height = params.padding_values.height;
1669   const int depth_multiplier = params.depth_multiplier;
1670   const int32 output_activation_min = params.quantized_activation_min;
1671   const int32 output_activation_max = params.quantized_activation_max;
1672   const int32 input_offset = params.input_offset;
1673   const int32 filter_offset = params.weights_offset;
1674   const int32 output_offset = params.output_offset;
1675   const int32 output_multiplier = params.output_multiplier;
1676   const int output_shift = params.output_shift;
1677   const int dilation_width_factor = params.dilation_width_factor;
1678   const int dilation_height_factor = params.dilation_height_factor;
1679   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
1680   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
1681   const int input_height = input_shape.Dims(1);
1682   const int input_width = input_shape.Dims(2);
1683   const int input_depth = input_shape.Dims(3);
1684   const int filter_height = filter_shape.Dims(1);
1685   const int filter_width = filter_shape.Dims(2);
1686   const int output_height = output_shape.Dims(1);
1687   const int output_width = output_shape.Dims(2);
1688 #ifdef USE_NEON
1689   const bool shift_left = (output_shift > 0);
1690   const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
1691 #endif  // USE_NEON
1692 
1693   // The default acc_buffer size is 2048; a larger buffer is heap-allocated
1694   // if output_depth does not fit.
1695   // TODO(b/136089667): If output_depth > 2048 happens a lot, we should just use
1696   // a scratch tensor.
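  // For example (an arbitrary illustration): with output_depth == 96, the stack
  // buffer is used and kOutputPixelsInAccBuffer == 2048 / 96 == 21, so each
  // pass over the accumulation buffer below covers 21 output x positions.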
1697   static const int kStackAccBufferSize = 2048;
1698   int acc_buffer_size = kStackAccBufferSize;
1699   int32 stack_acc_buffer[kStackAccBufferSize];
1700   int32* acc_buffer = stack_acc_buffer;
1701   std::unique_ptr<int32[]> heap_acc_buffer;
1702   if (kStackAccBufferSize < output_depth) {
1703     heap_acc_buffer.reset(new int32[output_depth]);
1704     acc_buffer = heap_acc_buffer.get();
1705     acc_buffer_size = output_depth;
1706   }
1707   const int kOutputPixelsInAccBuffer = acc_buffer_size / output_depth;
1708   const int acc_buffer_size_actually_used =
1709       kOutputPixelsInAccBuffer * output_depth;
1710   TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
1711                    acc_buffer_size_actually_used);
1712   TFLITE_DCHECK_LE(acc_buffer_size_actually_used, acc_buffer_size);
1713   TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
1714   TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1);
1715 
1716   // row_accum_func will point to the core accumulation function to be used
1717   // for this DepthwiseConv op.
1718   using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
1719   row_accum_func_t row_accum_func = nullptr;
1720 
1721 #define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
1722                                         FIXED_DEPTH_MULTIPLIER)           \
1723   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
1724       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
1725       depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
1726     row_accum_func =                                                      \
1727         QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,  \
1728                                        FIXED_DEPTH_MULTIPLIER>;           \
1729   }
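  // For illustration, TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) expands to:
  //   if (!row_accum_func && (stride_width == 1 || false) &&
  //       (input_depth == 8 || 8 == 0) && depth_multiplier == 1) {
  //     row_accum_func = QuantizedDepthwiseConvAccumRow<false, 8, 1>;
  //   }
  // so the first matching specialization in the lists below wins.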
1730 
1731 #ifdef USE_NEON
1732   // We go over our list of kernels in decreasing order of preference
1733   // for the cases where multiple kernels could apply.
1734 
1735   // Start with the fastest kernels: AllowStrided=false, fixed input depth.
1736 
1737   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
1738   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
1739   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
1740   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
1741   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
1742   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
1743   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
1744   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
1745   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
1746   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
1747 
1748   // Next come the strided kernels: AllowStrided=true, fixed input depth.
1749   // They are a bit less efficient, but allow stride!=1.
1750 
1751   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
1752   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
1753   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
1754   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20)
1755   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
1756   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
1757   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
1758   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
1759   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
1760 
1761   // Finally, the kernels that allow a variable input depth;
1762   // these are the least efficient but most general kernels.
1763 
1764   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
1765   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
1766   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
1767 #endif  // USE_NEON
1768 
1769   // No matching fast kernel found, use slow fallback.
1770   if (!row_accum_func) {
1771     row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
1772   }
1773 
1774 #undef TFMINI_USE_DEPTHWISECONV_KERNEL
1775 
1776   const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
1777   const int input_batch_stride = input_height_stride * input_shape.Dims(1);
1778   const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
1779 
1780   // Now that we have determined row_accum_func, we can start work.
1781   int batch_start = 0;
1782   int batch_end = batches;
1783   int row_start = 0;
1784   int row_end = output_height;
1785   int output_ptr_offset = 0;
1786 
1787   switch (thread_dim) {
1788     case 0:
1789       // Multithread along the batch axis
1790       TFLITE_DCHECK_GE(thread_start, 0);
1791       TFLITE_DCHECK_LE(thread_end, batches);
1792       batch_start = thread_start;
1793       batch_end = thread_end;
1794       output_ptr_offset = batch_start * FlatSizeSkipDim(output_shape, 0);
1795       break;
1796     case 1:
1797       // Multithread along the row axis
1798       TFLITE_DCHECK_GE(thread_start, 0);
1799       TFLITE_DCHECK_LE(thread_end, output_height);
1800       row_start = thread_start;
1801       row_end = thread_end;
1802       output_ptr_offset = row_start * output_width * output_depth;
1803       break;
1804   }
1805 
1806   uint8* output_ptr = output_data + output_ptr_offset;
1807   int batch_step =
1808       (output_height + row_start - row_end) * output_width * output_depth;
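  // batch_step skips the output rows this thread does not own: after the inner
  // loops advance output_ptr by (row_end - row_start) rows, adding batch_step
  // moves it to this thread's row_start in the next batch. It is zero when
  // threading over the batch axis, since each batch is then written in full.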
1809   for (int b = batch_start; b < batch_end; ++b) {
1810     for (int out_y = row_start; out_y < row_end; ++out_y) {
1811       const int in_y_origin = (out_y * stride_height) - pad_height;
1812       const int filter_y_start =
1813           std::max(0, (-in_y_origin + dilation_height_factor - 1) /
1814                           dilation_height_factor);
1815       const int filter_y_end =
1816           std::min(filter_height,
1817                    (input_height - in_y_origin + dilation_height_factor - 1) /
1818                        dilation_height_factor);
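      // For example, with filter_height == 3, dilation_height_factor == 1,
      // pad_height == 1 and out_y == 0 (so in_y_origin == -1), filter_y_start
      // becomes 1: the filter row that would read input y == -1 is skipped.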
1819       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
1820            out_x_buffer_start += kOutputPixelsInAccBuffer) {
1821         const int out_x_buffer_end = std::min(
1822             output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
1823         // We call a 'pixel' a group of activations that share all but the
1824         // 'depth'/'channel' coordinate. num_output_pixels is the number of
1825         // output pixels that we will accumulate in this loop iteration.
1826         const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
1827         // Initialize our local accumulator with the bias values, so we don't
1828         // have to add them later.
1829         DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
1830                                    acc_buffer);
1831         // Accumulation loop. Most of the time should be spent in here.
1832         for (int filter_y = filter_y_start; filter_y < filter_y_end;
1833              ++filter_y) {
1834           const int in_y = in_y_origin + dilation_height_factor * filter_y;
1835           row_accum_func(
1836               stride_width, dilation_width_factor, input_depth, input_width,
1837               input_data + in_y * input_height_stride + b * input_batch_stride,
1838               input_offset, pad_width, depth_multiplier, filter_width,
1839               filter_data + filter_y * filter_height_stride, filter_offset,
1840               out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
1841         }
1842         // Finished accumulating int32 values. Now we need to convert them to
1843         // the final 8-bit form and store them.
1844         ruy::profiler::ScopeLabel label("downquantize+store");
1845         const int num_output_values = output_depth * num_output_pixels;
1846         int i = 0;
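        // Both the NEON and scalar paths below implement, per accumulator,
        // roughly the following requantization (this is what
        // MultiplyByQuantizedMultiplier does in the scalar tail):
        //   if (output_shift <= 0)  // "!shift_left"
        //     acc = RoundingDivideByPOT(
        //         SaturatingRoundingDoublingHighMul(acc, output_multiplier),
        //         -output_shift);
        //   else                    // "shift_left"
        //     acc = SaturatingRoundingDoublingHighMul(acc << output_shift,
        //                                             output_multiplier);
        //   acc += output_offset;
        //   acc = clamp(acc, output_activation_min, output_activation_max);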
1847 #ifdef USE_NEON
1848         using gemmlowp::RoundingDivideByPOT;
1849         const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1850         const int32x4_t output_activation_min_vec =
1851             vdupq_n_s32(output_activation_min);
1852         const int32x4_t output_activation_max_vec =
1853             vdupq_n_s32(output_activation_max);
1854         // Handle 16 values at once.
1855         // This allows us to issue 4 mutually independent int32
1856         // multiplications (vqrdmulh), which should alleviate most of their
1857         // high latency.
1858         for (; i <= num_output_values - 16; i += 16) {
1859           int32x4_t acc[4];
1860           for (int j = 0; j < 4; j++) {
1861             acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
1862           }
1863 
1864           if (!shift_left) {
1865             // Fixed-point multiplication.
1866             for (int j = 0; j < 4; j++) {
1867               acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
1868             }
1869             for (int j = 0; j < 4; j++) {
1870               acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
1871             }
1872           } else {
1873             // Fixed-point multiplication.
1874             for (int j = 0; j < 4; j++) {
1875               acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two);
1876               acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
1877             }
1878           }
1879           // Add the output offset.
1880           for (int j = 0; j < 4; j++) {
1881             acc[j] = vaddq_s32(acc[j], output_offset_vec);
1882           }
1883           // Apply the activation function.
1884           for (int j = 0; j < 4; j++) {
1885             acc[j] = vmaxq_s32(acc[j], output_activation_min_vec);
1886           }
1887           for (int j = 0; j < 4; j++) {
1888             acc[j] = vminq_s32(acc[j], output_activation_max_vec);
1889           }
1890           // Saturating cast to uint8 and store to destination.
1891           int16x4_t acc_s16[4];
1892           for (int j = 0; j < 4; j++) {
1893             acc_s16[j] = vqmovn_s32(acc[j]);
1894           }
1895           const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]);
1896           const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]);
1897           const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0);
1898           const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1);
1899           vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1));
1900           output_ptr += 16;
1901         }
1902         // Handle 8 values at once.
1903         // Not as good as 16 (now we're only issuing 2 mutually independent
1904         // vqrdmulh instructions, so we're probably paying for their high
1905         // latency).
1906         for (; i <= num_output_values - 8; i += 8) {
1907           int32x4_t acc0 = vld1q_s32(acc_buffer + i);
1908           int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
1909           if (!shift_left) {
1910             // Fixed-point multiplication.
1911             acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
1912             acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
1913             // Rounding right shift.
1914             acc0 = RoundingDivideByPOT(acc0, -output_shift);
1915             acc1 = RoundingDivideByPOT(acc1, -output_shift);
1916           } else {
1917             // Fixed-point multiplication.
1918             acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
1919             acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
1920 
1921             acc1 = vmulq_n_s32(acc1, multiplier_power_of_two);
1922             acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
1923           }
1924           // Add the output offset.
1925           acc0 = vaddq_s32(acc0, output_offset_vec);
1926           acc1 = vaddq_s32(acc1, output_offset_vec);
1927           // Apply the activation function.
1928           acc0 = vmaxq_s32(acc0, output_activation_min_vec);
1929           acc1 = vmaxq_s32(acc1, output_activation_min_vec);
1930           acc0 = vminq_s32(acc0, output_activation_max_vec);
1931           acc1 = vminq_s32(acc1, output_activation_max_vec);
1932           // Saturating cast to uint8 and store to destination.
1933           const int16x4_t acc0_s16 = vqmovn_s32(acc0);
1934           const int16x4_t acc1_s16 = vqmovn_s32(acc1);
1935           const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16);
1936           const uint8x8_t res_u8 = vqmovun_s16(res_s16);
1937           vst1_u8(output_ptr, res_u8);
1938           output_ptr += 8;
1939         }
1940         // Handle 4 values at once. Now we're paying the full price of the
1941         // high latency of vqrdmulh. Also, storing only 4 bytes at the end
1942         // (without any alignment) can only be done 1 byte at a time.
1943         // Yet, that is still worth doing to minimize the amount of leftover
1944         // that will have to go through the very slow scalar code.
1945         for (; i <= num_output_values - 4; i += 4) {
1946           int32x4_t acc = vld1q_s32(acc_buffer + i);
1947           if (!shift_left) {
1948             // Fixed-point multiplication.
1949             acc = vqrdmulhq_n_s32(acc, output_multiplier);
1950             // Rounding right shift.
1951             acc = RoundingDivideByPOT(acc, -output_shift);
1952           } else {
1953             // Fixed-point multiplication.
1954             acc = vmulq_n_s32(acc, multiplier_power_of_two);
1955             acc = vqrdmulhq_n_s32(acc, output_multiplier);
1956           }
1957           // Add the output offset.
1958           acc = vaddq_s32(acc, output_offset_vec);
1959           // Apply the activation function.
1960           acc = vmaxq_s32(acc, output_activation_min_vec);
1961           acc = vminq_s32(acc, output_activation_max_vec);
1962           // Saturating cast to uint8 and store to destination.
1963           const int16x4_t acc_s16 = vqmovn_s32(acc);
1964           const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
1965           const uint8x8_t res_u8 = vqmovun_s16(res_s16);
1966           vst1_lane_u8(output_ptr + 0, res_u8, 0);
1967           vst1_lane_u8(output_ptr + 1, res_u8, 1);
1968           vst1_lane_u8(output_ptr + 2, res_u8, 2);
1969           vst1_lane_u8(output_ptr + 3, res_u8, 3);
1970           output_ptr += 4;
1971         }
1972 #endif  // USE_NEON
1973 
1974         // Handle leftover values, one by one. This is very slow.
1975         for (; i < num_output_values; i++) {
1976           int32 acc = acc_buffer[i];
1977           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
1978                                               output_shift);
1979           acc += output_offset;
1980           acc = std::max(acc, output_activation_min);
1981           acc = std::min(acc, output_activation_max);
1982           *output_ptr++ = static_cast<uint8>(acc);
1983         }
1984       }
1985     }
1986     output_ptr += batch_step;
1987   }
1988 }
1989 
1990 }  // namespace depthwise_conv
1991 
1992 template <DepthwiseConvOutputRounding kOutputRounding>
1993 inline void DepthwiseConvWithRounding(
1994     const DepthwiseParams& params, const RuntimeShape& input_shape,
1995     const uint8* input_data, const RuntimeShape& filter_shape,
1996     const uint8* filter_data, const RuntimeShape& bias_shape,
1997     const int32* bias_data, const RuntimeShape& output_shape,
1998     uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
1999     int thread_end, int thread_dim) {
2000   ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
2001   const int depth_multiplier = params.depth_multiplier;
2002   const int32 output_activation_min = params.quantized_activation_min;
2003   const int32 output_activation_max = params.quantized_activation_max;
2004   const int dilation_width_factor = params.dilation_width_factor;
2005   const int dilation_height_factor = params.dilation_height_factor;
2006   TFLITE_DCHECK_GE(dilation_width_factor, 1);
2007   TFLITE_DCHECK_GE(dilation_height_factor, 1);
2008   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2009   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2010   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2011   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2012   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
2013   const int input_depth = input_shape.Dims(3);
2014   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
2015   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
2016 
2017 // Enable for arm64, except under Nvidia Linux for Tegra (L4T) on Jetson TX2,
2018 // whose compiler does not support the offsetof() macro.
2019 #if defined(__aarch64__) && !defined(GOOGLE_L4T)
2020 #if defined(__ANDROID__) && defined(__clang__)
2021   // Dispatch to dot-product 3x3 kernels when supported.
2022   if (cpu_flags.neon_dotprod) {
2023     using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
2024     DotProduct3x3KernelType kernel_type =
2025         optimized_ops::depthwise_conv::CategorizeDotProductKernel(
2026             input_shape, filter_shape, output_shape, params);
2027     if (kernel_type != DotProduct3x3KernelType::kNone) {
2028       ruy::profiler::ScopeLabel specialized_label(
2029           "DepthwiseConv/8bit/3x3XDotProduct");
2030       optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3<
2031           DepthwiseConvImplementation::kUseNeon3x3DotProduct>(
2032           params, input_shape, input_data, filter_shape, filter_data,
2033           bias_shape, bias_data, output_shape, output_data, thread_start,
2034           thread_end, thread_dim);
2035       return;
2036     }
2037   }
2038 
2039 #endif  // defined(__ANDROID__) && defined(__clang__)
2040   // Dispatch to non-dot-product 3x3 kernels when supported.
2041 
2042   const int stride_width = params.stride_width;
2043   const int stride_height = params.stride_height;
2044   const int pad_width = params.padding_values.width;
2045   const int pad_height = params.padding_values.height;
2046   const int output_shift = params.output_shift;
2047 
2048   // Call kernel optimized for depthwise convolutions using 3x3 filters if
2049   // parameters are supported.
2050   if (depthwise_conv::Fast3x3FilterKernelSupported(
2051           input_shape, filter_shape, stride_width, stride_height,
2052           dilation_width_factor, dilation_height_factor, pad_width, pad_height,
2053           depth_multiplier, output_shape, output_shift)) {
2054     ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
2055     depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
2056         params, input_shape, input_data, filter_shape, filter_data, bias_shape,
2057         bias_data, output_shape, output_data, thread_start, thread_end,
2058         thread_dim);
2059     return;
2060   }
2061 #endif  // defined(__aarch64__) && !defined(GOOGLE_L4T)
2062 
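  // Neither 3x3 specialization applies (or this build does not compile them
  // in): fall through to the general kernel, which handles arbitrary filter
  // sizes, strides, dilations and depth multipliers.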
2063   ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
2064   depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
2065                                        filter_shape, filter_data, bias_shape,
2066                                        bias_data, output_shape, output_data,
2067                                        thread_start, thread_end, thread_dim);
2068 }
2069 
2070 inline void DepthwiseConvImpl(
2071     const DepthwiseParams& params, const RuntimeShape& input_shape,
2072     const uint8* input_data, const RuntimeShape& filter_shape,
2073     const uint8* filter_data, const RuntimeShape& bias_shape,
2074     const int32* bias_data, const RuntimeShape& output_shape,
2075     uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
2076     int thread_end, int thread_dim) {
2077   return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kUpward>(
2078       params, input_shape, input_data, filter_shape, filter_data, bias_shape,
2079       bias_data, output_shape, output_data, cpu_flags, thread_start, thread_end,
2080       thread_dim);
2081 }
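// DepthwiseConvImpl is normally not called directly; a multithreaded wrapper
// (for example the one in depthwiseconv_multithread.h) picks thread_start,
// thread_end and thread_dim for each worker and invokes it once per slice.
// thread_dim == 0 splits the work along the batch axis and thread_dim == 1
// along the output row axis, matching DepthwiseConvGeneral above.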
2082 
2083 }  // namespace optimized_ops
2084 }  // namespace tflite
2085 
2086 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
2087