/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "nnacl/kernel/f16/arithmetic_f16.h"
18 #include "nnacl/fp16/cast_fp16.h"
19 #include "nnacl/fp16/arithmetic_fp16.h"
20 #include "nnacl/fp16/utils_fp16.h"
21 #include "nnacl/tensor_c_utils.h"
22
/* Select the fp16 element-wise/opt (scalar-broadcast) function pair matching
 * this kernel's primitive type and activation type.  If no table entry
 * matches, functions_ is left untouched (zeroed at creation time); the
 * execute path null-checks the function pointers before calling them. */
void InitArithmeticF16RunFunction(KernelBase *base) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base;

  /* static const: the table is built once at load time instead of being
   * re-initialized on the stack every time this function runs. */
  static const ArithmeticF16Funcions f16_fun_table[] = {
    {PrimType_MulFusion, ActType_Relu, ElementMulReluFp16, ElementOptMulReluFp16},
    {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
    {PrimType_MulFusion, ActType_No, ElementMulFp16, ElementOptMulFp16},
    {PrimType_AddFusion, ActType_Relu, ElementAddReluFp16, ElementOptAddReluFp16},
    {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
    {PrimType_AddFusion, ActType_No, ElementAddFp16, ElementOptAddFp16},
    {PrimType_SubFusion, ActType_Relu, ElementSubReluFp16, ElementOptSubReluFp16},
    {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
    {PrimType_SubFusion, ActType_No, ElementSubFp16, ElementOptSubFp16},
    {PrimType_DivFusion, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16},
    {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
    {PrimType_DivFusion, ActType_No, ElementDivFp16, ElementOptDivFp16},
    {PrimType_RealDiv, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16},
    {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
    {PrimType_RealDiv, ActType_No, ElementDivFp16, ElementOptDivFp16},
    {PrimType_FloorMod, ActType_No, ElementFloorModFp16, ElementOptFloorModFp16},
    {PrimType_FloorDiv, ActType_No, ElementFloorDivFp16, ElementOptFloorDivFp16},
    {PrimType_LogicalAnd, ActType_No, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
    {PrimType_LogicalOr, ActType_No, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
    {PrimType_SquaredDifference, ActType_No, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16},
    {PrimType_Maximum, ActType_No, ElementMaximumFp16, ElementOptMaximumFp16},
    {PrimType_Minimum, ActType_No, ElementMinimumFp16, ElementOptMinimumFp16}};

  int activation_type = ((ArithmeticParameter *)(arithmetic_f16->arithmetic_.base_.param_))->activation_type_;
  size_t length = sizeof(f16_fun_table) / sizeof(f16_fun_table[0]);
  for (size_t i = 0; i < length; i++) {
    if (f16_fun_table[i].primitive_type_ == arithmetic_f16->arithmetic_.primitive_type_ &&
        f16_fun_table[i].activation_type_ == activation_type) {
      arithmetic_f16->functions_ = f16_fun_table[i];
      return;
    }
  }
}
60
/* Run one chunk of the element-wise computation on fp16 data.
 * Dispatches to the scalar-optimized kernel when one operand is a scalar
 * (scalar_opt_), otherwise to the plain element-wise kernel.  Returns the
 * kernel's status code, or an error if the selected function is missing. */
int ArithmeticF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base;
  const float16_t *in0 = (const float16_t *)input0;
  const float16_t *in1 = (const float16_t *)input1;
  float16_t *out = (float16_t *)output;

  if (!arithmetic_f16->arithmetic_.scalar_opt_) {
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.compute_);
    return arithmetic_f16->functions_.compute_(in0, in1, out, size);
  }

  /* Scalar path: the flag tells the kernel which side is the scalar. */
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.optimzie_);
  bool first_is_scalar = arithmetic_f16->arithmetic_.in_elements_num0_ == 1;
  return arithmetic_f16->functions_.optimzie_(in0, in1, out, size, first_is_scalar);
}
75
/* Convert the const fp32 input tensor at in_[index] to fp16 in place.
 * On success the tensor owns a newly allocated fp16 buffer and the old fp32
 * buffer is freed; on allocation failure the tensor is restored unchanged so
 * no memory is leaked and the tensor stays self-consistent. */
static int ArithmeticF16CastConstInputToFp16(KernelBase *self, int index) {
  TensorC *t = self->in_[index];
  NNACL_CHECK_NULL_RETURN_ERR(t->data_);
  void *f32_data = t->data_;
  t->data_type_ = kNumberTypeFloat16; /* set first so GetSize() reports fp16 bytes */
  t->data_ = self->env_->Alloc(self->env_->allocator_, GetSize(t));
  if (t->data_ == NULL) {
    /* Roll back: the original left data_type_ as fp16 with a NULL buffer and
     * leaked the fp32 data on this path. */
    t->data_type_ = kNumberTypeFloat32;
    t->data_ = f32_data;
    return NNACL_NULL_PTR;
  }
  Float32ToFloat16((float *)f32_data, (float16_t *)t->data_, GetElementNum(t));
  self->env_->Free(self->env_->allocator_, f32_data);
  return NNACL_OK;
}

/* Resize hook: fix the element sizes to fp16 and, when neither input is a
 * scalar, eagerly convert const fp32 inputs to fp16 before delegating to the
 * generic ArithmeticResize. */
int ArithmeticF16Resize(KernelBase *self) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16);
  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;

  arithmetic->in_data_size_ = sizeof(float16_t);
  arithmetic->out_data_size_ = sizeof(float16_t);
  if (arithmetic->in_elements_num1_ != 1 && arithmetic->in_elements_num0_ != 1) {
    if (arithmetic->a_matrix_.is_const_ && self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat32) {
      int ret = ArithmeticF16CastConstInputToFp16(self, FIRST_INPUT);
      if (ret != NNACL_OK) {
        return ret;
      }
    }
    /* Bug fix: the original duplicated this code and NULL-checked
     * in_[FIRST_INPUT]->data_ after allocating for SECOND_INPUT (copy-paste
     * error); the shared helper removes that class of mistake. */
    if (arithmetic->b_matrix_.is_const_ && self->in_[SECOND_INPUT]->data_type_ == kNumberTypeFloat32) {
      int ret = ArithmeticF16CastConstInputToFp16(self, SECOND_INPUT);
      if (ret != NNACL_OK) {
        return ret;
      }
    }
  }
  return ArithmeticResize(self);
}
107
/* Release every temporary fp32<->fp16 staging buffer recorded in
 * tmp_buffer_ and reset the slots to NULL so a later call is a no-op. */
void FreeArithmeticF16Buffers(ArithmeticF16Struct *arithmetic_f16) {
  KernelBase *base = &arithmetic_f16->arithmetic_.base_;
  for (int i = 0; i < THREE_TENSOR; i++) {
    if (arithmetic_f16->tmp_buffer_[i] == NULL) {
      continue;
    }
    base->env_->Free(base->env_->allocator_, arithmetic_f16->tmp_buffer_[i]);
    arithmetic_f16->tmp_buffer_[i] = NULL;
  }
}
117
/* Execute the fp16 arithmetic kernel.
 * Inputs/outputs may arrive as fp32; they are staged into temporary fp16
 * buffers (tracked in tmp_buffer_) that are always released before return.
 * Returns NNACL_OK on success or the underlying error code on failure. */
int ArithmeticF16Compute(KernelBase *self) {
  ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16);

  int in0_data_type = self->in_[FIRST_INPUT]->data_type_;
  int in1_data_type = self->in_[SECOND_INPUT]->data_type_;
  int out_data_type = self->out_[OUTPUT_INDEX]->data_type_;

  /* Only fp16/fp32 operands are supported; reject anything else up front. */
  NNACL_CHECK_FALSE(in0_data_type != kNumberTypeFloat32 && in0_data_type != kNumberTypeFloat16,
                    NNACL_UNSUPPORTED_DATA_TYPE);
  NNACL_CHECK_FALSE(in1_data_type != kNumberTypeFloat16 && in1_data_type != kNumberTypeFloat32,
                    NNACL_UNSUPPORTED_DATA_TYPE);

  /* Stage each operand as fp16.  When a tensor is already fp16 no buffer is
   * allocated and the matching tmp_buffer_ slot stays NULL. */
  if (!arithmetic_f16->arithmetic_.a_matrix_.is_valid_) {
    arithmetic_f16->arithmetic_.a_matrix_.data_ = GetOrAllocFp16Data(self->in_[FIRST_INPUT], self->env_, true);
    arithmetic_f16->tmp_buffer_[FIRST_INPUT] =
      in0_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.a_matrix_.data_;
  }

  if (!arithmetic_f16->arithmetic_.b_matrix_.is_valid_) {
    arithmetic_f16->arithmetic_.b_matrix_.data_ = GetOrAllocFp16Data(self->in_[SECOND_INPUT], self->env_, true);
    arithmetic_f16->tmp_buffer_[SECOND_INPUT] =
      in1_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.b_matrix_.data_;
  }

  arithmetic_f16->arithmetic_.c_matrix_.data_ = GetOrAllocFp16Data(self->out_[OUTPUT_INDEX], self->env_, false);
  arithmetic_f16->tmp_buffer_[THIRD_INPUT] =
    out_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.c_matrix_.data_;

  int ret = ArithmeticCompute(self);
  if (ret == NNACL_OK && out_data_type == kNumberTypeFloat32) {
    /* Cast the fp16 result back into the caller's fp32 output tensor.
     * Record failures instead of early-returning so the tmp buffers below
     * are still freed (the original leaked them on this path). */
    if (arithmetic_f16->arithmetic_.c_matrix_.data_ == NULL || self->out_[OUTPUT_INDEX]->data_ == NULL) {
      ret = NNACL_NULL_PTR;
    } else {
      Float16ToFloat32((float16_t *)(arithmetic_f16->arithmetic_.c_matrix_.data_),
                       (float *)(self->out_[OUTPUT_INDEX]->data_), GetElementNum(self->out_[OUTPUT_INDEX]));
    }
  }

  FreeArithmeticF16Buffers(arithmetic_f16);
  /* Bug fix: propagate the compute status instead of unconditionally
   * returning NNACL_OK, which silently swallowed ArithmeticCompute errors. */
  return ret;
}
158
CreateArithmeticF16(OpParameter * param,int data_type)159 KernelBase *CreateArithmeticF16(OpParameter *param, int data_type) {
160 ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)malloc(sizeof(ArithmeticF16Struct));
161 NNACL_CHECK_NULL_RETURN_NULL(arithmetic_f16);
162 memset(arithmetic_f16, 0, sizeof(ArithmeticF16Struct));
163
164 ArithmeticStruct *arithmetic = &arithmetic_f16->arithmetic_;
165 arithmetic->block_boundary_infos_size_ = 0;
166 arithmetic->a_matrix_.batch_post_sum_ = NULL;
167 arithmetic->b_matrix_.batch_post_sum_ = NULL;
168 arithmetic->c_matrix_.batch_post_sum_ = NULL;
169 arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
170 arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
171 arithmetic->base_.Prepare = ArithmeticPrepare;
172 arithmetic->base_.Resize = ArithmeticF16Resize;
173 arithmetic->base_.Release = ArithmeticRelease;
174 arithmetic->base_.Compute = ArithmeticF16Compute;
175
176 arithmetic->execute_ = ArithmeticF16DoExecute;
177 arithmetic->tile_function_ = TileOneDimensionFp16;
178 arithmetic->init_function_ = InitArithmeticF16RunFunction;
179
180 return (KernelBase *)arithmetic_f16;
181 }
182
183 REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat16, CreateArithmeticF16)
184 REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat16, CreateArithmeticF16)
185 REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat16, CreateArithmeticF16)
186 REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat16, CreateArithmeticF16)
187 REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat16, CreateArithmeticF16)
188 REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat16, CreateArithmeticF16)
189 REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat16, CreateArithmeticF16)
190 REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat16, CreateArithmeticF16)
191 REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat16, CreateArithmeticF16)
192 REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat16, CreateArithmeticF16)
193 REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat16, CreateArithmeticF16)
194 REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat16, CreateArithmeticF16)
195 REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat16, CreateArithmeticF16)
196