• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/arithmetic_self.h"
18 #include "nnacl/fp32/arithmetic_self_fp32.h"
19 #include "nnacl/kernel/default_kernel_base.h"
20 #include "nnacl/tensor_c_utils.h"
21 #ifdef ENABLE_FP16
22 #include "nnacl/fp16/arithmetic_self_fp16.h"
23 #endif
24 
/**
 * Select the non-fp16 element-wise kernels for a unary arithmetic op.
 *
 * Searches the table below for `primitive_type` and copies the matching entry into
 * `arithmetic_self->function_`. A NULL slot in an entry means the op has no kernel for
 * that data type; ArithmeticSelfExecute rejects NULL slots before calling them.
 *
 * If `primitive_type` is not in the table, every kernel pointer in `function_` is set to
 * NULL and the primitive type is recorded. The struct is typically freshly allocated by
 * the caller, so leaving `function_` untouched on a miss would leave garbage pointers
 * that a NULL check cannot detect.
 *
 * @param arithmetic_self  kernel struct whose `function_` member is filled in; must not be NULL.
 * @param primitive_type   primitive type id (a PrimType_* value) to look up.
 */
void ArithmeticSelfGetArithmeticSelfFunction(ArithmeticSelfStruct *arithmetic_self, int primitive_type) {
  /* The table is immutable, so build it once instead of on every call. */
  static const ArithmeticSelfFunction type_func_table[] = {
    {PrimType_Abs, ElementAbs, NULL, ElementAbsInt, NULL},
    {PrimType_Cos, ElementCos, NULL, NULL, NULL},
    {PrimType_Log, ElementLog, NULL, NULL, NULL},
    {PrimType_Log1p, ElementLog1p, NULL, NULL, NULL},
    {PrimType_Square, ElementSquare, NULL, NULL, NULL},
    {PrimType_Sqrt, ElementSqrt, NULL, NULL, NULL},
    {PrimType_Rsqrt, ElementRsqrt, NULL, NULL, NULL},
    {PrimType_Sin, ElementSin, NULL, NULL, NULL},
    {PrimType_LogicalNot, ElementLogicalNot, ElementLogicalNotBool, NULL, NULL},
    {PrimType_Floor, ElementFloor, NULL, NULL, NULL},
    {PrimType_Ceil, ElementCeil, NULL, NULL, NULL},
    {PrimType_Round, ElementRound, NULL, NULL, NULL},
    {PrimType_Neg, ElementNegative, NULL, ElementNegativeInt, NULL},
    {PrimType_Reciprocal, ElementReciprocal, NULL, NULL, NULL},
    {PrimType_Erf, ElementErf, NULL, NULL, NULL},
    {PrimType_IsFinite, NULL, NULL, NULL, ElementIsFinite}};
  const size_t table_size = sizeof(type_func_table) / sizeof(type_func_table[0]);

  for (size_t i = 0; i < table_size; i++) {
    if (type_func_table[i].primitive_type_ == primitive_type) {
      arithmetic_self->function_ = type_func_table[i];
      return;
    }
  }

  /* Unknown primitive type: leave a well-defined "no kernel" entry behind. */
  arithmetic_self->function_.primitive_type_ = primitive_type;
  arithmetic_self->function_.func_ = NULL;
  arithmetic_self->function_.func_bool_ = NULL;
  arithmetic_self->function_.func_int_ = NULL;
  arithmetic_self->function_.func_float_bool_ = NULL;
}
50 
/**
 * Select the fp16 element-wise kernel for a unary arithmetic op.
 *
 * When ENABLE_FP16 is defined, searches the table below for `primitive_type` and copies
 * the matching entry into `arithmetic_self->f16_function_`. If no entry matches (Log1p
 * and IsFinite, for example, are not in the table), `func_` is set to NULL so that the
 * NULL check in ArithmeticSelfExecute reliably rejects the op instead of calling through
 * an uninitialized pointer.
 *
 * On a miss, and whenever ENABLE_FP16 is not defined, `f16_function_.primitive_type_`
 * is set to `primitive_type`.
 *
 * @param arithmetic_self  kernel struct whose `f16_function_` member is filled in; must not be NULL.
 * @param primitive_type   primitive type id (a PrimType_* value) to look up.
 */
