• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/kernel/round/push_weight_kernel.h"
18 
19 namespace mindspore {
20 namespace fl {
21 namespace server {
22 namespace kernel {
InitKernel(size_t)23 void PushWeightKernel::InitKernel(size_t) {
24   executor_ = &Executor::GetInstance();
25   MS_EXCEPTION_IF_NULL(executor_);
26   if (!executor_->initialized()) {
27     MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
28     return;
29   }
30   local_rank_ = DistributedCountService::GetInstance().local_rank();
31 }
32 
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)33 bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
34                               const std::vector<AddressPtr> &outputs) {
35   MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
36   void *req_data = inputs[0]->addr;
37   std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
38   if (fbb == nullptr || req_data == nullptr) {
39     std::string reason = "FBBuilder builder or req_data is nullptr.";
40     MS_LOG(ERROR) << reason;
41     GenerateOutput(outputs, reason.c_str(), reason.size());
42     return true;
43   }
44 
45   flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
46   if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) {
47     std::string reason = "The schema of RequestPushWeight is invalid.";
48     BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
49     MS_LOG(ERROR) << reason;
50     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
51     return true;
52   }
53 
54   const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
55   if (push_weight_req == nullptr) {
56     std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
57     BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
58     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
59     return false;
60   }
61 
62   ResultCode result_code = PushWeight(fbb, push_weight_req);
63   GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
64   return ConvertResultCode(result_code);
65 }
66 
Reset()67 bool PushWeightKernel::Reset() {
68   MS_LOG(INFO) << "PushWeightKernel reset!";
69   StopTimer();
70   DistributedCountService::GetInstance().ResetCounter(name_);
71   return true;
72 }
73 
OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &)74 void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
75   if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
76     FinishIteration();
77   }
78   return;
79 }
80 
PushWeight(const std::shared_ptr<FBBuilder> & fbb,const schema::RequestPushWeight * push_weight_req)81 ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
82                                         const schema::RequestPushWeight *push_weight_req) {
83   MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
84   MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn);
85   size_t iteration = IntToSize(push_weight_req->iteration());
86   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
87   if (iteration != current_iter) {
88     std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
89                          ", current iteration:" + std::to_string(current_iter);
90     BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
91     MS_LOG(WARNING) << reason;
92     return ResultCode::kSuccessAndReturn;
93   }
94 
95   std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
96   if (upload_feature_map.empty()) {
97     std::string reason = "PushWeight feature_map is empty.";
98     BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
99     MS_LOG(ERROR) << reason;
100     return ResultCode::kSuccessAndReturn;
101   }
102 
103   if (!executor_->HandlePushWeight(upload_feature_map)) {
104     std::string reason = "Pushing weight failed.";
105     BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
106     MS_LOG(ERROR) << reason;
107     return ResultCode::kSuccessAndReturn;
108   }
109   MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
110 
111   std::string count_reason = "";
112   if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
113     std::string reason = "Count for push weight request failed.";
114     BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
115     MS_LOG(ERROR) << reason;
116     return count_reason == kNetworkError ? ResultCode::kFail : ResultCode::kSuccessAndReturn;
117   }
118   BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
119   return ResultCode::kSuccess;
120 }
121 
ParseFeatureMap(const schema::RequestPushWeight * push_weight_req)122 std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
123   MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, {});
124   std::map<std::string, Address> upload_feature_map;
125   auto fbs_feature_map = push_weight_req->feature_map();
126   MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, upload_feature_map);
127   for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
128     std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
129     float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
130     size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
131     upload_feature_map[weight_full_name] = {weight_data, weight_size};
132   }
133   return upload_feature_map;
134 }
135 
BuildPushWeightRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,size_t iteration)136 void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
137                                           const std::string &reason, size_t iteration) {
138   if (fbb == nullptr) {
139     MS_LOG(ERROR) << "Input fbb is nullptr.";
140     return;
141   }
142   auto fbs_reason = fbb->CreateString(reason);
143   schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
144   rsp_push_weight_builder.add_retcode(static_cast<int>(retcode));
145   rsp_push_weight_builder.add_reason(fbs_reason);
146   rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
147   auto rsp_push_weight = rsp_push_weight_builder.Finish();
148   fbb->Finish(rsp_push_weight);
149   return;
150 }
151 
// Registers PushWeightKernel for the round named `pushWeight`. The macro is defined
// elsewhere; presumably it adds the kernel to the round-kernel factory — confirm there.
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
153 }  // namespace kernel
154 }  // namespace server
155 }  // namespace fl
156 }  // namespace mindspore
157