1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "nnacl/kernel/splice.h"
#include <stdint.h>
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/splice_parameter.h"
#include "nnacl/fp32/splice_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/splice_fp16.h"
#endif
24
SpliceCompute(struct KernelBase * self)25 int SpliceCompute(struct KernelBase *self) {
26 TensorC *input = self->in_[FIRST_INPUT];
27 NNACL_CHECK_NULL_RETURN_ERR(input);
28 TensorC *output = self->out_[OUTPUT_INDEX];
29 NNACL_CHECK_NULL_RETURN_ERR(output);
30
31 NNACL_CHECK_FALSE(input->shape_size_ != output->shape_size_, NNACL_SPLICE_SHAPE_INVALID);
32 NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID);
33 NNACL_CHECK_FALSE(output->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID);
34
35 SpliceParameter *param = (SpliceParameter *)self->param_;
36 NNACL_CHECK_NULL_RETURN_ERR(param);
37
38 int src_row = input->shape_[Index1];
39 int src_col = input->shape_[Index2];
40 int dst_row = output->shape_[Index1];
41 int dst_col = output->shape_[Index2];
42
43 NNACL_CHECK_FALSE(src_col * param->context_dim_ != dst_col, NNACL_SPLICE_SHAPE_INVALID);
44 NNACL_CHECK_FALSE(param->context_dim_ * dst_row != param->forward_indexes_dim_, NNACL_SPLICE_SHAPE_INVALID);
45
46 for (int i = 0; i < param->forward_indexes_dim_; ++i) {
47 if (param->forward_indexes_[i] >= src_row) {
48 return NNACL_SPLICE_SHAPE_INVALID;
49 }
50 }
51
52 void *input_data = input->data_;
53 NNACL_CHECK_NULL_RETURN_ERR(input_data);
54 void *output_data = output->data_;
55 NNACL_CHECK_NULL_RETURN_ERR(output_data);
56
57 #ifdef ENABLE_FP16
58 if (input->data_type_ == kNumberTypeFloat16) {
59 SpliceFp16((float16_t *)input_data, src_row, src_col, param, (float16_t *)output_data, dst_row, dst_col);
60 return NNACL_OK;
61 }
62 #endif
63
64 SpliceFp32((float *)input_data, src_row, src_col, param, (float *)output_data, dst_row, dst_col);
65 return NNACL_OK;
66 }
67
CreateSplice(OpParameter * param,int data_type)68 KernelBase *CreateSplice(OpParameter *param, int data_type) {
69 SpliceStruct *splice = (SpliceStruct *)malloc(sizeof(SpliceStruct));
70 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(splice);
71 splice->base_.Release = DefaultRelease;
72 splice->base_.Prepare = DefaultPrepare1In1Out;
73 splice->base_.Resize = DefaultResize;
74 splice->base_.Compute = SpliceCompute;
75 return (KernelBase *)splice;
76 }
77
78 REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat32, CreateSplice)
79 REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat16, CreateSplice)
80