1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/int8/arithmetic_int8.h"
18 #ifdef ENABLE_NEON
19 #include <arm_neon.h>
20 #endif
21 #include "nnacl/errorcode.h"
22
/*
 * Recursively tiles one int8 tensor, starting at dimension `dim` and covering every dimension after it.
 *
 * inData      source data of the current sub-tensor
 * outData     destination of the tiled sub-tensor
 * dim         dimension handled by this call (the top-level call passes 0)
 * ndim        number of entries used from inShape / inStrides / outStrides / multiple
 * inShape     extent of each input dimension
 * inStrides   input stride of each dimension, in elements
 * outStrides  output stride of each dimension, in elements
 * multiple    how many times each input dimension is repeated in the output
 *
 * In the innermost dimension the contiguous input row is copied multiple[dim] times back to back.
 * In an outer dimension, input slice i is written to output slices i, i + inShape[dim],
 * i + 2 * inShape[dim], ... (multiple[dim] copies in total).
 * A call with dim outside [0, ndim), or with a non-positive extent, writes nothing.
 */
void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int *inShape,
                          const int *inStrides, const int *outStrides, const int *multiple) {
  /* Without this guard, ndim == 0 made `ndim - 1` wrap to SIZE_MAX and the recursion never terminated. */
  if (dim < 0 || (size_t)dim >= ndim) {
    return;
  }
  const int srcDimSize = inShape[dim];
  const int repeat = multiple[dim];
  if (srcDimSize <= 0) {
    return; /* nothing to copy; also avoids memcpy with a wrapped-around byte count */
  }
  if ((size_t)dim + 1 == ndim) {
    for (int r = 0; r < repeat; r++) {
      memcpy(outData, inData, (size_t)srcDimSize * sizeof(int8_t));
      outData += srcDimSize;
    }
    return;
  }
  /* Counters are int to match their int bounds; offsets stay in size_t, as in the original arithmetic. */
  for (int i = 0; i < srcDimSize; i++) {
    const int8_t *src = inData + (size_t)inStrides[dim] * (size_t)i;
    for (int j = 0; j < repeat; j++) {
      /* Copy j of slice i lands at index i + j * srcDimSize along this output dimension. */
      const size_t out_index = (size_t)i + (size_t)j * (size_t)srcDimSize;
      TileOneDimensionInt8(src, outData + (size_t)outStrides[dim] * out_index, dim + 1, ndim, inShape, inStrides,
                           outStrides, multiple);
    }
  }
}
40
/*
 * Tiles data0 into tile_data0 and data1 into tile_data1, using the shapes, strides and multiples held in param.
 * Both operands are tiled with the same param->ndim_ and param->out_strides_.
 */
void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1,
                        ArithmeticParameter *param) {
  /* NOTE(review): presumably fills param->multiples0_/multiples1_ and the stride arrays read below; confirm. */
  CalcMultiplesAndStrides(param);
  const int *shared_out_strides = param->out_strides_;
  TileOneDimensionInt8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_,
                       shared_out_strides, param->multiples0_);
  TileOneDimensionInt8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_,
                       shared_out_strides, param->multiples1_);
}
49
50 #define ACCURACY_DATA 0.00000001
51
ElementNotEqualInt8(int8_t * input0,int8_t * input1,uint8_t * output,int element_size,ArithmeticQuantArg * quant_arg)52 int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
53 ArithmeticQuantArg *quant_arg) {
54 float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
55 float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
56
57 for (int index = 0; index < element_size; ++index) {
58 float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
59 float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
60 float minus_inputs = in0_real - in1_real;
61 bool out_real = true;
62 if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
63 out_real = false;
64 }
65 output[index] = (uint8_t)out_real;
66 }
67 return NNACL_OK;
68 }
69
/*
 * Writes 1 to output[i] when the dequantized values of input0[i] and input1[i] lie within
 * +/-ACCURACY_DATA of each other, and 0 otherwise (including when the difference is NaN).
 * Each operand is dequantized as q * scale_ + (-zp_ * scale_), using its own quant args.
 * Always returns NNACL_OK.
 */
