• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp16/arg_min_max_fp16.h"
18 
ArgCompareAscFp16(const void * a,const void * b)19 int ArgCompareAscFp16(const void *a, const void *b) {
20   float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
21   float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
22   if (b_value > a_value) {
23     return -1;
24   }
25   if (b_value < a_value) {
26     return 1;
27   }
28 
29   return 0;
30 }
31 
ArgCompareDescFp16(const void * a,const void * b)32 int ArgCompareDescFp16(const void *a, const void *b) {
33   float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
34   float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
35   if (b_value > a_value) {
36     return 1;
37   }
38   if (b_value < a_value) {
39     return -1;
40   }
41 
42   return 0;
43 }
44 
/**
 * Top-1 arg-max along the reduced axis.
 *
 * The input is addressed as [pre_axis_count][axis_count][after_axis_count]:
 * element (i, k, j) lives at i * after_axis_count * axis_count + k * after_axis_count + j.
 * For each (i, j) pair the axis is scanned and the largest value is kept; the
 * comparison is strict, so the first occurrence wins on ties.
 *
 * @param input        fp16 input data.
 * @param output       receives float16_t values when param->out_value_ is set,
 *                     otherwise int indices; one entry per (i, j) pair.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param param        only out_value_ is read here.
 */
void ArgMaxTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param,
                     int pre_axis_count, int axis_count, int after_axis_count) {
  const bool emit_value = param->out_value_;
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;
  for (int outer = 0; outer < pre_axis_count; ++outer) {
    const size_t out_base = outer * after_axis_count;
    const size_t in_base = out_base * axis_count;
    for (int inner = 0; inner < after_axis_count; ++inner) {
      // NOTE(review): -FLT_MAX is outside float16_t's finite range; presumably this
      // converts to -inf on the target, which serves as the "lowest" seed -- confirm.
      float16_t best = -FLT_MAX;
      int best_pos = 0;
      for (int pos = 0; pos < axis_count; ++pos) {
        const float16_t candidate = input[in_base + pos * after_axis_count + inner];
        if (!(candidate > best)) {
          continue;
        }
        best = candidate;
        best_pos = pos;
      }
      const size_t dst = out_base + inner;
      if (emit_value) {
        value_out[dst] = best;
      } else {
        index_out[dst] = best_pos;
      }
      if (output_value != NULL) {
        output_value[dst] = best;
      }
    }
  }
}
74 
/**
 * Top-1 arg-min along the reduced axis.
 *
 * The input is addressed as [pre_axis_count][axis_count][after_axis_count]:
 * element (i, k, j) lives at i * after_axis_count * axis_count + k * after_axis_count + j.
 * For each (i, j) pair the axis is scanned and the smallest value is kept; the
 * comparison is strict, so the first occurrence wins on ties.
 *
 * @param input        fp16 input data.
 * @param output       receives float16_t values when param->out_value_ is set,
 *                     otherwise int indices; one entry per (i, j) pair.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param param        only out_value_ is read here.
 */
void ArgMinTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param,
                     int pre_axis_count, int axis_count, int after_axis_count) {
  const bool emit_value = param->out_value_;
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;
  for (int outer = 0; outer < pre_axis_count; ++outer) {
    const size_t out_base = outer * after_axis_count;
    const size_t in_base = out_base * axis_count;
    for (int inner = 0; inner < after_axis_count; ++inner) {
      // NOTE(review): FLT_MAX is outside float16_t's finite range; presumably this
      // converts to +inf on the target, which serves as the "highest" seed -- confirm.
      float16_t best = FLT_MAX;
      int best_pos = 0;
      for (int pos = 0; pos < axis_count; ++pos) {
        const float16_t candidate = input[in_base + pos * after_axis_count + inner];
        if (!(candidate < best)) {
          continue;
        }
        best = candidate;
        best_pos = pos;
      }
      const size_t dst = out_base + inner;
      if (emit_value) {
        value_out[dst] = best;
      } else {
        index_out[dst] = best_pos;
      }
      if (output_value != NULL) {
        output_value[dst] = best;
      }
    }
  }
}
104 
/**
 * Top-k arg-min/arg-max along axis 0.
 *
 * For each of the param->in_strides_[0] inner positions, the in_shape[0]
 * elements along axis 0 are gathered into the scratch array
 * param->arg_elements_ (value plus original index), sorted with compare_func,
 * and the first param->topk_ entries are written out.
 *
 * @param output       receives float16_t values when param->out_value_ is set,
 *                     otherwise int indices.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param compare_func qsort comparator deciding the ordering (ascending or descending).
 */