void ArithmeticSelfGetArithmeticSelfF16Function(ArithmeticSelfStruct *arithmetic_self, int primitive_type) {
#ifdef ENABLE_FP16
  /* The table is immutable, so build it once instead of on every call. */
  static const ArithmeticSelfF16Function type_func_table[] = {{PrimType_Abs, ElementAbsFp16},
                                                              {PrimType_Cos, ElementCosFp16},
                                                              {PrimType_Log, ElementLogFp16},
                                                              {PrimType_Square, ElementSquareFp16},
                                                              {PrimType_Sqrt, ElementSqrtFp16},
                                                              {PrimType_Rsqrt, ElementRsqrtFp16},
                                                              {PrimType_Sin, ElementSinFp16},
                                                              {PrimType_LogicalNot, ElementLogicalNotFp16},
                                                              {PrimType_Floor, ElementFloorFp16},
                                                              {PrimType_Ceil, ElementCeilFp16},
                                                              {PrimType_Round, ElementRoundFp16},
                                                              {PrimType_Neg, ElementNegativeFp16},
                                                              {PrimType_Reciprocal, ElementReciprocalFp16},
                                                              {PrimType_Erf, ElementErfFp16}};
  const size_t table_size = sizeof(type_func_table) / sizeof(type_func_table[0]);

  for (size_t i = 0; i < table_size; i++) {
    if (type_func_table[i].primitive_type_ == primitive_type) {
      arithmetic_self->f16_function_ = type_func_table[i];
      return;
    }
  }

  /* No fp16 kernel for this primitive type: make that explicit. */
  arithmetic_self->f16_function_.func_ = NULL;
#endif
  arithmetic_self->f16_function_.primitive_type_ = primitive_type;
}
77 
/**
 * Run one task's share of the element-wise operation.
 *
 * The input is split into chunks of UP_DIV(element count, thread_nr_) elements; task
 * `task_id` processes the chunk starting at `task_id * chunk`. A task whose chunk lies
 * past the end of the input has nothing to do and returns NNACL_OK.
 *
 * The kernel is chosen from the input tensor's data type (and, for float32 input, from
 * whether the output tensor is bool). A missing kernel pointer is rejected by
 * NNACL_CHECK_NULL_RETURN_ERR before the call.
 *
 * @param arithmetic_self  kernel struct holding the tensors and the selected kernels.
 * @param task_id          index of this task.
 * @return the selected kernel's return value; NNACL_OK when this task has no elements;
 *         NNACL_ERR when thread_nr_ is zero or the offset multiplication would overflow;
 *         NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT when no branch matches the input type.
 */
int ArithmeticSelfExecute(ArithmeticSelfStruct *arithmetic_self, int task_id) {
  int total = GetElementNum(arithmetic_self->base_.in_[FIRST_INPUT]);
  NNACL_CHECK_TRUE_RET(arithmetic_self->base_.thread_nr_, NNACL_ERR);
  int per_task = UP_DIV(total, arithmetic_self->base_.thread_nr_);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, per_task, NNACL_ERR);
  int begin = task_id * per_task;
  int todo = NNACL_MIN(per_task, total - begin);
  if (todo <= 0) {
    /* This task's chunk starts at or beyond the end of the data. */
    return NNACL_OK;
  }

  void *src = arithmetic_self->base_.in_[FIRST_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(src);
  void *dst = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(dst);
  int src_type = arithmetic_self->base_.in_[FIRST_INPUT]->data_type_;
  int dst_type = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_type_;

  if (src_type == kNumberTypeFloat32) {
    if (dst_type == kNumberTypeBool) {
      /* float32 -> bool kernels take precedence over the plain float32 kernel. */
      NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_float_bool_);
      return arithmetic_self->function_.func_float_bool_((float *)src + begin, (bool *)dst + begin, todo);
    }
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_);
    return arithmetic_self->function_.func_((float *)src + begin, (float *)dst + begin, todo);
  }

  if (src_type == kNumberTypeBool) {
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_bool_);
    return arithmetic_self->function_.func_bool_((bool *)src + begin, (bool *)dst + begin, todo);
  }

  if (src_type == kNumberTypeInt32) {
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_int_);
    return arithmetic_self->function_.func_int_((int32_t *)src + begin, (int32_t *)dst + begin, todo);
  }

#ifdef ENABLE_FP16
  if (src_type == kNumberTypeFloat16) {
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->f16_function_.func_);
    return arithmetic_self->f16_function_.func_((float16_t *)src + begin, (float16_t *)dst + begin, todo);
  }
#endif
  return NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT;
}
124 
ArithmeticSelfRun(void * cdata,int task_id,float l,float r)125 int ArithmeticSelfRun(void *cdata, int task_id, float l, float r) {
126   ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)cdata;
127   NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self);
128   return ArithmeticSelfExecute(arithmetic_self, task_id);
129 }
130 
/**
 * Resize hook: recompute how many threads the kernel should use.
 *
 * Passes the op type, the output tensor's element count and the current thread count to
 * the kernel's UpdateThread callback and stores the result in `self->thread_nr_`.
 *
 * @param self  the kernel (an ArithmeticSelfStruct viewed through its KernelBase).
 * @return NNACL_OK on success; the error returned by NNACL_CHECK_NULL_RETURN_ERR when
 *         `self` is NULL.
 */
