1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/kernel/round/push_weight_kernel.h"
18
19 namespace mindspore {
20 namespace fl {
21 namespace server {
22 namespace kernel {
InitKernel(size_t)23 void PushWeightKernel::InitKernel(size_t) {
24 executor_ = &Executor::GetInstance();
25 MS_EXCEPTION_IF_NULL(executor_);
26 if (!executor_->initialized()) {
27 MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
28 return;
29 }
30 local_rank_ = DistributedCountService::GetInstance().local_rank();
31 }
32
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)33 bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
34 const std::vector<AddressPtr> &outputs) {
35 MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
36 void *req_data = inputs[0]->addr;
37 std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
38 if (fbb == nullptr || req_data == nullptr) {
39 std::string reason = "FBBuilder builder or req_data is nullptr.";
40 MS_LOG(ERROR) << reason;
41 GenerateOutput(outputs, reason.c_str(), reason.size());
42 return true;
43 }
44
45 flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
46 if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) {
47 std::string reason = "The schema of RequestPushWeight is invalid.";
48 BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
49 MS_LOG(ERROR) << reason;
50 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
51 return true;
52 }
53
54 const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
55 if (push_weight_req == nullptr) {
56 std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
57 BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
58 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
59 return false;
60 }
61
62 ResultCode result_code = PushWeight(fbb, push_weight_req);
63 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
64 return ConvertResultCode(result_code);
65 }
66
Reset()67 bool PushWeightKernel::Reset() {
68 MS_LOG(INFO) << "PushWeightKernel reset!";
69 StopTimer();
70 DistributedCountService::GetInstance().ResetCounter(name_);
71 return true;
72 }
73
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &)74 void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
75 if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
76 FinishIteration();
77 }
78 return;
79 }
80
PushWeight(const std::shared_ptr<FBBuilder> & fbb,const schema::RequestPushWeight * push_weight_req)81 ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
82 const schema::RequestPushWeight *push_weight_req) {
83 MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
84 MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn);
85 size_t iteration = IntToSize(push_weight_req->iteration());
86 size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
87 if (iteration != current_iter) {
88 std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
89 ", current iteration:" + std::to_string(current_iter);
90 BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
91 MS_LOG(WARNING) << reason;
92 return ResultCode::kSuccessAndReturn;
93 }
94
95 std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
96 if (upload_feature_map.empty()) {
97 std::string reason = "PushWeight feature_map is empty.";
98 BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
99 MS_LOG(ERROR) << reason;
100 return ResultCode::kSuccessAndReturn;
101 }
102
103 if (!executor_->HandlePushWeight(upload_feature_map)) {
104 std::string reason = "Pushing weight failed.";
105 BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
106 MS_LOG(ERROR) << reason;
107 return ResultCode::kSuccessAndReturn;
108 }
109 MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
110
111 std::string count_reason = "";
112 if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
113 std::string reason = "Count for push weight request failed.";
114 BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
115 MS_LOG(ERROR) << reason;
116 return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
117 }
118 BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
119 return ResultCode::kSuccess;
120 }
121
ParseFeatureMap(const schema::RequestPushWeight * push_weight_req)122 std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
123 MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, {});
124 std::map<std::string, Address> upload_feature_map;
125 auto fbs_feature_map = push_weight_req->feature_map();
126 MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, upload_feature_map);
127 for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
128 std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
129 float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
130 size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
131 upload_feature_map[weight_full_name] = {weight_data, weight_size};
132 }
133 return upload_feature_map;
134 }
135
BuildPushWeightRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,size_t iteration)136 void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
137 const std::string &reason, size_t iteration) {
138 if (fbb == nullptr) {
139 MS_LOG(ERROR) << "Input fbb is nullptr.";
140 return;
141 }
142 auto fbs_reason = fbb->CreateString(reason);
143 schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
144 rsp_push_weight_builder.add_retcode(static_cast<int>(retcode));
145 rsp_push_weight_builder.add_reason(fbs_reason);
146 rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
147 auto rsp_push_weight = rsp_push_weight_builder.Finish();
148 fbb->Finish(rsp_push_weight);
149 return;
150 }
151
152 REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
153 } // namespace kernel
154 } // namespace server
155 } // namespace fl
156 } // namespace mindspore
157