• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_AVX
18 #include "nnacl/kernel/convolution_sw_1x1.h"
19 #include "nnacl/kernel/matmul_base.h"
20 #include "nnacl/kernel/matmul_create.h"
21 
/**
 * Map the 1x1 convolution's shape onto the wrapped matmul kernel as a GEMM:
 *   deep = input channels (in_c_)
 *   col  = output channels (out_c_)
 *   row  = in_hw_ * in_n_ (spatial positions of every batch folded into rows)
 * Shared by Prepare and Resize so the two code paths cannot drift apart.
 *
 * @param sw_1x1 Kernel whose matmul_ has already been checked to be non-NULL.
 */
static void MatmulConv1x1SetShape(ConvolutionSW1x1Struct *sw_1x1) {
  sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_;
  sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_;
  sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_;
}

/**
 * Configure the wrapped matmul kernel to compute the 1x1 convolution, then run
 * the matmul kernel's own Prepare.
 *
 * The matmul is set up with batch_ = a_batch_ = b_batch_ = 1; the input batch is
 * instead folded into the row count (see MatmulConv1x1SetShape).
 *
 * @param sw_1x1 Convolution kernel; it and its matmul_ must be non-NULL.
 * @return NNACL_OK-style status from matmul_->base_.Prepare, or the
 *         NNACL_CHECK_NULL_RETURN_ERR error code on a NULL argument.
 */
int MatmulConv1x1Prelare(ConvolutionSW1x1Struct *sw_1x1) {
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);

  sw_1x1->matmul_->batch_ = 1;
  sw_1x1->matmul_->a_batch_ = 1;
  sw_1x1->matmul_->b_batch_ = 1;

  MatmulConv1x1SetShape(sw_1x1);
  return sw_1x1->matmul_->base_.Prepare(&sw_1x1->matmul_->base_);
}

/**
 * Refresh the matmul kernel after the convolution's shape changed.
 *
 * Re-derives the GEMM shape, rebuilds the matmul's batch-offset table
 * (free then re-allocate), and finally runs the matmul kernel's own Resize.
 *
 * @param sw_1x1 Convolution kernel; it and its matmul_ must be non-NULL.
 * @return NNACL_OK on success; the MatmulBaseMallocBatchOffset error if that
 *         allocation fails; otherwise the status from matmul_->base_.Resize.
 */
int MatmulConv1x1Resize(ConvolutionSW1x1Struct *sw_1x1) {
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);

  MatmulConv1x1SetShape(sw_1x1);

  MatmulBaseFreeBatchOffset(sw_1x1->matmul_);
  int ret = MatmulBaseMallocBatchOffset(sw_1x1->matmul_);
  if (ret != NNACL_OK) {
    return ret;
  }

  return sw_1x1->matmul_->base_.Resize(&sw_1x1->matmul_->base_);
}
47 
/**
 * Mirror the outer convolution kernel's tensor bindings onto the wrapped
 * matmul kernel: input/output tensor arrays, their sizes, and the workspace
 * pointer are copied from `self` to `sw_1x1->matmul_->base_`.
 *
 * @param self   Outer kernel providing the bindings.
 * @param sw_1x1 Kernel whose matmul_ receives them; matmul_ must be non-NULL.
 */
void UpdateTensorInfo(KernelBase *self, ConvolutionSW1x1Struct *sw_1x1) {
  KernelBase *inner = &sw_1x1->matmul_->base_;

  inner->workspace_ = self->workspace_;
  inner->out_size_ = self->out_size_;
  inner->out_ = self->out_;
  inner->in_size_ = self->in_size_;
  inner->in_ = self->in_;
}
55 
/**
 * Compute hook: re-bind the current tensors to the matmul kernel and delegate
 * the whole computation to its Compute.
 *
 * @param self Kernel created by CreateConvolutionSW1x1.
 * @return Status from the matmul kernel's Compute, or the
 *         NNACL_CHECK_NULL_RETURN_ERR error code on a NULL kernel/matmul.
 */
int ConvolutionSW1x1Compute(KernelBase *self) {
  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);

  /* Tensor pointers may differ between calls, so sync them before delegating. */
  UpdateTensorInfo(self, sw_1x1);
  KernelBase *matmul_base = &sw_1x1->matmul_->base_;
  return matmul_base->Compute(matmul_base);
}
64 
/**
 * Resize hook: re-bind tensors to the matmul kernel, then let
 * MatmulConv1x1Resize recompute the GEMM shape and resize the matmul.
 *
 * @param self Kernel created by CreateConvolutionSW1x1.
 * @return Status from MatmulConv1x1Resize, or the NNACL_CHECK_NULL_RETURN_ERR
 *         error code on a NULL kernel/matmul.
 */