int ArithmeticSelfResize(KernelBase *self) {
  ArithmeticSelfStruct *kernel = (ArithmeticSelfStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  self->thread_nr_ = kernel->base_.UpdateThread(TC_PTYPE(kernel->op_type_), 1, 1,
                                                GetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
  return NNACL_OK;
}
138 
/**
 * Compute hook: launch ArithmeticSelfRun through the environment's ParallelLaunch
 * callback, passing the thread pool, the kernel and `self->thread_nr_`.
 *
 * @param self  the kernel to run.
 * @return whatever ParallelLaunch returns.
 */
int ArithmeticSelfCompute(KernelBase *self) {
  int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticSelfRun, self, self->thread_nr_);
  return ret;
}
142 
/**
 * Prepare hook: validate the kernel's tensor configuration.
 *
 * Requires exactly one input tensor and exactly one output tensor, and rejects an output
 * tensor whose category is ConstTensor or ConstScalar.
 *
 * @param self  the kernel to validate.
 * @return NNACL_OK when the configuration is accepted; NNACL_INPUT_TENSOR_ERROR for a
 *         wrong input count; NNACL_OUTPUT_TENSOR_ERROR for a wrong output count or a
 *         constant output tensor.
 */
int ArithmeticSelfPrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ != ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
  /* Both constant categories yield the same error, so they are tested together. */
  NNACL_CHECK_FALSE(((self->out_[OUTPUT_INDEX]->category_ == ConstTensor) ||
                     (self->out_[OUTPUT_INDEX]->category_ == ConstScalar)),
                    NNACL_OUTPUT_TENSOR_ERROR);
  return NNACL_OK;
}
150 
CreateArithmeticSelf(OpParameter * param,int data_type)151 KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type) {
152   ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)malloc(sizeof(ArithmeticSelfStruct));
153   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_self);
154   ArithmeticSelfGetArithmeticSelfFunction(arithmetic_self, param->type_);
155   ArithmeticSelfGetArithmeticSelfF16Function(arithmetic_self, param->type_);
156   arithmetic_self->op_type_ = param->type_;
157   arithmetic_self->base_.Prepare = ArithmeticSelfPrepare;
158   arithmetic_self->base_.Resize = ArithmeticSelfResize;
159   arithmetic_self->base_.Release = DefaultRelease;
160   arithmetic_self->base_.Compute = ArithmeticSelfCompute;
161   return (KernelBase *)arithmetic_self;
162 }
163 
164 REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeBool, CreateArithmeticSelf)
165 
166 REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeInt32, CreateArithmeticSelf)
167 REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeInt32, CreateArithmeticSelf)
168 
169 REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat32, CreateArithmeticSelf)
170 REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat32, CreateArithmeticSelf)
171 REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat32, CreateArithmeticSelf)
172 REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat32, CreateArithmeticSelf)
173 REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat32, CreateArithmeticSelf)
174 REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat32, CreateArithmeticSelf)
175 REG_KERNEL_CREATOR(PrimType_Log1p, kNumberTypeFloat32, CreateArithmeticSelf)
176 REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat32, CreateArithmeticSelf)
177 REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat32, CreateArithmeticSelf)
178 REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat32, CreateArithmeticSelf)
179 REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat32, CreateArithmeticSelf)
180 REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat32, CreateArithmeticSelf)
181 REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat32, CreateArithmeticSelf)
182 REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat32, CreateArithmeticSelf)
183 REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat32, CreateArithmeticSelf)
184 REG_KERNEL_CREATOR(PrimType_IsFinite, kNumberTypeFloat32, CreateArithmeticSelf)
185 
186 REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat16, CreateArithmeticSelf)
187 REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat16, CreateArithmeticSelf)
188 REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat16, CreateArithmeticSelf)
189 REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat16, CreateArithmeticSelf)
190 REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat16, CreateArithmeticSelf)
191 REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat16, CreateArithmeticSelf)
192 REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat16, CreateArithmeticSelf)
193 REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat16, CreateArithmeticSelf)
194 REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat16, CreateArithmeticSelf)
195 REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat16, CreateArithmeticSelf)
196 REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat16, CreateArithmeticSelf)
197 REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat16, CreateArithmeticSelf)
198 REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat16, CreateArithmeticSelf)
199 REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat16, CreateArithmeticSelf)
200