• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "nnacl/kernel/fullconnection.h"
#include <limits.h>
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/kernel/matmul_create.h"
20 
/*
 * Prepare hook for the fp32 FullConnection kernel.
 *
 * Validates the tensor counts and records the matmul dimensions that are
 * already knowable from constant (or shape-inferred) inputs. It then forces
 * a single batch with layout A * B^T, allocates the batch offsets, and
 * delegates to MatmulBasePrepare.
 *
 * The second input is read as [col, deep]. FullConnectionResize recomputes
 * row_, col_ and deep_ from the current tensor shapes.
 *
 * Returns NNACL_ERR in two cases: fewer than two inputs or no output, or a
 * second input with fewer than two dims when its shape is consumed here.
 * Otherwise it returns the first non-NNACL_OK status of
 * MatmulBaseMallocBatchOffset / MatmulBasePrepare.
 */
int FullConnectionPrepare(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR);

  if (matmul->a_const_ || matmul->infer_shape_) {
    /* Takes A's first two dims as [row, deep]. This is provisional:
     * FullConnectionResize re-derives row_ and deep_ from the shapes. */
    int *a_shape = self->in_[FIRST_INPUT]->shape_;
    matmul->compute_.row_ = a_shape[0];
    matmul->compute_.deep_ = a_shape[1];
  }

  if (matmul->b_const_ || matmul->infer_shape_) {
    /* B is indexed at dims 0 and 1 below; a lower-rank tensor has no valid
     * value there, so reject it instead of using an invalid dimension. */
    NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->shape_size_ < C2NUM, NNACL_ERR);
    int *b_shape = self->in_[SECOND_INPUT]->shape_;
    matmul->compute_.col_ = b_shape[0];
    /* When both branches run, B's deep overrides the value taken from A. */
    matmul->compute_.deep_ = b_shape[1];
  }

  /* FullConnection is always a single, non-batched matmul. */
  matmul->batch_ = 1;
  matmul->a_batch_ = 1;
  matmul->b_batch_ = 1;

  /* Fixed layout: A as stored, B transposed (weight stored as [col, deep]). */
  MatMulParameter *param = (MatMulParameter *)matmul->base_.param_;
  param->a_transpose_ = false;
  param->b_transpose_ = true;

  int ret = MatmulBaseMallocBatchOffset(matmul);
  if (ret != NNACL_OK) {
    return ret;
  }

  return MatmulBasePrepare(self);
}
54 
/*
 * Resize hook for the fp32 FullConnection kernel.
 *
 * Re-derives the matmul dimensions from the current tensor shapes, then
 * delegates to MatmulBaseResize:
 *   row_  = product of every output dim except the last
 *   col_  = last output dim
 *   deep_ = dim 1 of the second input (weight read as [col, deep], the same
 *           interpretation FullConnectionPrepare uses)
 *
 * Returns NNACL_ERR in any of these cases:
 *   - fewer than two inputs
 *   - a rank-0 output
 *   - a second input with fewer than two dims
 *   - a negative leading output dim
 *   - a row product that would overflow int
 * Otherwise it returns the status of MatmulBaseResize.
 */
int FullConnectionResize(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  NNACL_CHECK_TRUE_RET(self->in_size_ >= C2NUM, NNACL_ERR);

  int *out_shape = self->out_[OUTPUT_INDEX]->shape_;
  size_t out_rank = self->out_[OUTPUT_INDEX]->shape_size_;
  NNACL_CHECK_TRUE_RET(out_rank > 0, NNACL_ERR);

  /* dim 1 of the weight is read below; make sure it exists. */
  NNACL_CHECK_TRUE_RET(self->in_[SECOND_INPUT]->shape_size_ >= C2NUM, NNACL_ERR);
  int *b_shape = self->in_[SECOND_INPUT]->shape_;

  int row = 1;
  for (size_t i = 0; i < out_rank - 1; ++i) {
    int dim = out_shape[i];
    NNACL_CHECK_TRUE_RET(dim >= 0, NNACL_ERR);
    /* Signed overflow is undefined behaviour: check before multiplying. */
    NNACL_CHECK_TRUE_RET(dim == 0 || row <= INT_MAX / dim, NNACL_ERR);
    row *= dim;
  }

  matmul->compute_.row_ = row;
  matmul->compute_.col_ = out_shape[out_rank - 1];
  matmul->compute_.deep_ = b_shape[1];

  return MatmulBaseResize(self);
}
69 
CreateFullconnection(OpParameter * param,int data_type)70 KernelBase *CreateFullconnection(OpParameter *param, int data_type) {
71   KernelBase *kernel = NULL;
72   if (data_type == kNumberTypeFloat32) {
73     kernel = CreateMatmulKernel();
74     NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel);
75     kernel->Prepare = FullConnectionPrepare;
76     kernel->Resize = FullConnectionResize;
77   }
78   return kernel;
79 }
80 
81 REG_KERNEL_CREATOR(PrimType_FullConnection, kNumberTypeFloat32, CreateFullconnection);
82