1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 #include "fl/server/kernel/round/update_model_kernel.h"
22
23 namespace mindspore {
24 namespace fl {
25 namespace server {
26 namespace kernel {
InitKernel(size_t threshold_count)27 void UpdateModelKernel::InitKernel(size_t threshold_count) {
28 if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
29 iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
30 }
31
32 executor_ = &Executor::GetInstance();
33 MS_EXCEPTION_IF_NULL(executor_);
34 if (!executor_->initialized()) {
35 MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
36 return;
37 }
38
39 PBMetadata client_list;
40 DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
41 LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
42 LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
43 }
44
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)45 bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
46 const std::vector<AddressPtr> &outputs) {
47 MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
48 if (inputs.size() != 1 || outputs.size() != 1) {
49 std::string reason = "inputs or outputs size is invalid.";
50 MS_LOG(ERROR) << reason;
51 GenerateOutput(outputs, reason.c_str(), reason.size());
52 return true;
53 }
54
55 void *req_data = inputs[0]->addr;
56 std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
57 if (fbb == nullptr || req_data == nullptr) {
58 std::string reason = "FBBuilder builder or req_data is nullptr.";
59 MS_LOG(ERROR) << reason;
60 GenerateOutput(outputs, reason.c_str(), reason.size());
61 return true;
62 }
63
64 flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
65 if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) {
66 std::string reason = "The schema of RequestUpdateModel is invalid.";
67 BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
68 MS_LOG(ERROR) << reason;
69 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
70 return true;
71 }
72
73 ResultCode result_code = ReachThresholdForUpdateModel(fbb);
74 if (result_code != ResultCode::kSuccess) {
75 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
76 return ConvertResultCode(result_code);
77 }
78
79 const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
80 if (update_model_req == nullptr) {
81 std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
82 BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
83 MS_LOG(ERROR) << reason;
84 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
85 return true;
86 }
87
88 result_code = UpdateModel(update_model_req, fbb);
89 if (result_code != ResultCode::kSuccess) {
90 MS_LOG(ERROR) << "Updating model failed.";
91 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
92 return ConvertResultCode(result_code);
93 }
94
95 result_code = CountForUpdateModel(fbb, update_model_req);
96 if (result_code != ResultCode::kSuccess) {
97 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
98 return ConvertResultCode(result_code);
99 }
100 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
101 return true;
102 }
103
Reset()104 bool UpdateModelKernel::Reset() {
105 MS_LOG(INFO) << "Update model kernel reset!";
106 StopTimer();
107 DistributedCountService::GetInstance().ResetCounter(name_);
108 executor_->ResetAggregationStatus();
109 DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
110 size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
111 total_data_size = 0;
112 return true;
113 }
114
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &)115 void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
116 if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
117 while (!executor_->IsAllWeightAggregationDone()) {
118 std::this_thread::sleep_for(std::chrono::milliseconds(5));
119 }
120
121 size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
122 MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
123 << total_data_size;
124 if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
125 FinishIteration();
126 }
127 }
128 }
129
ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> & fbb)130 ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
131 if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
132 std::string reason = "Current amount for updateModel is enough. Please retry later.";
133 BuildUpdateModelRsp(
134 fbb, schema::ResponseCode_OutOfTime, reason,
135 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
136 MS_LOG(WARNING) << reason;
137 return ResultCode::kSuccessAndReturn;
138 }
139 return ResultCode::kSuccess;
140 }
141
UpdateModel(const schema::RequestUpdateModel * update_model_req,const std::shared_ptr<FBBuilder> & fbb)142 ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
143 const std::shared_ptr<FBBuilder> &fbb) {
144 MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
145 size_t iteration = IntToSize(update_model_req->iteration());
146 if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
147 std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
148 ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
149 ". Retry later.";
150 BuildUpdateModelRsp(
151 fbb, schema::ResponseCode_OutOfTime, reason,
152 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
153 MS_LOG(WARNING) << reason;
154 return ResultCode::kSuccessAndReturn;
155 }
156
157 PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
158 FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
159 std::string update_model_fl_id = update_model_req->fl_id()->str();
160 MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
161 if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
162 if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
163 std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
164 BuildUpdateModelRsp(
165 fbb, schema::ResponseCode_OutOfTime, reason,
166 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
167 MS_LOG(ERROR) << reason;
168 return ResultCode::kSuccessAndReturn;
169 }
170 } else {
171 std::vector<std::string> get_secrets_clients;
172 #ifdef ENABLE_ARMOUR
173 mindspore::armour::CipherMetaStorage cipher_meta_storage;
174 cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
175 #endif
176 if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
177 get_secrets_clients.end()) { // the client not in get_secrets_clients
178 std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
179 BuildUpdateModelRsp(
180 fbb, schema::ResponseCode_OutOfTime, reason,
181 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
182 MS_LOG(ERROR) << reason;
183 return ResultCode::kSuccessAndReturn;
184 }
185 }
186
187 size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
188 auto feature_map = ParseFeatureMap(update_model_req);
189 if (feature_map.empty()) {
190 std::string reason = "Feature map is empty.";
191 BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
192 MS_LOG(ERROR) << reason;
193 return ResultCode::kSuccessAndReturn;
194 }
195
196 for (auto weight : feature_map) {
197 weight.second[kNewDataSize].addr = &data_size;
198 weight.second[kNewDataSize].size = sizeof(size_t);
199 if (!executor_->HandleModelUpdate(weight.first, weight.second)) {
200 std::string reason = "Updating weight " + weight.first + " failed.";
201 BuildUpdateModelRsp(
202 fbb, schema::ResponseCode_OutOfTime, reason,
203 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
204 MS_LOG(ERROR) << reason;
205 return ResultCode::kFail;
206 }
207 }
208
209 FLId fl_id;
210 fl_id.set_fl_id(update_model_fl_id);
211 PBMetadata comm_value;
212 *comm_value.mutable_fl_id() = fl_id;
213 std::string update_reason = "";
214 if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
215 std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason;
216 BuildUpdateModelRsp(
217 fbb, schema::ResponseCode_OutOfTime, reason,
218 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
219 MS_LOG(ERROR) << reason;
220 return update_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
221 }
222
223 BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
224 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
225 return ResultCode::kSuccess;
226 }
227
ParseFeatureMap(const schema::RequestUpdateModel * update_model_req)228 std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
229 const schema::RequestUpdateModel *update_model_req) {
230 MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
231 std::map<std::string, UploadData> feature_map;
232 auto fbs_feature_map = update_model_req->feature_map();
233 MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
234 for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
235 std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
236 float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
237 size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
238 UploadData upload_data;
239 upload_data[kNewWeight].addr = weight_data;
240 upload_data[kNewWeight].size = weight_size;
241 feature_map[weight_full_name] = upload_data;
242 }
243 return feature_map;
244 }
245
CountForUpdateModel(const std::shared_ptr<FBBuilder> & fbb,const schema::RequestUpdateModel * update_model_req)246 ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
247 const schema::RequestUpdateModel *update_model_req) {
248 MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
249 MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
250 std::string count_reason = "";
251 if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
252 std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
253 BuildUpdateModelRsp(
254 fbb, schema::ResponseCode_OutOfTime, reason,
255 std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
256 MS_LOG(ERROR) << reason;
257 return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
258 }
259 return ResultCode::kSuccess;
260 }
261
BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,const std::string & next_req_time)262 void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
263 const std::string &reason, const std::string &next_req_time) {
264 if (fbb == nullptr) {
265 MS_LOG(ERROR) << "Input fbb is nullptr.";
266 return;
267 }
268 auto fbs_reason = fbb->CreateString(reason);
269 auto fbs_next_req_time = fbb->CreateString(next_req_time);
270
271 schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get()));
272 rsp_update_model_builder.add_retcode(static_cast<int>(retcode));
273 rsp_update_model_builder.add_reason(fbs_reason);
274 rsp_update_model_builder.add_next_req_time(fbs_next_req_time);
275 auto rsp_update_model = rsp_update_model_builder.Finish();
276 fbb->Finish(rsp_update_model);
277 return;
278 }
279
280 REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
281 } // namespace kernel
282 } // namespace server
283 } // namespace fl
284 } // namespace mindspore
285