int ConvolutionSW1x1Resize(KernelBase *self) {
  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);

  UpdateTensorInfo(self, sw_1x1);
  int status = MatmulConv1x1Resize(sw_1x1);
  return status;
}
73 
/**
 * Prepare hook: hand the convolution's configuration, weight and bias to the
 * wrapped matmul kernel, bind tensors, then run MatmulConv1x1Prelare.
 *
 * @param self Kernel created by CreateConvolutionSW1x1.
 * @return Status from MatmulConv1x1Prelare, or the NNACL_CHECK_NULL_RETURN_ERR
 *         error code on a NULL kernel/matmul.
 */
int ConvolutionSW1x1Prepare(KernelBase *self) {
  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);

  /* Propagate the outer kernel's runtime configuration to the matmul. */
  KernelBase *matmul_base = &sw_1x1->matmul_->base_;
  matmul_base->env_ = self->env_;
  matmul_base->thread_nr_ = self->thread_nr_;
  matmul_base->train_session_ = self->train_session_;
  sw_1x1->matmul_->infer_shape_ = sw_1x1->conv_.infershape_done_;

  /* Weight becomes matrix B and bias becomes matrix C. origin_need_free_ is
   * cleared, presumably so the matmul never frees buffers owned by conv_. */
  sw_1x1->matmul_->matrix_c_.origin_need_free_ = false;
  sw_1x1->matmul_->matrix_c_.origin_ptr_ = sw_1x1->conv_.origin_bias_;
  sw_1x1->matmul_->matrix_b_.origin_need_free_ = false;
  sw_1x1->matmul_->matrix_b_.origin_ptr_ = sw_1x1->conv_.origin_weight_;

  UpdateTensorInfo(self, sw_1x1);
  return MatmulConv1x1Prelare(sw_1x1);
}
92 
/**
 * Release hook: tear down the wrapped matmul kernel (if any) and then the
 * convolution base.
 *
 * The matmul's matrix B/C origin pointers (the convolution's weight and bias,
 * assigned in Prepare) are detached first. Then the matmul's Release runs; its
 * status is deliberately ignored. Finally the MatMulParameter allocated by
 * CreateConvolutionSW1x1 and the matmul struct itself are freed.
 *
 * @param self Kernel created by CreateConvolutionSW1x1.
 * @return NNACL_OK, or the NNACL_CHECK_NULL_RETURN_ERR error code on NULL self.
 */
int ConvolutionSW1x1Release(KernelBase *self) {
  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);

  if (sw_1x1->matmul_ != NULL) {
    KernelBase *matmul_base = &sw_1x1->matmul_->base_;

    /* Weight/bias are borrowed from conv_; drop the references before release. */
    sw_1x1->matmul_->matrix_b_.origin_ptr_ = NULL;
    sw_1x1->matmul_->matrix_c_.origin_ptr_ = NULL;
    (void)matmul_base->Release(matmul_base);

    /* free(NULL) is a no-op, so the parameter needs no separate NULL guard. */
    free(matmul_base->param_);
    matmul_base->param_ = NULL;

    free(sw_1x1->matmul_);
    sw_1x1->matmul_ = NULL;
  }

  ConvBaseRelease(&sw_1x1->conv_);
  return NNACL_OK;
}
115 
CreateConvolutionSW1x1(ConvParameter * conv_param,bool input_const,bool weight_const)116 ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const) {
117   ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)malloc(sizeof(ConvolutionSW1x1Struct));
118   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw_1x1);
119   memset(sw_1x1, 0, sizeof(ConvolutionSW1x1Struct));
120 
121   sw_1x1->conv_.is_sharing_pack_ = false;
122   sw_1x1->conv_.base_.Compute = ConvolutionSW1x1Compute;
123   sw_1x1->conv_.base_.Resize = ConvolutionSW1x1Resize;
124   sw_1x1->conv_.base_.Prepare = ConvolutionSW1x1Prepare;
125   sw_1x1->conv_.base_.Release = ConvolutionSW1x1Release;
126 
127   OpParameter *param = (OpParameter *)malloc(sizeof(MatMulParameter));
128   if (param == NULL) {
129     free(sw_1x1);
130     return NULL;
131   }
132   MatMulParameter *matmul_param = (MatMulParameter *)param;
133   matmul_param->op_parameter_ = conv_param->op_parameter_;
134   matmul_param->act_type_ = conv_param->act_type_;
135   matmul_param->a_transpose_ = false;
136   matmul_param->b_transpose_ = true;
137 
138   KernelBase *matmul = CreateMatmulKernel();
139   if (matmul == NULL) {
140     free(sw_1x1);
141     free(param);
142     return NULL;
143   }
144 
145   ((MatmulStruct *)matmul)->is_sharing_pack_ = false;
146   ((MatmulStruct *)matmul)->a_const_ = input_const;
147   ((MatmulStruct *)matmul)->b_const_ = weight_const;
148   ((MatmulStruct *)matmul)->base_.param_ = param;
149   sw_1x1->matmul_ = (MatmulStruct *)matmul;
150   return (ConvolutionBaseStruct *)sw_1x1;
151 }
152 #endif
153