/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_AVX
#include "nnacl/kernel/convolution_sw_1x1.h"
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/kernel/matmul_create.h"

/* Configure the wrapped matmul kernel from the convolution's shape info and
 * run its Prepare. A 1x1 stride-1 convolution is exactly a matmul, so the
 * whole input is folded into a single batch.
 * (NOTE(review): "Prelare" looks like a typo for "Prepare" — kept to match
 * the existing call site.) */
int MatmulConv1x1Prelare(ConvolutionSW1x1Struct *sw_1x1) {
  MatmulStruct *matmul = sw_1x1->matmul_;

  /* One flat batch: the batch dimension is absorbed into the row count. */
  matmul->batch_ = 1;
  matmul->a_batch_ = 1;
  matmul->b_batch_ = 1;

  /* Map conv dimensions onto matmul dimensions:
   * deep = input channels, col = output channels,
   * row  = every spatial position across the whole input batch. */
  matmul->compute_.deep_ = sw_1x1->conv_.compute_.in_c_;
  matmul->compute_.col_ = sw_1x1->conv_.compute_.out_c_;
  matmul->compute_.row_ = sw_1x1->conv_.compute_.in_n_ * sw_1x1->conv_.compute_.in_hw_;

  return matmul->base_.Prepare(&matmul->base_);
}

/* Refresh the matmul's shape info after the input shape changed, rebuild the
 * per-batch offset tables, then forward to the matmul's Resize. */
int MatmulConv1x1Resize(ConvolutionSW1x1Struct *sw_1x1) {
  MatmulStruct *matmul = sw_1x1->matmul_;

  /* Same mapping as Prepare: deep = in channels, col = out channels,
   * row = all spatial positions over the whole batch. */
  matmul->compute_.deep_ = sw_1x1->conv_.compute_.in_c_;
  matmul->compute_.col_ = sw_1x1->conv_.compute_.out_c_;
  matmul->compute_.row_ = sw_1x1->conv_.compute_.in_n_ * sw_1x1->conv_.compute_.in_hw_;

  /* Offsets depend on the shapes above, so drop and rebuild them. */
  MatmulBaseFreeBatchOffset(matmul);
  int ret = MatmulBaseMallocBatchOffset(matmul);
  if (ret != NNACL_OK) {
    return ret;
  }

  return matmul->base_.Resize(&matmul->base_);
}

/* Mirror the convolution kernel's current tensor bindings and workspace into
 * the wrapped matmul kernel so both operate on the same buffers. */
void UpdateTensorInfo(KernelBase *self, ConvolutionSW1x1Struct *sw_1x1) {
  KernelBase *matmul_base = &sw_1x1->matmul_->base_;

  matmul_base->in_ = self->in_;
  matmul_base->in_size_ = self->in_size_;
  matmul_base->out_ = self->out_;
  matmul_base->out_size_ = self->out_size_;
  matmul_base->workspace_ = self->workspace_;
}

/* Compute entry point: re-sync tensor bindings and delegate to the matmul. */
int ConvolutionSW1x1Compute(KernelBase *self) {
  ConvolutionSW1x1Struct *conv1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1);
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1->matmul_);

  /* Tensors may have been re-bound between calls; mirror them first. */
  UpdateTensorInfo(self, conv1x1);

  KernelBase *matmul_base = &conv1x1->matmul_->base_;
  return matmul_base->Compute(matmul_base);
}

/* Resize entry point: re-sync tensor bindings, then rebuild the matmul's
 * shape-dependent state. */
int ConvolutionSW1x1Resize(KernelBase *self) {
  ConvolutionSW1x1Struct *conv1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1);
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1->matmul_);

  UpdateTensorInfo(self, conv1x1);
  return MatmulConv1x1Resize(conv1x1);
}

/* Prepare entry point: bind the conv's weight/bias as the matmul's B and C
 * matrices, propagate runtime settings, and run the matmul's Prepare. */
