• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/kernel/round/get_model_kernel.h"
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include "fl/server/iteration.h"
23 #include "fl/server/model_store.h"
24 
25 namespace mindspore {
26 namespace fl {
27 namespace server {
28 namespace kernel {
InitKernel(size_t)29 void GetModelKernel::InitKernel(size_t) {
30   if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
31     iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
32   }
33 
34   executor_ = &Executor::GetInstance();
35   MS_EXCEPTION_IF_NULL(executor_);
36   if (!executor_->initialized()) {
37     MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
38     return;
39   }
40 }
41 
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)42 bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
43                             const std::vector<AddressPtr> &outputs) {
44   if (inputs.size() != 1 || outputs.size() != 1) {
45     std::string reason = "inputs or outputs size is invalid.";
46     MS_LOG(ERROR) << reason;
47     GenerateOutput(outputs, reason.c_str(), reason.size());
48     return true;
49   }
50 
51   void *req_data = inputs[0]->addr;
52   std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
53   if (fbb == nullptr || req_data == nullptr) {
54     std::string reason = "FBBuilder builder or req_data is nullptr.";
55     MS_LOG(ERROR) << reason;
56     GenerateOutput(outputs, reason.c_str(), reason.size());
57     return true;
58   }
59 
60   flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
61   if (!verifier.VerifyBuffer<schema::RequestGetModel>()) {
62     std::string reason = "The schema of RequestGetModel is invalid.";
63     BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {},
64                      "");
65     MS_LOG(ERROR) << reason;
66     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
67     return true;
68   }
69 
70   (void)++retry_count_;
71   if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
72     MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load();
73   }
74 
75   const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
76   if (get_model_req == nullptr) {
77     std::string reason = "Building flatbuffers schema failed for RequestGetModel.";
78     MS_LOG(ERROR) << reason;
79     GenerateOutput(outputs, reason.c_str(), reason.size());
80     return true;
81   }
82   GetModel(get_model_req, fbb);
83   GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
84   return true;
85 }
86 
Reset()87 bool GetModelKernel::Reset() {
88   MS_LOG(INFO) << "Get model kernel reset!";
89   StopTimer();
90   retry_count_ = 0;
91   return true;
92 }
93 
GetModel(const schema::RequestGetModel * get_model_req,const std::shared_ptr<FBBuilder> & fbb)94 void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
95   auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
96   std::map<std::string, AddressPtr> feature_maps;
97   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
98   size_t get_model_iter = IntToSize(get_model_req->iteration());
99   const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
100   size_t latest_iter_num = iter_to_model.rbegin()->first;
101   // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
102   if ((current_iter == get_model_iter && latest_iter_num != current_iter)) {
103     std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
104                          ". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
105                          "2. Worker has not push all the weights to servers.";
106     BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
107                      std::to_string(next_req_time));
108     if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
109       MS_LOG(WARNING) << reason;
110     }
111     return;
112   }
113 
114   if (iter_to_model.count(get_model_iter) == 0) {
115     // If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
116     MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
117                     << " is invalid. Current iteration is " << std::to_string(current_iter);
118     feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
119   } else {
120     feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
121   }
122 
123   MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
124                << ", next request time is " << next_req_time << ", current iteration is " << current_iter;
125   BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
126                    current_iter, feature_maps, std::to_string(next_req_time));
127   return;
128 }
129 
BuildGetModelRsp(const std::shared_ptr<FBBuilder> & fbb,const schema::ResponseCode retcode,const std::string & reason,const size_t iter,const std::map<std::string,AddressPtr> & feature_maps,const std::string & timestamp)130 void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
131                                       const std::string &reason, const size_t iter,
132                                       const std::map<std::string, AddressPtr> &feature_maps,
133                                       const std::string &timestamp) {
134   if (fbb == nullptr) {
135     MS_LOG(ERROR) << "Input fbb is nullptr.";
136     return;
137   }
138   auto fbs_reason = fbb->CreateString(reason);
139   auto fbs_timestamp = fbb->CreateString(timestamp);
140   std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
141   for (const auto &feature_map : feature_maps) {
142     auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
143     auto fbs_weight_data =
144       fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
145     auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
146     fbs_feature_maps.push_back(fbs_feature_map);
147   }
148   auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
149 
150   schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
151   rsp_get_model_builder.add_retcode(static_cast<int>(retcode));
152   rsp_get_model_builder.add_reason(fbs_reason);
153   rsp_get_model_builder.add_iteration(static_cast<int>(iter));
154   rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
155   rsp_get_model_builder.add_timestamp(fbs_timestamp);
156   auto rsp_get_model = rsp_get_model_builder.Finish();
157   fbb->Finish(rsp_get_model);
158   return;
159 }
160 
161 REG_ROUND_KERNEL(getModel, GetModelKernel)
162 }  // namespace kernel
163 }  // namespace server
164 }  // namespace fl
165 }  // namespace mindspore
166