/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdint.h>
#include "nnacl/int8/reduce_int8.h"
#include "nnacl/errorcode.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/common_func.h"

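// Note: the per-axis ReduceMean* variants below are placeholders that only return NNACL_OK;
// ReduceMeanHW carries the actual int8 mean-over-HW implementation.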
int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}
int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg,
                 int32_t bias) {
  int stride = plane * UP_ROUND(c, C4NUM);
  for (int batch = 0; batch < n; ++batch) {
    int8_t *in_ptr = in_data + batch * stride;
    int8_t *out_ptr = out_data + batch * c;
    for (int i = 0; i < count; ++i) {
      int32_t sum_array = 0;
      int j = 0;
#ifdef ENABLE_ARM64
      // Vectorized accumulation in chunks of 16, 8 and 4; the scalar loop below handles the tail.
      for (; j <= plane - 16; j += 16) {
        int8x16_t in_data_vec = vld1q_s8(in_ptr);
        sum_array += vaddlvq_s8(in_data_vec);
        in_ptr += 16;
      }
      for (; j <= plane - 8; j += 8) {
        int8x8_t in_data_vec = vld1_s8(in_ptr);
        sum_array += vaddlv_s8(in_data_vec);
        in_ptr += 8;
      }
      for (; j <= plane - 4; j += 4) {
        int32x4_t in_data_vec;
        in_data_vec[0] = in_ptr[0];
        in_data_vec[1] = in_ptr[1];
        in_data_vec[2] = in_ptr[2];
        in_data_vec[3] = in_ptr[3];
        sum_array += vaddvq_s32(in_data_vec);
        in_ptr += 4;
      }
#elif ENABLE_ARM32
      int32x4_t accum = vmovq_n_s32(0);
      for (; j <= plane - 16; j += 16) {
        int32x4_t in_data_vec1;
        int32x4_t in_data_vec2;
        int32x4_t in_data_vec3;
        int32x4_t in_data_vec4;
        in_data_vec1[0] = in_ptr[0];
        in_data_vec1[1] = in_ptr[1];
        in_data_vec1[2] = in_ptr[2];
        in_data_vec1[3] = in_ptr[3];
        in_data_vec2[0] = in_ptr[4];
        in_data_vec2[1] = in_ptr[5];
        in_data_vec2[2] = in_ptr[6];
        in_data_vec2[3] = in_ptr[7];
        in_data_vec3[0] = in_ptr[8];
        in_data_vec3[1] = in_ptr[9];
        in_data_vec3[2] = in_ptr[10];
        in_data_vec3[3] = in_ptr[11];
        in_data_vec4[0] = in_ptr[12];
        in_data_vec4[1] = in_ptr[13];
        in_data_vec4[2] = in_ptr[14];
        in_data_vec4[3] = in_ptr[15];
        accum = vaddq_s32(accum, in_data_vec1);
        accum = vaddq_s32(accum, in_data_vec2);
        accum = vaddq_s32(accum, in_data_vec3);
        accum = vaddq_s32(accum, in_data_vec4);
        in_ptr += 16;
      }
      for (; j <= plane - 8; j += 8) {
        int32x4_t in_data_vec1;
        int32x4_t in_data_vec2;
        in_data_vec1[0] = in_ptr[0];
        in_data_vec1[1] = in_ptr[1];
        in_data_vec1[2] = in_ptr[2];
        in_data_vec1[3] = in_ptr[3];
        in_data_vec2[0] = in_ptr[4];
        in_data_vec2[1] = in_ptr[5];
        in_data_vec2[2] = in_ptr[6];
        in_data_vec2[3] = in_ptr[7];
        accum = vaddq_s32(accum, in_data_vec1);
        accum = vaddq_s32(accum, in_data_vec2);
        in_ptr += 8;
      }
      for (; j <= plane - 4; j += 4) {
        int32x4_t in_data_vec;
        in_data_vec[0] = in_ptr[0];
        in_data_vec[1] = in_ptr[1];
        in_data_vec[2] = in_ptr[2];
        in_data_vec[3] = in_ptr[3];
        accum = vaddq_s32(accum, in_data_vec);
        in_ptr += 4;
      }
      sum_array += accum[0];
      sum_array += accum[1];
      sum_array += accum[2];
      sum_array += accum[3];
#endif
      for (; j < plane; j++) {
        sum_array += in_ptr[0];
        in_ptr++;
      }
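      // Requantize the accumulated sum into a per-element mean: the QuantMulArg multiplier and
      // shifts are expected (assumption) to encode 1 / plane in fixed point, so the
      // SaturatingRoundingDoublingHighMul / RoundingDivideByPOT pair approximates sum / plane.
      // The caller-supplied bias is then added and the result clamped to the int8 range.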
      int32_t mean =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum_array * (1 << (unsigned int)quant_arg.left_shift_),
                                                              quant_arg.multiplier_),
                            quant_arg.right_shift_);
      mean += bias;
      *out_ptr++ = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN);
    }
  }
  return NNACL_OK;
}

int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
  return NNACL_OK;
}

// Get x such that (x - zp_in) * scale_in = mean.
// When n axes are reduced, this handles each of the first n - 1 reductions; one call per reduced axis.
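// The mean_* multiplier/shift pair is expected (assumption) to encode 1 / axis_size in fixed point,
// so the result stays in the input scale; in_zp_ is added back so a following reduce step can
// subtract it again.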
int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                   int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isAddOverflow(sum, tmp)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }
      int32_t mean = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_),
        quant->mean_right_shift_);
      if (isAddOverflow(mean, quant->in_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      *inner_dst = mean + quant->in_zp_;
    }
  }
  return NNACL_OK;
}

// When n axes are reduced, this handles the last reduction axis.
// Get y such that (y - zp_out) * scale_out = mean((x - zp_in) * scale_in).
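// Two requantization stages: the mean_* multiplier/shifts first turn the sum into a mean in the
// input scale, then the in_out_* multiplier/shifts rescale from input to output scale before
// out_zp_ is added and the value is clamped to int8.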
int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                       int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isAddOverflow(tmp, sum)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }
      int32_t mean = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_),
        quant->mean_right_shift_);
      // trans to output scale
      int32_t mean_scaled =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(mean * (1 << (unsigned int)quant->in_out_left_shift_),
                                                              quant->in_out_multiplier_),
                            quant->in_out_right_shift_);
      if (isAddOverflow(mean_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      mean = mean_scaled + quant->out_zp_;

      *inner_dst = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN);
    }
  }
  return NNACL_OK;
}

// Get x such that (x - zp_in) * scale_in = sum((item - zp_in) * scale_in).
// When n axes are reduced, this handles each of the first n - 1 reductions; one call per reduced axis.
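// The running sum stays in the input scale; in_zp_ is re-added so a following reduce step can
// subtract it again without changing the encoded value.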
int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                  int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isAddOverflow(tmp, sum)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }

      if (isAddOverflow(quant->in_zp_, sum)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      *inner_dst = sum + quant->in_zp_;
    }
  }
  return NNACL_OK;
}

// When n axes are reduced, this handles the last reduction axis.
// Get y such that (y - zp_out) * scale_out = sum((item - zp_in) * scale_in).
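// The in_out_* multiplier/shifts rescale the input-scale sum to the output scale; out_zp_ is then
// added and the result saturated to the int8 range.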
int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                      int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isAddOverflow(tmp, sum)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }
      int32_t sum_scaled =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->in_out_left_shift_),
                                                              quant->in_out_multiplier_),
                            quant->in_out_right_shift_);
      if (isAddOverflow(sum_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      sum = sum_scaled + quant->out_zp_;
      if (sum > INT8_MAX) {
        *inner_dst = INT8_MAX;
      } else if (sum < INT8_MIN) {
        *inner_dst = INT8_MIN;
      } else {
        *inner_dst = (int8_t)sum;
      }
    }
  }
  return NNACL_OK;
}

int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                      int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t tmp = INT8_MIN;
      for (i = 0; i < axis_size; i++) {
        tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
      }
      int32_t tmp_scaled = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul((tmp - quant->in_zp_) * (1 << (unsigned int)quant->in_out_left_shift_),
                                          quant->in_out_multiplier_),
        quant->in_out_right_shift_);
      if (isAddOverflow(tmp_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      tmp = tmp_scaled + quant->out_zp_;
      if (tmp > INT8_MAX) {
        *inner_dst = INT8_MAX;
      } else if (tmp < INT8_MIN) {
        *inner_dst = INT8_MIN;
      } else {
        *inner_dst = (int8_t)tmp;
      }
    }
  }
  return NNACL_OK;
}

int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                  int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t tmp = INT8_MIN;
      for (i = 0; i < axis_size; i++) {
        tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
      }

      *inner_dst = tmp;
    }
  }
  return NNACL_OK;
}

int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                      int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  const int base_offset = 20;
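  // base_offset pre-scales the small (tmp - in_zp_) value by an extra 2^20 before the
  // doubling-high multiply and removes it again in the final right shift, presumably to keep
  // precision that the high multiply would otherwise discard for values of int8 magnitude.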
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t tmp = INT8_MAX;
      for (i = 0; i < axis_size; i++) {
        tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
      }
      int32_t tmp_scaled =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                              (tmp - quant->in_zp_) * (1 << ((unsigned int)quant->in_out_left_shift_ + base_offset)),
                              quant->in_out_multiplier_),
                            quant->in_out_right_shift_ + base_offset);
      if (isAddOverflow(tmp_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      tmp = tmp_scaled + quant->out_zp_;
      if (tmp > INT8_MAX) {
        *inner_dst = INT8_MAX;
      } else if (tmp < INT8_MIN) {
        *inner_dst = INT8_MIN;
      } else {
        *inner_dst = (int8_t)tmp;
      }
    }
  }
  return NNACL_OK;
}

int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                  int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t tmp = INT8_MAX;
      for (i = 0; i < axis_size; i++) {
        tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
      }
      *inner_dst = tmp;
    }
  }
  return NNACL_OK;
}

int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                       int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t prod = 1;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isMulOverflow(prod, tmp)) {
          return NNACL_ERRCODE_MUL_OVERFLOW;
        }
        prod *= tmp;
      }
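      // Two requantization stages (assumption about what the multipliers encode): prod_* first
      // brings the product of (x - in_zp_) terms, which lives in scale_in^axis_size, back to the
      // input scale; in_out_* then converts from input to output scale before out_zp_ is added.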
      prod = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_),
        quant->prod_right_shift_);
      int32_t prod_scaled =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->in_out_left_shift_),
                                                              quant->in_out_multiplier_),
                            quant->in_out_right_shift_);
      if (isAddOverflow(prod_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      prod = prod_scaled + quant->out_zp_;
      if (prod > INT8_MAX) {
        *inner_dst = INT8_MAX;
      } else if (prod < INT8_MIN) {
        *inner_dst = INT8_MIN;
      } else {
        *inner_dst = (int8_t)prod;
      }
    }
  }
  return NNACL_OK;
}

int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                   int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t prod = 1;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp = inner_src[i * inner_size] - quant->in_zp_;
        if (isMulOverflow(prod, tmp)) {
          return NNACL_ERRCODE_MUL_OVERFLOW;
        }
        prod *= tmp;
      }
      prod = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_),
        quant->prod_right_shift_);
      if (isAddOverflow(prod, quant->in_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      *inner_dst = prod + quant->in_zp_;
    }
  }
  return NNACL_OK;
}

int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                            int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int8_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int8_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp;
        if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) {
          return NNACL_ERRCODE_MUL_OVERFLOW;
        }
        tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_);
        if (isAddOverflow(sum, tmp)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }
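      // sum_square_* rescales the sum of squared (x - in_zp_) terms, which lives in scale_in^2,
      // to the output scale (assumption); out_zp_ is then added and the value saturated to int8.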
      int32_t sum_scaled =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_),
                                                              quant->sum_square_multiplier_),
                            quant->sum_square_right_shift_);
      if (isAddOverflow(sum_scaled, quant->out_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      sum = sum_scaled + quant->out_zp_;

      if (sum > INT8_MAX) {
        *inner_dst = INT8_MAX;
      } else if (sum < INT8_MIN) {
        *inner_dst = INT8_MIN;
      } else {
        *inner_dst = (int8_t)sum;
      }
    }
  }
  return NNACL_OK;
}

int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data,
                        int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int32_t *outer_src = src_data + j * axis_size * inner_size;
    int32_t *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int32_t *inner_src = outer_src + k;
      int32_t *inner_dst = outer_dst + k;
      int32_t sum = 0;
      for (i = 0; i < axis_size; i++) {
        int32_t tmp;
        if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) {
          return NNACL_ERRCODE_MUL_OVERFLOW;
        }
        tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_);
        if (isAddOverflow(sum, tmp)) {
          return NNACL_ERRCODE_ADD_OVERFLOW;
        }
        sum += tmp;
      }
      sum =
        RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_),
                                                              quant->sum_square_multiplier_),
                            quant->sum_square_right_shift_);
      if (isAddOverflow(sum, quant->in_zp_)) {
        return NNACL_ERRCODE_ADD_OVERFLOW;
      }
      *inner_dst = sum + quant->in_zp_;
    }
  }
  return NNACL_OK;
}
598