int ConvolutionSW1x1Prepare(KernelBase *self) {
  ConvolutionSW1x1Struct *conv1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1);
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1->matmul_);

  MatmulStruct *matmul = conv1x1->matmul_;

  /* Weight feeds matrix B, bias feeds matrix C. The conv base owns both
   * buffers, so the matmul must not free them. */
  matmul->matrix_b_.origin_ptr_ = conv1x1->conv_.origin_weight_;
  matmul->matrix_b_.origin_need_free_ = false;
  matmul->matrix_c_.origin_ptr_ = conv1x1->conv_.origin_bias_;
  matmul->matrix_c_.origin_need_free_ = false;

  /* Forward runtime configuration from the conv kernel to the matmul. */
  matmul->infer_shape_ = conv1x1->conv_.infershape_done_;
  matmul->base_.env_ = self->env_;
  matmul->base_.thread_nr_ = self->thread_nr_;
  matmul->base_.train_session_ = self->train_session_;

  UpdateTensorInfo(self, conv1x1);
  return MatmulConv1x1Prelare(conv1x1);
}

/* Release entry point: tear down the wrapped matmul (detaching the borrowed
 * weight/bias pointers first), then release the conv base. */
int ConvolutionSW1x1Release(KernelBase *self) {
  ConvolutionSW1x1Struct *conv1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(conv1x1);

  MatmulStruct *matmul = conv1x1->matmul_;
  if (matmul != NULL) {
    /* Weight/bias buffers are owned by the conv base; detach them so the
     * matmul release cannot touch them. */
    matmul->matrix_b_.origin_ptr_ = NULL;
    matmul->matrix_c_.origin_ptr_ = NULL;

    (void)matmul->base_.Release(&matmul->base_);

    free(matmul->base_.param_); /* free(NULL) is a no-op */
    matmul->base_.param_ = NULL;

    free(matmul);
    conv1x1->matmul_ = NULL;
  }

  ConvBaseRelease(&conv1x1->conv_);
  return NNACL_OK;
}

CreateConvolutionSW1x1(ConvParameter * conv_param,bool input_const,bool weight_const)116 ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const) {
117 ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)malloc(sizeof(ConvolutionSW1x1Struct));
118 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw_1x1);
119 memset(sw_1x1, 0, sizeof(ConvolutionSW1x1Struct));
120
121 sw_1x1->conv_.is_sharing_pack_ = false;
122 sw_1x1->conv_.base_.Compute = ConvolutionSW1x1Compute;
123 sw_1x1->conv_.base_.Resize = ConvolutionSW1x1Resize;
124 sw_1x1->conv_.base_.Prepare = ConvolutionSW1x1Prepare;
125 sw_1x1->conv_.base_.Release = ConvolutionSW1x1Release;
126
127 OpParameter *param = (OpParameter *)malloc(sizeof(MatMulParameter));
128 if (param == NULL) {
129 free(sw_1x1);
130 return NULL;
131 }
132 MatMulParameter *matmul_param = (MatMulParameter *)param;
133 matmul_param->op_parameter_ = conv_param->op_parameter_;
134 matmul_param->act_type_ = conv_param->act_type_;
135 matmul_param->a_transpose_ = false;
136 matmul_param->b_transpose_ = true;
137
138 KernelBase *matmul = CreateMatmulKernel();
139 if (matmul == NULL) {
140 free(sw_1x1);
141 free(param);
142 return NULL;
143 }
144
145 ((MatmulStruct *)matmul)->is_sharing_pack_ = false;
146 ((MatmulStruct *)matmul)->a_const_ = input_const;
147 ((MatmulStruct *)matmul)->b_const_ = weight_const;
148 ((MatmulStruct *)matmul)->base_.param_ = param;
149 sw_1x1->matmul_ = (MatmulStruct *)matmul;
150 return (ConvolutionBaseStruct *)sw_1x1;
151 }
#endif