void ArgMinMaxDim0Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;
  ArgElement *scratch = param->arg_elements_;
  for (int32_t inner = 0; inner < param->in_strides_[0]; ++inner) {
    // Gather one column along axis 0, remembering where each value came from.
    for (int pos = 0; pos < in_shape[0]; ++pos) {
      size_t src = param->in_strides_[0] * pos + inner;
      scratch[pos].index_ = pos;
      scratch[pos].data_.f16_data_ = input[src];
    }
    qsort(scratch, in_shape[0], sizeof(ArgElement), compare_func);
    // Emit the leading topk_ entries of the sorted column.
    for (int rank = 0; rank < param->topk_; ++rank) {
      const ArgElement *picked = &scratch[rank];
      size_t dst = rank * param->out_strides_[0] + inner;
      if (param->out_value_) {
        value_out[dst] = picked->data_.f16_data_;
      } else {
        index_out[dst] = picked->index_;
      }
      if (output_value != NULL) {
        output_value[dst] = picked->data_.f16_data_;
      }
    }
  }
}
130 
/**
 * Top-k arg-min/arg-max along axis 1.
 *
 * For each index along axis 0 and each of the param->in_strides_[1] inner
 * positions, the in_shape[1] elements along axis 1 are gathered into the
 * scratch array param->arg_elements_ (value plus original index), sorted with
 * compare_func, and the first param->topk_ entries are written out.
 *
 * @param output       receives float16_t values when param->out_value_ is set,
 *                     otherwise int indices.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param compare_func qsort comparator deciding the ordering (ascending or descending).
 */
void ArgMinMaxDim1Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  const int axis_len = in_shape[1];
  float16_t *value_out = (float16_t *)output;
  int *index_out = (int *)output;
  ArgElement *scratch = param->arg_elements_;
  for (int d0 = 0; d0 < in_shape[0]; ++d0) {
    size_t in_base = d0 * param->in_strides_[0];
    size_t out_base = d0 * param->out_strides_[0];
    for (int inner = 0; inner < param->in_strides_[1]; ++inner) {
      // Gather the slice along axis 1, remembering where each value came from.
      for (int pos = 0; pos < axis_len; ++pos) {
        size_t src = param->in_strides_[1] * pos + in_base + inner;
        scratch[pos].index_ = pos;
        scratch[pos].data_.f16_data_ = input[src];
      }
      qsort(scratch, axis_len, sizeof(ArgElement), compare_func);
      // Emit the leading topk_ entries of the sorted slice.
      for (int rank = 0; rank < param->topk_; ++rank) {
        const ArgElement *picked = &scratch[rank];
        size_t dst = out_base + inner + rank * param->out_strides_[1];
        if (param->out_value_) {
          value_out[dst] = picked->data_.f16_data_;
        } else {
          index_out[dst] = picked->index_;
        }
        if (output_value != NULL) {
          output_value[dst] = picked->data_.f16_data_;
        }
      }
    }
  }
}
161 
/**
 * Top-k arg-min/arg-max along axis 2.
 *
 * For each (i, j) over in_shape[0] x in_shape[1] and each of the
 * param->in_strides_[2] inner positions, the in_shape[2] elements along axis 2
 * are gathered into the scratch array param->arg_elements_ (value plus
 * original index), sorted with compare_func, and the first param->topk_
 * entries are written out.
 *
 * @param input        fp16 input data.
 * @param output       receives float16_t values when param->out_value_ is set;
 *                     otherwise the same buffer is written as int indices.
 *                     NOTE(review): in index mode the buffer presumably must be
 *                     sized and aligned for int -- confirm against the caller.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param in_shape     input shape; entries 0..2 are read.
 * @param param        in_strides_, out_strides_, topk_, out_value_ and arg_elements_ are used.
 * @param compare_func qsort comparator deciding the ordering (ascending or descending).
 */
