• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "fl/server/kernel/round/update_model_kernel.h"
22 
23 namespace mindspore {
24 namespace fl {
25 namespace server {
26 namespace kernel {
InitKernel(size_t threshold_count)27 void UpdateModelKernel::InitKernel(size_t threshold_count) {
28   if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
29     iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
30   }
31 
32   executor_ = &Executor::GetInstance();
33   MS_EXCEPTION_IF_NULL(executor_);
34   if (!executor_->initialized()) {
35     MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
36     return;
37   }
38 
39   PBMetadata client_list;
40   DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
41   LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
42   LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
43 }
44 
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)45 bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
46                                const std::vector<AddressPtr> &outputs) {
47   MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
48   if (inputs.size() != 1 || outputs.size() != 1) {
49     std::string reason = "inputs or outputs size is invalid.";
50     MS_LOG(ERROR) << reason;
51     GenerateOutput(outputs, reason.c_str(), reason.size());
52     return true;
53   }
54 
55   void *req_data = inputs[0]->addr;
56   std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
57   if (fbb == nullptr || req_data == nullptr) {
58     std::string reason = "FBBuilder builder or req_data is nullptr.";
59     MS_LOG(ERROR) << reason;
60     GenerateOutput(outputs, reason.c_str(), reason.size());
61     return true;
62   }
63 
64   flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
65   if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) {
66     std::string reason = "The schema of RequestUpdateModel is invalid.";
67     BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
68     MS_LOG(ERROR) << reason;
69     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
70     return true;
71   }
72 
73   ResultCode result_code = ReachThresholdForUpdateModel(fbb);
74   if (result_code != ResultCode::kSuccess) {
75     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
76     return ConvertResultCode(result_code);
77   }
78 
79   const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
80   if (update_model_req == nullptr) {
81     std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
82     BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
83     MS_LOG(ERROR) << reason;
84     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
85     return true;
86   }
87 
88   result_code = UpdateModel(update_model_req, fbb);
89   if (result_code != ResultCode::kSuccess) {
90     MS_LOG(ERROR) << "Updating model failed.";
91     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
92     return ConvertResultCode(result_code);
93   }
94 
95   result_code = CountForUpdateModel(fbb, update_model_req);
96   if (result_code != ResultCode::kSuccess) {
97     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
98     return ConvertResultCode(result_code);
99   }
100   GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
101   return true;
102 }
103 
Reset()104 bool UpdateModelKernel::Reset() {
105   MS_LOG(INFO) << "Update model kernel reset!";
106   StopTimer();
107   DistributedCountService::GetInstance().ResetCounter(name_);
108   executor_->ResetAggregationStatus();
109   DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
110   size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
111   total_data_size = 0;
112   return true;
113 }
114 
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &)115 void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
116   if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
117     while (!executor_->IsAllWeightAggregationDone()) {
118       std::this_thread::sleep_for(std::chrono::milliseconds(5));
119     }
120 
121     size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
122     MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
123                  << total_data_size;
124     if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
125       FinishIteration();
126     }
127   }
128 }
129 
ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> & fbb)130 ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
131   if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
132     std::string reason = "Current amount for updateModel is enough. Please retry later.";
133     BuildUpdateModelRsp(
134       fbb, schema::ResponseCode_OutOfTime, reason,
135       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
136     MS_LOG(WARNING) << reason;
137     return ResultCode::kSuccessAndReturn;
138   }
139   return ResultCode::kSuccess;
140 }
141 
UpdateModel(const schema::RequestUpdateModel * update_model_req,const std::shared_ptr<FBBuilder> & fbb)142 ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
143                                           const std::shared_ptr<FBBuilder> &fbb) {
144   MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
145   size_t iteration = IntToSize(update_model_req->iteration());
146   if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
147     std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
148                          ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
149                          ". Retry later.";
150     BuildUpdateModelRsp(
151       fbb, schema::ResponseCode_OutOfTime, reason,
152       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
153     MS_LOG(WARNING) << reason;
154     return ResultCode::kSuccessAndReturn;
155   }
156 
157   PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
158   FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
159   std::string update_model_fl_id = update_model_req->fl_id()->str();
160   MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
161   if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
162     if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
163       std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
164       BuildUpdateModelRsp(
165         fbb, schema::ResponseCode_OutOfTime, reason,
166         std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
167       MS_LOG(ERROR) << reason;
168       return ResultCode::kSuccessAndReturn;
169     }
170   } else {
171     std::vector<std::string> get_secrets_clients;
172 #ifdef ENABLE_ARMOUR
173     mindspore::armour::CipherMetaStorage cipher_meta_storage;
174     cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
175 #endif
176     if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
177         get_secrets_clients.end()) {  // the client not in get_secrets_clients
178       std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
179       BuildUpdateModelRsp(
180         fbb, schema::ResponseCode_OutOfTime, reason,
181         std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
182       MS_LOG(ERROR) << reason;
183       return ResultCode::kSuccessAndReturn;
184     }
185   }
186 
187   size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
188   auto feature_map = ParseFeatureMap(update_model_req);
189   if (feature_map.empty()) {
190     std::string reason = "Feature map is empty.";
191     BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
192     MS_LOG(ERROR) << reason;
193     return ResultCode::kSuccessAndReturn;
194   }
195 
196   for (auto weight : feature_map) {
197     weight.second[kNewDataSize].addr = &data_size;
198     weight.second[kNewDataSize].size = sizeof(size_t);
199     if (!executor_->HandleModelUpdate(weight.first, weight.second)) {
200       std::string reason = "Updating weight " + weight.first + " failed.";
201       BuildUpdateModelRsp(
202         fbb, schema::ResponseCode_OutOfTime, reason,
203         std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
204       MS_LOG(ERROR) << reason;
205       return ResultCode::kFail;
206     }
207   }
208 
209   FLId fl_id;
210   fl_id.set_fl_id(update_model_fl_id);
211   PBMetadata comm_value;
212   *comm_value.mutable_fl_id() = fl_id;
213   std::string update_reason = "";
214   if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
215     std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason;
216     BuildUpdateModelRsp(
217       fbb, schema::ResponseCode_OutOfTime, reason,
218       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
219     MS_LOG(ERROR) << reason;
220     return update_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
221   }
222 
223   BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
224                       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
225   return ResultCode::kSuccess;
226 }
227 
ParseFeatureMap(const schema::RequestUpdateModel * update_model_req)228 std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
229   const schema::RequestUpdateModel *update_model_req) {
230   MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
231   std::map<std::string, UploadData> feature_map;
232   auto fbs_feature_map = update_model_req->feature_map();
233   MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
234   for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
235     std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
236     float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
237     size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
238     UploadData upload_data;
239     upload_data[kNewWeight].addr = weight_data;
240     upload_data[kNewWeight].size = weight_size;
241     feature_map[weight_full_name] = upload_data;
242   }
243   return feature_map;
244 }
245 
CountForUpdateModel(const std::shared_ptr<FBBuilder> & fbb,const schema::RequestUpdateModel * update_model_req)246 ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
247                                                   const schema::RequestUpdateModel *update_model_req) {
248   MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
249   MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
250   std::string count_reason = "";
251   if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
252     std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
253     BuildUpdateModelRsp(
254       fbb, schema::ResponseCode_OutOfTime, reason,
255       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
256     MS_LOG(ERROR) << reason;
257     return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
258   }
259   return ResultCode::kSuccess;
260 }
261 
BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,const std::string & next_req_time)262 void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
263                                             const std::string &reason, const std::string &next_req_time) {
264   if (fbb == nullptr) {
265     MS_LOG(ERROR) << "Input fbb is nullptr.";
266     return;
267   }
268   auto fbs_reason = fbb->CreateString(reason);
269   auto fbs_next_req_time = fbb->CreateString(next_req_time);
270 
271   schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get()));
272   rsp_update_model_builder.add_retcode(static_cast<int>(retcode));
273   rsp_update_model_builder.add_reason(fbs_reason);
274   rsp_update_model_builder.add_next_req_time(fbs_next_req_time);
275   auto rsp_update_model = rsp_update_model_builder.Finish();
276   fbb->Finish(rsp_update_model);
277   return;
278 }
279 
// Presumably registers UpdateModelKernel as the round kernel named updateModel; the macro is
// defined outside this file - confirm against its definition.
280 REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
281 }  // namespace kernel
282 }  // namespace server
283 }  // namespace fl
284 }  // namespace mindspore
285