1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/distributed_metadata_store.h"
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 namespace mindspore {
23 namespace fl {
24 namespace server {
Initialize(const std::shared_ptr<ps::core::ServerNode> & server_node)25 void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
26 MS_EXCEPTION_IF_NULL(server_node);
27 server_node_ = server_node;
28 local_rank_ = server_node_->rank_id();
29 server_num_ = ps::PSContext::instance()->initial_server_num();
30 InitHashRing();
31 return;
32 }
33
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)34 void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
35 MS_EXCEPTION_IF_NULL(communicator);
36 communicator_ = communicator;
37 communicator_->RegisterMsgCallBack(
38 "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
39 communicator_->RegisterMsgCallBack(
40 "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
41 return;
42 }
43
RegisterMetadata(const std::string & name,const PBMetadata & meta)44 void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
45 if (router_ == nullptr) {
46 MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
47 return;
48 }
49
50 uint32_t stored_rank = router_->Find(name);
51 if (local_rank_ == stored_rank) {
52 if (metadata_.count(name) != 0) {
53 MS_LOG(WARNING) << "The metadata for " << name << " is already registered.";
54 return;
55 }
56
57 MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name;
58 metadata_[name] = meta;
59 mutex_[name];
60 }
61 return;
62 }
63
ResetMetadata(const std::string & name)64 void DistributedMetadataStore::ResetMetadata(const std::string &name) {
65 if (router_ == nullptr) {
66 MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
67 return;
68 }
69
70 uint32_t stored_rank = router_->Find(name);
71 if (local_rank_ == stored_rank) {
72 if (metadata_.count(name) == 0) {
73 MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
74 return;
75 }
76
77 MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name;
78 std::unique_lock<std::mutex> lock(mutex_[name]);
79 PBMetadata empty_meta;
80 metadata_[name] = empty_meta;
81 }
82 return;
83 }
84
UpdateMetadata(const std::string & name,const PBMetadata & meta,std::string * reason)85 bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
86 if (router_ == nullptr) {
87 MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
88 return false;
89 }
90
91 uint32_t stored_rank = router_->Find(name);
92 MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
93 if (local_rank_ == stored_rank) {
94 if (!DoUpdateMetadata(name, meta)) {
95 MS_LOG(ERROR) << "Updating meta data failed.";
96 return false;
97 }
98 } else {
99 PBMetadataWithName metadata_with_name;
100 metadata_with_name.set_name(name);
101 *metadata_with_name.mutable_metadata() = meta;
102 std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
103 if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
104 &update_meta_rsp_msg)) {
105 MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
106 if (reason != nullptr) {
107 *reason = kNetworkError;
108 }
109 return false;
110 }
111
112 MS_ERROR_IF_NULL_W_RET_VAL(update_meta_rsp_msg, false);
113 std::string update_meta_rsp =
114 std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size());
115 if (update_meta_rsp != kSuccess) {
116 MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp;
117 return false;
118 }
119 }
120 return true;
121 }
122
GetMetadata(const std::string & name)123 PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
124 if (router_ == nullptr) {
125 MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
126 return {};
127 }
128
129 uint32_t stored_rank = router_->Find(name);
130 MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
131 if (local_rank_ == stored_rank) {
132 std::unique_lock<std::mutex> lock(mutex_[name]);
133 return metadata_[name];
134 } else {
135 GetMetadataRequest get_metadata_req;
136 get_metadata_req.set_name(name);
137 PBMetadata get_metadata_rsp;
138
139 std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
140 if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
141 &get_meta_rsp_msg)) {
142 MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
143 return get_metadata_rsp;
144 }
145
146 MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, get_metadata_rsp);
147 (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
148 return get_metadata_rsp;
149 }
150 }
151
ReInitForScaling()152 bool DistributedMetadataStore::ReInitForScaling() {
153 // If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception.
154 if (server_node_ == nullptr) {
155 return true;
156 }
157
158 MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed metadata store.";
159 local_rank_ = server_node_->rank_id();
160 server_num_ = IntToUint(server_node_->server_num());
161 MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
162 << server_num_;
163 InitHashRing();
164
165 // Clear old metadata.
166 metadata_.clear();
167 return true;
168 }
169
InitHashRing()170 void DistributedMetadataStore::InitHashRing() {
171 router_ = std::make_shared<ConsistentHashRing>(kDefaultVirtualNodeNum);
172 MS_EXCEPTION_IF_NULL(router_);
173 for (uint32_t i = 0; i < server_num_; i++) {
174 bool ret = router_->Insert(i);
175 if (!ret) {
176 MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed.";
177 return;
178 }
179 }
180 return;
181 }
182
HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> & message)183 void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
184 MS_ERROR_IF_NULL_WO_RET_VAL(message);
185 PBMetadataWithName meta_with_name;
186 (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
187 const std::string &name = meta_with_name.name();
188 MS_LOG(INFO) << "Update metadata for " << name;
189
190 std::string update_meta_rsp_msg;
191 if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
192 update_meta_rsp_msg = "Updating meta data failed.";
193 MS_LOG(ERROR) << update_meta_rsp_msg;
194 } else {
195 update_meta_rsp_msg = "Success";
196 }
197 if (!communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message)) {
198 MS_LOG(ERROR) << "Sending response failed.";
199 return;
200 }
201 return;
202 }
203
HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> & message)204 void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
205 MS_ERROR_IF_NULL_WO_RET_VAL(message);
206 GetMetadataRequest get_metadata_req;
207 (void)get_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len()));
208 const std::string &name = get_metadata_req.name();
209 MS_LOG(INFO) << "Getting metadata for " << name;
210
211 std::unique_lock<std::mutex> lock(mutex_[name]);
212 if (metadata_.count(name) == 0) {
213 MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
214 return;
215 }
216 PBMetadata stored_meta = metadata_[name];
217 std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
218 if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
219 MS_LOG(ERROR) << "Sending response failed.";
220 return;
221 }
222 return;
223 }
224
DoUpdateMetadata(const std::string & name,const PBMetadata & meta)225 bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
226 std::unique_lock<std::mutex> lock(mutex_[name]);
227 if (metadata_.count(name) == 0) {
228 MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
229 return false;
230 }
231 if (meta.has_device_meta()) {
232 auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
233 auto &device_meta_fl_id = meta.device_meta().fl_id();
234 if (fl_id_to_meta_map.count(device_meta_fl_id) != 0) {
235 MS_LOG(WARNING) << "The fl id " << device_meta_fl_id << " already exists.";
236 return false;
237 }
238 auto &device_meta = meta.device_meta();
239 fl_id_to_meta_map[device_meta_fl_id] = device_meta;
240 } else if (meta.has_fl_id()) {
241 auto client_list = metadata_[name].mutable_client_list();
242 auto &fl_id = meta.fl_id().fl_id();
243 // Check whether the new item already exists.
244 bool add_flag = true;
245 for (int i = 0; i < client_list->fl_id_size(); i++) {
246 if (fl_id == client_list->fl_id(i)) {
247 add_flag = false;
248 break;
249 }
250 }
251 if (add_flag) {
252 client_list->add_fl_id(fl_id);
253 }
254 } else if (meta.has_update_model_threshold()) {
255 auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
256 *update_model_threshold = meta.update_model_threshold();
257 } else if (meta.has_prime()) {
258 metadata_[name] = meta;
259 } else if (meta.has_pair_client_keys()) {
260 auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
261 auto &fl_id = meta.pair_client_keys().fl_id();
262 auto &client_keys = meta.pair_client_keys().client_keys();
263 // Check whether the new item already exists.
264 bool add_flag = true;
265 for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
266 if (fl_id == iter->first) {
267 add_flag = false;
268 MS_LOG(ERROR) << "Leader server updating value for " << name
269 << " failed: The Protobuffer of this value already exists.";
270 break;
271 }
272 }
273 if (add_flag) {
274 client_keys_map[fl_id] = client_keys;
275 } else {
276 return false;
277 }
278 } else if (meta.has_pair_client_shares()) {
279 auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
280 auto &fl_id = meta.pair_client_shares().fl_id();
281 auto &client_shares = meta.pair_client_shares().client_shares();
282 // google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
283 // Check whether the new item already exists.
284 bool add_flag = true;
285 for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
286 if (fl_id == iter->first) {
287 add_flag = false;
288 MS_LOG(ERROR) << "Leader server updating value for " << name
289 << " failed: The Protobuffer of this value already exists.";
290 break;
291 }
292 }
293 if (add_flag) {
294 client_shares_map[fl_id] = client_shares;
295 } else {
296 return false;
297 }
298 } else if (meta.has_one_client_noises()) {
299 auto &client_noises = *metadata_[name].mutable_client_noises();
300 if (client_noises.has_one_client_noises()) {
301 MS_LOG(WARNING) << "Leader server updating value for " << name
302 << " failed: The Protobuffer of this value already exists.";
303 client_noises.Clear();
304 }
305 client_noises.mutable_one_client_noises()->MergeFrom(meta.one_client_noises());
306 } else {
307 MS_LOG(ERROR) << "Leader server updating value for " << name
308 << " failed: The Protobuffer of this value is not defined.";
309 return false;
310 }
311 return true;
312 }
313 } // namespace server
314 } // namespace fl
315 } // namespace mindspore
316