• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/arithmetic_compare.h"
18 #include "nnacl/kernel/arithmetic.h"
19 #include "nnacl/fp32/arithmetic_fp32.h"
20 #include "nnacl/fp32/arithmetic_compare_fp32.h"
21 
/**
 * Kernel set for one comparison primitive.
 *
 * primitive_type_ identifies the primitive the entry belongs to; the remaining
 * members are function pointers, one per data type / variant. A NULL pointer
 * means the entry provides no kernel for that slot. Every kernel takes two
 * input buffers, a uint8_t output buffer and an int size, and returns int.
 * The optimize_* variants take an additional first_scalar flag.
 */
typedef struct ArithmeticCompareFuncions {
  int primitive_type_;

  /* General kernels for float32 and int32 inputs. */
  int (*compute_f32_)(const float *lhs, const float *rhs, uint8_t *out, int size);
  int (*compute_i32_)(const int *lhs, const int *rhs, uint8_t *out, int size);

  /* float32 / int32 variants that additionally receive first_scalar. */
  int (*optimize_f32)(const float *lhs, const float *rhs, uint8_t *out, int size, bool first_scalar);
  int (*optimize_i32)(const int *lhs, const int *rhs, uint8_t *out, int size, bool first_scalar);

  /* int64 kernels: general form and the variant with first_scalar. */
  int (*compute_i64)(const int64_t *lhs, const int64_t *rhs, uint8_t *out, int size);
  int (*optimize_i64)(const int64_t *lhs, const int64_t *rhs, uint8_t *out, int size, bool first_scalar);

  /* bool kernel; there is no first_scalar variant for bool in this struct. */
  int (*compute_bool)(const bool *lhs, const bool *rhs, uint8_t *out, int size);
} ArithmeticCompareFuncions;
33 
/**
 * Kernel object used by the comparison operators in this file.
 *
 * arithmetic_ is the first member; the functions in this file rely on that by
 * casting between KernelBase *, ArithmeticStruct * and ArithmeticCompareStruct *.
 */
typedef struct ArithmeticCompareStruct {
  /* Shared arithmetic kernel state (primitive type, scalar_opt_, data sizes, callbacks). */
  ArithmeticStruct arithmetic_;
  /* Kernel set copied in by InitArithmeticCompareRunFunction for the kernel's primitive type. */
  ArithmeticCompareFuncions functions_;
} ArithmeticCompareStruct;
38 
/**
 * Picks the comparison kernels that belong to the kernel's primitive type.
 *
 * Searches a local table for the entry whose primitive_type_ equals
 * arithmetic_.primitive_type_ and copies that entry into functions_.
 * When no entry matches, functions_ is left as it was.
 *
 * @param self kernel object, interpreted as an ArithmeticCompareStruct;
 *             a NULL pointer makes the function return immediately.
 */
