1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/int8/pooling_int8.h"
18 #include "nnacl/common_func.h"
19 #include "nnacl/errorcode.h"
20
AvgPoolingInt8(const int8_t * input_ptr,int8_t * output_ptr,const PoolingParameter * pooling_param,int task_id)21 int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, int task_id) {
22 int stride_w = pooling_param->stride_w_;
23 int stride_h = pooling_param->stride_h_;
24 int pad_w = pooling_param->pad_l_;
25 int pad_h = pooling_param->pad_u_;
26 int win_w = pooling_param->window_w_;
27 int win_h = pooling_param->window_h_;
28 int channel = pooling_param->input_channel_;
29 int in_w = pooling_param->input_w_;
30 int in_h = pooling_param->input_h_;
31 int output_w = pooling_param->output_w_;
32 int output_h = pooling_param->output_h_;
33 int output_batch = pooling_param->output_batch_;
34 int out_plane = output_w * output_h;
35 float input_scale = pooling_param->quant_args_[0][0].scale_;
36 int input_zp = pooling_param->quant_args_[0][0].zp_;
37 float output_scale = pooling_param->quant_args_[1][0].scale_;
38 int output_zp = pooling_param->quant_args_[1][0].zp_;
39 double real_multiplier = input_scale / output_scale;
40 const int8_t out_min = INT8_MIN;
41 const int8_t out_max = INT8_MAX;
42
43 for (int batch = 0; batch < output_batch; batch++) {
44 int in_batch_offset = batch * in_h * in_w * channel;
45 int out_batch_offset = batch * output_h * output_w * channel;
46 for (int i = 0; i < out_plane; i++) {
47 int out_w_index = i % output_w;
48 int out_h_index = i / output_w;
49 int in_w_index = out_w_index * stride_w - pad_w;
50 int in_h_index = out_h_index * stride_h - pad_h;
51 int out_plane_offset = out_batch_offset + i * channel;
52 for (int j = 0; j < channel; j++) {
53 int in_channel_offset = in_batch_offset + j;
54 int out_channel_offset = out_plane_offset + j;
55 int16_t tmp_avg = 0;
56 int real_count = 0;
57 for (int h = 0; h < win_h; h++) {
58 for (int w = 0; w < win_w; w++) {
59 if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) {
60 continue;
61 } else {
62 int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
63 tmp_avg += *(input_ptr + in_offset);
64 ++real_count;
65 }
66 } // win_w loop
67 } // win_h loop
68 if (real_count == 0) {
69 return NNACL_ERR;
70 }
71 int16_t tmp_out = round((float)tmp_avg / (float)real_count);
72 tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp);
73 int8_t real_out = tmp_out < out_min ? out_min : tmp_out;
74 real_out = real_out > out_max ? out_max : real_out;
75 *(output_ptr + out_channel_offset) = real_out;
76 } // in_channel loop
77 } // out_plane loop
78 } // out_batch loop
79 return NNACL_OK;
80 }
81
AvgPoolingOptInt8(const int8_t * input_ptr,int8_t * output_ptr,const PoolingParameter * pooling_param,int task_id)82 int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, int task_id) {
83 int win_w = pooling_param->window_w_;
84 int win_h = pooling_param->window_h_;
85 int channel = pooling_param->input_channel_;
86 int c16 = channel / C16NUM;
87 int in_w = pooling_param->input_w_;
88 int output_w = pooling_param->output_w_;
89 int out_plane = output_w * pooling_param->output_h_;
90 int out_tile_count = UP_DIV(out_plane, TILE_NUM);
91 int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
92 int input_zp = pooling_param->quant_args_[0][0].zp_;
93 int output_zp = pooling_param->quant_args_[1][0].zp_;
94 double real_multiplier = pooling_param->quant_args_[0][0].scale_ / pooling_param->quant_args_[1][0].scale_;
95 const int8_t out_min = INT8_MIN;
96 const int8_t out_max = INT8_MAX;
97 NNACL_CHECK_ZERO_RETURN_ERR(output_w);
98 for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
99 int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel;
100 int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel;
101 for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
102 int cal_start_index = thread_id * TILE_NUM;
103 int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
104 for (int i = 0; i < real_cal_num; i++) {
105 int index = cal_start_index + i;
106 int out_w_index = index % output_w;
107 int out_h_index = index / output_w;
108 int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
109 int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
110 int out_plane_offset = out_batch_offset + index * channel;
111 int input_stride = (in_h_index * in_w + in_w_index) * channel;
112 int kw_s = MSMAX(0, -in_w_index);
113 int kw_e = MSMIN(win_w, in_w - in_w_index);
114 int kh_s = MSMAX(0, -in_h_index);
115 int kh_e = MSMIN(win_h, pooling_param->input_h_ - in_h_index);
116 int real_count = (kw_e - kw_s) * (kh_e - kh_s);
117 if (real_count == 0) {
118 return NNACL_ERR;
119 }
120
121 // 16 channels
122 for (int j = 0; j < c16; j++) {
123 #ifdef ENABLE_NEON
124 int16x8_t tmp_avg[2];
125 tmp_avg[0] = vmovq_n_s16(0);
126 tmp_avg[1] = vmovq_n_s16(0);
127 #else
128 int16_t tmp_avg[16];
129 int16_t real_out[16];
130 for (int m = 0; m < C16NUM; ++m) {
131 tmp_avg[m] = 0;
132 }
133 #endif
134 int in_channel_offset = in_batch_offset + j * C16NUM;
135 int out_channel_offset = out_plane_offset + j * C16NUM;
136
137 for (int h = kh_s; h < kh_e; h++) {
138 for (int w = kw_s; w < kw_e; w++) {
139 int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
140 #ifdef ENABLE_NEON
141 int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset);
142 int8x8_t in_data1 = vget_low_s8(in_ptr);
143 int8x8_t in_data2 = vget_high_s8(in_ptr);
144 int16x8_t data1 = vmovl_s8(in_data1);
145 int16x8_t data2 = vmovl_s8(in_data2);
146 tmp_avg[0] = vaddq_s16(tmp_avg[0], data1);
147 tmp_avg[1] = vaddq_s16(tmp_avg[1], data2);
148 #else
149 for (int k = 0; k < C16NUM; ++k) {
150 tmp_avg[k] += input_ptr[in_offset + k];
151 }
152 #endif
153 } // win_w loop
154 } // win_h loop
155 #ifdef ENABLE_NEON
156 int16_t tmp_data[8];
157 int16_t tmp_out[8];
158 int16_t tmp_data1[8];
159 int16_t tmp_out1[8];
160 for (int l = 0; l < C8NUM; l++) {
161 tmp_data[l] = tmp_avg[0][l] + 128 * real_count;
162 tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
163 tmp_out[l] -= 128;
164 tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
165 }
166 for (int l = 0; l < C8NUM; l++) {
167 tmp_data1[l] = tmp_avg[1][l] + 128 * real_count;
168 tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count;
169 tmp_out1[l] -= 128;
170 tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp;
171 }
172 int8x8_t real_out[2];
173 int8x8_t output_min = vdup_n_s8(out_min);
174 int8x8_t output_max = vdup_n_s8(out_max);
175 real_out[0] = vqmovn_s16(vld1q_s16(tmp_out));
176 real_out[0] = vmin_s8(real_out[0], output_max);
177 real_out[0] = vmax_s8(real_out[0], output_min);
178 vst1_s8(output_ptr + out_channel_offset, real_out[0]);
179 real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1));
180 real_out[1] = vmin_s8(real_out[1], output_max);
181 real_out[1] = vmax_s8(real_out[1], output_min);
182 vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]);
183 #else
184 for (int l = 0; l < C16NUM; ++l) {
185 int16_t tmp_data = tmp_avg[l] + 128 * real_count;
186 real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
187 real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
188 real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
189 real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
190 *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
191 }
192 #endif
193 }
194
195 // 8 channels
196 int channel_16_res = channel - c16 * C16NUM;
197 int c8 = channel_16_res / C8NUM;
198 int in_c16_offset = in_batch_offset + c16 * C16NUM;
199 int out_c16_offset = out_plane_offset + c16 * C16NUM;
200 for (int j = 0; j < c8; j++) {
201 #ifdef ENABLE_NEON
202 int16x8_t tmp_avg = vmovq_n_s16(0);
203 #else
204 int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0};
205 int16_t real_out[8];
206 #endif
207 int in_channel_offset = in_c16_offset + j * C8NUM;
208 int out_channel_offset = out_c16_offset + j * C8NUM;
209 for (int h = kh_s; h < kh_e; h++) {
210 for (int w = kw_s; w < kw_e; w++) {
211 int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
212 #ifdef ENABLE_NEON
213 int8x8_t in_ptr = vld1_s8(input_ptr + in_offset);
214 int16x8_t data = vmovl_s8(in_ptr);
215 tmp_avg = vaddq_s16(tmp_avg, data);
216 #else
217 for (int k = 0; k < C8NUM; ++k) {
218 tmp_avg[k] += input_ptr[in_offset + k];
219 }
220 #endif
221 } // win_w loop
222 } // win_h loop
223 #ifdef ENABLE_NEON
224 int16_t tmp_data[8];
225 int16_t tmp_out[8];
226 for (int l = 0; l < C8NUM; l++) {
227 tmp_data[l] = tmp_avg[l] + 128 * real_count;
228 tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
229 tmp_out[l] -= 128;
230 tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
231 }
232 int8x8_t real_out;
233 int8x8_t output_min = vdup_n_s8(out_min);
234 int8x8_t output_max = vdup_n_s8(out_max);
235 real_out = vqmovn_s16(vld1q_s16(tmp_out));
236 real_out = vmin_s8(real_out, output_max);
237 real_out = vmax_s8(real_out, output_min);
238 vst1_s8(output_ptr + out_channel_offset, real_out);
239 #else
240 for (int l = 0; l < C8NUM; ++l) {
241 int16_t tmp_data = tmp_avg[l] + 128 * real_count;
242 real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
243 real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
244 real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
245 real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
246 *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
247 }
248 #endif
249 }
250
251 // less than 8 channel
252 int channel_8_res = channel_16_res - c8 * C8NUM;
253 int in_c8_offset = in_c16_offset + c8 * C8NUM;
254 int out_c8_offset = out_c16_offset + c8 * C8NUM;
255 for (int k = 0; k < channel_8_res; k++) {
256 int in_channel_offset = in_c8_offset + k;
257 int out_channel_offset = out_c8_offset + k;
258 int16_t tmp_avg = 0;
259 for (int h = kh_s; h < kh_e; h++) {
260 for (int w = kw_s; w < kw_e; w++) {
261 int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
262 tmp_avg += input_ptr[in_offset];
263 } // win_w loop
264 } // win_h loop
265 int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128;
266 tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp);
267 int16_t real_out = tmp_out < out_min ? out_min : tmp_out;
268 real_out = real_out > out_max ? out_max : real_out;
269 *(output_ptr + out_channel_offset) = (int8_t)real_out;
270 } // channel_res loop
271 } // out_plane loop
272 } // out_batch loop
273 }
274 return NNACL_OK;
275 }
276
/**
 * Reference int8 max pooling over NHWC data with requantization (single task; task_id is ignored).
 *
 * Each output value is the maximum of the in-bounds window elements (padded positions are
 * skipped; a window with no in-bounds element yields a maximum of INT8_MIN), requantized as
 *   out = round((max - input_zp) * input_scale / output_scale) + output_zp
 * and saturated to [INT8_MIN, INT8_MAX]. Input and output share the channel count.
 *
 * @param input_ptr     input tensor; element (n, h, w, c) at ((n * input_h + h) * input_w + w) * channel + c
 * @param output_ptr    output tensor, same layout using the output_* dimensions
 * @param pooling_param window, stride, padding, shape and quantization parameters
 * @param task_id       unused
 */
