• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/distributed_metadata_store.h"
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 namespace mindspore {
23 namespace fl {
24 namespace server {
Initialize(const std::shared_ptr<ps::core::ServerNode> & server_node)25 void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
26   MS_EXCEPTION_IF_NULL(server_node);
27   server_node_ = server_node;
28   local_rank_ = server_node_->rank_id();
29   server_num_ = ps::PSContext::instance()->initial_server_num();
30   InitHashRing();
31   return;
32 }
33 
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)34 void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
35   MS_EXCEPTION_IF_NULL(communicator);
36   communicator_ = communicator;
37   communicator_->RegisterMsgCallBack(
38     "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
39   communicator_->RegisterMsgCallBack(
40     "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
41   return;
42 }
43 
RegisterMetadata(const std::string & name,const PBMetadata & meta)44 void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
45   if (router_ == nullptr) {
46     MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
47     return;
48   }
49 
50   uint32_t stored_rank = router_->Find(name);
51   if (local_rank_ == stored_rank) {
52     if (metadata_.count(name) != 0) {
53       MS_LOG(WARNING) << "The metadata for " << name << " is already registered.";
54       return;
55     }
56 
57     MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name;
58     metadata_[name] = meta;
59     mutex_[name];
60   }
61   return;
62 }
63 
ResetMetadata(const std::string & name)64 void DistributedMetadataStore::ResetMetadata(const std::string &name) {
65   if (router_ == nullptr) {
66     MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
67     return;
68   }
69 
70   uint32_t stored_rank = router_->Find(name);
71   if (local_rank_ == stored_rank) {
72     if (metadata_.count(name) == 0) {
73       MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
74       return;
75     }
76 
77     MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name;
78     std::unique_lock<std::mutex> lock(mutex_[name]);
79     PBMetadata empty_meta;
80     metadata_[name] = empty_meta;
81   }
82   return;
83 }
84 
UpdateMetadata(const std::string & name,const PBMetadata & meta,std::string * reason)85 bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
86   if (router_ == nullptr) {
87     MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
88     return false;
89   }
90 
91   uint32_t stored_rank = router_->Find(name);
92   MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
93   if (local_rank_ == stored_rank) {
94     if (!DoUpdateMetadata(name, meta)) {
95       MS_LOG(ERROR) << "Updating meta data failed.";
96       return false;
97     }
98   } else {
99     PBMetadataWithName metadata_with_name;
100     metadata_with_name.set_name(name);
101     *metadata_with_name.mutable_metadata() = meta;
102     std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
103     if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
104                                       &update_meta_rsp_msg)) {
105       MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
106       if (reason != nullptr) {
107         *reason = kNetworkError;
108       }
109       return false;
110     }
111 
112     MS_ERROR_IF_NULL_W_RET_VAL(update_meta_rsp_msg, false);
113     std::string update_meta_rsp =
114       std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size());
115     if (update_meta_rsp != kSuccess) {
116       MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed. " << update_meta_rsp;
117       return false;
118     }
119   }
120   return true;
121 }
122 
GetMetadata(const std::string & name)123 PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
124   if (router_ == nullptr) {
125     MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
126     return {};
127   }
128 
129   uint32_t stored_rank = router_->Find(name);
130   MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
131   if (local_rank_ == stored_rank) {
132     std::unique_lock<std::mutex> lock(mutex_[name]);
133     return metadata_[name];
134   } else {
135     GetMetadataRequest get_metadata_req;
136     get_metadata_req.set_name(name);
137     PBMetadata get_metadata_rsp;
138 
139     std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
140     if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
141                                       &get_meta_rsp_msg)) {
142       MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
143       return get_metadata_rsp;
144     }
145 
146     MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, get_metadata_rsp);
147     (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
148     return get_metadata_rsp;
149   }
150 }
151 
ReInitForScaling()152 bool DistributedMetadataStore::ReInitForScaling() {
153   // If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception.
154   if (server_node_ == nullptr) {
155     return true;
156   }
157 
158   MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed metadata store.";
159   local_rank_ = server_node_->rank_id();
160   server_num_ = IntToUint(server_node_->server_num());
161   MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
162                << server_num_;
163   InitHashRing();
164 
165   // Clear old metadata.
166   metadata_.clear();
167   return true;
168 }
169 
InitHashRing()170 void DistributedMetadataStore::InitHashRing() {
171   router_ = std::make_shared<ConsistentHashRing>(kDefaultVirtualNodeNum);
172   MS_EXCEPTION_IF_NULL(router_);
173   for (uint32_t i = 0; i < server_num_; i++) {
174     bool ret = router_->Insert(i);
175     if (!ret) {
176       MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed.";
177       return;
178     }
179   }
180   return;
181 }
182 
HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> & message)183 void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
184   MS_ERROR_IF_NULL_WO_RET_VAL(message);
185   PBMetadataWithName meta_with_name;
186   (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
187   const std::string &name = meta_with_name.name();
188   MS_LOG(INFO) << "Update metadata for " << name;
189 
190   std::string update_meta_rsp_msg;
191   if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
192     update_meta_rsp_msg = "Updating meta data failed.";
193     MS_LOG(ERROR) << update_meta_rsp_msg;
194   } else {
195     update_meta_rsp_msg = "Success";
196   }
197   if (!communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message)) {
198     MS_LOG(ERROR) << "Sending response failed.";
199     return;
200   }
201   return;
202 }
203 
HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> & message)204 void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
205   MS_ERROR_IF_NULL_WO_RET_VAL(message);
206   GetMetadataRequest get_metadata_req;
207   (void)get_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len()));
208   const std::string &name = get_metadata_req.name();
209   MS_LOG(INFO) << "Getting metadata for " << name;
210 
211   std::unique_lock<std::mutex> lock(mutex_[name]);
212   if (metadata_.count(name) == 0) {
213     MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
214     return;
215   }
216   PBMetadata stored_meta = metadata_[name];
217   std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
218   if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
219     MS_LOG(ERROR) << "Sending response failed.";
220     return;
221   }
222   return;
223 }
224 
DoUpdateMetadata(const std::string & name,const PBMetadata & meta)225 bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
226   std::unique_lock<std::mutex> lock(mutex_[name]);
227   if (metadata_.count(name) == 0) {
228     MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
229     return false;
230   }
231   if (meta.has_device_meta()) {
232     auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
233     auto &device_meta_fl_id = meta.device_meta().fl_id();
234     if (fl_id_to_meta_map.count(device_meta_fl_id) != 0) {
235       MS_LOG(WARNING) << "The fl id " << device_meta_fl_id << " already exists.";
236       return false;
237     }
238     auto &device_meta = meta.device_meta();
239     fl_id_to_meta_map[device_meta_fl_id] = device_meta;
240   } else if (meta.has_fl_id()) {
241     auto client_list = metadata_[name].mutable_client_list();
242     auto &fl_id = meta.fl_id().fl_id();
243     // Check whether the new item already exists.
244     bool add_flag = true;
245     for (int i = 0; i < client_list->fl_id_size(); i++) {
246       if (fl_id == client_list->fl_id(i)) {
247         add_flag = false;
248         break;
249       }
250     }
251     if (add_flag) {
252       client_list->add_fl_id(fl_id);
253     }
254   } else if (meta.has_update_model_threshold()) {
255     auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
256     *update_model_threshold = meta.update_model_threshold();
257   } else if (meta.has_prime()) {
258     metadata_[name] = meta;
259   } else if (meta.has_pair_client_keys()) {
260     auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
261     auto &fl_id = meta.pair_client_keys().fl_id();
262     auto &client_keys = meta.pair_client_keys().client_keys();
263     // Check whether the new item already exists.
264     bool add_flag = true;
265     for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
266       if (fl_id == iter->first) {
267         add_flag = false;
268         MS_LOG(ERROR) << "Leader server updating value for " << name
269                       << " failed: The Protobuffer of this value already exists.";
270         break;
271       }
272     }
273     if (add_flag) {
274       client_keys_map[fl_id] = client_keys;
275     } else {
276       return false;
277     }
278   } else if (meta.has_pair_client_shares()) {
279     auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
280     auto &fl_id = meta.pair_client_shares().fl_id();
281     auto &client_shares = meta.pair_client_shares().client_shares();
282     // google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
283     // Check whether the new item already exists.
284     bool add_flag = true;
285     for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
286       if (fl_id == iter->first) {
287         add_flag = false;
288         MS_LOG(ERROR) << "Leader server updating value for " << name
289                       << " failed: The Protobuffer of this value already exists.";
290         break;
291       }
292     }
293     if (add_flag) {
294       client_shares_map[fl_id] = client_shares;
295     } else {
296       return false;
297     }
298   } else if (meta.has_one_client_noises()) {
299     auto &client_noises = *metadata_[name].mutable_client_noises();
300     if (client_noises.has_one_client_noises()) {
301       MS_LOG(WARNING) << "Leader server updating value for " << name
302                       << " failed: The Protobuffer of this value already exists.";
303       client_noises.Clear();
304     }
305     client_noises.mutable_one_client_noises()->MergeFrom(meta.one_client_noises());
306   } else {
307     MS_LOG(ERROR) << "Leader server updating value for " << name
308                   << " failed: The Protobuffer of this value is not defined.";
309     return false;
310   }
311   return true;
312 }
313 }  // namespace server
314 }  // namespace fl
315 }  // namespace mindspore
316