int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
  const float shift0 = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
  const float shift1 = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
  int pos = 0;
  while (pos < element_size) {
    const float value0 = input0[pos] * quant_arg->in0_args_.scale_ + shift0;
    const float value1 = input1[pos] * quant_arg->in1_args_.scale_ + shift1;
    const float gap = value0 - value1;
    uint8_t same = 0;
    if (gap >= -ACCURACY_DATA) {
      if (gap <= ACCURACY_DATA) {
        same = 1;
      }
    }
    output[pos] = same;
    pos++;
  }
  return NNACL_OK;
}
85
/*
 * Writes 1 to output[k] when the dequantized input0[k] is strictly less than the dequantized input1[k],
 * and 0 otherwise. Each operand is dequantized as q * scale_ + (-zp_ * scale_) with its own quant args.
 * Always returns NNACL_OK.
 */
int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
  const float zp_term0 = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
  const float zp_term1 = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
  for (int k = 0; k < element_size; k++) {
    const float a = input0[k] * quant_arg->in0_args_.scale_ + zp_term0;
    const float b = input1[k] * quant_arg->in1_args_.scale_ + zp_term1;
    output[k] = (uint8_t)(a < b);
  }
  return NNACL_OK;
}
97
/*
 * Writes 1 to output[n] when the dequantized input0[n] is less than or equal to the dequantized input1[n],
 * and 0 otherwise. Each operand is dequantized as q * scale_ + (-zp_ * scale_) with its own quant args.
 * Always returns NNACL_OK.
 */
int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                         ArithmeticQuantArg *quant_arg) {
  const float base0 = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
  const float base1 = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
  int n = 0;
  while (n < element_size) {
    const float first = input0[n] * quant_arg->in0_args_.scale_ + base0;
    const float second = input1[n] * quant_arg->in1_args_.scale_ + base1;
    output[n] = (first <= second) ? 1 : 0;
    ++n;
  }
  return NNACL_OK;
}
110
/*
 * Writes 1 to each output element when the dequantized input0 element is strictly greater than the
 * dequantized input1 element at the same position, and 0 otherwise.
 * Each operand is dequantized as q * scale_ + (-zp_ * scale_) with its own quant args.
 * Always returns NNACL_OK.
 */
int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                       ArithmeticQuantArg *quant_arg) {
  const float bias_lhs = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
  const float bias_rhs = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
  const int8_t *lhs = input0;
  const int8_t *rhs = input1;
  uint8_t *dst = output;
  for (int remaining = element_size; remaining > 0; --remaining) {
    const float lhs_real = *lhs++ * quant_arg->in0_args_.scale_ + bias_lhs;
    const float rhs_real = *rhs++ * quant_arg->in1_args_.scale_ + bias_rhs;
    /* rhs < lhs is the same predicate as lhs > rhs, including for NaN (both false). */
    *dst++ = (uint8_t)(rhs_real < lhs_real);
  }
  return NNACL_OK;
}
123
/*
 * Writes 1 to output[idx] when the dequantized input0[idx] is greater than or equal to the dequantized
 * input1[idx], and 0 otherwise. Each operand is dequantized as q * scale_ + (-zp_ * scale_) with its own
 * quant args. Always returns NNACL_OK.
 */
int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
                            ArithmeticQuantArg *quant_arg) {
  const float adjust0 = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
  const float adjust1 = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
  const int count = element_size;
  int idx = 0;
  for (; idx < count; idx++) {
    const float x = input0[idx] * quant_arg->in0_args_.scale_ + adjust0;
    const float y = input1[idx] * quant_arg->in1_args_.scale_ + adjust1;
    /* y <= x is the same predicate as x >= y, including for NaN (both false). */
    output[idx] = (y <= x) ? 1 : 0;
  }
  return NNACL_OK;
}

/* Drop the comparison tolerance so it does not leak into the rest of the translation unit. */
#undef ACCURACY_DATA
138