1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/pack_weight.h"
18 #include "src/extendrt/dynamic_mem_allocator.h"
19 namespace mindspore::lite {
InitPackWeight(const void * model_buf,size_t model_size,std::string id,int numa_id,bool need_copy_buf)20 STATUS PackWeight::InitPackWeight(const void *model_buf, size_t model_size, std::string id, int numa_id,
21 bool need_copy_buf) {
22 std::lock_guard<std::mutex> lock(mtx_weight_);
23 if (model_buf == nullptr || model_weights_.size() != shared_bufs_.size()) {
24 MS_LOG(ERROR) << "model buf is nullptr in pack weight manager.";
25 return RET_ERROR;
26 }
27 if (model_weights_.find(id) != model_weights_.end() && model_weights_[id].find(numa_id) != model_weights_[id].end()) {
28 MS_LOG(INFO) << "same numa id, use same model buf.";
29 return RET_OK;
30 }
31 std::shared_ptr<Allocator> allocator = nullptr;
32 #ifdef BFC_MEMORY
33 allocator = std::make_shared<DynamicMemAllocator>(numa_id);
34 #else
35 allocator = std::make_shared<DefaultAllocator>();
36 #endif
37 if (allocator == nullptr) {
38 MS_LOG(ERROR) << "allocator is nullptr in pack weight manager.";
39 return RET_ERROR;
40 }
41 auto *model_const_weight = new (std::nothrow) ModelConstWeight();
42 if (model_const_weight == nullptr) {
43 MS_LOG(ERROR) << "model const weight is nullptr.";
44 return RET_ERROR;
45 }
46 void *new_model_buf = const_cast<void *>(model_buf);
47 if (need_copy_buf) {
48 new_model_buf = allocator->Malloc(model_size);
49 if (new_model_buf == nullptr) {
50 MS_LOG(ERROR) << "new model buf is nullptr in pack weight manager.";
51 return RET_ERROR;
52 }
53 memcpy(new_model_buf, model_buf, model_size);
54 model_const_weight->copy_buf = need_copy_buf;
55 }
56 model_const_weight->allocator = allocator;
57 model_const_weight->numa_id = numa_id;
58 if (model_weights_.find(id) != model_weights_.end()) {
59 model_weights_[id][numa_id] = model_const_weight;
60 shared_bufs_[id][numa_id] = new_model_buf;
61 } else {
62 std::unordered_map<int, ModelConstWeight *> numa_model_weight;
63 numa_model_weight[numa_id] = model_const_weight;
64 model_weights_[id] = numa_model_weight;
65 std::unordered_map<int, void *> numa_model_buf;
66 numa_model_buf[numa_id] = new_model_buf;
67 shared_bufs_[id] = numa_model_buf;
68 }
69 return RET_OK;
70 }
71
GetSharedModelBuf(std::string id,int numa_id)72 char *PackWeight::GetSharedModelBuf(std::string id, int numa_id) {
73 std::lock_guard<std::mutex> lock(mtx_weight_);
74 if (shared_bufs_.find(id) == shared_bufs_.end() || shared_bufs_[id].find(numa_id) == shared_bufs_[id].end()) {
75 MS_LOG(ERROR) << "can not find numa id in saved model buf, id: " << id << ", numa id: " << numa_id;
76 return nullptr;
77 }
78 return static_cast<char *>(shared_bufs_[id][numa_id]);
79 }
80
StoreOriginTensorData(const void * model_buf,const void * origin_tensor_data)81 STATUS PackWeight::StoreOriginTensorData(const void *model_buf, const void *origin_tensor_data) {
82 std::lock_guard<std::mutex> lock(mtx_weight_);
83 for (auto &item : shared_bufs_) {
84 for (auto &numa_item : item.second) {
85 if (numa_item.second == model_buf) {
86 std::string id = item.first;
87 int numa_id = numa_item.first;
88 auto &model_weight = model_weights_[id][numa_id];
89 auto &packed_pair = model_weight->origin_and_packed_pair;
90 if (packed_pair.find(origin_tensor_data) != packed_pair.end()) {
91 MS_LOG(DEBUG) << "origin tensor data already store by other model.";
92 return RET_OK;
93 }
94 packed_pair.insert(std::make_pair(origin_tensor_data, nullptr));
95 return RET_OK;
96 }
97 }
98 }
99 MS_LOG(ERROR) << "can not find model buf in store origin Tensor";
100 return RET_ERROR;
101 }
102
ReplaceFp16Data(void * origin_fp16_data,size_t size)103 void *PackWeight::ReplaceFp16Data(void *origin_fp16_data, size_t size) {
104 std::lock_guard<std::mutex> lock(mtx_weight_);
105 if (fp16_fp32_data_pair_.find(origin_fp16_data) != fp16_fp32_data_pair_.end()) {
106 return fp16_fp32_data_pair_[origin_fp16_data];
107 } else {
108 for (auto &numa_item : model_weights_) {
109 for (auto &item : numa_item.second) {
110 if (item.second->origin_and_packed_pair.find(origin_fp16_data) != item.second->origin_and_packed_pair.end()) {
111 auto &model_weight = item.second;
112 auto &origin_and_packed_pair = model_weight->origin_and_packed_pair;
113 if (origin_and_packed_pair.find(origin_fp16_data) == origin_and_packed_pair.end()) {
114 MS_LOG(ERROR) << "origin fp16 data not find.";
115 return nullptr;
116 }
117 auto allocator = model_weight->allocator;
118 void *data = allocator->Malloc(size);
119 if (data == nullptr) {
120 MS_LOG(ERROR) << "malloc failed.";
121 return nullptr;
122 }
123 origin_and_packed_pair.insert(std::make_pair(data, nullptr));
124 model_weight->fp16_fp32_data.insert(data);
125 origin_and_packed_pair.erase(origin_fp16_data);
126 fp16_fp32_data_pair_.insert(std::make_pair(origin_fp16_data, data));
127 return data;
128 }
129 }
130 }
131 }
132 MS_LOG(ERROR) << "ReplaceFp16Data failed.";
133 return nullptr;
134 }
135
ReplaceOriginTensorData(const void * model_buf,std::vector<Tensor * > * tensors,int tensor_index)136 STATUS PackWeight::ReplaceOriginTensorData(const void *model_buf, std::vector<Tensor *> *tensors, int tensor_index) {
137 std::lock_guard<std::mutex> lock(mtx_weight_);
138 for (auto &item : shared_bufs_) {
139 for (auto &numa_item : item.second) {
140 if (numa_item.second == model_buf) {
141 std::string id = item.first;
142 int numa_id = numa_item.first;
143 auto &tensor = tensors->at(tensor_index);
144 auto &model_weight = model_weights_[id][numa_id];
145 if (model_weight->tensors_data.find(tensor_index) == model_weight->tensors_data.end()) {
146 auto allocator = model_weight->allocator;
147 void *new_data = allocator->Malloc(tensor->Size());
148 if (new_data == nullptr) {
149 MS_LOG(ERROR) << "allocator malloc data failed.";
150 return RET_ERROR;
151 }
152 memcpy(new_data, tensor->data(), tensor->Size());
153 MS_CHECK_TRUE_MSG(tensor->own_data(), RET_ERROR, "tensor data is not own data.");
154 tensor->FreeData();
155 tensor->set_data(new_data);
156 tensor->set_own_data(false);
157 model_weight->tensors_data.insert(std::make_pair(tensor_index, new_data));
158 } else {
159 auto new_data = model_weight->tensors_data[tensor_index];
160 tensor->FreeData();
161 tensor->set_data(new_data);
162 tensor->set_own_data(false);
163 }
164 return RET_OK;
165 }
166 }
167 }
168 MS_LOG(ERROR) << "can not find model buf in store origin Tensor";
169 return RET_ERROR;
170 }
171
GetPackData(const void * tensor_data,const size_t size,bool * is_packed)172 void *PackWeight::GetPackData(const void *tensor_data, const size_t size, bool *is_packed) {
173 std::lock_guard<std::mutex> lock(mtx_weight_);
174 MS_CHECK_TRUE_RET(tensor_data != nullptr, nullptr);
175 for (auto &numa_item : model_weights_) {
176 for (auto &item : numa_item.second) {
177 auto &model_weight = item.second;
178 auto &origin_packed_weight = model_weight->origin_and_packed_pair;
179 if (origin_packed_weight.find(tensor_data) == origin_packed_weight.end()) {
180 continue;
181 }
182 auto packed_tensor_data = origin_packed_weight[tensor_data];
183 if (packed_tensor_data != nullptr) {
184 *is_packed = true;
185 return packed_tensor_data;
186 } else {
187 auto weight_allocator = model_weight->allocator;
188 packed_tensor_data = weight_allocator->Malloc(size);
189 if (packed_tensor_data == nullptr) {
190 MS_LOG(ERROR) << "malloc failed.";
191 return nullptr;
192 }
193 origin_packed_weight[tensor_data] = packed_tensor_data;
194 *is_packed = false;
195 return packed_tensor_data;
196 }
197 }
198 }
199 *is_packed = false;
200 MS_LOG(ERROR) << "can not find tensor data in origin tensor data.";
201 return nullptr;
202 }
203
FreePackedWeight(ModelConstWeight * weight)204 void PackWeight::FreePackedWeight(ModelConstWeight *weight) {
205 MS_CHECK_TRUE_RET_VOID(weight != nullptr);
206 for (auto &origin_and_packed_pair : weight->origin_and_packed_pair) {
207 auto &packed_data = origin_and_packed_pair.second;
208 auto allocator = weight->allocator;
209 MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
210 if (packed_data != nullptr) {
211 allocator->Free(packed_data);
212 packed_data = nullptr;
213 }
214 }
215 weight->origin_and_packed_pair.clear();
216 }
217
FreeTensorData(ModelConstWeight * weight)218 void PackWeight::FreeTensorData(ModelConstWeight *weight) {
219 MS_CHECK_TRUE_RET_VOID(weight != nullptr);
220 for (auto &tensor_data : weight->tensors_data) {
221 auto &data = tensor_data.second;
222 auto allocator = weight->allocator;
223 MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
224 if (data != nullptr) {
225 allocator->Free(data);
226 data = nullptr;
227 }
228 }
229 weight->tensors_data.clear();
230 }
231
FreeFp16ToFp32Data(ModelConstWeight * weight)232 void PackWeight::FreeFp16ToFp32Data(ModelConstWeight *weight) {
233 MS_CHECK_TRUE_RET_VOID(weight != nullptr);
234 for (auto &data : weight->fp16_fp32_data) {
235 auto allocator = weight->allocator;
236 MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
237 if (data != nullptr) {
238 allocator->Free(data);
239 }
240 }
241 weight->fp16_fp32_data.clear();
242 }
243
FreePackWeight(std::string id,bool free_all)244 void PackWeight::FreePackWeight(std::string id, bool free_all) {
245 std::lock_guard<std::mutex> lock(mtx_weight_);
246 MS_LOG(INFO) << "model weight size: " << model_weights_.size() << " | shared buf size: " << shared_bufs_.size();
247 if (model_weights_.find(id) == model_weights_.end() || shared_bufs_.find(id) == shared_bufs_.end()) {
248 MS_LOG(INFO) << "can not find id in shared bufs or model weights.";
249 return;
250 }
251 for (auto &item : model_weights_[id]) {
252 auto numa_id = item.first;
253 ModelConstWeight *model_weight = model_weights_[id][numa_id];
254 void *model_buf = shared_bufs_[id][numa_id];
255 if (model_buf == nullptr || model_weight == nullptr) {
256 MS_LOG(ERROR) << "model buf or model weight is nullptr.";
257 return;
258 }
259 FreePackedWeight(model_weight);
260 FreeFp16ToFp32Data(model_weight);
261 FreeTensorData(model_weight);
262 if (model_weight->copy_buf) {
263 auto &allocator = model_weight->allocator;
264 allocator->Free(model_buf);
265 model_buf = nullptr;
266 }
267 delete model_weight;
268 model_weight = nullptr;
269 }
270 if (!free_all) {
271 model_weights_.erase(id);
272 }
273 shared_bufs_.erase(id);
274 MS_LOG(INFO) << "FreePackWeight done.";
275 }
276
~PackWeight()277 PackWeight::~PackWeight() {
278 MS_LOG(INFO) << "~PackWeight() begin";
279 if (model_weights_.empty()) {
280 MS_LOG(INFO) << "~PackWeight() empty end";
281 return;
282 }
283 for (auto &numa_item : model_weights_) {
284 std::string id = numa_item.first;
285 FreePackWeight(id, true);
286 }
287 model_weights_.clear();
288 MS_LOG(INFO) << "~PackWeight() end";
289 }
290 } // namespace mindspore::lite
291