void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
  (void)task_id;
  const int stride_w = pooling_param->stride_w_;
  const int stride_h = pooling_param->stride_h_;
  const int pad_w = pooling_param->pad_l_;
  const int pad_h = pooling_param->pad_u_;
  const int win_w = pooling_param->window_w_;
  const int win_h = pooling_param->window_h_;
  const int channel = pooling_param->input_channel_;
  const int in_w = pooling_param->input_w_;
  const int in_h = pooling_param->input_h_;
  const int output_w = pooling_param->output_w_;
  const int output_h = pooling_param->output_h_;
  const int output_batch = pooling_param->output_batch_;
  const int out_plane = output_w * output_h;  // 0 when output_w == 0, so the modulo below never runs
  // input channel is equal to output channel
  const float input_scale = pooling_param->quant_args_[0][0].scale_;
  const int input_zp = pooling_param->quant_args_[0][0].zp_;
  const float output_scale = pooling_param->quant_args_[1][0].scale_;
  const int output_zp = pooling_param->quant_args_[1][0].zp_;
  const double real_multiplier = input_scale / output_scale;

  for (int batch = 0; batch < output_batch; batch++) {
    const int in_batch_offset = batch * in_h * in_w * channel;
    const int out_batch_offset = batch * output_h * output_w * channel;
    for (int i = 0; i < out_plane; i++) {
      const int in_w_index = (i % output_w) * stride_w - pad_w;
      const int in_h_index = (i / output_w) * stride_h - pad_h;
      const int out_plane_offset = out_batch_offset + i * channel;
      for (int j = 0; j < channel; j++) {
        const int in_channel_offset = in_batch_offset + j;
        int8_t tmp_max = INT8_MIN;
        for (int h = 0; h < win_h; h++) {
          const int ih = in_h_index + h;
          if (ih < 0 || ih >= in_h) {
            continue;
          }
          for (int w = 0; w < win_w; w++) {
            const int iw = in_w_index + w;
            if (iw < 0 || iw >= in_w) {
              continue;
            }
            tmp_max = MaxInt8(tmp_max, input_ptr[in_channel_offset + (ih * in_w + iw) * channel]);
          }  // win_w loop
        }    // win_h loop
        const double requant = round((tmp_max - input_zp) * real_multiplier) + output_zp;
        // Saturate before converting: an out-of-range double -> int8_t conversion is undefined
        // behavior. The final branch also catches NaN.
        int8_t out_value;
        if (requant >= INT8_MAX) {
          out_value = INT8_MAX;
        } else if (requant > INT8_MIN) {
          out_value = (int8_t)requant;
        } else {
          out_value = INT8_MIN;
        }
        output_ptr[out_plane_offset + j] = out_value;
      }  // in_channel loop
    }    // out_plane loop
  }      // out_batch loop
}
326
/**
 * Requantizes one max-pooled int8 value:
 *   round((value - input_zp) * real_multiplier) + output_zp,
 * saturated to [INT8_MIN, INT8_MAX] in floating point before conversion (an out-of-range
 * double -> int8_t conversion is undefined behavior). NaN maps to INT8_MIN.
 */
