/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/arg_min_max.h"
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/arg_min_max_parameter.h"
#include "nnacl/nnacl_common.h"
#include "nnacl/fp32/arg_min_max_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/arg_min_max_fp16.h"
#endif

/* Prepare: copy the op parameters into the kernel's compute descriptor.
 * Returns NNACL_OK, or an error code if the kernel or its parameter is NULL. */
int ArgMinMaxPrepare(KernelBase *self) {
  ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arg_min_max);
  ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  ArgMinMaxComputeParam *compute = &arg_min_max->compute_;
  compute->topk_ = param->topk_;
  compute->axis_ = param->axis_;
  compute->keep_dims_ = param->keep_dims_;
  compute->out_value_ = param->out_value_;
  /* ArgMinFusion searches for minima; any other primitive type means ArgMax. */
  compute->get_max_ = self->param_->type_ != PrimType_ArgMinFusion;
  /* Compute() allocates a per-call ArgElement scratch buffer when topk > 1
   * or when dimensions are kept in the output. */
  arg_min_max->arg_elements_alloc_ = param->keep_dims_ || param->topk_ > Num1;
  return NNACL_OK;
}
40 
/* Resize: recompute input/output strides and normalize and validate the axis
 * and topk against the current input shape. Called whenever shapes change.
 * Returns NNACL_OK, or NNACL_ARG_MIN_MAX_AXIS_INVALID on bad axis/topk. */
int ArgMinMaxResize(KernelBase *self) {
  ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arg_min_max);
  ArgMinMaxComputeParam *compute = &arg_min_max->compute_;

  TensorC *input_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
  ComputeStrides(input_tensor->shape_, compute->in_strides_, input_tensor->shape_size_);

  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
  ComputeStrides(output_tensor->shape_, compute->out_strides_, output_tensor->shape_size_);

  compute->dims_size_ = (int)input_tensor->shape_size_;
  /* A negative axis counts from the last dimension; fold it into [0, dims_size). */
  compute->axis_ = compute->axis_ < 0 ? compute->axis_ + compute->dims_size_ : compute->axis_;
  /* Validate the axis BEFORE it is used to index shape_ below; an out-of-range
   * axis would otherwise read past the end of the shape array. */
  NNACL_CHECK_FALSE(compute->axis_ < 0 || compute->axis_ >= compute->dims_size_, NNACL_ARG_MIN_MAX_AXIS_INVALID);
  NNACL_CHECK_FALSE(compute->topk_ <= 0, NNACL_ARG_MIN_MAX_AXIS_INVALID);
  NNACL_CHECK_FALSE(compute->topk_ > input_tensor->shape_[compute->axis_], NNACL_ARG_MIN_MAX_AXIS_INVALID);
  return NNACL_OK;
}
60 
/* Compute: run arg-min/arg-max along the configured axis, dispatching on the
 * input tensor's runtime data type. Returns NNACL_OK on success,
 * NNACL_UNSUPPORTED_DATA_TYPE for types without an implementation, or an
 * error code on NULL pointers / failed allocation. */
int ArgMinMaxCompute(KernelBase *self) {
  ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(arg_min_max);
  ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  void *input_data = input->data_;
  NNACL_CHECK_NULL_RETURN_ERR(input_data);

  TensorC *output = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output);
  void *output_data = output->data_;
  NNACL_CHECK_NULL_RETURN_ERR(output_data);

  /* An optional second output receives the extreme values themselves. */
  void *value_data = NULL;
  if (self->out_size_ == TWO_TENSOR) {
    value_data = self->out_[Index1]->data_;
    NNACL_CHECK_NULL_RETURN_ERR(value_data);
  }

  /* topk > 1 / keep_dims need one ArgElement per entry along the axis. */
  if (arg_min_max->arg_elements_alloc_) {
    int scratch_bytes = input->shape_[arg_min_max->compute_.axis_] * (int)sizeof(ArgElement);
    NNACL_CHECK_MALLOC_SIZE(scratch_bytes);
    arg_min_max->compute_.arg_elements_ = (ArgElement *)self->env_->Alloc(self->env_->allocator_, scratch_bytes);
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arg_min_max->compute_.arg_elements_);
  }

  int ret = NNACL_OK;
  switch (input->data_type_) {
    case kNumberTypeFloat32:
      ArgMinMaxFp32((float *)input_data, output_data, (float *)value_data, input->shape_, &arg_min_max->compute_);
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      ArgMinMaxFp16((float16_t *)input_data, output_data, (float16_t *)value_data, input->shape_,
                    &arg_min_max->compute_);
      break;
#endif
    case kNumberTypeInt32:
      ArgMinMaxInt32((int32_t *)input_data, output_data, (int32_t *)value_data, input->shape_,
                     &arg_min_max->compute_);
      break;
    default:
      ret = NNACL_UNSUPPORTED_DATA_TYPE;
      break;
  }

  /* The scratch buffer lives only for the duration of one Compute call;
   * it is released on every path, including the unsupported-type one. */
  if (arg_min_max->arg_elements_alloc_) {
    self->env_->Free(self->env_->allocator_, arg_min_max->compute_.arg_elements_);
    arg_min_max->compute_.arg_elements_ = NULL;
  }
  return ret;
}
108 
CreateArgMinMax(OpParameter * param,int data_type)109 KernelBase *CreateArgMinMax(OpParameter *param, int data_type) {
110   ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)malloc(sizeof(ArgMinMaxStruct));
111   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arg_min_max);
112   memset(arg_min_max, 0, sizeof(ArgMinMaxStruct));
113 
114   arg_min_max->base_.Prepare = ArgMinMaxPrepare;
115   arg_min_max->base_.Resize = ArgMinMaxResize;
116   arg_min_max->base_.Release = DefaultRelease;
117   arg_min_max->base_.Compute = ArgMinMaxCompute;
118   return (KernelBase *)arg_min_max;
119 }
120 
/* Register the shared creator for both ArgMin and ArgMax fusion primitives
 * across the supported data types; the kernel dispatches on type at runtime. */
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeInt32, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat16, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat32, CreateArgMinMax)

REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeInt32, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat16, CreateArgMinMax)
REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat32, CreateArgMinMax)
128