1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp32/arithmetic_compare_fp32.h"
18
/*
 * Forward declarations of the scalar comparison predicates used by the
 * element-wise kernels in this file.
 *
 * Every prototype here carries `inline`, whereas the definitions of these
 * functions in this file do not. Per C11 6.7.4p7, a definition is only an
 * "inline definition" when ALL file-scope declarations use `inline` without
 * `extern`; because the definitions omit it, they are ordinary external
 * definitions. Keep this specifier combination intact when editing, or the
 * linkage of these symbols changes.
 */

/* float / bool predicates */
inline bool EqualFp32(float lhs, float rhs);
inline bool NotEqualFp32(float lhs, float rhs);
inline bool LessFp32(float lhs, float rhs);
inline bool LessEqualFp32(float lhs, float rhs);
inline bool GreaterFp32(float lhs, float rhs);
inline bool GreaterEqualFp32(float lhs, float rhs);
inline bool EqualBool(bool lhs, bool rhs);

/* integer predicates */
inline bool EqualInt32(int lhs, int rhs);
inline bool NotEqualInt32(int lhs, int rhs);
inline bool LessInt32(int lhs, int rhs);
inline bool LessEqualInt32(int lhs, int rhs);
inline bool GreaterInt32(int lhs, int rhs);
inline bool GreaterEqualInt32(int lhs, int rhs);
inline bool NotEqualInt64(int64_t lhs, int64_t rhs);
34
/*
 * Scalar predicates on float (and bool) operands, written with the operands
 * mirrored (e.g. x < y as y > x). Under IEEE 754 semantics (C Annex F), every
 * ordered comparison involving NaN yields false and NaN != NaN yields true.
 */
bool EqualFp32(float x, float y) {
  return y == x;
}

bool EqualBool(bool x, bool y) {
  return y == x;
}

bool NotEqualFp32(float x, float y) {
  return !(x == y);
}

bool LessFp32(float x, float y) {
  return y > x;
}

bool LessEqualFp32(float x, float y) {
  return y >= x;
}

bool GreaterFp32(float x, float y) {
  return y < x;
}

bool GreaterEqualFp32(float x, float y) {
  return y <= x;
}
42
/*
 * Scalar predicates on integer operands, written with the operands mirrored.
 * Pure comparisons: no arithmetic is performed, so no overflow is possible.
 */
bool EqualInt32(int x, int y) {
  return y == x;
}

bool NotEqualInt32(int x, int y) {
  return !(x == y);
}

bool NotEqualInt64(int64_t x, int64_t y) {
  return !(x == y);
}

bool LessInt32(int x, int y) {
  return y > x;
}

bool LessEqualInt32(int x, int y) {
  return y >= x;
}

bool GreaterInt32(int x, int y) {
  return y < x;
}

bool GreaterEqualInt32(int x, int y) {
  return y <= x;
}
50
/*
 * ELEMENT_COMPARE: for every index k in [0, element_size), store
 * compare_func(input0[k], input1[k]) into output[k], then execute
 * `return NNACL_OK;` in the ENCLOSING function. The hidden return is
 * intentional: each Element* kernel in this file consists solely of this macro.
 *
 * Every data argument is parenthesized so that expression arguments expand
 * with the intended precedence (e.g. `n ? a : b` as element_size). compare_func
 * is deliberately left bare so that it may also name a function-like macro.
 * The loop index has a distinctive name so it cannot capture a caller's `i`.
 * element_size is re-evaluated on each iteration; pass a side-effect-free
 * expression.
 */
#define ELEMENT_COMPARE(input0, input1, output, element_size, compare_func)           \
  do {                                                                                \
    for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {    \
      (output)[elem_cmp_idx_] = compare_func((input0)[elem_cmp_idx_], (input1)[elem_cmp_idx_]); \
    }                                                                                 \
    return NNACL_OK;                                                                  \
  } while (0)
58
/*
 * ELEMENT_COMPARE_OPT: broadcast variant of ELEMENT_COMPARE in which one side
 * is a single value.
 *  - first_scalar true:  output[k] = compare_func(input0[0], input1[k])
 *  - first_scalar false: output[k] = compare_func(input0[k], input1[0])
 * for every k in [0, element_size). It then executes `return NNACL_OK;` in the
 * ENCLOSING function; the Element*Opt kernels rely on this hidden return.
 *
 * Data arguments are parenthesized so that expression arguments expand
 * correctly. compare_func is left bare so that it may name a function-like
 * macro. The loop index has a distinctive name to avoid capturing a caller's
 * `i`.
 */
