1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/rank.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19
/**
 * RankCompute - store the rank (number of dimensions) of the first input
 * tensor as a single scalar in the output tensor's data buffer.
 *
 * The scalar is written as float16_t when the input tensor's data type is
 * kNumberTypeFloat16 (only possible in builds with ENABLE_FP16), and as
 * float for every other input data type.
 *
 * @param self  kernel instance; in_[FIRST_INPUT] and out_[OUTPUT_INDEX] are read.
 * @return NNACL_OK on success; the NNACL null-pointer error code if an input
 *         tensor, output tensor or output buffer is NULL; NNACL_ERR if a
 *         float16 input is seen in a build without ENABLE_FP16.
 */
int RankCompute(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);
  void *output_data = self->out_[OUTPUT_INDEX]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(output_data);

  size_t rank = self->in_[FIRST_INPUT]->shape_size_;
  if (self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat16) {
#ifdef ENABLE_FP16
    *(float16_t *)output_data = (float16_t)rank;
    return NNACL_OK;
#else
    /* No float16 support compiled in: report failure instead of returning
     * success with the output buffer left unwritten. */
    return NNACL_ERR;
#endif
  }

  *(float *)output_data = (float)rank;
  return NNACL_OK;
}
32
CreateRank(OpParameter * param,int data_type)33 KernelBase *CreateRank(OpParameter *param, int data_type) {
34 RankStruct *rank = (RankStruct *)malloc(sizeof(RankStruct));
35 NNACL_CHECK_NULL_RETURN_NULL(rank);
36 rank->base_.Release = DefaultRelease;
37 rank->base_.Prepare = DefaultPrepare1In1Out;
38 rank->base_.Resize = DefaultResize;
39 rank->base_.Compute = RankCompute;
40 return (KernelBase *)rank;
41 }
42
43 REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat32, CreateRank)
44 REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat16, CreateRank)
45