1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/f16/arithmetic_f16.h"
18 #include "nnacl/fp16/cast_fp16.h"
19 #include "nnacl/fp16/arithmetic_fp16.h"
20 #include "nnacl/fp16/utils_fp16.h"
21 #include "nnacl/tensor_c_utils.h"
22 
void InitArithmeticF16RunFunction(KernelBase *base) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base;

  /* Dispatch table mapping (primitive type, fused activation) to the fp16
   * element-wise kernel pair (full compute + scalar-broadcast variant).
   * static const: built once at load time instead of being re-initialized on
   * the stack on every call. */
  static const ArithmeticF16Funcions f16_fun_table[] = {
    {PrimType_MulFusion, ActType_Relu, ElementMulReluFp16, ElementOptMulReluFp16},
    {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
    {PrimType_MulFusion, ActType_No, ElementMulFp16, ElementOptMulFp16},
    {PrimType_AddFusion, ActType_Relu, ElementAddReluFp16, ElementOptAddReluFp16},
    {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
    {PrimType_AddFusion, ActType_No, ElementAddFp16, ElementOptAddFp16},
    {PrimType_SubFusion, ActType_Relu, ElementSubReluFp16, ElementOptSubReluFp16},
    {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
    {PrimType_SubFusion, ActType_No, ElementSubFp16, ElementOptSubFp16},
    {PrimType_DivFusion, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16},
    {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
    {PrimType_DivFusion, ActType_No, ElementDivFp16, ElementOptDivFp16},
    {PrimType_RealDiv, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16},
    {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
    {PrimType_RealDiv, ActType_No, ElementDivFp16, ElementOptDivFp16},
    {PrimType_FloorMod, ActType_No, ElementFloorModFp16, ElementOptFloorModFp16},
    {PrimType_FloorDiv, ActType_No, ElementFloorDivFp16, ElementOptFloorDivFp16},
    {PrimType_LogicalAnd, ActType_No, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
    {PrimType_LogicalOr, ActType_No, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
    {PrimType_SquaredDifference, ActType_No, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16},
    {PrimType_Maximum, ActType_No, ElementMaximumFp16, ElementOptMaximumFp16},
    {PrimType_Minimum, ActType_No, ElementMinimumFp16, ElementOptMinimumFp16}};

  /* Hoist the loop-invariant lookup keys out of the scan. */
  int primitive_type = arithmetic_f16->arithmetic_.primitive_type_;
  int activation_type = ((ArithmeticParameter *)(arithmetic_f16->arithmetic_.base_.param_))->activation_type_;

  size_t length = sizeof(f16_fun_table) / sizeof(f16_fun_table[0]);
  for (size_t i = 0; i < length; i++) {
    if (f16_fun_table[i].primitive_type_ == primitive_type && f16_fun_table[i].activation_type_ == activation_type) {
      arithmetic_f16->functions_ = f16_fun_table[i];
      return;
    }
  }
  /* No match: functions_ is left untouched (zeroed at creation time); the
   * NULL checks in ArithmeticF16DoExecute surface the error at execute time. */
}
60 
int ArithmeticF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) {
  /* Run the selected fp16 kernel on `size` elements; picks the
   * scalar-broadcast variant when one operand is a scalar. */
  ArithmeticF16Struct *f16_kernel = (ArithmeticF16Struct *)base;
  const float16_t *in0 = (const float16_t *)input0;
  const float16_t *in1 = (const float16_t *)input1;
  float16_t *out = (float16_t *)output;

  if (!f16_kernel->arithmetic_.scalar_opt_) {
    /* General element-wise path. */
    NNACL_CHECK_NULL_RETURN_ERR(f16_kernel->functions_.compute_);
    return f16_kernel->functions_.compute_(in0, in1, out, size);
  }

  /* Scalar-broadcast path; the flag tells the kernel which side is scalar. */
  NNACL_CHECK_NULL_RETURN_ERR(f16_kernel->functions_.optimzie_);
  return f16_kernel->functions_.optimzie_(in0, in1, out, size, f16_kernel->arithmetic_.in_elements_num0_ == 1);
}
75 
ArithmeticF16Resize(KernelBase * self)76 int ArithmeticF16Resize(KernelBase *self) {
77   ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self;
78   NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16);
79   ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
80 
81   arithmetic->in_data_size_ = sizeof(float16_t);
82   arithmetic->out_data_size_ = sizeof(float16_t);
83   if (arithmetic->in_elements_num1_ != 1 && arithmetic->in_elements_num0_ != 1) {
84     if (arithmetic->a_matrix_.is_const_ && self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat32) {
85       TensorC *t = self->in_[FIRST_INPUT];
86       NNACL_CHECK_NULL_RETURN_ERR(t->data_);
87       void *f32_data = t->data_;
88       t->data_type_ = kNumberTypeFloat16;
89       t->data_ = self->env_->Alloc(self->env_->allocator_, GetSize(t));
90       NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_);
91       Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), GetElementNum(t));
92       self->env_->Free(self->env_->allocator_, f32_data);
93     }
94     if (arithmetic->b_matrix_.is_const_ && self->in_[SECOND_INPUT]->data_type_ == kNumberTypeFloat32) {
95       TensorC *t = self->in_[SECOND_INPUT];
96       NNACL_CHECK_NULL_RETURN_ERR(t->data_);
97       void *f32_data = t->data_;
98       t->data_type_ = kNumberTypeFloat16;
99       t->data_ = self->env_->Alloc(self->env_->allocator_, GetSize(t));
100       NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_);
101       Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), GetElementNum(t));
102       self->env_->Free(self->env_->allocator_, f32_data);
103     }
104   }
105   return ArithmeticResize(self);
106 }
107 
/* Release the fp32<->fp16 staging buffers recorded in tmp_buffer_ during
 * Compute and reset the slots so a later call is a no-op. */
