• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/shape.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 
ShapeCompute(struct KernelBase * self)20 int ShapeCompute(struct KernelBase *self) {
21   ShapeStruct *shape = (ShapeStruct *)self;
22   memcpy(self->out_[OUTPUT_INDEX]->data_, self->in_[FIRST_INPUT]->shape_, shape->shape_size_);
23   return NNACL_OK;
24 }
25 
/**
 * Resize step of the Shape kernel: records how many bytes ShapeCompute must
 * copy for the current input shape.
 *
 * @param self Kernel instance. It is cast to ShapeStruct, so it must be the
 *             KernelBase that starts a ShapeStruct allocated by CreateShape.
 * @return NNACL_OK on success. If self or the input/output tensor is NULL, it
 *         returns the error produced by NNACL_CHECK_NULL_RETURN_ERR.
 */
int ShapeResize(KernelBase *self) {
  /* Check self before any dereference. A NULL check placed after self->in_ is
   * evaluated can never fire, and the compiler is free to delete it. */
  NNACL_CHECK_NULL_RETURN_ERR(self);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);

  ShapeStruct *shape = (ShapeStruct *)self;
  /* Byte count = rank * size of one element of the shape_ array that
   * ShapeCompute copies from. Deriving it from shape_ keeps the two in sync. */
  shape->shape_size_ = self->in_[FIRST_INPUT]->shape_size_ * sizeof(self->in_[FIRST_INPUT]->shape_[0]);
  return NNACL_OK;
}
34 
CreateShape(OpParameter * param,int data_type)35 KernelBase *CreateShape(OpParameter *param, int data_type) {
36   ShapeStruct *shape = (ShapeStruct *)malloc(sizeof(ShapeStruct));
37   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(shape);
38   shape->base_.Release = DefaultRelease;
39   shape->base_.Prepare = DefaultPrepare1In1Out;
40   shape->base_.Resize = ShapeResize;
41   shape->base_.Compute = ShapeCompute;
42   return (KernelBase *)shape;
43 }
44 
45 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt32, CreateShape)
46 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeBool, CreateShape)
47 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat16, CreateShape)
48 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat32, CreateShape)
49 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt8, CreateShape)
50 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeUInt8, CreateShape)
51 REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt64, CreateShape)
52