1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/pow.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/tensor_c_utils.h"
20 #include "nnacl/nnacl_common.h"
21 #include "nnacl/fp32/power_fp32.h"
22 #ifdef ENABLE_FP16
23 #include "nnacl/fp16/power_fp16.h"
24 #endif
25
PowImpl(void * cdata,int task_id,float l,float r)26 int PowImpl(void *cdata, int task_id, float l, float r) {
27 PowStruct *pow = (PowStruct *)cdata;
28 TensorC *input0 = pow->base_.in_[FIRST_INPUT];
29 TensorC *input1 = pow->base_.in_[SECOND_INPUT];
30 TensorC *output = pow->base_.out_[OUTPUT_INDEX];
31
32 int size = GetElementNum(input0);
33 int stride = UP_DIV(size, pow->base_.thread_nr_);
34 int len = MSMIN(stride, size - stride * task_id);
35 if (len <= 0) {
36 return NNACL_OK;
37 }
38 bool broadcast = !ShapeEqual(input0->shape_, input0->shape_size_, input1->shape_, input1->shape_size_);
39 float scale = ((PowParameter *)pow->base_.param_)->scale_;
40 float shift = ((PowParameter *)pow->base_.param_)->shift_;
41 int task_stride = stride * task_id;
42
43 uint8_t *exp_addr = (uint8_t *)input1->data_;
44 void *cur_exp = NULL;
45 if (broadcast) {
46 cur_exp = exp_addr;
47 } else {
48 cur_exp = exp_addr + task_stride * DataTypeCSize(pow->data_type_);
49 }
50
51 if (pow->data_type_ == kNumberTypeFloat16) {
52 #ifdef ENABLE_FP16
53 return PowerFp16((float16_t *)input0->data_ + task_stride, (float16_t *)cur_exp,
54 (float16_t *)output->data_ + task_stride, len, scale, shift, broadcast);
55 #endif
56 } else if (pow->data_type_ == kNumberTypeFloat32) {
57 return Power((float *)input0->data_ + task_stride, (float *)cur_exp, (float *)output->data_ + task_stride, len,
58 scale, shift, broadcast);
59 }
60 return NNACL_POW_INVALID_DATA_TYPE;
61 }
62
/**
 * Compute entry point of the Pow kernel.
 *
 * Hands PowImpl to the environment's ParallelLaunch together with the
 * environment's thread pool, the kernel itself as callback data and
 * self->thread_nr_.
 *
 * @param self  Kernel instance (created by CreatePow).
 * @return Whatever ParallelLaunch returns.
 */
int PowCompute(KernelBase *self) {
  int launch_ret = self->env_->ParallelLaunch(self->env_->thread_pool_, PowImpl, self, self->thread_nr_);
  return launch_ret;
}
66
CreatePow(OpParameter * param,int data_type)67 KernelBase *CreatePow(OpParameter *param, int data_type) {
68 PowStruct *pow = (PowStruct *)malloc(sizeof(PowStruct));
69 NNACL_CHECK_NULL_RETURN_NULL(pow);
70 pow->data_type_ = data_type;
71 pow->base_.Release = DefaultRelease;
72 pow->base_.Prepare = DefaultPrepare2In1Out;
73 pow->base_.Resize = DefaultResize;
74 pow->base_.Compute = PowCompute;
75 return (KernelBase *)pow;
76 }
77
/*
 * Associate CreatePow with PrimType_PowFusion for both fp32 and fp16.
 * Note: the fp16 entry is unconditional, but PowImpl only has an fp16 compute
 * path when ENABLE_FP16 is defined; without it an fp16 kernel returns
 * NNACL_POW_INVALID_DATA_TYPE at compute time.
 */
REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat32, CreatePow)
REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat16, CreatePow)
80