1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/kernel/round/get_model_kernel.h"
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include "fl/server/iteration.h"
23 #include "fl/server/model_store.h"
24
25 namespace mindspore {
26 namespace fl {
27 namespace server {
28 namespace kernel {
InitKernel(size_t)29 void GetModelKernel::InitKernel(size_t) {
30 if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
31 iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
32 }
33
34 executor_ = &Executor::GetInstance();
35 MS_EXCEPTION_IF_NULL(executor_);
36 if (!executor_->initialized()) {
37 MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
38 return;
39 }
40 }
41
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)42 bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
43 const std::vector<AddressPtr> &outputs) {
44 if (inputs.size() != 1 || outputs.size() != 1) {
45 std::string reason = "inputs or outputs size is invalid.";
46 MS_LOG(ERROR) << reason;
47 GenerateOutput(outputs, reason.c_str(), reason.size());
48 return true;
49 }
50
51 void *req_data = inputs[0]->addr;
52 std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
53 if (fbb == nullptr || req_data == nullptr) {
54 std::string reason = "FBBuilder builder or req_data is nullptr.";
55 MS_LOG(ERROR) << reason;
56 GenerateOutput(outputs, reason.c_str(), reason.size());
57 return true;
58 }
59
60 flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
61 if (!verifier.VerifyBuffer<schema::RequestGetModel>()) {
62 std::string reason = "The schema of RequestGetModel is invalid.";
63 BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {},
64 "");
65 MS_LOG(ERROR) << reason;
66 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
67 return true;
68 }
69
70 (void)++retry_count_;
71 if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
72 MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load();
73 }
74
75 const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
76 if (get_model_req == nullptr) {
77 std::string reason = "Building flatbuffers schema failed for RequestGetModel.";
78 MS_LOG(ERROR) << reason;
79 GenerateOutput(outputs, reason.c_str(), reason.size());
80 return true;
81 }
82 GetModel(get_model_req, fbb);
83 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
84 return true;
85 }
86
Reset()87 bool GetModelKernel::Reset() {
88 MS_LOG(INFO) << "Get model kernel reset!";
89 StopTimer();
90 retry_count_ = 0;
91 return true;
92 }
93
GetModel(const schema::RequestGetModel * get_model_req,const std::shared_ptr<FBBuilder> & fbb)94 void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
95 auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
96 std::map<std::string, AddressPtr> feature_maps;
97 size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
98 size_t get_model_iter = IntToSize(get_model_req->iteration());
99 const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
100 size_t latest_iter_num = iter_to_model.rbegin()->first;
101 // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
102 if ((current_iter == get_model_iter && latest_iter_num != current_iter)) {
103 std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
104 ". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
105 "2. Worker has not push all the weights to servers.";
106 BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
107 std::to_string(next_req_time));
108 if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
109 MS_LOG(WARNING) << reason;
110 }
111 return;
112 }
113
114 if (iter_to_model.count(get_model_iter) == 0) {
115 // If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
116 MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
117 << " is invalid. Current iteration is " << std::to_string(current_iter);
118 feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
119 } else {
120 feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
121 }
122
123 MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
124 << ", next request time is " << next_req_time << ", current iteration is " << current_iter;
125 BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
126 current_iter, feature_maps, std::to_string(next_req_time));
127 return;
128 }
129
BuildGetModelRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,const size_t iter,const std::map<std::string,AddressPtr> & feature_maps,const std::string & timestamp)130 void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
131 const std::string &reason, const size_t iter,
132 const std::map<std::string, AddressPtr> &feature_maps,
133 const std::string ×tamp) {
134 if (fbb == nullptr) {
135 MS_LOG(ERROR) << "Input fbb is nullptr.";
136 return;
137 }
138 auto fbs_reason = fbb->CreateString(reason);
139 auto fbs_timestamp = fbb->CreateString(timestamp);
140 std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
141 for (const auto &feature_map : feature_maps) {
142 auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
143 auto fbs_weight_data =
144 fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
145 auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
146 fbs_feature_maps.push_back(fbs_feature_map);
147 }
148 auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
149
150 schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
151 rsp_get_model_builder.add_retcode(static_cast<int>(retcode));
152 rsp_get_model_builder.add_reason(fbs_reason);
153 rsp_get_model_builder.add_iteration(static_cast<int>(iter));
154 rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
155 rsp_get_model_builder.add_timestamp(fbs_timestamp);
156 auto rsp_get_model = rsp_get_model_builder.Finish();
157 fbb->Finish(rsp_get_model);
158 return;
159 }
160
// Presumably registers GetModelKernel as the round kernel for the "getModel" round; the REG_ROUND_KERNEL macro
// is defined elsewhere — TODO confirm its exact registration semantics.
REG_ROUND_KERNEL(getModel, GetModelKernel)
162 } // namespace kernel
163 } // namespace server
164 } // namespace fl
165 } // namespace mindspore
166