• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/ragged_range.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/fp32/ragged_range_fp32.h"
20 #ifdef ENABLE_FP16
21 #include "nnacl/fp16/ragged_range_fp16.h"
22 #endif
23 
/**
 * Compute hook for RaggedRange: dispatches on the data type of the first input
 * to the matching type-specific implementation.
 *
 * in_[0], in_[1] and in_[2] are the starts, limits and deltas tensors (the same
 * order RaggedRangeResize uses for its *_is_scalar_ flags). out_[0] is always
 * written as int32 data (presumably the row splits — confirm against the
 * ragged_range_* implementations). out_[1] has the same element type as the inputs.
 *
 * @param self kernel instance; must actually be a RaggedRangeStruct.
 * @return NNACL_OK on success; NNACL_UNSUPPORTED_DATA_TYPE when the input type is
 *         not handled by this build (including float16 when ENABLE_FP16 is not
 *         defined); an error code from NNACL_CHECK_NULL_RETURN_ERR if self or any
 *         input/output tensor pointer is NULL.
 */
int RaggedRangeCompute(KernelBase *self) {
  RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(ragged_range);

  TensorC *input0 = self->in_[Index0];
  TensorC *input1 = self->in_[Index1];
  TensorC *input2 = self->in_[Index2];
  TensorC *output0 = self->out_[Index0];
  TensorC *output1 = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(input0);
  NNACL_CHECK_NULL_RETURN_ERR(input1);
  NNACL_CHECK_NULL_RETURN_ERR(input2);
  NNACL_CHECK_NULL_RETURN_ERR(output0);
  NNACL_CHECK_NULL_RETURN_ERR(output1);

  switch (input0->data_type_) {
    case kNumberTypeFloat32:
      RaggedRangeFp32((float *)input0->data_, (float *)input1->data_, (float *)input2->data_, (int *)output0->data_,
                      (float *)output1->data_, ragged_range);
      break;
    case kNumberTypeInt32:
      RaggedRangeInt((int *)input0->data_, (int *)input1->data_, (int *)input2->data_, (int *)output0->data_,
                     (int *)output1->data_, ragged_range);
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      RaggedRangeFp16((float16_t *)input0->data_, (float16_t *)input1->data_, (float16_t *)input2->data_,
                      (int *)output0->data_, (float16_t *)output1->data_, ragged_range);
      break;
#endif
    default:
      // Also reached for float16 in builds without ENABLE_FP16. The Float16 creator is
      // registered unconditionally, so report the type as unsupported here rather than
      // returning success with the outputs left unwritten.
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}
50 
/**
 * Resize hook: caches shape-derived values that RaggedRangeCompute relies on.
 *
 * rows_ is the leading dimension of out_[0] minus one. Each *_is_scalar_ flag
 * records whether the corresponding input (in_[0] starts, in_[1] limits,
 * in_[2] deltas) has rank 0.
 *
 * @param self kernel instance; must actually be a RaggedRangeStruct.
 * @return NNACL_OK on success; an error code from NNACL_CHECK_NULL_RETURN_ERR if
 *         self or a tensor pointer is NULL; NNACL_ERR if out_[0] has no dimensions
 *         or a leading dimension below 1, either of which would make rows_ negative.
 */
int RaggedRangeResize(KernelBase *self) {
  RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(ragged_range);

  TensorC *output0 = self->out_[Index0];
  TensorC *starts = self->in_[Index0];
  TensorC *limits = self->in_[Index1];
  TensorC *deltas = self->in_[Index2];
  NNACL_CHECK_NULL_RETURN_ERR(output0);
  NNACL_CHECK_NULL_RETURN_ERR(starts);
  NNACL_CHECK_NULL_RETURN_ERR(limits);
  NNACL_CHECK_NULL_RETURN_ERR(deltas);

  // out_[0] must hold at least one element (rows_ + 1 of them); otherwise rows_ would be
  // negative and compute would be handed a bogus row count.
  if (output0->shape_size_ < 1 || output0->shape_[Index0] < 1) {
    return NNACL_ERR;
  }

  ragged_range->rows_ = output0->shape_[Index0] - 1;
  ragged_range->starts_is_scalar_ = starts->shape_size_ == 0;
  ragged_range->limits_is_scalar_ = limits->shape_size_ == 0;
  ragged_range->deltas_is_scalar_ = deltas->shape_size_ == 0;
  return NNACL_OK;
}
61 
CreateRaggedRange(OpParameter * param,int data_type)62 KernelBase *CreateRaggedRange(OpParameter *param, int data_type) {
63   RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)malloc(sizeof(RaggedRangeStruct));
64   NNACL_CHECK_NULL_RETURN_NULL(ragged_range);
65   ragged_range->base_.Release = DefaultRelease;
66   ragged_range->base_.Prepare = DefaultPrepare3In2Out;
67   ragged_range->base_.Resize = RaggedRangeResize;
68   ragged_range->base_.Compute = RaggedRangeCompute;
69   return (KernelBase *)ragged_range;
70 }
71 
72 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeInt32, CreateRaggedRange)
73 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat16, CreateRaggedRange)
74 REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat32, CreateRaggedRange)
75