• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/unique.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/fp32/unique_fp32.h"
20 #include "nnacl/tensor_c_utils.h"
21 #ifdef ENABLE_FP16
22 #include "nnacl/fp16/unique_fp16.h"
23 #endif
24 
/**
 * Runs the Unique operator on the kernel's input tensor.
 *
 * Dispatches on the input data type to UniqueInt (int32), Unique (float32) or,
 * only when built with ENABLE_FP16, UniqueFp16 (float16). Each routine receives
 * the input buffer, its element count, the output0 buffer, the address of
 * output0_len, and output1 viewed as an int buffer. It reports the resulting
 * length of output0 through output0_len. The routines are declared in the
 * nnacl unique headers, so their exact output layout is not visible here.
 * NOTE(review): output1 presumably holds one index into output0 per input
 * element — confirm against unique_fp32.h.
 *
 * Afterwards the last dimension of output0's shape is overwritten with
 * output0_len. shape_changed_ records whether that differs from the value it
 * held before.
 *
 * @param self  kernel providing in_[FIRST_INPUT], out_[Index0] and out_[Index1].
 * @return NNACL_OK on success;
 *         the error returned by NNACL_CHECK_NULL_RETURN_ERR for a missing tensor,
 *         or for a missing data buffer when there is data to process;
 *         NNACL_OUTPUT_TENSOR_ERROR if output0 has no dimensions;
 *         NNACL_INPUT_TENSOR_ERROR if the input reports a negative element count;
 *         NNACL_UNSUPPORTED_DATA_TYPE for any other input data type.
 */
int UniqueCompute(KernelBase *self) {
  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  TensorC *output0 = self->out_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(output0);
  TensorC *output1 = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(output1);

  /* The unique count is stored in output0's last dimension; with zero
   * dimensions that index would be shape_[-1]. Reject before doing any work. */
  if (output0->shape_size_ == 0) {
    return NNACL_OUTPUT_TENSOR_ERROR;
  }

  int num = GetElementNum(input);
  if (num < 0) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  /* The buffers are only dereferenced when there is at least one element, so
   * empty tensors without a data buffer are still accepted. */
  if (num > 0) {
    NNACL_CHECK_NULL_RETURN_ERR(input->data_);
    NNACL_CHECK_NULL_RETURN_ERR(output0->data_);
    NNACL_CHECK_NULL_RETURN_ERR(output1->data_);
  }

  int output0_len = 0;
  switch (input->data_type_) {
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      UniqueFp16((float16_t *)input->data_, num, (float16_t *)output0->data_, &output0_len, (int *)output1->data_);
      break;
#endif
    case kNumberTypeInt32:
      UniqueInt((int *)input->data_, num, (int *)output0->data_, &output0_len, (int *)output1->data_);
      break;
    case kNumberTypeFloat32:
      Unique((float *)input->data_, num, (float *)output0->data_, &output0_len, (int *)output1->data_);
      break;
    default:
      /* Any other type, including float16 in builds without ENABLE_FP16, is
       * rejected. Otherwise success would be reported with nothing computed. */
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }

  output0->shape_changed_ = (output0->shape_[output0->shape_size_ - 1] != output0_len);
  output0->shape_[output0->shape_size_ - 1] = output0_len;
  return NNACL_OK;
}
52 
CreateUnique(OpParameter * param,int data_type)53 KernelBase *CreateUnique(OpParameter *param, int data_type) {
54   UniqueStruct *unique = (UniqueStruct *)malloc(sizeof(UniqueStruct));
55   NNACL_CHECK_NULL_RETURN_NULL(unique);
56   unique->data_type_ = data_type;
57   unique->base_.Release = DefaultRelease;
58   unique->base_.Prepare = DefaultPrepare1In2Out;
59   unique->base_.Resize = DefaultResize;
60   unique->base_.Compute = UniqueCompute;
61   return (KernelBase *)unique;
62 }
63 
64 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeInt32, CreateUnique)
65 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat32, CreateUnique)
66 REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat16, CreateUnique)
67