1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "nnacl/kernel/shape.h"
#include <string.h>
#include "nnacl/kernel/default_kernel_base.h"
19
ShapeCompute(struct KernelBase * self)20 int ShapeCompute(struct KernelBase *self) {
21 ShapeStruct *shape = (ShapeStruct *)self;
22 memcpy(self->out_[OUTPUT_INDEX]->data_, self->in_[FIRST_INPUT]->shape_, shape->shape_size_);
23 return NNACL_OK;
24 }
25
/**
 * Record how many bytes ShapeCompute must copy: one int per dimension of the first input.
 *
 * @param self Kernel created by CreateShape. It is NULL-checked before any dereference.
 * @return NNACL_OK on success. If self or a required tensor is NULL, returns the error
 *         produced by NNACL_CHECK_NULL_RETURN_ERR.
 */
int ShapeResize(KernelBase *self) {
  /* Check self first. The original checked the cast pointer only after self->in_ and
   * self->out_ had already been dereferenced, so that check could never prevent a crash. */
  NNACL_CHECK_NULL_RETURN_ERR(self);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);

  ShapeStruct *shape = (ShapeStruct *)self;
  /* Byte count for the memcpy in ShapeCompute, which copies from the input's int shape_ array. */
  shape->shape_size_ = self->in_[FIRST_INPUT]->shape_size_ * sizeof(int);
  return NNACL_OK;
}
34
CreateShape(OpParameter * param,int data_type)35 KernelBase *CreateShape(OpParameter *param, int data_type) {
36 ShapeStruct *shape = (ShapeStruct *)malloc(sizeof(ShapeStruct));
37 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(shape);
38 shape->base_.Release = DefaultRelease;
39 shape->base_.Prepare = DefaultPrepare1In1Out;
40 shape->base_.Resize = ShapeResize;
41 shape->base_.Compute = ShapeCompute;
42 return (KernelBase *)shape;
43 }
44
/* Register CreateShape as the PrimType_Shape creator for each data type listed below.
 * CreateShape ignores data_type, so every type shares one implementation.
 * NOTE(review): the data type here presumably selects on the input tensor's dtype;
 * confirm against the REG_KERNEL_CREATOR definition. */
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt32, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeBool, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat16, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat32, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt8, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeUInt8, CreateShape)
REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt64, CreateShape)
52