void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  // Values must be stored with float16_t width, as the Dim0/Dim1 variants do;
  // viewing the buffer as float would write 4-byte elements at the wrong
  // offsets and run past the end of an fp16 output.
  float16_t *outputfp16 = output;
  int *outputint = (int *)output;
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < param->in_strides_[2]; ++k) {
        // Gather the slice along axis 2, remembering where each value came from.
        for (int l = 0; l < in_shape2; ++l) {
          size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
          param->arg_elements_[l].index_ = l;
          param->arg_elements_[l].data_.f16_data_ = input[offset];
        }
        qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func);
        // Emit the leading topk_ entries of the sorted slice.
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
          if (param->out_value_) {
            outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_;
          } else {
            outputint[out_offset] = param->arg_elements_[l].index_;
          }
          if (output_value != NULL) {
            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
          }
        }
      }
    }
  }
}
196 
/**
 * Top-k arg-min/arg-max along axis 3 (the innermost, contiguous axis here:
 * consecutive elements along it are addressed with a step of 1).
 *
 * For each (i, j, k) over in_shape[0] x in_shape[1] x in_shape[2], the
 * in_shape[3] elements along axis 3 are gathered into the scratch array
 * param->arg_elements_ (value plus original index), sorted with compare_func,
 * and the first param->topk_ entries are written out contiguously.
 *
 * @param input        fp16 input data.
 * @param output       receives float16_t values when param->out_value_ is set;
 *                     otherwise the same buffer is written as int indices.
 *                     NOTE(review): in index mode the buffer presumably must be
 *                     sized and aligned for int -- confirm against the caller.
 * @param output_value optional (may be NULL); when given, always receives the values.
 * @param in_shape     input shape; entries 0..3 are read.
 * @param param        in_strides_, out_strides_, topk_, out_value_ and arg_elements_ are used.
 * @param compare_func qsort comparator deciding the ordering (ascending or descending).
 */
void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) {
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  int in_shape3 = in_shape[3];
  // Values must be stored with float16_t width, as the Dim0/Dim1 variants do;
  // viewing the buffer as float would write 4-byte elements at the wrong
  // offsets and run past the end of an fp16 output.
  float16_t *outputfp16 = output;
  int *outputint = (int *)output;
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < in_shape2; ++k) {
        size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
        size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
        // Gather the row along axis 3, remembering where each value came from.
        for (int l = 0; l < in_shape3; ++l) {
          size_t offset = l + in_dim2_offset;
          param->arg_elements_[l].index_ = l;
          param->arg_elements_[l].data_.f16_data_ = input[offset];
        }
        qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func);
        // Emit the leading topk_ entries of the sorted row.
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim2_offset + l;
          if (param->out_value_) {
            outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_;
          } else {
            outputint[out_offset] = param->arg_elements_[l].index_;
          }
          if (output_value != NULL) {
            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
          }
        }
      }
    }
  }
}
234 
/**
 * Entry point for fp16 arg-min/arg-max.
 *
 * When param->topk_ is 1, the axis extents are obtained from ComputeAxisDims
 * and the dedicated top-1 scan (max or min, per param->get_max_) is used.
 * Otherwise a sort-based routine is selected by param->axis_ (0..3), with a
 * descending comparator for max and an ascending one for min. Any other axis
 * value results in no work being done.
 *
 * @param output       receives values or indices, as decided by the callee
 *                     from param->out_value_.
 * @param output_value optional secondary value output, forwarded unchanged.
 */
void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape,
                   const ArgMinMaxComputeParam *param) {
  if (param->topk_ == 1) {
    int pre_axis_count = 1;
    int axis_count = 1;
    int after_axis_count = 1;
    ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);

    if (!param->get_max_) {
      ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
    } else {
      ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
    }
    return;
  }

  // Max wants the largest values first, min the smallest.
  COMPARE_FUNCTION compare_function = param->get_max_ ? ArgCompareDescFp16 : ArgCompareAscFp16;

  switch (param->axis_) {
    case 0:
      ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 1:
      ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 2:
      ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    case 3:
      ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function);
      break;
    default:
      // Unsupported axis: nothing is written.
      break;
  }
}
274