• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/int8/arithmetic_int8.h"
18 #ifdef ENABLE_NEON
19 #include <arm_neon.h>
20 #endif
21 #include "nnacl/errorcode.h"
22 
/**
 * Recursively tiles (repeats) an int8 tensor along every dimension, starting at `dim`.
 *
 * For each of the `inShape[dim]` input slices along axis `dim`, the slice is emitted
 * `multiple[dim]` times; copy j of input slice i lands at output position
 * i + j * inShape[dim] along that axis. The innermost axis is copied with memcpy, i.e. it is
 * treated as contiguous in both input and output (the last entries of inStrides/outStrides
 * are not consulted).
 *
 * @param inData     Input data for the current sub-tensor.
 * @param outData    Output buffer for the tiled sub-tensor.
 * @param dim        Axis currently being expanded (callers start at 0).
 * @param ndim       Total number of axes. ndim == 0 is treated as a scalar and copies one element.
 * @param inShape    Input extent per axis.
 * @param inStrides  Input element stride per axis.
 * @param outStrides Output element stride per axis.
 * @param multiple   Repeat count per axis.
 */
void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int *inShape,
                          const int *inStrides, const int *outStrides, const int *multiple) {
  /* Rank-0 tensor: a single value. Without this guard `ndim - 1` wraps to SIZE_MAX, the base
   * case below is never hit, and the recursion walks off the end of the shape arrays. */
  if (ndim == 0) {
    *outData = *inData;
    return;
  }
  const int srcDimSize = inShape[dim];
  const int repeats = multiple[dim];
  if ((size_t)dim + 1 == ndim) {
    /* Innermost axis: append the contiguous input row `repeats` times. */
    for (int r = 0; r < repeats; ++r) {
      memcpy(outData, inData, (size_t)srcDimSize * sizeof(int8_t));
      outData += srcDimSize;
    }
    return;
  }
  for (size_t i = 0; i < (size_t)srcDimSize; ++i) {
    for (size_t j = 0; j < (size_t)repeats; ++j) {
      /* Copy j of input slice i goes to output slice i + j * srcDimSize along this axis. */
      TileOneDimensionInt8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1,
                           ndim, inShape, inStrides, outStrides, multiple);
    }
  }
}
40 
/**
 * Expands both operands of a binary int8 op into their own tile buffers.
 *
 * CalcMultiplesAndStrides(param) runs first; presumably it populates the multiples and
 * stride fields of `param` that the two tiling calls below read (NOTE(review): confirm).
 * Operand 0 goes to tile_data0 and operand 1 to tile_data1; both use param->out_strides_.
 */
void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1,
                        ArithmeticParameter *param) {
  CalcMultiplesAndStrides(param);
  const int *shared_out_strides = param->out_strides_;

  /* First operand. */
  TileOneDimensionInt8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_,
                       shared_out_strides, param->multiples0_);

  /* Second operand, laid out with the same output strides. */
  TileOneDimensionInt8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_,
                       shared_out_strides, param->multiples1_);
}
49 
50 #define ACCURACY_DATA 0.00000001
51 
ElementNotEqualInt8(int8_t * input0,int8_t * input1,uint8_t * output,int element_size,ArithmeticQuantArg * quant_arg)52 int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
53                         ArithmeticQuantArg *quant_arg) {
54   float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
55   float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
56 
57   for (int index = 0; index < element_size; ++index) {
58     float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
59     float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
60     float minus_inputs = in0_real - in1_real;
61     bool out_real = true;
62     if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
63       out_real = false;
64     }
65     output[index] = (uint8_t)out_real;
66   }
67   return NNACL_OK;
68 }
69 
/**
 * Element-wise "equal" on two int8-quantized inputs.
 *
 * Each element is dequantized with its own input's params as q * scale - zp * scale.
 * output[i] is 1 when the two real values differ by at most ACCURACY_DATA in either
 * direction, and 0 otherwise (including when the difference is NaN).
 * NOTE(review): the tolerance is absolute and tiny; see the note on ACCURACY_DATA usage.
 *
 * @return NNACL_OK in all cases.
 */
