/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "src/litert/pack_weight.h"
18 #include "src/extendrt/dynamic_mem_allocator.h"
19 namespace mindspore::lite {
InitPackWeight(const void * model_buf,size_t model_size,std::string id,int numa_id,bool need_copy_buf)20 STATUS PackWeight::InitPackWeight(const void *model_buf, size_t model_size, std::string id, int numa_id,
21                                   bool need_copy_buf) {
22   std::lock_guard<std::mutex> lock(mtx_weight_);
23   if (model_buf == nullptr || model_weights_.size() != shared_bufs_.size()) {
24     MS_LOG(ERROR) << "model buf is nullptr in pack weight manager.";
25     return RET_ERROR;
26   }
27   if (model_weights_.find(id) != model_weights_.end() && model_weights_[id].find(numa_id) != model_weights_[id].end()) {
28     MS_LOG(INFO) << "same numa id, use same model buf.";
29     return RET_OK;
30   }
31   std::shared_ptr<Allocator> allocator = nullptr;
32 #ifdef BFC_MEMORY
33   allocator = std::make_shared<DynamicMemAllocator>(numa_id);
34 #else
35   allocator = std::make_shared<DefaultAllocator>();
36 #endif
37   if (allocator == nullptr) {
38     MS_LOG(ERROR) << "allocator is nullptr in pack weight manager.";
39     return RET_ERROR;
40   }
41   auto *model_const_weight = new (std::nothrow) ModelConstWeight();
42   if (model_const_weight == nullptr) {
43     MS_LOG(ERROR) << "model const weight is nullptr.";
44     return RET_ERROR;
45   }
46   void *new_model_buf = const_cast<void *>(model_buf);
47   if (need_copy_buf) {
48     new_model_buf = allocator->Malloc(model_size);
49     if (new_model_buf == nullptr) {
50       MS_LOG(ERROR) << "new model buf is nullptr in pack weight manager.";
51       return RET_ERROR;
52     }
53     memcpy(new_model_buf, model_buf, model_size);
54     model_const_weight->copy_buf = need_copy_buf;
55   }
56   model_const_weight->allocator = allocator;
57   model_const_weight->numa_id = numa_id;
58   if (model_weights_.find(id) != model_weights_.end()) {
59     model_weights_[id][numa_id] = model_const_weight;
60     shared_bufs_[id][numa_id] = new_model_buf;
61   } else {
62     std::unordered_map<int, ModelConstWeight *> numa_model_weight;
63     numa_model_weight[numa_id] = model_const_weight;
64     model_weights_[id] = numa_model_weight;
65     std::unordered_map<int, void *> numa_model_buf;
66     numa_model_buf[numa_id] = new_model_buf;
67     shared_bufs_[id] = numa_model_buf;
68   }
69   return RET_OK;
70 }
71 
GetSharedModelBuf(std::string id,int numa_id)72 char *PackWeight::GetSharedModelBuf(std::string id, int numa_id) {
73   std::lock_guard<std::mutex> lock(mtx_weight_);
74   if (shared_bufs_.find(id) == shared_bufs_.end() || shared_bufs_[id].find(numa_id) == shared_bufs_[id].end()) {
75     MS_LOG(ERROR) << "can not find numa id in saved model buf, id: " << id << ", numa id: " << numa_id;
76     return nullptr;
77   }
78   return static_cast<char *>(shared_bufs_[id][numa_id]);
79 }
80 
StoreOriginTensorData(const void * model_buf,const void * origin_tensor_data)81 STATUS PackWeight::StoreOriginTensorData(const void *model_buf, const void *origin_tensor_data) {
82   std::lock_guard<std::mutex> lock(mtx_weight_);
83   for (auto &item : shared_bufs_) {
84     for (auto &numa_item : item.second) {
85       if (numa_item.second == model_buf) {
86         std::string id = item.first;
87         int numa_id = numa_item.first;
88         auto &model_weight = model_weights_[id][numa_id];
89         auto &packed_pair = model_weight->origin_and_packed_pair;
90         if (packed_pair.find(origin_tensor_data) != packed_pair.end()) {
91           MS_LOG(DEBUG) << "origin tensor data already store by other model.";
92           return RET_OK;
93         }
94         packed_pair.insert(std::make_pair(origin_tensor_data, nullptr));
95         return RET_OK;
96       }
97     }
98   }
99   MS_LOG(ERROR) << "can not find model buf in store origin Tensor";
100   return RET_ERROR;
101 }
102 
ReplaceFp16Data(void * origin_fp16_data,size_t size)103 void *PackWeight::ReplaceFp16Data(void *origin_fp16_data, size_t size) {
104   std::lock_guard<std::mutex> lock(mtx_weight_);
105   if (fp16_fp32_data_pair_.find(origin_fp16_data) != fp16_fp32_data_pair_.end()) {
106     return fp16_fp32_data_pair_[origin_fp16_data];
107   } else {
108     for (auto &numa_item : model_weights_) {
109       for (auto &item : numa_item.second) {
110         if (item.second->origin_and_packed_pair.find(origin_fp16_data) != item.second->origin_and_packed_pair.end()) {
111           auto &model_weight = item.second;
112           auto &origin_and_packed_pair = model_weight->origin_and_packed_pair;
113           if (origin_and_packed_pair.find(origin_fp16_data) == origin_and_packed_pair.end()) {
114             MS_LOG(ERROR) << "origin fp16 data not find.";
115             return nullptr;
116           }
117           auto allocator = model_weight->allocator;
118           void *data = allocator->Malloc(size);
119           if (data == nullptr) {
120             MS_LOG(ERROR) << "malloc failed.";
121             return nullptr;
122           }
123           origin_and_packed_pair.insert(std::make_pair(data, nullptr));
124           model_weight->fp16_fp32_data.insert(data);
125           origin_and_packed_pair.erase(origin_fp16_data);
126           fp16_fp32_data_pair_.insert(std::make_pair(origin_fp16_data, data));
127           return data;
128         }
129       }
130     }
131   }
132   MS_LOG(ERROR) << "ReplaceFp16Data failed.";
133   return nullptr;
134 }
135 
ReplaceOriginTensorData(const void * model_buf,std::vector<Tensor * > * tensors,int tensor_index)136 STATUS PackWeight::ReplaceOriginTensorData(const void *model_buf, std::vector<Tensor *> *tensors, int tensor_index) {
137   std::lock_guard<std::mutex> lock(mtx_weight_);
138   for (auto &item : shared_bufs_) {
139     for (auto &numa_item : item.second) {
140       if (numa_item.second == model_buf) {
141         std::string id = item.first;
142         int numa_id = numa_item.first;
143         auto &tensor = tensors->at(tensor_index);
144         auto &model_weight = model_weights_[id][numa_id];
145         if (model_weight->tensors_data.find(tensor_index) == model_weight->tensors_data.end()) {
146           auto allocator = model_weight->allocator;
147           void *new_data = allocator->Malloc(tensor->Size());
148           if (new_data == nullptr) {
149             MS_LOG(ERROR) << "allocator malloc data failed.";
150             return RET_ERROR;
151           }
152           memcpy(new_data, tensor->data(), tensor->Size());
153           MS_CHECK_TRUE_MSG(tensor->own_data(), RET_ERROR, "tensor data is not own data.");
154           tensor->FreeData();
155           tensor->set_data(new_data);
156           tensor->set_own_data(false);
157           model_weight->tensors_data.insert(std::make_pair(tensor_index, new_data));
158         } else {
159           auto new_data = model_weight->tensors_data[tensor_index];
160           tensor->FreeData();
161           tensor->set_data(new_data);
162           tensor->set_own_data(false);
163         }
164         return RET_OK;
165       }
166     }
167   }
168   MS_LOG(ERROR) << "can not find model buf in store origin Tensor";
169   return RET_ERROR;
170 }
171 
GetPackData(const void * tensor_data,const size_t size,bool * is_packed)172 void *PackWeight::GetPackData(const void *tensor_data, const size_t size, bool *is_packed) {
173   std::lock_guard<std::mutex> lock(mtx_weight_);
174   MS_CHECK_TRUE_RET(tensor_data != nullptr, nullptr);
175   for (auto &numa_item : model_weights_) {
176     for (auto &item : numa_item.second) {
177       auto &model_weight = item.second;
178       auto &origin_packed_weight = model_weight->origin_and_packed_pair;
179       if (origin_packed_weight.find(tensor_data) == origin_packed_weight.end()) {
180         continue;
181       }
182       auto packed_tensor_data = origin_packed_weight[tensor_data];
183       if (packed_tensor_data != nullptr) {
184         *is_packed = true;
185         return packed_tensor_data;
186       } else {
187         auto weight_allocator = model_weight->allocator;
188         packed_tensor_data = weight_allocator->Malloc(size);
189         if (packed_tensor_data == nullptr) {
190           MS_LOG(ERROR) << "malloc failed.";
191           return nullptr;
192         }
193         origin_packed_weight[tensor_data] = packed_tensor_data;
194         *is_packed = false;
195         return packed_tensor_data;
196       }
197     }
198   }
199   *is_packed = false;
200   MS_LOG(ERROR) << "can not find tensor data in origin tensor data.";
201   return nullptr;
202 }
203 
FreePackedWeight(ModelConstWeight * weight)204 void PackWeight::FreePackedWeight(ModelConstWeight *weight) {
205   MS_CHECK_TRUE_RET_VOID(weight != nullptr);
206   for (auto &origin_and_packed_pair : weight->origin_and_packed_pair) {
207     auto &packed_data = origin_and_packed_pair.second;
208     auto allocator = weight->allocator;
209     MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
210     if (packed_data != nullptr) {
211       allocator->Free(packed_data);
212       packed_data = nullptr;
213     }
214   }
215   weight->origin_and_packed_pair.clear();
216 }
217 
FreeTensorData(ModelConstWeight * weight)218 void PackWeight::FreeTensorData(ModelConstWeight *weight) {
219   MS_CHECK_TRUE_RET_VOID(weight != nullptr);
220   for (auto &tensor_data : weight->tensors_data) {
221     auto &data = tensor_data.second;
222     auto allocator = weight->allocator;
223     MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
224     if (data != nullptr) {
225       allocator->Free(data);
226       data = nullptr;
227     }
228   }
229   weight->tensors_data.clear();
230 }
231 
FreeFp16ToFp32Data(ModelConstWeight * weight)232 void PackWeight::FreeFp16ToFp32Data(ModelConstWeight *weight) {
233   MS_CHECK_TRUE_RET_VOID(weight != nullptr);
234   for (auto &data : weight->fp16_fp32_data) {
235     auto allocator = weight->allocator;
236     MS_CHECK_TRUE_RET_VOID(allocator != nullptr);
237     if (data != nullptr) {
238       allocator->Free(data);
239     }
240   }
241   weight->fp16_fp32_data.clear();
242 }
243 
FreePackWeight(std::string id,bool free_all)244 void PackWeight::FreePackWeight(std::string id, bool free_all) {
245   std::lock_guard<std::mutex> lock(mtx_weight_);
246   MS_LOG(INFO) << "model weight size: " << model_weights_.size() << " | shared buf size: " << shared_bufs_.size();
247   if (model_weights_.find(id) == model_weights_.end() || shared_bufs_.find(id) == shared_bufs_.end()) {
248     MS_LOG(INFO) << "can not find id in shared bufs or model weights.";
249     return;
250   }
251   for (auto &item : model_weights_[id]) {
252     auto numa_id = item.first;
253     ModelConstWeight *model_weight = model_weights_[id][numa_id];
254     void *model_buf = shared_bufs_[id][numa_id];
255     if (model_buf == nullptr || model_weight == nullptr) {
256       MS_LOG(ERROR) << "model buf or model weight is nullptr.";
257       return;
258     }
259     FreePackedWeight(model_weight);
260     FreeFp16ToFp32Data(model_weight);
261     FreeTensorData(model_weight);
262     if (model_weight->copy_buf) {
263       auto &allocator = model_weight->allocator;
264       allocator->Free(model_buf);
265       model_buf = nullptr;
266     }
267     delete model_weight;
268     model_weight = nullptr;
269   }
270   if (!free_all) {
271     model_weights_.erase(id);
272   }
273   shared_bufs_.erase(id);
274   MS_LOG(INFO) << "FreePackWeight done.";
275 }
276 
~PackWeight()277 PackWeight::~PackWeight() {
278   MS_LOG(INFO) << "~PackWeight() begin";
279   if (model_weights_.empty()) {
280     MS_LOG(INFO) << "~PackWeight() empty end";
281     return;
282   }
283   for (auto &numa_item : model_weights_) {
284     std::string id = numa_item.first;
285     FreePackWeight(id, true);
286   }
287   model_weights_.clear();
288   MS_LOG(INFO) << "~PackWeight() end";
289 }
290 }  // namespace mindspore::lite
291