• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "nnacl/kernel/splice.h"

#include <stdint.h>

#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/splice_parameter.h"
#include "nnacl/fp32/splice_fp32.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/splice_fp16.h"
#endif
24 
SpliceCompute(struct KernelBase * self)25 int SpliceCompute(struct KernelBase *self) {
26   TensorC *input = self->in_[FIRST_INPUT];
27   NNACL_CHECK_NULL_RETURN_ERR(input);
28   TensorC *output = self->out_[OUTPUT_INDEX];
29   NNACL_CHECK_NULL_RETURN_ERR(output);
30 
31   NNACL_CHECK_FALSE(input->shape_size_ != output->shape_size_, NNACL_SPLICE_SHAPE_INVALID);
32   NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID);
33   NNACL_CHECK_FALSE(output->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID);
34 
35   SpliceParameter *param = (SpliceParameter *)self->param_;
36   NNACL_CHECK_NULL_RETURN_ERR(param);
37 
38   int src_row = input->shape_[Index1];
39   int src_col = input->shape_[Index2];
40   int dst_row = output->shape_[Index1];
41   int dst_col = output->shape_[Index2];
42 
43   NNACL_CHECK_FALSE(src_col * param->context_dim_ != dst_col, NNACL_SPLICE_SHAPE_INVALID);
44   NNACL_CHECK_FALSE(param->context_dim_ * dst_row != param->forward_indexes_dim_, NNACL_SPLICE_SHAPE_INVALID);
45 
46   for (int i = 0; i < param->forward_indexes_dim_; ++i) {
47     if (param->forward_indexes_[i] >= src_row) {
48       return NNACL_SPLICE_SHAPE_INVALID;
49     }
50   }
51 
52   void *input_data = input->data_;
53   NNACL_CHECK_NULL_RETURN_ERR(input_data);
54   void *output_data = output->data_;
55   NNACL_CHECK_NULL_RETURN_ERR(output_data);
56 
57 #ifdef ENABLE_FP16
58   if (input->data_type_ == kNumberTypeFloat16) {
59     SpliceFp16((float16_t *)input_data, src_row, src_col, param, (float16_t *)output_data, dst_row, dst_col);
60     return NNACL_OK;
61   }
62 #endif
63 
64   SpliceFp32((float *)input_data, src_row, src_col, param, (float *)output_data, dst_row, dst_col);
65   return NNACL_OK;
66 }
67 
CreateSplice(OpParameter * param,int data_type)68 KernelBase *CreateSplice(OpParameter *param, int data_type) {
69   SpliceStruct *splice = (SpliceStruct *)malloc(sizeof(SpliceStruct));
70   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(splice);
71   splice->base_.Release = DefaultRelease;
72   splice->base_.Prepare = DefaultPrepare1In1Out;
73   splice->base_.Resize = DefaultResize;
74   splice->base_.Compute = SpliceCompute;
75   return (KernelBase *)splice;
76 }
77 
78 REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat32, CreateSplice)
79 REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat16, CreateSplice)
80