• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp32/arithmetic_compare_fp32.h"
18 
/*
 * Scalar comparison predicates used by the element-wise kernels in this file,
 * grouped by operator (mirroring the kernel sections below).
 * These declarations carry `inline`, but the definitions later in this file do
 * not, so under the C99/C11 inline rules this translation unit emits ordinary
 * external definitions for them.
 * NOTE(review): the names are generic and have external linkage; consider
 * `static inline` if no other translation unit links against them.
 */

/* == */
inline bool EqualFp32(float x, float y);
inline bool EqualInt32(int x, int y);
inline bool EqualBool(bool x, bool y);

/* != */
inline bool NotEqualFp32(float x, float y);
inline bool NotEqualInt32(int x, int y);
inline bool NotEqualInt64(int64_t x, int64_t y);

/* < and <= */
inline bool LessFp32(float x, float y);
inline bool LessInt32(int x, int y);
inline bool LessEqualFp32(float x, float y);
inline bool LessEqualInt32(int x, int y);

/* > and >= */
inline bool GreaterFp32(float x, float y);
inline bool GreaterInt32(int x, int y);
inline bool GreaterEqualFp32(float x, float y);
inline bool GreaterEqualInt32(int x, int y);
34 
/*
 * Float / bool scalar predicates. Each one just applies the matching C
 * operator (written here with the operands swapped, which is an exact
 * equivalent). On IEEE-754 targets, every ==, <, <=, >, >= involving a NaN is
 * false, and != involving a NaN is true.
 */
bool EqualFp32(float x, float y) {
  return y == x;
}

bool EqualBool(bool x, bool y) {
  return y == x;
}

bool NotEqualFp32(float x, float y) {
  return y != x;
}

bool LessFp32(float x, float y) {
  return y > x;
}

bool LessEqualFp32(float x, float y) {
  return y >= x;
}

bool GreaterFp32(float x, float y) {
  return y < x;
}

bool GreaterEqualFp32(float x, float y) {
  return y <= x;
}
42 
/*
 * Integer scalar predicates. Each one just applies the matching C operator,
 * written with the operands swapped (an exact equivalent for integers).
 */
bool EqualInt32(int x, int y) {
  return y == x;
}

bool NotEqualInt32(int x, int y) {
  return y != x;
}

bool NotEqualInt64(int64_t x, int64_t y) {
  return y != x;
}

bool LessInt32(int x, int y) {
  return y > x;
}

bool LessEqualInt32(int x, int y) {
  return y >= x;
}

bool GreaterInt32(int x, int y) {
  return y < x;
}

bool GreaterEqualInt32(int x, int y) {
  return y <= x;
}
50 
/*
 * ELEMENT_COMPARE: element-wise comparison of two equally sized arrays,
 *   output[k] = compare_func(input0[k], input1[k])   for k in [0, element_size).
 *
 * NOTE: the expansion ends with `return NNACL_OK;`, i.e. it returns from the
 * *enclosing* function, so it must be the final statement of a function that
 * returns int.
 *
 * Array/size arguments are parenthesized, so expression arguments such as
 * `src + offset` are indexed as a whole. The loop counter has a private name,
 * so a caller variable called `i` can be passed in safely. compare_func is
 * deliberately left unparenthesized so a function-like macro also works.
 * element_size is re-evaluated on every iteration; pass a side-effect-free
 * expression.
 */
#define ELEMENT_COMPARE(input0, input1, output, element_size, compare_func)                    \
  do {                                                                                         \
    for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {            \
      (output)[elem_cmp_idx_] = compare_func((input0)[elem_cmp_idx_], (input1)[elem_cmp_idx_]); \
    }                                                                                          \
    return NNACL_OK;                                                                           \
  } while (0)
58 
/*
 * ELEMENT_COMPARE_OPT: element-wise comparison where one operand is a single
 * broadcast scalar.
 *   first_scalar true : output[k] = compare_func(input0[0], input1[k])
 *   first_scalar false: output[k] = compare_func(input0[k], input1[0])
 * for k in [0, element_size).
 *
 * NOTE: the expansion ends with `return NNACL_OK;`, i.e. it returns from the
 * *enclosing* function, so it must be the final statement of a function that
 * returns int.
 *
 * Array/size arguments are parenthesized, so expression arguments such as
 * `src + offset` are indexed as a whole. The loop counter has a private name,
 * so a caller variable called `i` can be passed in safely. compare_func is
 * deliberately left unparenthesized so a function-like macro also works.
 * element_size is re-evaluated on every iteration; pass a side-effect-free
 * expression.
 */
