1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "nnacl/kernel/fullconnection.h"

#include <limits.h>

#include "nnacl/kernel/matmul_base.h"
#include "nnacl/kernel/matmul_create.h"
20
/*
 * Prepare hook for the FullConnection kernel.
 *
 * Requires at least two inputs (data A, weight B) and one output. When A is
 * constant or infer_shape_ is set, row_/deep_ are taken from A's dims 0/1.
 * When B is constant or infer_shape_ is set, col_/deep_ are taken from B's
 * dims 0/1 (so B's deep_ wins if both branches run). The operation is then
 * configured as a single, non-batched matmul with B transposed.
 *
 * NOTE(review): A and B are assumed to be at least 2-D here; their ranks are
 * not checked before dims 0 and 1 are read.
 *
 * Returns NNACL_ERR on too few tensors, the error from
 * MatmulBaseMallocBatchOffset if it fails, otherwise MatmulBasePrepare(self).
 */
int FullConnectionPrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR);

  MatmulStruct *matmul = (MatmulStruct *)self;

  if (matmul->a_const_ || matmul->infer_shape_) {
    const int *input_dims = self->in_[FIRST_INPUT]->shape_;
    matmul->compute_.row_ = input_dims[0];
    matmul->compute_.deep_ = input_dims[1];
  }
  if (matmul->b_const_ || matmul->infer_shape_) {
    const int *weight_dims = self->in_[SECOND_INPUT]->shape_;
    matmul->compute_.col_ = weight_dims[0];
    /* Overwrites any deep_ value taken from the first input above. */
    matmul->compute_.deep_ = weight_dims[1];
  }

  /* FullConnection is always a single (non-batched) matmul. */
  matmul->a_batch_ = 1;
  matmul->b_batch_ = 1;
  matmul->batch_ = 1;

  /* A is used as-is; B is treated as transposed (laid out as [col, deep]). */
  MatMulParameter *matmul_param = (MatMulParameter *)matmul->base_.param_;
  matmul_param->a_transpose_ = false;
  matmul_param->b_transpose_ = true;

  int status = MatmulBaseMallocBatchOffset(matmul);
  return (status != NNACL_OK) ? status : MatmulBasePrepare(self);
}
54
/*
 * Resize hook for the FullConnection kernel.
 *
 * Derives the matmul problem size from the current tensor shapes:
 *   row_  = product of every output dimension except the last
 *   col_  = last output dimension
 *   deep_ = weight dimension 1 (the weight is read as [col, deep], matching
 *           the b_transpose_ = true layout used by the Prepare hook)
 * and then defers to MatmulBaseResize.
 *
 * Returns NNACL_ERR if the input/output tensors are missing, the output shape
 * is empty, the weight has rank < 2, an output dimension is negative, or the
 * row product would overflow int; otherwise the result of MatmulBaseResize.
 */
int FullConnectionResize(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  NNACL_CHECK_TRUE_RET(self->in_size_ > SECOND_INPUT, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(self->out_size_ > OUTPUT_INDEX, NNACL_ERR);

  const int *out_shape = self->out_[OUTPUT_INDEX]->shape_;
  size_t out_rank = self->out_[OUTPUT_INDEX]->shape_size_;
  NNACL_CHECK_TRUE_RET(out_rank > 0, NNACL_ERR);

  /* deep_ is read from weight dimension 1, so the weight must be at least 2-D. */
  const int *weight_shape = self->in_[SECOND_INPUT]->shape_;
  NNACL_CHECK_TRUE_RET(self->in_[SECOND_INPUT]->shape_size_ >= 2, NNACL_ERR);

  int row = 1;
  for (size_t i = 0; i + 1 < out_rank; ++i) {
    int dim = out_shape[i];
    NNACL_CHECK_TRUE_RET(dim >= 0, NNACL_ERR);
    /* Signed int overflow is undefined behavior: check before multiplying. */
    NNACL_CHECK_TRUE_RET(dim == 0 || row <= INT_MAX / dim, NNACL_ERR);
    row *= dim;
  }

  matmul->compute_.row_ = row;
  matmul->compute_.col_ = out_shape[out_rank - 1];
  matmul->compute_.deep_ = weight_shape[1];

  return MatmulBaseResize(self);
}
69
CreateFullconnection(OpParameter * param,int data_type)70 KernelBase *CreateFullconnection(OpParameter *param, int data_type) {
71 KernelBase *kernel = NULL;
72 if (data_type == kNumberTypeFloat32) {
73 kernel = CreateMatmulKernel();
74 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel);
75 kernel->Prepare = FullConnectionPrepare;
76 kernel->Resize = FullConnectionResize;
77 }
78 return kernel;
79 }
80
81 REG_KERNEL_CREATOR(PrimType_FullConnection, kNumberTypeFloat32, CreateFullconnection);
82