/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/pow.h"
#include <string.h>
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/nnacl_common.h"
#include "nnacl/fp32/power_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/power_fp16.h"
#endif

PowImpl(void * cdata,int task_id,float l,float r)26 int PowImpl(void *cdata, int task_id, float l, float r) {
27   PowStruct *pow = (PowStruct *)cdata;
28   TensorC *input0 = pow->base_.in_[FIRST_INPUT];
29   TensorC *input1 = pow->base_.in_[SECOND_INPUT];
30   TensorC *output = pow->base_.out_[OUTPUT_INDEX];
31 
32   int size = GetElementNum(input0);
33   int stride = UP_DIV(size, pow->base_.thread_nr_);
34   int len = MSMIN(stride, size - stride * task_id);
35   if (len <= 0) {
36     return NNACL_OK;
37   }
38   bool broadcast = !ShapeEqual(input0->shape_, input0->shape_size_, input1->shape_, input1->shape_size_);
39   float scale = ((PowParameter *)pow->base_.param_)->scale_;
40   float shift = ((PowParameter *)pow->base_.param_)->shift_;
41   int task_stride = stride * task_id;
42 
43   uint8_t *exp_addr = (uint8_t *)input1->data_;
44   void *cur_exp = NULL;
45   if (broadcast) {
46     cur_exp = exp_addr;
47   } else {
48     cur_exp = exp_addr + task_stride * DataTypeCSize(pow->data_type_);
49   }
50 
51   if (pow->data_type_ == kNumberTypeFloat16) {
52 #ifdef ENABLE_FP16
53     return PowerFp16((float16_t *)input0->data_ + task_stride, (float16_t *)cur_exp,
54                      (float16_t *)output->data_ + task_stride, len, scale, shift, broadcast);
55 #endif
56   } else if (pow->data_type_ == kNumberTypeFloat32) {
57     return Power((float *)input0->data_ + task_stride, (float *)cur_exp, (float *)output->data_ + task_stride, len,
58                  scale, shift, broadcast);
59   }
60   return NNACL_POW_INVALID_DATA_TYPE;
61 }
62 
/* Kernel entry point: fans the Pow computation out across the thread pool by
 * running PowImpl once per worker slot, and propagates the launch result. */
int PowCompute(KernelBase *self) {
  return self->env_->ParallelLaunch(self->env_->thread_pool_, PowImpl, self, self->thread_nr_);
}

CreatePow(OpParameter * param,int data_type)67 KernelBase *CreatePow(OpParameter *param, int data_type) {
68   PowStruct *pow = (PowStruct *)malloc(sizeof(PowStruct));
69   NNACL_CHECK_NULL_RETURN_NULL(pow);
70   pow->data_type_ = data_type;
71   pow->base_.Release = DefaultRelease;
72   pow->base_.Prepare = DefaultPrepare2In1Out;
73   pow->base_.Resize = DefaultResize;
74   pow->base_.Compute = PowCompute;
75   return (KernelBase *)pow;
76 }
77 
78 REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat32, CreatePow)
79 REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat16, CreatePow)
80