1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/runtime/kernel/arm/base/group_convolution_creator.h"
18
19 namespace mindspore::kernel {
CopyTensorQuantParam(lite::Tensor * dst,lite::Tensor * src)20 void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) {
21 for (size_t i = 0; i < src->quant_params().size(); i++) {
22 dst->AddQuantParam(src->quant_params().at(i));
23 }
24 }
25
CreateNewConvParameter(ConvParameter * parameter)26 ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
27 auto conv_parameter = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
28 if (conv_parameter == nullptr) {
29 MS_LOG(ERROR) << "Malloc new conv parameter failed.";
30 return nullptr;
31 }
32 memcpy(conv_parameter, parameter, sizeof(ConvParameter));
33 return conv_parameter;
34 }
35
FreeCurrentConv(ConvParameter * conv_param,std::vector<lite::Tensor * > * new_inputs,std::vector<lite::Tensor * > * new_outputs)36 void FreeCurrentConv(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs,
37 std::vector<lite::Tensor *> *new_outputs) {
38 if (conv_param != nullptr) {
39 free(conv_param);
40 }
41 if (new_inputs != nullptr) {
42 for (auto &in_tensor : *new_inputs) {
43 delete in_tensor;
44 in_tensor = nullptr;
45 }
46 }
47 if (new_outputs != nullptr) {
48 for (auto &out_tensor : *new_outputs) {
49 delete out_tensor;
50 out_tensor = nullptr;
51 }
52 }
53 }
54
TensorMalloc(lite::Tensor * tensor)55 static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
56 if (tensor->MallocData() != lite::RET_OK) {
57 delete tensor;
58 MS_LOG(ERROR) << "malloc tensor data failed.";
59 return nullptr;
60 }
61 return tensor;
62 }
63
CreateConstTensor(lite::Tensor * tensor,const std::vector<int> & shape,const int index)64 lite::Tensor *CreateConstTensor(lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
65 auto new_tensor =
66 new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Tensor::Category::CONST_TENSOR);
67 if (new_tensor == nullptr) {
68 MS_LOG(ERROR) << "Create new_tensor failed.";
69 return nullptr;
70 }
71 auto ret = new_tensor->MallocData();
72 if (ret != lite::RET_OK) {
73 delete new_tensor;
74 MS_LOG(ERROR) << "Malloc new_tensor failed.";
75 return nullptr;
76 }
77
78 uint8_t *new_tensor_data = reinterpret_cast<uint8_t *>(tensor->data()) + index * new_tensor->Size();
79 memcpy(new_tensor->data(), new_tensor_data, new_tensor->Size());
80 return new_tensor;
81 }
82
CreateVarTensor(const TensorInfo & tensor_info,bool inferred)83 lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
84 auto tensor = new (std::nothrow) lite::Tensor();
85 if (!tensor) {
86 MS_LOG(ERROR) << "new tensor failed.";
87 return nullptr;
88 }
89 tensor->set_data_type(tensor_info.data_type_);
90 tensor->set_format(tensor_info.format_);
91 tensor->set_category(tensor_info.tensor_type_);
92 if (tensor_info.is_in_) {
93 tensor->set_shape(tensor_info.shape_);
94 }
95
96 if (inferred) {
97 // set shape of out tensor
98 if (!tensor_info.is_in_) {
99 tensor->set_shape(tensor_info.shape_);
100 }
101 return TensorMalloc(tensor);
102 }
103 return tensor;
104 }
105
106 /* Class GroupConv Creator Implement Part */
CopyQuantParam(std::vector<lite::Tensor * > * tensors)107 void GroupConvCreator::CopyQuantParam(std::vector<lite::Tensor *> *tensors) {
108 for (size_t j = 0; j < origin_inputs_.size(); ++j) {
109 CopyTensorQuantParam(tensors->at(j), origin_inputs_.at(j));
110 }
111 }
112
FreeGroupConvs()113 void GroupConvCreator::FreeGroupConvs() {
114 for (auto &sub_conv : group_convs_) {
115 for (auto &in_tensor : sub_conv->in_tensors()) {
116 delete in_tensor;
117 }
118 for (auto &out_tensor : sub_conv->out_tensors()) {
119 delete out_tensor;
120 }
121 delete sub_conv;
122 sub_conv = nullptr;
123 }
124 group_convs_.clear();
125 }
126
NewInputTensor(std::vector<lite::Tensor * > * tensors)127 int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
128 auto in_tensor =
129 CreateVarTensor({input_shape_, mindspore::NHWC, data_type_, lite::Tensor::Category::VAR, true}, infered_);
130 if (in_tensor == nullptr) {
131 return lite::RET_ERROR;
132 }
133 tensors->emplace_back(in_tensor);
134 return lite::RET_OK;
135 }
136
NewOutputTensor(std::vector<lite::Tensor * > * tensors,lite::Tensor * output)137 int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output) {
138 auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_);
139 if (out_tensor == nullptr) {
140 return lite::RET_ERROR;
141 }
142 if (is_quant_) {
143 CopyTensorQuantParam(out_tensor, output);
144 }
145 tensors->emplace_back(out_tensor);
146 return lite::RET_OK;
147 }
148
NewConstTensor(std::vector<lite::Tensor * > * tensors,int group_id)149 int GroupConvCreator::NewConstTensor(std::vector<lite::Tensor *> *tensors, int group_id) {
150 std::vector<std::pair<int, std::vector<int>>> const_tensor_list{std::make_pair(kWeightIndex, filter_shape_)};
151 if (origin_inputs_.size() == kInputSize2) {
152 const_tensor_list.emplace_back(std::make_pair(kBiasIndex, bias_shape_));
153 }
154 for (auto &info : const_tensor_list) {
155 auto const_tensor = CreateConstTensor(origin_inputs_.at(info.first), info.second, group_id);
156 if (const_tensor == nullptr) {
157 return lite::RET_ERROR;
158 }
159 tensors->emplace_back(const_tensor);
160 }
161 return lite::RET_OK;
162 }
163
SetShapeOfTensors()164 void GroupConvCreator::SetShapeOfTensors() {
165 int new_in_channel = origin_inputs_.at(kWeightIndex)->Channel();
166 int new_out_channel;
167 if (conv_param_->group_ == 0) {
168 MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
169 return;
170 } else {
171 new_out_channel = origin_inputs_.at(kWeightIndex)->Batch() / conv_param_->group_;
172 }
173
174 /* set shape */
175 set_filter_shape({new_out_channel, conv_param_->kernel_h_, conv_param_->kernel_w_, new_in_channel});
176 set_bias_shape({new_out_channel});
177 if (infered_) {
178 conv_param_->input_channel_ = new_in_channel;
179 conv_param_->output_channel_ = new_out_channel;
180 set_input_shape({origin_inputs_.front()->Batch(), origin_inputs_.front()->Height(), origin_inputs_.front()->Width(),
181 new_in_channel});
182 set_output_shape({origin_inputs_.front()->Batch(), origin_outputs_.front()->Height(),
183 origin_outputs_.front()->Width(), new_out_channel});
184 }
185 }
186
GetSingleConvParam(ConvParameter * conv_param,std::vector<lite::Tensor * > * new_inputs,std::vector<lite::Tensor * > * new_outputs,int group_id)187 int GroupConvCreator::GetSingleConvParam(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs,
188 std::vector<lite::Tensor *> *new_outputs, int group_id) {
189 if (conv_param == nullptr) {
190 FreeGroupConvs();
191 return lite::RET_ERROR;
192 }
193 // create new input for each group
194 if (NewInputTensor(new_inputs) != lite::RET_OK) {
195 MS_LOG(ERROR) << "new input tensor failed.";
196 FreeGroupConvs();
197 FreeCurrentConv(conv_param, new_inputs, {});
198 return lite::RET_ERROR;
199 }
200 // const tensor
201 if (NewConstTensor(new_inputs, group_id) != lite::RET_OK) {
202 MS_LOG(ERROR) << "new const tensor failed.";
203 FreeGroupConvs();
204 FreeCurrentConv(conv_param, new_inputs, {});
205 return lite::RET_ERROR;
206 }
207 // create new output tensor
208 for (auto &output : origin_outputs_) {
209 if (NewOutputTensor(new_outputs, output) != lite::RET_OK) {
210 MS_LOG(ERROR) << "new output tensor failed.";
211 FreeGroupConvs();
212 FreeCurrentConv(conv_param, new_inputs, new_outputs);
213 return lite::RET_ERROR;
214 }
215 }
216 return lite::RET_OK;
217 }
218 } // namespace mindspore::kernel
219