1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/unique.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/fp32/unique_fp32.h"
20 #include "nnacl/tensor_c_utils.h"
21 #ifdef ENABLE_FP16
22 #include "nnacl/fp16/unique_fp16.h"
23 #endif
24
/**
 * Run the Unique kernel.
 *
 * Dispatches on the input tensor's data type to the matching Unique routine.
 * That routine receives:
 *   - the input buffer and its element count;
 *   - output 0's buffer and the address of the length it reports back;
 *   - output 1's buffer (as int *).
 * Afterwards the last dimension of output 0 is set to that reported length, and
 * `shape_changed_` records whether it differed from the previous value.
 *
 * @param self kernel whose in_[FIRST_INPUT] is the source tensor and whose
 *             out_[Index0] / out_[Index1] are the two output tensors.
 * @return NNACL_OK on success; NNACL_UNSUPPORTED_DATA_TYPE when the input type has
 *         no compiled-in implementation (e.g. Float16 without ENABLE_FP16);
 *         NNACL_ERR when output 0 has no dimension to resize; a null-pointer error
 *         code if a required tensor is missing.
 */
int UniqueCompute(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self);
  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  TensorC *output0 = self->out_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(output0);
  TensorC *output1 = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(output1);

  // The last dimension of output 0 is rewritten below; without this guard a
  // rank-0 tensor would make `shape_size_ - 1` index outside shape_.
  if (output0->shape_size_ == 0) {
    return NNACL_ERR;
  }

  int num = GetElementNum(input);
  int output0_len = 0;

  switch (input->data_type_) {
    case kNumberTypeInt32:
      UniqueInt((int *)input->data_, num, (int *)output0->data_, &output0_len, (int *)output1->data_);
      break;
    case kNumberTypeFloat32:
      Unique((float *)input->data_, num, (float *)output0->data_, &output0_len, (int *)output1->data_);
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      UniqueFp16((float16_t *)input->data_, num, (float16_t *)output0->data_, &output0_len, (int *)output1->data_);
      break;
#endif
    default:
      // Previously an unhandled type fell through and reported success with an
      // empty result; surface it as an error instead.
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }

  output0->shape_changed_ = (output0->shape_[output0->shape_size_ - 1] != output0_len);
  output0->shape_[output0->shape_size_ - 1] = output0_len;
  return NNACL_OK;
}
52
CreateUnique(OpParameter * param,int data_type)53 KernelBase *CreateUnique(OpParameter *param, int data_type) {
54 UniqueStruct *unique = (UniqueStruct *)malloc(sizeof(UniqueStruct));
55 NNACL_CHECK_NULL_RETURN_NULL(unique);
56 unique->data_type_ = data_type;
57 unique->base_.Release = DefaultRelease;
58 unique->base_.Prepare = DefaultPrepare1In2Out;
59 unique->base_.Resize = DefaultResize;
60 unique->base_.Compute = UniqueCompute;
61 return (KernelBase *)unique;
62 }
63
64 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeInt32, CreateUnique)
65 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat32, CreateUnique)
66 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat16, CreateUnique)
67