#define ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, compare_func) \
  do {                                                                                        \
    if (first_scalar) {                                                                       \
      for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {         \
        (output)[elem_cmp_idx_] = compare_func((input0)[0], (input1)[elem_cmp_idx_]);         \
      }                                                                                       \
    } else {                                                                                  \
      for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {         \
        (output)[elem_cmp_idx_] = compare_func((input0)[elem_cmp_idx_], (input1)[0]);         \
      }                                                                                       \
    }                                                                                         \
    return NNACL_OK;                                                                          \
  } while (0)
73 
74 // equal:
ElementEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)75 int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
76   ELEMENT_COMPARE(input0, input1, output, element_size, EqualFp32);
77 }
78 
ElementEqualBool(const bool * input0,const bool * input1,uint8_t * output,int element_size)79 int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size) {
80   ELEMENT_COMPARE(input0, input1, output, element_size, EqualBool);
81 }
82 
ElementOptEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)83 int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
84                         bool first_scalar) {
85   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualFp32);
86 }
87 
ElementEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)88 int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
89   ELEMENT_COMPARE(input0, input1, output, element_size, EqualInt32);
90 }
91 
ElementOptEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)92 int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
93                          bool first_scalar) {
94   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualInt32);
95 }
96 
97 // not equal
ElementNotEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)98 int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
99   ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualFp32);
100 }
101 
ElementOptNotEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)102 int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
103                            bool first_scalar) {
104   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualFp32);
105 }
106 
ElementNotEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)107 int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
108   ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt32);
109 }
110 
ElementOptNotEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)111 int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
112                             bool first_scalar) {
113   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt32);
114 }
115 
ElementNotEqualInt64(const int64_t * input0,const int64_t * input1,uint8_t * output,int element_size)116 int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size) {
117   ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt64);
118 }
119 
ElementOptNotEqualInt64(const int64_t * input0,const int64_t * input1,uint8_t * output,int element_size,bool first_scalar)120 int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size,
121                             bool first_scalar) {
122   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt64);
123 }
124 
125 // less
ElementLessFp32(const float * input0,const float * input1,uint8_t * output,int element_size)126 int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
127   ELEMENT_COMPARE(input0, input1, output, element_size, LessFp32);
128 }
129 
ElementOptLessFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)130 int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar) {
131   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessFp32);
132 }
133 
ElementLessInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)134 int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
135   ELEMENT_COMPARE(input0, input1, output, element_size, LessInt32);
136 }
137 
ElementOptLessInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)138 int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
139                         bool first_scalar) {
140   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessInt32);
141 }
142 
143 // less equal
ElementLessEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)144 int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
145   ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualFp32);
146 }
147 
ElementOptLessEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)148 int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
149                             bool first_scalar) {
150   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualFp32);
151 }
152 
ElementLessEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)153 int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
154   ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualInt32);
155 }
156 
ElementOptLessEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)157 int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
158                              bool first_scalar) {
159   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualInt32);
160 }
161 
162 // greater
ElementGreaterFp32(const float * input0,const float * input1,uint8_t * output,int element_size)163 int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
164   ELEMENT_COMPARE(input0, input1, output, element_size, GreaterFp32);
165 }
166 
ElementOptGreaterFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)167 int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
168                           bool first_scalar) {
169   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterFp32);
170 }
171 
ElementGreaterInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)172 int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
173   ELEMENT_COMPARE(input0, input1, output, element_size, GreaterInt32);
174 }
175 
ElementOptGreaterInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)176 int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
177                            bool first_scalar) {
178   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterInt32);
179 }
180 
181 // greater equal
ElementGreaterEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)182 int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
183   ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualFp32);
184 }
185 
ElementOptGreaterEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)186 int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
187                                bool first_scalar) {
188   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualFp32);
189 }
190 
ElementGreaterEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)191 int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
192   ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualInt32);
193 }
194 
ElementOptGreaterEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)195 int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
196                                 bool first_scalar) {
197   ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualInt32);
198 }
199