1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/model_store.h"
18 #include <map>
19 #include <string>
20 #include <memory>
21 #include "fl/server/executor.h"
22
23 namespace mindspore {
24 namespace fl {
25 namespace server {
Initialize(uint32_t max_count)26 void ModelStore::Initialize(uint32_t max_count) {
27 if (!Executor::GetInstance().initialized()) {
28 MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
29 return;
30 }
31
32 max_model_count_ = max_count;
33 initial_model_ = AssignNewModelMemory();
34 iteration_to_model_[kInitIterationNum] = initial_model_;
35 model_size_ = ComputeModelSize();
36 }
37
StoreModelByIterNum(size_t iteration,const std::map<std::string,AddressPtr> & new_model)38 void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
39 std::unique_lock<std::mutex> lock(model_mtx_);
40 if (iteration_to_model_.count(iteration) != 0) {
41 MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
42 return;
43 }
44 if (new_model.empty()) {
45 MS_LOG(ERROR) << "Model feature map is empty.";
46 return;
47 }
48
49 std::shared_ptr<MemoryRegister> memory_register = nullptr;
50 if (iteration_to_model_.size() < max_model_count_) {
51 // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
52 memory_register = AssignNewModelMemory();
53 MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
54 iteration_to_model_[iteration] = memory_register;
55 } else {
56 // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
57 memory_register = iteration_to_model_.begin()->second;
58 MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
59 (void)iteration_to_model_.erase(iteration_to_model_.begin());
60 }
61
62 // Copy new model data to the the stored model.
63 auto &stored_model = memory_register->addresses();
64 for (const auto &weight : new_model) {
65 const std::string &weight_name = weight.first;
66 if (stored_model.count(weight_name) == 0) {
67 MS_LOG(ERROR) << "The stored model has no weight " << weight_name;
68 continue;
69 }
70
71 MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]);
72 MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]->addr);
73 MS_ERROR_IF_NULL_WO_RET_VAL(weight.second);
74 MS_ERROR_IF_NULL_WO_RET_VAL(weight.second->addr);
75 void *dst_addr = stored_model[weight_name]->addr;
76 size_t dst_size = stored_model[weight_name]->size;
77 void *src_addr = weight.second->addr;
78 size_t src_size = weight.second->size;
79 int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
80 if (ret != 0) {
81 MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
82 return;
83 }
84 }
85 iteration_to_model_[iteration] = memory_register;
86 return;
87 }
88
GetModelByIterNum(size_t iteration)89 std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
90 std::unique_lock<std::mutex> lock(model_mtx_);
91 std::map<std::string, AddressPtr> model = {};
92 if (iteration_to_model_.count(iteration) == 0) {
93 MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
94 return model;
95 }
96 model = iteration_to_model_[iteration]->addresses();
97 return model;
98 }
99
Reset()100 void ModelStore::Reset() {
101 std::unique_lock<std::mutex> lock(model_mtx_);
102 initial_model_ = iteration_to_model_.rbegin()->second;
103 iteration_to_model_.clear();
104 iteration_to_model_[kInitIterationNum] = initial_model_;
105 }
106
iteration_to_model()107 const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() {
108 std::unique_lock<std::mutex> lock(model_mtx_);
109 return iteration_to_model_;
110 }
111
model_size() const112 size_t ModelStore::model_size() const { return model_size_; }
113
AssignNewModelMemory()114 std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
115 std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
116 if (model.empty()) {
117 MS_LOG(EXCEPTION) << "Model feature map is empty.";
118 return nullptr;
119 }
120
121 // Assign new memory for the model.
122 std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
123 MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr);
124 for (const auto &weight : model) {
125 const std::string weight_name = weight.first;
126 size_t weight_size = weight.second->size;
127 auto weight_data = std::make_unique<char[]>(weight_size);
128 MS_ERROR_IF_NULL_W_RET_VAL(weight_data, nullptr);
129 MS_ERROR_IF_NULL_W_RET_VAL(weight.second, nullptr);
130 MS_ERROR_IF_NULL_W_RET_VAL(weight.second->addr, nullptr);
131
132 auto src_data_size = weight_size;
133 auto dst_data_size = weight_size;
134 int ret = memcpy_s(weight_data.get(), dst_data_size, weight.second->addr, src_data_size);
135 if (ret != 0) {
136 MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
137 return nullptr;
138 }
139 memory_register->RegisterArray(weight_name, &weight_data, weight_size);
140 }
141 return memory_register;
142 }
143
ComputeModelSize()144 size_t ModelStore::ComputeModelSize() {
145 std::unique_lock<std::mutex> lock(model_mtx_);
146 if (iteration_to_model_.empty()) {
147 MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
148 return 0;
149 }
150
151 const auto &model = iteration_to_model_[kInitIterationNum];
152 MS_EXCEPTION_IF_NULL(model);
153 size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0),
154 [](size_t s, const auto &weight) { return s + weight.second->size; });
155 MS_LOG(INFO) << "Model size in byte is " << model_size;
156 return model_size;
157 }
158 } // namespace server
159 } // namespace fl
160 } // namespace mindspore
161