/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/prelu.h"
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/fp32/prelu_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/prelu_fp16.h"
#endif

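/*
 * Parallel worker for the PReLU kernel, invoked once per task by the thread
 * pool. The trailing float parameters (l, r) appear to be unused placeholders
 * required by the ParallelLaunch callback signature.
 */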
int PReluRun(void *cdata, int task_id, float l, float r) {
  PReluStruct *prelu = (PReluStruct *)cdata;
  NNACL_CHECK_NULL_RETURN_ERR(prelu);

  /* Partition the work evenly across threads: over all elements when the
   * slope is shared, otherwise over the per-channel plane. */
  int thread_num = prelu->base_.thread_nr_;
  int num = prelu->channel_shared_ ? prelu->input_num_ : prelu->input_num_ / prelu->channel_num_;
  int step = UP_DIV(num, thread_num);
  int start = task_id * step;
  int end = MSMIN(start + step, num);

  void *in_data = prelu->base_.in_[FIRST_INPUT]->data_;
  void *out_data = prelu->base_.out_[OUTPUT_INDEX]->data_;
  void *slope_data = prelu->base_.in_[SECOND_INPUT]->data_;

  /* Dispatch to the fp16 or fp32 micro-kernels. */
  if (prelu->data_type_ == kNumberTypeFloat16) {
#ifdef ENABLE_FP16
    if (prelu->channel_shared_) {
      PReluShareChannelFp16((float16_t *)in_data, (float16_t *)out_data, ((float16_t *)slope_data)[0], start, end);
    } else {
      PReluFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)slope_data, start, end, prelu->channel_num_);
    }
#endif
  } else {
    if (prelu->channel_shared_) {
      PReluShareChannel((float *)in_data, (float *)out_data, ((float *)slope_data)[0], start, end);
    } else {
      PRelu((float *)in_data, (float *)out_data, (float *)slope_data, start, end, prelu->channel_num_);
    }
  }
  return NNACL_OK;
}

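/*
 * One-time setup: validates the tensor counts and decides between the
 * shared-slope path (a single scalar slope broadcast to every channel) and
 * the per-channel path (one slope per input channel).
 */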
int PReluPrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
  PReluStruct *prelu = (PReluStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(prelu);
  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  TensorC *slope = self->in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(slope);

  int slope_num = GetElementNum(slope);
  if (slope_num == Num1) {
    prelu->channel_shared_ = true;
  } else if (slope_num == GetChannel(input)) {
    prelu->channel_shared_ = false;
  }
  /* Shape-dependent sizes are computed in PReluResize, so Prepare succeeds
   * whether or not shape inference has finished. */
  if (!CheckInferShapeDone(self->in_, TWO_TENSOR, NULL, 0)) {
    return NNACL_OK;
  }
  return NNACL_OK;
}

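/*
 * Caches the shape-dependent element and channel counts used by PReluRun;
 * called whenever the input shape changes.
 */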
int PReluResize(KernelBase *self) {
  PReluStruct *prelu = (PReluStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(prelu);
  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  prelu->input_num_ = GetElementNum(input);
  prelu->channel_num_ = GetChannel(input);
  return NNACL_OK;
}

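/*
 * Null-checks the tensor buffers, then runs PReluRun across the kernel's
 * thread pool.
 */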
int PReluCompute(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]->data_);
  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);
  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_);
  return self->env_->ParallelLaunch(self->env_->thread_pool_, PReluRun, self, self->thread_nr_);
}

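/*
 * Creator registered for both fp32 and fp16 PReLUFusion: zero-initializes
 * the kernel struct, records the data type, and wires up the KernelBase
 * entry points. The OpParameter argument is unused here; presumably the
 * framework attaches it to the kernel base elsewhere.
 */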
KernelBase *CreatePRelu(OpParameter *param, int data_type) {
  PReluStruct *prelu = (PReluStruct *)malloc(sizeof(PReluStruct));
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prelu);
  memset(prelu, 0, sizeof(PReluStruct));
  prelu->data_type_ = data_type;
  prelu->base_.Prepare = PReluPrepare;
  prelu->base_.Resize = PReluResize;
  prelu->base_.Compute = PReluCompute;
  prelu->base_.Release = DefaultRelease;
  return (KernelBase *)prelu;
}

REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat16, CreatePRelu)
REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat32, CreatePRelu)