void InitArithmeticCompareRunFunction(KernelBase *self) {
  ArithmeticCompareStruct *compare = (ArithmeticCompareStruct *)self;
  NNACL_CHECK_NULL_RETURN_VOID(compare);

  /* Members not named in an entry are zero-initialised, i.e. NULL kernels. */
  ArithmeticCompareFuncions candidates[] = {
    {.primitive_type_ = PrimType_Equal,
     .compute_f32_ = ElementEqualFp32,
     .compute_i32_ = ElementEqualInt32,
     .optimize_f32 = ElementOptEqualFp32,
     .optimize_i32 = ElementOptEqualInt32,
     .compute_bool = ElementEqualBool},
    {.primitive_type_ = PrimType_NotEqual,
     .compute_f32_ = ElementNotEqualFp32,
     .compute_i32_ = ElementNotEqualInt32,
     .optimize_f32 = ElementOptNotEqualFp32,
     .optimize_i32 = ElementOptNotEqualInt32,
     .compute_i64 = ElementNotEqualInt64,
     .optimize_i64 = ElementOptNotEqualInt64},
    {.primitive_type_ = PrimType_Less,
     .compute_f32_ = ElementLessFp32,
     .compute_i32_ = ElementLessInt32,
     .optimize_f32 = ElementOptLessFp32,
     .optimize_i32 = ElementOptLessInt32},
    {.primitive_type_ = PrimType_LessEqual,
     .compute_f32_ = ElementLessEqualFp32,
     .compute_i32_ = ElementLessEqualInt32,
     .optimize_f32 = ElementOptLessEqualFp32,
     .optimize_i32 = ElementOptLessEqualInt32},
    {.primitive_type_ = PrimType_Greater,
     .compute_f32_ = ElementGreaterFp32,
     .compute_i32_ = ElementGreaterInt32,
     .optimize_f32 = ElementOptGreaterFp32,
     .optimize_i32 = ElementOptGreaterInt32},
    {.primitive_type_ = PrimType_GreaterEqual,
     .compute_f32_ = ElementGreaterEqualFp32,
     .compute_i32_ = ElementGreaterEqualInt32,
     .optimize_f32 = ElementOptGreaterEqualFp32,
     .optimize_i32 = ElementOptGreaterEqualInt32},
  };

  const size_t count = sizeof(candidates) / sizeof(candidates[0]);
  for (size_t idx = 0; idx < count; ++idx) {
    if (candidates[idx].primitive_type_ != compare->arithmetic_.primitive_type_) {
      continue;
    }
    /* First match wins; the whole entry is copied by value. */
    compare->functions_ = candidates[idx];
    break;
  }
}
64 
/**
 * Runs the selected comparison kernel on one pair of input buffers.
 *
 * The kernel is chosen by the data type of the first input tensor
 * (float32, int/int32, int64 or bool) and by arithmetic_.scalar_opt_:
 * when scalar_opt_ is set the optimize_* kernel is called and told whether
 * the first input has exactly one element; otherwise the compute_* kernel
 * is called. Bool has no optimize_* kernel, so bool with scalar_opt_ set is
 * reported as unsupported.
 *
 * @param base   kernel object, interpreted as an ArithmeticCompareStruct; must not be NULL.
 * @param input0 first operand buffer; must not be NULL.
 * @param input1 second operand buffer; must not be NULL.
 * @param output destination buffer, handed to the kernel as uint8_t *; must not be NULL.
 * @param size   value forwarded as the kernel's element_size argument.
 * @return the kernel's return value; NNACL_UNSUPPORTED_DATA_TYPE when the data
 *         type / scalar_opt_ combination is not handled; the value returned by
 *         NNACL_CHECK_NULL_RETURN_ERR when a required pointer or kernel is NULL.
 */
int ArithmeticCompareExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) {
  ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)base;
  /* base is dereferenced right below, so reject NULL first (same guard the other entry points in this file use). */
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare);
  NNACL_CHECK_NULL_RETURN_ERR(input0);
  NNACL_CHECK_NULL_RETURN_ERR(input1);
  /* Every kernel writes through output, so it is validated like the inputs. */
  NNACL_CHECK_NULL_RETURN_ERR(output);

  int data_type = base->in_[FIRST_INPUT]->data_type_;
  bool scalar_opt = arithmetic_compare->arithmetic_.scalar_opt_;
  bool first_scalar = arithmetic_compare->arithmetic_.in_elements_num0_ == 1;

  /* NOTE(review): size is int64_t but every kernel takes an int, so it is narrowed
   * implicitly at each call below - confirm callers never exceed INT_MAX. */

  if (data_type == kNumberTypeFloat32) {
    if (scalar_opt) {
      NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_f32);
      return arithmetic_compare->functions_.optimize_f32((const float *)input0, (const float *)input1,
                                                         (uint8_t *)output, size, first_scalar);
    }
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_f32_);
    return arithmetic_compare->functions_.compute_f32_((const float *)input0, (const float *)input1,
                                                       (uint8_t *)output, size);
  }

  if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) {
    if (scalar_opt) {
      NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i32);
      return arithmetic_compare->functions_.optimize_i32((const int *)input0, (const int *)input1, (uint8_t *)output,
                                                         size, first_scalar);
    }
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i32_);
    return arithmetic_compare->functions_.compute_i32_((const int *)input0, (const int *)input1, (uint8_t *)output,
                                                       size);
  }

  if (data_type == kNumberTypeInt64) {
    if (scalar_opt) {
      NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i64);
      return arithmetic_compare->functions_.optimize_i64((const int64_t *)input0, (const int64_t *)input1,
                                                         (uint8_t *)output, size, first_scalar);
    }
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i64);
    return arithmetic_compare->functions_.compute_i64((const int64_t *)input0, (const int64_t *)input1,
                                                      (uint8_t *)output, size);
  }

  if (data_type == kNumberTypeBool) {
    if (scalar_opt) {
      /* There is no scalar-optimised bool kernel in the function table. */
      return NNACL_UNSUPPORTED_DATA_TYPE;
    }
    NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_bool);
    return arithmetic_compare->functions_.compute_bool((const bool *)input0, (const bool *)input1, (uint8_t *)output,
                                                       size);
  }

  return NNACL_UNSUPPORTED_DATA_TYPE;
}
120 
/**
 * Resize hook for comparison kernels.
 *
 * Refreshes in_data_size_ from the first input tensor's data type and
 * out_data_size_ from the output tensor's data type, then delegates to
 * ArithmeticResize.
 *
 * @param self kernel object, interpreted as an ArithmeticStruct.
 * @return the result of ArithmeticResize, or the value returned by
 *         NNACL_CHECK_NULL_RETURN_ERR when self is NULL.
 */
int ArithmeticCompareResize(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self);

  ArithmeticStruct *state = (ArithmeticStruct *)self;
  state->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
  state->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_);

  return ArithmeticResize(self);
}
128 
CreateArithmeticCompare(OpParameter * param,int data_type)129 KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type) {
130   ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)malloc(sizeof(ArithmeticCompareStruct));
131   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_compare);
132   memset(arithmetic_compare, 0, sizeof(ArithmeticCompareStruct));
133 
134   ArithmeticStruct *arithmetic = (ArithmeticStruct *)arithmetic_compare;
135   arithmetic->in_data_size_ = DataTypeCSize(data_type);
136   arithmetic->out_data_size_ = DataTypeCSize(data_type);
137   arithmetic->block_boundary_infos_size_ = 0;
138   arithmetic->a_matrix_.batch_post_sum_ = NULL;
139   arithmetic->b_matrix_.batch_post_sum_ = NULL;
140   arithmetic->c_matrix_.batch_post_sum_ = NULL;
141   arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
142   arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
143   arithmetic->tile_function_ = TileOneDimensionFp32;
144   arithmetic->init_function_ = InitArithmeticCompareRunFunction;
145   arithmetic->execute_ = ArithmeticCompareExecute;
146   arithmetic->base_.Prepare = ArithmeticPrepare;
147   arithmetic->base_.Resize = ArithmeticCompareResize;
148   arithmetic->base_.Release = ArithmeticRelease;
149   arithmetic->base_.Compute = ArithmeticCompute;
150   return (KernelBase *)arithmetic_compare;
151 }
152 
/*
 * Kernel registrations: each line passes a (primitive type, data type) pair
 * together with CreateArithmeticCompare to REG_KERNEL_CREATOR - presumably
 * registering it as the creator for that pair; confirm against the macro.
 */
/* Equal: float32, bool, int32 */
REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeBool, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeInt32, CreateArithmeticCompare)
/* NotEqual: float32, int32, int64 */
REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt64, CreateArithmeticCompare)
/* Less, LessEqual, Greater, GreaterEqual: float32 and int32 */
REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeInt32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeInt32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeInt32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat32, CreateArithmeticCompare)
REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeInt32, CreateArithmeticCompare)
167