static int8_t MaxPoolingWithQuantRequantInt8(int8_t value, int input_zp, int output_zp, double real_multiplier) {
  const double requant = round((value - input_zp) * real_multiplier) + output_zp;
  if (requant >= INT8_MAX) {
    return INT8_MAX;
  }
  if (requant > INT8_MIN) {
    return (int8_t)requant;
  }
  return INT8_MIN;
}

/**
 * Multi-threaded int8 max pooling over NHWC data with requantization.
 *
 * Output pixels are split into tiles of TILE_NUM; this call handles tiles task_id,
 * task_id + thread_num, ... of every batch. Channels are processed in full blocks of 16
 * (NEON-accelerated when ENABLE_NEON is defined) followed by a scalar tail. Padded window
 * positions are skipped; a window with no in-bounds element yields a maximum of INT8_MIN.
 * Each maximum is requantized from the input to the output quantization and saturated to int8.
 * output_w == 0 is rejected by NNACL_CHECK_ZERO_RETURN.
 */
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
                             int task_id) {
  const int channel = pooling_param->input_channel_;
  const int in_w = pooling_param->input_w_;
  const int in_h = pooling_param->input_h_;
  const int output_w = pooling_param->output_w_;
  const int out_plane = output_w * pooling_param->output_h_;
  const int out_tile_count = UP_DIV(out_plane, TILE_NUM);
  const int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
  // Number of complete 16-channel blocks. The remainder starts at c16 * C16NUM, which is never
  // negative (UP_DIV(channel, 16) - 1 blocks would start it at -16 for channel == 0).
  const int c16 = channel / C16NUM;
  // input channel is equal to output channel
  const float input_scale = pooling_param->quant_args_[0][0].scale_;
  const int input_zp = pooling_param->quant_args_[0][0].zp_;
  const float output_scale = pooling_param->quant_args_[1][0].scale_;
  const int output_zp = pooling_param->quant_args_[1][0].zp_;
  const double real_multiplier = input_scale / output_scale;

  NNACL_CHECK_ZERO_RETURN(output_w);
  for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
    const int in_batch_offset = batch * in_h * in_w * channel;
    const int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel;
    for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
      const int cal_start_index = thread_id * TILE_NUM;
      const int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
      for (int i = 0; i < real_cal_num; i++) {
        const int index = cal_start_index + i;
        const int in_w_index = (index % output_w) * pooling_param->stride_w_ - pooling_param->pad_l_;
        const int in_h_index = (index / output_w) * pooling_param->stride_h_ - pooling_param->pad_u_;
        const int out_plane_offset = out_batch_offset + index * channel;

        // full 16-channel blocks
        for (int j = 0; j < c16; j++) {
          const int in_channel_offset = in_batch_offset + j * C16NUM;
          const int out_channel_offset = out_plane_offset + j * C16NUM;
          int8_t block_max[C16NUM];
#ifdef ENABLE_NEON
          int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
#else
          for (int m = 0; m < C16NUM; ++m) {
            block_max[m] = INT8_MIN;
          }
#endif
          for (int h = 0; h < pooling_param->window_h_; h++) {
            const int ih = in_h_index + h;
            if (ih < 0 || ih >= in_h) {
              continue;
            }
            for (int w = 0; w < pooling_param->window_w_; w++) {
              const int iw = in_w_index + w;
              if (iw < 0 || iw >= in_w) {
                continue;
              }
              const int in_offset = in_channel_offset + (ih * in_w + iw) * channel;
#ifdef ENABLE_NEON
              tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
#else
              for (int k = 0; k < C16NUM; ++k) {
                block_max[k] = MaxInt8(block_max[k], input_ptr[in_offset + k]);
              }
#endif
            }  // win_w loop
          }    // win_h loop
#ifdef ENABLE_NEON
          // Spill to memory instead of subscripting the vector (a non-portable compiler extension).
          vst1q_s8(block_max, tmp_max);
#endif
          for (int l = 0; l < C16NUM; ++l) {
            output_ptr[out_channel_offset + l] =
              MaxPoolingWithQuantRequantInt8(block_max[l], input_zp, output_zp, real_multiplier);
          }
        }  // 16-channel block loop

        // remaining channels
        for (int k = c16 * C16NUM; k < channel; k++) {
          const int in_channel_offset = in_batch_offset + k;
          int8_t tmp_max = INT8_MIN;
          for (int h = 0; h < pooling_param->window_h_; h++) {
            const int ih = in_h_index + h;
            if (ih < 0 || ih >= in_h) {
              continue;
            }
            for (int w = 0; w < pooling_param->window_w_; w++) {
              const int iw = in_w_index + w;
              if (iw < 0 || iw >= in_w) {
                continue;
              }
              tmp_max = MaxInt8(tmp_max, input_ptr[in_channel_offset + (ih * in_w + iw) * channel]);
            }  // win_w loop
          }    // win_h loop
          output_ptr[out_plane_offset + k] =
            MaxPoolingWithQuantRequantInt8(tmp_max, input_zp, output_zp, real_multiplier);
        }  // channel_res loop
      }    // out_plane (tile element) loop
    }      // tile loop
  }        // out_batch loop
}
422
/**
 * Multi-threaded int8 max pooling over NHWC data, without requantization: each output value is
 * copied directly from the window maximum (NOTE(review): presumably only used when input and
 * output share quantization parameters -- confirm at the call site).
 *
 * Output pixels are split into tiles of TILE_NUM; this call handles tiles task_id,
 * task_id + step, ... of every batch. Channels are processed in chunks of at most
 * MAX_MAXPOOL_SIZE. Each chunk is reduced into a stack buffer initialised to INT8_MIN and then
 * copied to the output, so a window lying entirely in the padding produces INT8_MIN.
 * output_w == 0 is rejected by NNACL_CHECK_ZERO_RETURN.
 */
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
  const int channels = pooling_param->input_channel_;
  const int in_width = pooling_param->input_w_;
  const int in_height = pooling_param->input_h_;
  const int out_width = pooling_param->output_w_;
  const int plane_size = out_width * pooling_param->output_h_;
  const int tile_count = UP_DIV(plane_size, TILE_NUM);
  const int step = MSMIN(tile_count, pooling_param->thread_num_);
  int8_t chunk_max[MAX_MAXPOOL_SIZE];

  NNACL_CHECK_ZERO_RETURN(out_width);
  for (int b = 0; b < pooling_param->output_batch_; ++b) {
    const int in_batch_base = b * in_height * in_width * channels;
    const int out_batch_base = b * pooling_param->output_h_ * out_width * channels;
    for (int tile = task_id; tile < tile_count; tile += step) {
      const int first = tile * TILE_NUM;
      const int last = MSMIN(first + TILE_NUM, plane_size);
      for (int pos = first; pos < last; ++pos) {
        const int top = (pos / out_width) * pooling_param->stride_h_ - pooling_param->pad_u_;
        const int left = (pos % out_width) * pooling_param->stride_w_ - pooling_param->pad_l_;
        // Clip the window to the input: skip rows/columns that fall in the padding.
        const int kh_begin = top < 0 ? -top : 0;
        const int kh_end = MSMIN(pooling_param->window_h_, in_height - top);
        const int kw_begin = left < 0 ? -left : 0;
        const int kw_end = MSMIN(pooling_param->window_w_, in_width - left);
        // Offset of the (possibly padded) window origin; may be negative, so it is only combined
        // with other integer offsets before being added to input_ptr.
        const int window_origin = in_batch_base + (top * in_width + left) * channels;
        int8_t *out_pixel = output_ptr + out_batch_base + pos * channels;

        for (int c = 0; c < channels; c += MAX_MAXPOOL_SIZE) {
          const int chunk = MSMIN(channels - c, MAX_MAXPOOL_SIZE);
          memset(chunk_max, INT8_MIN, chunk);
          for (int kh = kh_begin; kh < kh_end; ++kh) {
            const int row_base = window_origin + kh * in_width * channels + c;
            for (int kw = kw_begin; kw < kw_end; ++kw) {
              const int8_t *src = input_ptr + (row_base + kw * channels);
              int k = 0;
#ifdef ENABLE_NEON
              for (; k + 16 <= chunk; k += 16) {
                vst1q_s8(chunk_max + k, vmaxq_s8(vld1q_s8(src + k), vld1q_s8(chunk_max + k)));
              }
              for (; k + 8 <= chunk; k += 8) {
                vst1_s8(chunk_max + k, vmax_s8(vld1_s8(src + k), vld1_s8(chunk_max + k)));
              }
#endif
              for (; k < chunk; ++k) {
                if (src[k] > chunk_max[k]) {
                  chunk_max[k] = src[k];
                }
              }
            }  // kw loop
          }    // kh loop
          memcpy(out_pixel + c, chunk_max, chunk);
        }  // channel chunk loop
      }    // out_plane (tile element) loop
    }      // tile loop
  }        // out_batch loop
}
513