#define ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, compare_func)     \
  do {                                                                                           \
    if (first_scalar) {                                                                          \
      for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {             \
        (output)[elem_cmp_idx_] = compare_func((input0)[0], (input1)[elem_cmp_idx_]);            \
      }                                                                                          \
    } else {                                                                                     \
      for (int elem_cmp_idx_ = 0; elem_cmp_idx_ < (element_size); elem_cmp_idx_++) {             \
        (output)[elem_cmp_idx_] = compare_func((input0)[elem_cmp_idx_], (input1)[0]);            \
      }                                                                                          \
    }                                                                                            \
    return NNACL_OK;                                                                             \
  } while (0)
73
74 // equal:
ElementEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)75 int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
76 ELEMENT_COMPARE(input0, input1, output, element_size, EqualFp32);
77 }
78
ElementEqualBool(const bool * input0,const bool * input1,uint8_t * output,int element_size)79 int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size) {
80 ELEMENT_COMPARE(input0, input1, output, element_size, EqualBool);
81 }
82
ElementOptEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)83 int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
84 bool first_scalar) {
85 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualFp32);
86 }
87
ElementEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)88 int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
89 ELEMENT_COMPARE(input0, input1, output, element_size, EqualInt32);
90 }
91
ElementOptEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)92 int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
93 bool first_scalar) {
94 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualInt32);
95 }
96
97 // not equal
ElementNotEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)98 int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
99 ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualFp32);
100 }
101
ElementOptNotEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)102 int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
103 bool first_scalar) {
104 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualFp32);
105 }
106
ElementNotEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)107 int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
108 ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt32);
109 }
110
ElementOptNotEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)111 int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
112 bool first_scalar) {
113 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt32);
114 }
115
ElementNotEqualInt64(const int64_t * input0,const int64_t * input1,uint8_t * output,int element_size)116 int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size) {
117 ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt64);
118 }
119
ElementOptNotEqualInt64(const int64_t * input0,const int64_t * input1,uint8_t * output,int element_size,bool first_scalar)120 int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size,
121 bool first_scalar) {
122 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt64);
123 }
124
125 // less
ElementLessFp32(const float * input0,const float * input1,uint8_t * output,int element_size)126 int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
127 ELEMENT_COMPARE(input0, input1, output, element_size, LessFp32);
128 }
129
ElementOptLessFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)130 int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar) {
131 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessFp32);
132 }
133
ElementLessInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)134 int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
135 ELEMENT_COMPARE(input0, input1, output, element_size, LessInt32);
136 }
137
ElementOptLessInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)138 int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
139 bool first_scalar) {
140 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessInt32);
141 }
142
143 // less equal
ElementLessEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)144 int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
145 ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualFp32);
146 }
147
ElementOptLessEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)148 int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
149 bool first_scalar) {
150 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualFp32);
151 }
152
ElementLessEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)153 int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
154 ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualInt32);
155 }
156
ElementOptLessEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)157 int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
158 bool first_scalar) {
159 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualInt32);
160 }
161
162 // greater
ElementGreaterFp32(const float * input0,const float * input1,uint8_t * output,int element_size)163 int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
164 ELEMENT_COMPARE(input0, input1, output, element_size, GreaterFp32);
165 }
166
ElementOptGreaterFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)167 int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
168 bool first_scalar) {
169 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterFp32);
170 }
171
ElementGreaterInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)172 int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
173 ELEMENT_COMPARE(input0, input1, output, element_size, GreaterInt32);
174 }
175
ElementOptGreaterInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)176 int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
177 bool first_scalar) {
178 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterInt32);
179 }
180
181 // greater equal
ElementGreaterEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size)182 int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
183 ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualFp32);
184 }
185
ElementOptGreaterEqualFp32(const float * input0,const float * input1,uint8_t * output,int element_size,bool first_scalar)186 int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
187 bool first_scalar) {
188 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualFp32);
189 }
190
ElementGreaterEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size)191 int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
192 ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualInt32);
193 }
194
ElementOptGreaterEqualInt32(const int32_t * input0,const int32_t * input1,uint8_t * output,int element_size,bool first_scalar)195 int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
196 bool first_scalar) {
197 ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualInt32);
198 }
199