void FreeArithmeticF16Buffers(ArithmeticF16Struct *arithmetic_f16) {
  KernelBase *base = &arithmetic_f16->arithmetic_.base_;
  for (int idx = 0; idx < THREE_TENSOR; ++idx) {
    void *buffer = arithmetic_f16->tmp_buffer_[idx];
    if (buffer == NULL) {
      continue;
    }
    base->env_->Free(base->env_->allocator_, buffer);
    arithmetic_f16->tmp_buffer_[idx] = NULL;
  }
}
117 
/* Compute hook for the fp16 arithmetic kernel. Inputs/outputs may be fp32 or
 * fp16; fp32 data is staged through fp16 scratch buffers recorded in
 * tmp_buffer_ and freed before returning.
 * Bug fix: the compute status is now propagated to the caller instead of
 * being replaced by an unconditional NNACL_OK, which silently swallowed
 * kernel failures. */
int ArithmeticF16Compute(KernelBase *self) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16);

  int in0_data_type = self->in_[FIRST_INPUT]->data_type_;
  int in1_data_type = self->in_[SECOND_INPUT]->data_type_;
  int out_data_type = self->out_[OUTPUT_INDEX]->data_type_;

  /* Only fp16 and fp32 operands are supported. */
  NNACL_CHECK_FALSE(in0_data_type != kNumberTypeFloat32 && in0_data_type != kNumberTypeFloat16,
                    NNACL_UNSUPPORTED_DATA_TYPE);
  NNACL_CHECK_FALSE(in1_data_type != kNumberTypeFloat16 && in1_data_type != kNumberTypeFloat32,
                    NNACL_UNSUPPORTED_DATA_TYPE);

  if (!arithmetic_f16->arithmetic_.a_matrix_.is_valid_) {
    arithmetic_f16->arithmetic_.a_matrix_.data_ = GetOrAllocFp16Data(self->in_[FIRST_INPUT], self->env_, true);
    /* Record the buffer for cleanup only when it was freshly allocated for an
     * fp32 source; fp16 inputs are used in place and must not be freed. */
    arithmetic_f16->tmp_buffer_[FIRST_INPUT] =
      in0_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.a_matrix_.data_;
  }

  if (!arithmetic_f16->arithmetic_.b_matrix_.is_valid_) {
    arithmetic_f16->arithmetic_.b_matrix_.data_ = GetOrAllocFp16Data(self->in_[SECOND_INPUT], self->env_, true);
    arithmetic_f16->tmp_buffer_[SECOND_INPUT] =
      in1_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.b_matrix_.data_;
  }

  arithmetic_f16->arithmetic_.c_matrix_.data_ = GetOrAllocFp16Data(self->out_[OUTPUT_INDEX], self->env_, false);
  arithmetic_f16->tmp_buffer_[THIRD_INPUT] =
    out_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.c_matrix_.data_;

  int ret = ArithmeticCompute(self);
  if (ret == NNACL_OK && out_data_type == kNumberTypeFloat32) {
    /* fp32 output: cast the fp16 scratch result back into the output tensor.
     * NOTE(review): these early-return NULL checks leak the tmp_buffer_
     * entries if they fire (pre-existing behavior) — confirm the macros'
     * return semantics before restructuring. */
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->arithmetic_.c_matrix_.data_);
    NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_);
    Float16ToFloat32((float16_t *)(arithmetic_f16->arithmetic_.c_matrix_.data_),
                     (float *)(self->out_[OUTPUT_INDEX]->data_), GetElementNum(self->out_[OUTPUT_INDEX]));
  }

  FreeArithmeticF16Buffers(arithmetic_f16);
  return ret;
}
158 
CreateArithmeticF16(OpParameter * param,int data_type)159 KernelBase *CreateArithmeticF16(OpParameter *param, int data_type) {
160   ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)malloc(sizeof(ArithmeticF16Struct));
161   NNACL_CHECK_NULL_RETURN_NULL(arithmetic_f16);
162   memset(arithmetic_f16, 0, sizeof(ArithmeticF16Struct));
163 
164   ArithmeticStruct *arithmetic = &arithmetic_f16->arithmetic_;
165   arithmetic->block_boundary_infos_size_ = 0;
166   arithmetic->a_matrix_.batch_post_sum_ = NULL;
167   arithmetic->b_matrix_.batch_post_sum_ = NULL;
168   arithmetic->c_matrix_.batch_post_sum_ = NULL;
169   arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
170   arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
171   arithmetic->base_.Prepare = ArithmeticPrepare;
172   arithmetic->base_.Resize = ArithmeticF16Resize;
173   arithmetic->base_.Release = ArithmeticRelease;
174   arithmetic->base_.Compute = ArithmeticF16Compute;
175 
176   arithmetic->execute_ = ArithmeticF16DoExecute;
177   arithmetic->tile_function_ = TileOneDimensionFp16;
178   arithmetic->init_function_ = InitArithmeticF16RunFunction;
179 
180   return (KernelBase *)arithmetic_f16;
181 }
182 
/* Register CreateArithmeticF16 as the fp16 kernel creator for every binary
 * arithmetic primitive handled by this file. */
REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat16, CreateArithmeticF16)
REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat16, CreateArithmeticF16)
196