1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/f16/arithmetic_compare_f16.h"
18 #include "nnacl/kernel/f16/arithmetic_f16.h"
19 #include "nnacl/fp16/arithmetic_fp16.h"
20
21 typedef struct ArithmeticCompareF16Funcions {
22 int primitive_type_;
23 int activation_type_;
24 int (*compute_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
25 int (*optimzie_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
26 bool first_scalar);
27 } ArithmeticCompareF16Funcions;
28
29 typedef struct ArithmeticCompareF16Struct {
30 ArithmeticF16Struct arithmetic_f16_;
31 ArithmeticCompareF16Funcions functions_;
32 } ArithmeticCompareF16Struct;
33
/**
 * Select the FP16 comparison kernels for this operator instance.
 *
 * Searches the dispatch table for the entry whose primitive type equals the
 * kernel's arithmetic_.primitive_type_ and whose activation type equals the
 * parameter's activation_type_, then copies that entry into functions_.
 *
 * If no entry matches, compute_ and optimzie_ are left NULL so that the NULL
 * checks in ArithmeticCompareF16DoExecute fail cleanly instead of calling
 * through a stale or uninitialized pointer.
 *
 * @param base Kernel object; must actually be an ArithmeticCompareF16Struct
 *             whose param_ points at an ArithmeticParameter.
 */
void InitArithmeticCompareF16RunFunction(KernelBase *base) {
  ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base;
  ArithmeticParameter *arithmetic_param = (ArithmeticParameter *)base->param_;

  ArithmeticCompareF16Funcions arithmetic_cp_fun_table_fp16[] = {
    {PrimType_NotEqual, ActType_No, ElementNotEqualFp16, ElementOptNotEqualFp16},
    {PrimType_Equal, ActType_No, ElementEqualFp16, ElementOptEqualFp16},
    {PrimType_Less, ActType_No, ElementLessFp16, ElementOptLessFp16},
    {PrimType_LessEqual, ActType_No, ElementLessEqualFp16, ElementOptLessEqualFp16},
    {PrimType_Greater, ActType_No, ElementGreaterFp16, ElementOptGreaterFp16},
    {PrimType_GreaterEqual, ActType_No, ElementGreaterEqualFp16, ElementOptGreaterEqualFp16}};

  // Start from a known "nothing selected" state; a failed lookup must not
  // leave behind whatever the pointers held before this call.
  arithmetic_compare_f16->functions_.compute_ = NULL;
  arithmetic_compare_f16->functions_.optimzie_ = NULL;

  size_t length = sizeof(arithmetic_cp_fun_table_fp16) / sizeof(arithmetic_cp_fun_table_fp16[0]);
  for (size_t i = 0; i < length; i++) {
    if (arithmetic_cp_fun_table_fp16[i].primitive_type_ ==
          arithmetic_compare_f16->arithmetic_f16_.arithmetic_.primitive_type_ &&
        arithmetic_cp_fun_table_fp16[i].activation_type_ == arithmetic_param->activation_type_) {
      arithmetic_compare_f16->functions_ = arithmetic_cp_fun_table_fp16[i];
      return;
    }
  }
}
56
/**
 * Run the selected FP16 comparison kernel over one span of elements.
 *
 * When the kernel's scalar_opt_ flag is set, the scalar-optimized kernel is
 * used and told whether the first input is the single-element operand
 * (in_elements_num0_ == 1); otherwise the general element-wise kernel runs.
 * The kernel pointer is NULL-checked before it is called.
 *
 * @param base   Kernel object (an ArithmeticCompareF16Struct).
 * @param input0 First operand, interpreted as float16_t data.
 * @param input1 Second operand, interpreted as float16_t data.
 * @param output Destination buffer, interpreted as uint8_t data.
 * @param size   Element count forwarded to the kernel (narrowed to int).
 * @return The selected kernel's return value, or the error produced by
 *         NNACL_CHECK_NULL_RETURN_ERR when the kernel pointer is NULL.
 */
int ArithmeticCompareF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output,
                                  int64_t size) {
  ArithmeticCompareF16Struct *compare = (ArithmeticCompareF16Struct *)base;
  const float16_t *lhs = (const float16_t *)input0;
  const float16_t *rhs = (const float16_t *)input1;
  uint8_t *result = (uint8_t *)output;

  // General element-wise path.
  if (!compare->arithmetic_f16_.arithmetic_.scalar_opt_) {
    NNACL_CHECK_NULL_RETURN_ERR(compare->functions_.compute_);
    return compare->functions_.compute_(lhs, rhs, result, (int)size);
  }

  // Scalar-optimized path: one operand holds a single element.
  bool first_scalar = compare->arithmetic_f16_.arithmetic_.in_elements_num0_ == 1;
  NNACL_CHECK_NULL_RETURN_ERR(compare->functions_.optimzie_);
  return compare->functions_.optimzie_(lhs, rhs, result, (int)size, first_scalar);
}

/**
 * Compute entry point for FP16 comparison kernels.
 *
 * Records the per-element byte sizes of the first input tensor and of the
 * output tensor (via DataTypeCSize on their data_type_ fields) on the
 * arithmetic struct, then delegates to ArithmeticF16Compute.
 *
 * @param self Kernel object; NULL is rejected by NNACL_CHECK_NULL_RETURN_ERR.
 * @return Whatever ArithmeticF16Compute returns.
 */
int ArithmeticCompareF16Compute(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self);

  ArithmeticStruct *arith = (ArithmeticStruct *)self;
  int input_elem_bytes = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
  arith->in_data_size_ = input_elem_bytes;
  int output_elem_bytes = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_);
  arith->out_data_size_ = output_elem_bytes;

  return ArithmeticF16Compute(self);
}
79
CreateArithmeticCompareF16(OpParameter * param,int data_type)80 KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type) {
81 ArithmeticCompareF16Struct *arithmetic_compare_f16 =
82 (ArithmeticCompareF16Struct *)malloc(sizeof(ArithmeticCompareF16Struct));
83 NNACL_CHECK_NULL_RETURN_NULL(arithmetic_compare_f16);
84 memset(arithmetic_compare_f16, 0, sizeof(ArithmeticF16Struct));
85
86 ArithmeticStruct *arithmetic = &arithmetic_compare_f16->arithmetic_f16_.arithmetic_;
87 arithmetic->block_boundary_infos_size_ = 0;
88 arithmetic->a_matrix_.batch_post_sum_ = NULL;
89 arithmetic->b_matrix_.batch_post_sum_ = NULL;
90 arithmetic->c_matrix_.batch_post_sum_ = NULL;
91 arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
92 arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
93 arithmetic->base_.Prepare = ArithmeticPrepare;
94 arithmetic->base_.Resize = ArithmeticF16Resize;
95 arithmetic->base_.Release = ArithmeticRelease;
96 arithmetic->base_.Compute = ArithmeticCompareF16Compute;
97
98 arithmetic->execute_ = ArithmeticCompareF16DoExecute;
99 arithmetic->tile_function_ = TileOneDimensionFp16;
100 arithmetic->init_function_ = InitArithmeticCompareF16RunFunction;
101
102 return (KernelBase *)arithmetic_compare_f16;
103 }
104
105 REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
106 REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat16, CreateArithmeticCompareF16)
107 REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat16, CreateArithmeticCompareF16)
108 REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
109 REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat16, CreateArithmeticCompareF16)
110 REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
111