1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp16/arg_min_max_fp16.h"
18
ArgCompareAscFp16(const void * a,const void * b)19 int ArgCompareAscFp16(const void *a, const void *b) {
20 float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
21 float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
22 if (b_value > a_value) {
23 return -1;
24 }
25 if (b_value < a_value) {
26 return 1;
27 }
28
29 return 0;
30 }
31
ArgCompareDescFp16(const void * a,const void * b)32 int ArgCompareDescFp16(const void *a, const void *b) {
33 float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
34 float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
35 if (b_value > a_value) {
36 return 1;
37 }
38 if (b_value < a_value) {
39 return -1;
40 }
41
42 return 0;
43 }
44
// Top-1 arg-max along one axis of a tensor viewed as
// [pre_axis_count, axis_count, after_axis_count].
//
// For every (pre, after) pair the axis is scanned and the largest element is
// kept; the comparison is strict, so the first occurrence wins on ties.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise int indices along the axis.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMaxTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param,
                     int pre_axis_count, int axis_count, int after_axis_count) {
  const bool write_value = param->out_value_;
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;

  for (int outer = 0; outer < pre_axis_count; ++outer) {
    size_t out_base = outer * after_axis_count;
    size_t in_base = out_base * axis_count;
    for (int inner = 0; inner < after_axis_count; ++inner) {
      // NOTE(review): -FLT_MAX is narrowed to float16_t here and lies outside the
      // finite half-precision range; presumably it becomes -infinity on the
      // target - confirm.
      float16_t best = -FLT_MAX;
      int best_index = 0;
      int pos = 0;
      while (pos < axis_count) {
        float16_t candidate = input[in_base + pos * after_axis_count + inner];
        if (candidate > best) {
          best = candidate;
          best_index = pos;
        }
        ++pos;
      }

      size_t dst = out_base + inner;
      if (write_value) {
        value_out[dst] = best;
      } else {
        index_out[dst] = best_index;
      }
      if (output_value != NULL) {
        output_value[dst] = best;
      }
    }
  }
}
74
// Top-1 arg-min along one axis of a tensor viewed as
// [pre_axis_count, axis_count, after_axis_count].
//
// For every (pre, after) pair the axis is scanned and the smallest element is
// kept; the comparison is strict, so the first occurrence wins on ties.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise int indices along the axis.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMinTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param,
                     int pre_axis_count, int axis_count, int after_axis_count) {
  const bool write_value = param->out_value_;
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;

  for (int outer = 0; outer < pre_axis_count; ++outer) {
    size_t out_base = outer * after_axis_count;
    size_t in_base = out_base * axis_count;
    for (int inner = 0; inner < after_axis_count; ++inner) {
      // NOTE(review): FLT_MAX is narrowed to float16_t here and lies outside the
      // finite half-precision range; presumably it becomes +infinity on the
      // target - confirm.
      float16_t best = FLT_MAX;
      int best_index = 0;
      int pos = 0;
      while (pos < axis_count) {
        float16_t candidate = input[in_base + pos * after_axis_count + inner];
        if (candidate < best) {
          best = candidate;
          best_index = pos;
        }
        ++pos;
      }

      size_t dst = out_base + inner;
      if (write_value) {
        value_out[dst] = best;
      } else {
        index_out[dst] = best_index;
      }
      if (output_value != NULL) {
        output_value[dst] = best;
      }
    }
  }
}
104
// Top-k selection along axis 0.
//
// For each of the param->in_strides_[0] positions behind axis 0, the
// in_shape[0] elements along the axis are copied (with their indices) into the
// scratch array param->arg_elements_, sorted with compare_func, and the first
// param->topk_ entries are emitted.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise int indices along the axis.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMinMaxDim0Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;

  for (int32_t inner = 0; inner < param->in_strides_[0]; ++inner) {
    // Gather the slice along axis 0 into the scratch array.
    for (int pos = 0; pos < in_shape[0]; ++pos) {
      size_t src = param->in_strides_[0] * pos + inner;
      ArgElement *slot = &param->arg_elements_[pos];
      slot->index_ = pos;
      slot->data_.f16_data_ = input[src];
    }

    qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), compare_func);

    // Emit the leading topk_ entries of the sorted slice.
    for (int rank = 0; rank < param->topk_; ++rank) {
      size_t dst = rank * param->out_strides_[0] + inner;
      const ArgElement *picked = &param->arg_elements_[rank];
      if (param->out_value_) {
        value_out[dst] = picked->data_.f16_data_;
      } else {
        index_out[dst] = picked->index_;
      }
      if (output_value != NULL) {
        output_value[dst] = picked->data_.f16_data_;
      }
    }
  }
}
130
// Top-k selection along axis 1.
//
// For each index on axis 0 and each of the param->in_strides_[1] positions
// behind axis 1, the in_shape[1] elements along the axis are copied (with their
// indices) into the scratch array param->arg_elements_, sorted with
// compare_func, and the first param->topk_ entries are emitted.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise int indices along the axis.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMinMaxDim1Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  const int axis_len = in_shape[1];
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;

  for (int n = 0; n < in_shape[0]; ++n) {
    size_t in_base = n * param->in_strides_[0];
    size_t out_base = n * param->out_strides_[0];
    for (int inner = 0; inner < param->in_strides_[1]; ++inner) {
      // Gather the slice along axis 1 into the scratch array.
      for (int pos = 0; pos < axis_len; ++pos) {
        size_t src = param->in_strides_[1] * pos + in_base + inner;
        ArgElement *slot = &param->arg_elements_[pos];
        slot->index_ = pos;
        slot->data_.f16_data_ = input[src];
      }

      qsort(param->arg_elements_, axis_len, sizeof(ArgElement), compare_func);

      // Emit the leading topk_ entries of the sorted slice.
      for (int rank = 0; rank < param->topk_; ++rank) {
        size_t dst = out_base + inner + rank * param->out_strides_[1];
        const ArgElement *picked = &param->arg_elements_[rank];
        if (param->out_value_) {
          value_out[dst] = picked->data_.f16_data_;
        } else {
          index_out[dst] = picked->index_;
        }
        if (output_value != NULL) {
          output_value[dst] = picked->data_.f16_data_;
        }
      }
    }
  }
}
161
// Top-k selection along axis 2.
//
// For each (axis-0, axis-1) index pair and each of the param->in_strides_[2]
// positions behind axis 2, the in_shape[2] elements along the axis are copied
// (with their indices) into the scratch array param->arg_elements_, sorted with
// compare_func, and the first param->topk_ entries are emitted.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise the same buffer is written as int indices.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  // Values must be stored with float16_t element size. Viewing the buffer as
  // float * (as this function previously did) scales out_offset by the wrong
  // element width and stores 32-bit values into a half-precision buffer.
  float16_t *outputfp16 = output;
  int *outputint = (int *)output;
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < param->in_strides_[2]; ++k) {
        // Gather the slice along axis 2 into the scratch array.
        for (int l = 0; l < in_shape2; ++l) {
          size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
          param->arg_elements_[l].index_ = l;
          param->arg_elements_[l].data_.f16_data_ = input[offset];
        }
        qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func);
        // Emit the leading topk_ entries of the sorted slice.
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
          if (param->out_value_) {
            outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_;
          } else {
            outputint[out_offset] = param->arg_elements_[l].index_;
          }
          if (output_value != NULL) {
            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
          }
        }
      }
    }
  }
}
196
// Top-k selection along axis 3.
//
// For each (axis-0, axis-1, axis-2) index triple, the in_shape[3] elements
// along axis 3 (read at consecutive offsets) are copied with their indices into
// the scratch array param->arg_elements_, sorted with compare_func, and the
// first param->topk_ entries are emitted at consecutive output offsets.
//   output       - receives float16_t values when param->out_value_ is set,
//                  otherwise the same buffer is written as int indices.
//   output_value - optional; when non-NULL it additionally receives the values.
void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  int in_shape3 = in_shape[3];
  // Values must be stored with float16_t element size. Viewing the buffer as
  // float * (as this function previously did) scales out_offset by the wrong
  // element width and stores 32-bit values into a half-precision buffer.
  float16_t *outputfp16 = output;
  int *outputint = (int *)output;
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < in_shape2; ++k) {
        size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
        size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
        // Gather the slice along axis 3 into the scratch array.
        for (int l = 0; l < in_shape3; ++l) {
          size_t offset = l + in_dim2_offset;
          param->arg_elements_[l].index_ = l;
          param->arg_elements_[l].data_.f16_data_ = input[offset];
        }
        qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func);
        // Emit the leading topk_ entries of the sorted slice.
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim2_offset + l;
          if (param->out_value_) {
            outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_;
          } else {
            outputint[out_offset] = param->arg_elements_[l].index_;
          }
          if (output_value != NULL) {
            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
          }
        }
      }
    }
  }
}
234
// Entry point for fp16 arg-min / arg-max.
//
// When param->topk_ is 1 the dedicated single-pass kernels are used; the three
// axis counts are initialised to 1 and then filled in by ComputeAxisDims from
// in_shape, param->dims_size_ and param->axis_. Otherwise a sort-based kernel
// is chosen by param->axis_ (0..3) with a descending comparator for max and an
// ascending one for min. Any other axis value results in no work being done.
void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                   const ArgMinMaxComputeParam *param) {
  if (param->topk_ == 1) {
    int pre_axis_count = 1;
    int axis_count = 1;
    int after_axis_count = 1;
    ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);

    if (!param->get_max_) {
      ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
      return;
    }
    ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
    return;
  }

  // Descending order puts the largest values first (max); ascending puts the
  // smallest first (min).
  COMPARE_FUNCTION compare_function = param->get_max_ ? ArgCompareDescFp16 : ArgCompareAscFp16;

  switch (param->axis_) {
    case 0:
      ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 1:
      ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 2:
      ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 3:
      ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    default:
      break;
  }
}
274