1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/arg_min_max.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/arg_min_max_parameter.h"
20 #include "nnacl/nnacl_common.h"
21 #include "nnacl/fp32/arg_min_max_fp32.h"
22 #ifdef ENABLE_FP16
23 #include "nnacl/fp16/arg_min_max_fp16.h"
24 #endif
25
/* Copies the user-supplied ArgMinMaxParameter into the kernel's internal
 * compute descriptor. Called once before the first Resize/Compute. */
int ArgMinMaxPrepare(KernelBase *self) {
  ArgMinMaxStruct *kernel = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);
  ArgMinMaxParameter *arg_param = (ArgMinMaxParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(arg_param);

  ArgMinMaxComputeParam *compute = &kernel->compute_;
  compute->topk_ = arg_param->topk_;
  compute->axis_ = arg_param->axis_;
  compute->keep_dims_ = arg_param->keep_dims_;
  compute->out_value_ = arg_param->out_value_;
  /* ArgMinFusion searches for the minimum; any other primitive type routed
   * here (ArgMaxFusion) searches for the maximum. */
  compute->get_max_ = self->param_->type_ != PrimType_ArgMinFusion;
  /* A temporary ArgElement workspace is required when sorting more than one
   * candidate (topk > 1) or when the reduced axis is kept in the output. */
  kernel->arg_elements_alloc_ = arg_param->topk_ > Num1 || arg_param->keep_dims_;
  return NNACL_OK;
}
40
/* Recomputes shape-dependent state (strides, normalized axis) from the
 * current input/output tensor shapes.
 *
 * Returns NNACL_OK on success, or NNACL_ARG_MIN_MAX_AXIS_INVALID when the
 * axis is out of range or topk is not in (0, shape[axis]].
 *
 * Fix: the original indexed input_tensor->shape_[compute->axis_] without
 * first validating that the (normalized) axis lies inside [0, dims_size),
 * so an out-of-range axis caused an out-of-bounds read. The axis range is
 * now checked before it is used as an index. */
int ArgMinMaxResize(KernelBase *self) {
  ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arg_min_max);
  ArgMinMaxComputeParam *compute = &arg_min_max->compute_;

  TensorC *input_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
  ComputeStrides(input_tensor->shape_, compute->in_strides_, input_tensor->shape_size_);

  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
  ComputeStrides(output_tensor->shape_, compute->out_strides_, output_tensor->shape_size_);

  compute->dims_size_ = (int)input_tensor->shape_size_;
  /* Negative axes count from the back, e.g. -1 is the last dimension. */
  compute->axis_ = compute->axis_ < 0 ? compute->axis_ + compute->dims_size_ : compute->axis_;
  /* Validate the axis before using it to index shape_ below. */
  NNACL_CHECK_FALSE(compute->axis_ < 0 || compute->axis_ >= compute->dims_size_, NNACL_ARG_MIN_MAX_AXIS_INVALID);
  NNACL_CHECK_FALSE(compute->topk_ <= 0, NNACL_ARG_MIN_MAX_AXIS_INVALID);
  NNACL_CHECK_FALSE(compute->topk_ > input_tensor->shape_[compute->axis_], NNACL_ARG_MIN_MAX_AXIS_INVALID);
  return NNACL_OK;
}
60
/* Runs the arg-min/arg-max computation for one inference step, dispatching
 * on the input tensor's data type. Optionally writes the selected values to
 * a second output tensor when one is present. */
int ArgMinMaxCompute(KernelBase *self) {
  ArgMinMaxStruct *kernel = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);
  ArgMinMaxParameter *arg_param = (ArgMinMaxParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(arg_param);

  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  void *input_data = input->data_;
  NNACL_CHECK_NULL_RETURN_ERR(input_data);

  TensorC *output = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output);
  void *output_data = output->data_;
  NNACL_CHECK_NULL_RETURN_ERR(output_data);

  /* Second output, if present, receives the min/max values themselves. */
  void *value_data = NULL;
  if (self->out_size_ == TWO_TENSOR) {
    value_data = self->out_[Index1]->data_;
    NNACL_CHECK_NULL_RETURN_ERR(value_data);
  }

  /* Workspace for sorting candidates along the reduced axis; only needed
   * when Prepare decided it is (topk > 1 or keep_dims). */
  if (kernel->arg_elements_alloc_) {
    int workspace_bytes = input->shape_[kernel->compute_.axis_] * (int)sizeof(ArgElement);
    NNACL_CHECK_MALLOC_SIZE(workspace_bytes);
    kernel->compute_.arg_elements_ = (ArgElement *)self->env_->Alloc(self->env_->allocator_, workspace_bytes);
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(kernel->compute_.arg_elements_);
  }

  int ret = NNACL_OK;
  int *shape = input->shape_;
  switch (input->data_type_) {
    case kNumberTypeFloat32:
      ArgMinMaxFp32((float *)input_data, output_data, (float *)value_data, shape, &kernel->compute_);
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      ArgMinMaxFp16((float16_t *)input_data, output_data, (float16_t *)value_data, shape, &kernel->compute_);
      break;
#endif
    case kNumberTypeInt32:
      ArgMinMaxInt32((int32_t *)input_data, output_data, (int32_t *)value_data, shape, &kernel->compute_);
      break;
    default:
      ret = NNACL_UNSUPPORTED_DATA_TYPE;
      break;
  }

  /* Release the per-invocation workspace regardless of dispatch outcome. */
  if (kernel->arg_elements_alloc_) {
    self->env_->Free(self->env_->allocator_, kernel->compute_.arg_elements_);
    kernel->compute_.arg_elements_ = NULL;
  }
  return ret;
}
108
CreateArgMinMax(OpParameter * param,int data_type)109 KernelBase *CreateArgMinMax(OpParameter *param, int data_type) {
110 ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)malloc(sizeof(ArgMinMaxStruct));
111 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arg_min_max);
112 memset(arg_min_max, 0, sizeof(ArgMinMaxStruct));
113
114 arg_min_max->base_.Prepare = ArgMinMaxPrepare;
115 arg_min_max->base_.Resize = ArgMinMaxResize;
116 arg_min_max->base_.Release = DefaultRelease;
117 arg_min_max->base_.Compute = ArgMinMaxCompute;
118 return (KernelBase *)arg_min_max;
119 }
120
/* Register the shared factory for both ArgMin and ArgMax primitives across
 * the supported input data types (int32 / fp16 / fp32); Prepare reads the
 * primitive type to decide min vs. max at runtime. */
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeInt32, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat16, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat32, CreateArgMinMax)

REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeInt32, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat16, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat32, CreateArgMinMax)
128