• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/f16/arithmetic_compare_f16.h"
18 #include "nnacl/kernel/f16/arithmetic_f16.h"
19 #include "nnacl/fp16/arithmetic_fp16.h"
20 
21 typedef struct ArithmeticCompareF16Funcions {
22   int primitive_type_;
23   int activation_type_;
24   int (*compute_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
25   int (*optimzie_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
26                    bool first_scalar);
27 } ArithmeticCompareF16Funcions;
28 
29 typedef struct ArithmeticCompareF16Struct {
30   ArithmeticF16Struct arithmetic_f16_;
31   ArithmeticCompareF16Funcions functions_;
32 } ArithmeticCompareF16Struct;
33 
/**
 * Selects the fp16 comparison kernels for this node.
 *
 * Looks up (primitive_type_, activation_type_) in a fixed table covering the
 * six supported comparison primitives and copies the matching row into
 * functions_. Every row uses ActType_No, so a node carrying any other
 * activation finds no match.
 *
 * functions_ is cleared before the lookup, so an unmatched combination leaves
 * compute_ and optimzie_ NULL. ArithmeticCompareF16DoExecute NULL-checks those
 * pointers and returns an error instead of calling a stale or uninitialised
 * function pointer.
 *
 * @param base Kernel created by CreateArithmeticCompareF16; base->param_ is
 *             read as an ArithmeticParameter.
 */
void InitArithmeticCompareF16RunFunction(KernelBase *base) {
  ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base;
  ArithmeticParameter *arithmetic_param = (ArithmeticParameter *)base->param_;

  /* The rows never change, so keep one read-only copy instead of rebuilding
   * the table on the stack on every call. */
  static const ArithmeticCompareF16Funcions arithmetic_cp_fun_table_fp16[] = {
    {PrimType_NotEqual, ActType_No, ElementNotEqualFp16, ElementOptNotEqualFp16},
    {PrimType_Equal, ActType_No, ElementEqualFp16, ElementOptEqualFp16},
    {PrimType_Less, ActType_No, ElementLessFp16, ElementOptLessFp16},
    {PrimType_LessEqual, ActType_No, ElementLessEqualFp16, ElementOptLessEqualFp16},
    {PrimType_Greater, ActType_No, ElementGreaterFp16, ElementOptGreaterFp16},
    {PrimType_GreaterEqual, ActType_No, ElementGreaterEqualFp16, ElementOptGreaterEqualFp16}};

  /* Reset first so that "no match" means NULL kernels, not leftovers. */
  arithmetic_compare_f16->functions_ = (ArithmeticCompareF16Funcions){0};

  size_t length = sizeof(arithmetic_cp_fun_table_fp16) / sizeof(arithmetic_cp_fun_table_fp16[0]);
  for (size_t i = 0; i < length; i++) {
    if (arithmetic_cp_fun_table_fp16[i].primitive_type_ ==
          arithmetic_compare_f16->arithmetic_f16_.arithmetic_.primitive_type_ &&
        arithmetic_cp_fun_table_fp16[i].activation_type_ == arithmetic_param->activation_type_) {
      arithmetic_compare_f16->functions_ = arithmetic_cp_fun_table_fp16[i];
      return;
    }
  }
}
56 
/**
 * Runs the selected fp16 comparison kernel over one chunk of data.
 *
 * When arithmetic_.scalar_opt_ is set, the scalar-optimised kernel is used,
 * with first_scalar telling it whether input0 is the single-element operand
 * (in_elements_num0_ == 1). Otherwise the general kernel is used. A missing
 * (NULL) kernel is reported through NNACL_CHECK_NULL_RETURN_ERR.
 *
 * @param base   ArithmeticCompareF16Struct kernel.
 * @param input0 First operand, read as float16_t.
 * @param input1 Second operand, read as float16_t.
 * @param output Result buffer, written as uint8_t.
 * @param size   Element count passed to the kernel.
 * @return The selected kernel's status code.
 *
 * NOTE(review): size is int64_t but the kernels take int, so values above
 * INT_MAX would be narrowed. Presumably callers pass int-sized chunks; confirm.
 */
int ArithmeticCompareF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output,
                                  int64_t size) {
  ArithmeticCompareF16Struct *compare = (ArithmeticCompareF16Struct *)base;
  const ArithmeticCompareF16Funcions *funcs = &compare->functions_;
  const float16_t *lhs = (const float16_t *)input0;
  const float16_t *rhs = (const float16_t *)input1;
  uint8_t *dst = (uint8_t *)output;

  if (!compare->arithmetic_f16_.arithmetic_.scalar_opt_) {
    NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_);
    return funcs->compute_(lhs, rhs, dst, size);
  }

  bool lhs_is_scalar = compare->arithmetic_f16_.arithmetic_.in_elements_num0_ == 1;
  NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_);
  return funcs->optimzie_(lhs, rhs, dst, size, lhs_is_scalar);
}
/**
 * Compute hook for fp16 comparison kernels.
 *
 * Records the element byte sizes of the first input and of the output tensor,
 * then delegates to ArithmeticF16Compute. The two sizes are looked up
 * separately because comparison kernels read float16_t operands but write
 * uint8_t results.
 *
 * @param self Kernel to run; a NULL kernel is rejected via
 *             NNACL_CHECK_NULL_RETURN_ERR.
 * @return Status from ArithmeticF16Compute.
 */
int ArithmeticCompareF16Compute(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self);
  ArithmeticStruct *arith = (ArithmeticStruct *)self;

  arith->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
  arith->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_);

  return ArithmeticF16Compute(self);
}
79 
CreateArithmeticCompareF16(OpParameter * param,int data_type)80 KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type) {
81   ArithmeticCompareF16Struct *arithmetic_compare_f16 =
82     (ArithmeticCompareF16Struct *)malloc(sizeof(ArithmeticCompareF16Struct));
83   NNACL_CHECK_NULL_RETURN_NULL(arithmetic_compare_f16);
84   memset(arithmetic_compare_f16, 0, sizeof(ArithmeticF16Struct));
85 
86   ArithmeticStruct *arithmetic = &arithmetic_compare_f16->arithmetic_f16_.arithmetic_;
87   arithmetic->block_boundary_infos_size_ = 0;
88   arithmetic->a_matrix_.batch_post_sum_ = NULL;
89   arithmetic->b_matrix_.batch_post_sum_ = NULL;
90   arithmetic->c_matrix_.batch_post_sum_ = NULL;
91   arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
92   arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
93   arithmetic->base_.Prepare = ArithmeticPrepare;
94   arithmetic->base_.Resize = ArithmeticF16Resize;
95   arithmetic->base_.Release = ArithmeticRelease;
96   arithmetic->base_.Compute = ArithmeticCompareF16Compute;
97 
98   arithmetic->execute_ = ArithmeticCompareF16DoExecute;
99   arithmetic->tile_function_ = TileOneDimensionFp16;
100   arithmetic->init_function_ = InitArithmeticCompareF16RunFunction;
101 
102   return (KernelBase *)arithmetic_compare_f16;
103 }
104 
105 REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
106 REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat16, CreateArithmeticCompareF16)
107 REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat16, CreateArithmeticCompareF16)
108 REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
109 REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat16, CreateArithmeticCompareF16)
110 REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat16, CreateArithmeticCompareF16)
111