1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/ragged_range.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/fp32/ragged_range_fp32.h"
20 #ifdef ENABLE_FP16
21 #include "nnacl/fp16/ragged_range_fp16.h"
22 #endif
23
/**
 * Compute entry point of the RaggedRange kernel.
 *
 * Takes the three input tensors in_[Index0..Index2] and the two output tensors
 * out_[Index0], out_[Index1], then dispatches on the data type of the first
 * input to the fp32, int32 or fp16 implementation. The first output is always
 * handed over as an int buffer; the second output is handed over with the same
 * element type as the inputs.
 *
 * @param self Kernel instance; it is reinterpreted as a RaggedRangeStruct.
 * @return NNACL_OK once the selected implementation has run;
 *         NNACL_UNSUPPORTED_DATA_TYPE when the first input's data type has no
 *         implementation in this build (this includes float16 when ENABLE_FP16
 *         is not defined); otherwise whatever NNACL_CHECK_NULL_RETURN_ERR
 *         returns when self or one of the tensor pointers is NULL.
 */
int RaggedRangeCompute(KernelBase *self) {
  RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(ragged_range);

  /* Validate every tensor pointer before its fields are dereferenced below. */
  TensorC *input0 = self->in_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(input0);
  TensorC *input1 = self->in_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(input1);
  TensorC *input2 = self->in_[Index2];
  NNACL_CHECK_NULL_RETURN_ERR(input2);
  TensorC *output0 = self->out_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(output0);
  TensorC *output1 = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(output1);

  if (input0->data_type_ == kNumberTypeFloat32) {
    RaggedRangeFp32((float *)input0->data_, (float *)input1->data_, (float *)input2->data_, (int *)output0->data_,
                    (float *)output1->data_, ragged_range);
  } else if (input0->data_type_ == kNumberTypeInt32) {
    RaggedRangeInt((int *)input0->data_, (int *)input1->data_, (int *)input2->data_, (int *)output0->data_,
                   (int *)output1->data_, ragged_range);
  } else if (input0->data_type_ == kNumberTypeFloat16) {
#ifdef ENABLE_FP16
    RaggedRangeFp16((float16_t *)input0->data_, (float16_t *)input1->data_, (float16_t *)input2->data_,
                    (int *)output0->data_, (float16_t *)output1->data_, ragged_range);
#else
    /* No fp16 implementation is compiled in: report that instead of returning
     * success while leaving both outputs unwritten. */
    return NNACL_UNSUPPORTED_DATA_TYPE;
#endif
  } else {
    return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}
50
/**
 * Resize hook of the RaggedRange kernel: refreshes the shape-derived fields of
 * the RaggedRangeStruct from the currently attached tensors.
 *
 * @param self Kernel instance; it is reinterpreted as a RaggedRangeStruct.
 * @return NNACL_OK after the fields are updated, or whatever
 *         NNACL_CHECK_NULL_RETURN_ERR returns when self is NULL.
 */
int RaggedRangeResize(KernelBase *self) {
  RaggedRangeStruct *kernel = (RaggedRangeStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  /* An input counts as scalar when its tensor has no dimensions at all. */
  kernel->deltas_is_scalar_ = (self->in_[THIRD_INPUT]->shape_size_ == 0);
  kernel->limits_is_scalar_ = (self->in_[SECOND_INPUT]->shape_size_ == 0);
  kernel->starts_is_scalar_ = (self->in_[FIRST_INPUT]->shape_size_ == 0);

  /* The row count is one less than the leading dimension of the output at OUTPUT_INDEX. */
  kernel->rows_ = self->out_[OUTPUT_INDEX]->shape_[Index0] - 1;

  return NNACL_OK;
}
61
CreateRaggedRange(OpParameter * param,int data_type)62 KernelBase *CreateRaggedRange(OpParameter *param, int data_type) {
63 RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)malloc(sizeof(RaggedRangeStruct));
64 NNACL_CHECK_NULL_RETURN_NULL(ragged_range);
65 ragged_range->base_.Release = DefaultRelease;
66 ragged_range->base_.Prepare = DefaultPrepare3In2Out;
67 ragged_range->base_.Resize = RaggedRangeResize;
68 ragged_range->base_.Compute = RaggedRangeCompute;
69 return (KernelBase *)ragged_range;
70 }
71
/* Ties CreateRaggedRange to PrimType_RaggedRange for the int32, float16 and
 * float32 data types, i.e. the three types RaggedRangeCompute dispatches on.
 * NOTE(review): REG_KERNEL_CREATOR is defined elsewhere; presumably it registers
 * the creator in a kernel lookup table - confirm against the macro definition. */
72 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeInt32, CreateRaggedRange)
73 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat16, CreateRaggedRange)
74 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat32, CreateRaggedRange)
75