int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
  const float scale0 = quant_arg->in0_args_.scale_;
  const float scale1 = quant_arg->in1_args_.scale_;
  const float offset0 = -quant_arg->in0_args_.zp_ * scale0;
  const float offset1 = -quant_arg->in1_args_.zp_ * scale1;

  for (int i = 0; i < element_size; i++) {
    const float real0 = input0[i] * scale0 + offset0;
    const float real1 = input1[i] * scale1 + offset1;
    const float diff = real0 - real1;
    const bool close = (diff >= -ACCURACY_DATA) && (diff <= ACCURACY_DATA);
    output[i] = (uint8_t)close;
  }
  return NNACL_OK;
}
85 
/**
 * Element-wise "less than" on two int8-quantized inputs.
 *
 * Each element is dequantized with its own input's params as q * scale - zp * scale;
 * output[i] is 1 when real0 < real1 and 0 otherwise.
 *
 * @return NNACL_OK in all cases.
 */
int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
  const float scale0 = quant_arg->in0_args_.scale_;
  const float scale1 = quant_arg->in1_args_.scale_;
  const float offset0 = -quant_arg->in0_args_.zp_ * scale0;
  const float offset1 = -quant_arg->in1_args_.zp_ * scale1;

  for (int i = 0; i < element_size; i++) {
    const float real0 = input0[i] * scale0 + offset0;
    const float real1 = input1[i] * scale1 + offset1;
    output[i] = (uint8_t)(real0 < real1);
  }
  return NNACL_OK;
}
97 
/**
 * Element-wise "less than or equal" on two int8-quantized inputs.
 *
 * Each element is dequantized with its own input's params as q * scale - zp * scale;
 * output[i] is 1 when real0 <= real1 and 0 otherwise.
 *
 * @return NNACL_OK in all cases.
 */
int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                         ArithmeticQuantArg *quant_arg) {
  const float scale0 = quant_arg->in0_args_.scale_;
  const float scale1 = quant_arg->in1_args_.scale_;
  const float offset0 = -quant_arg->in0_args_.zp_ * scale0;
  const float offset1 = -quant_arg->in1_args_.zp_ * scale1;

  for (int i = 0; i < element_size; i++) {
    const float real0 = input0[i] * scale0 + offset0;
    const float real1 = input1[i] * scale1 + offset1;
    output[i] = (uint8_t)(real0 <= real1);
  }
  return NNACL_OK;
}
110 
/**
 * Element-wise "greater than" on two int8-quantized inputs.
 *
 * Each element is dequantized with its own input's params as q * scale - zp * scale;
 * output[i] is 1 when real0 > real1 and 0 otherwise.
 *
 * @return NNACL_OK in all cases.
 */
int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                       ArithmeticQuantArg *quant_arg) {
  const float scale0 = quant_arg->in0_args_.scale_;
  const float scale1 = quant_arg->in1_args_.scale_;
  const float offset0 = -quant_arg->in0_args_.zp_ * scale0;
  const float offset1 = -quant_arg->in1_args_.zp_ * scale1;

  for (int i = 0; i < element_size; i++) {
    const float real0 = input0[i] * scale0 + offset0;
    const float real1 = input1[i] * scale1 + offset1;
    output[i] = (uint8_t)(real0 > real1);
  }
  return NNACL_OK;
}
123 
/**
 * Element-wise "greater than or equal" on two int8-quantized inputs.
 *
 * Each element is dequantized with its own input's params as q * scale - zp * scale;
 * output[i] is 1 when real0 >= real1 and 0 otherwise.
 *
 * @return NNACL_OK in all cases.
 */
int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                            ArithmeticQuantArg *quant_arg) {
  const float scale0 = quant_arg->in0_args_.scale_;
  const float scale1 = quant_arg->in1_args_.scale_;
  const float offset0 = -quant_arg->in0_args_.zp_ * scale0;
  const float offset1 = -quant_arg->in1_args_.zp_ * scale1;

  for (int i = 0; i < element_size; i++) {
    const float real0 = input0[i] * scale0 + offset0;
    const float real1 = input1[i] * scale1 + offset1;
    output[i] = (uint8_t)(real0 >= real1);
  }
  return NNACL_OK;
}

/* ACCURACY_DATA is only meant for the comparison kernels in this file. */
#undef ACCURACY_DATA
138