/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <functional>
#include <numeric>
#include <thread>
#include <chrono>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"

namespace mindspore {
namespace kernel {
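// Worker-side kernel that fetches the latest aggregated model from a federated learning server
// and copies the returned weights into this worker's parameter tensors.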
class GetModelKernel : public CPUKernel {
 public:
  GetModelKernel() = default;
  ~GetModelKernel() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &) {
    MS_LOG(INFO) << "Launching client GetModelKernel";
    if (!BuildGetModelReq(fbb_, inputs)) {
      MS_LOG(EXCEPTION) << "Building request for GetModel failed.";
      return false;
    }

    const schema::ResponseGetModel *get_model_rsp = nullptr;
    std::shared_ptr<std::vector<unsigned char>> get_model_rsp_msg = nullptr;
    int response_code = schema::ResponseCode_SucNotReady;
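    // Poll the target server until the model for the requested iteration is ready.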
    while (response_code == schema::ResponseCode_SucNotReady) {
      if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(),
                                                            fbb_->GetSize(), ps::core::TcpUserCommand::kGetModel,
                                                            &get_model_rsp_msg)) {
        MS_LOG(EXCEPTION) << "Sending request for GetModel to server " << target_server_rank_ << " failed.";
        return false;
      }
      flatbuffers::Verifier verifier(get_model_rsp_msg->data(), get_model_rsp_msg->size());
      if (!verifier.VerifyBuffer<schema::ResponseGetModel>()) {
        MS_LOG(EXCEPTION) << "The schema of ResponseGetModel is invalid.";
        return false;
      }

      get_model_rsp = flatbuffers::GetRoot<schema::ResponseGetModel>(get_model_rsp_msg->data());
      MS_EXCEPTION_IF_NULL(get_model_rsp);
      response_code = get_model_rsp->retcode();
      if (response_code == schema::ResponseCode_SUCCEED) {
        break;
      } else if (response_code == schema::ResponseCode_SucNotReady) {
        std::this_thread::sleep_for(std::chrono::milliseconds(200));
        continue;
      } else {
        MS_LOG(EXCEPTION) << "Launching get model for worker failed. Reason: " << get_model_rsp->reason();
      }
    }

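    // Copy each weight returned by the server into the matching input tensor of this kernel.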
    auto feature_map = get_model_rsp->feature_map();
    MS_EXCEPTION_IF_NULL(feature_map);
    if (feature_map->size() == 0) {
      MS_LOG(EXCEPTION) << "Feature map after GetModel is empty.";
      return false;
    }
    for (size_t i = 0; i < feature_map->size(); i++) {
      std::string weight_full_name = feature_map->Get(i)->weight_fullname()->str();
      float *weight_data = const_cast<float *>(feature_map->Get(i)->data()->data());
      size_t weight_size = feature_map->Get(i)->data()->size() * sizeof(float);
      if (weight_name_to_input_idx_.count(weight_full_name) == 0) {
        MS_LOG(EXCEPTION) << "Weight " << weight_full_name << " doesn't exist in FL worker.";
        return false;
      }
      MS_LOG(INFO) << "Overwrite weight " << weight_full_name << " with the model from the server.";
      size_t index = weight_name_to_input_idx_[weight_full_name];
      int ret = memcpy_s(inputs[index]->addr, inputs[index]->size, weight_data, weight_size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
        return false;
      }
    }
    return true;
  }

  void Init(const CNodePtr &kernel_node) {
    MS_LOG(INFO) << "Initializing GetModel kernel";
    fbb_ = std::make_shared<fl::FBBuilder>();
    MS_EXCEPTION_IF_NULL(fbb_);

    MS_EXCEPTION_IF_NULL(kernel_node);
    server_num_ = fl::worker::FLWorker::GetInstance().server_num();
    rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
    if (rank_id_ == UINT32_MAX) {
      MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
      return;
    }
    target_server_rank_ = rank_id_ % server_num_;
    fl_name_ = fl::worker::FLWorker::GetInstance().fl_name();
    MS_LOG(INFO) << "Initializing GetModel kernel. fl_name: " << fl_name_ << ". Request will be sent to server "
                 << target_server_rank_;

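    // Record the mapping from each weight's full name to its input index so Launch can
    // locate the destination tensor for every weight returned by the server.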
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    for (size_t i = 0; i < input_num; i++) {
      auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, i), 0).first;
      MS_EXCEPTION_IF_NULL(input_node);
      auto weight_node = input_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(weight_node);
      std::string weight_name = weight_node->fullname_with_scope();
      MS_LOG(INFO) << "Parameter name is " << weight_name;
      weight_name_to_input_idx_.insert(std::make_pair(weight_name, i));

      auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
      size_t weight_size =
        std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(float), std::multiplies<size_t>());
      input_size_list_.push_back(weight_size);
    }
    output_size_list_.push_back(sizeof(float));
  }

  void InitKernel(const CNodePtr &kernel_node) { return; }

 protected:
  void InitSizeLists() { return; }

 private:
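  // Serialize a RequestGetModel flatbuffer carrying the federated job name and the current iteration number.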
  bool BuildGetModelReq(const std::shared_ptr<fl::FBBuilder> &fbb, const std::vector<AddressPtr> &weights) {
    MS_EXCEPTION_IF_NULL(fbb);
    auto fbs_fl_name = fbb->CreateString(fl_name_);
    schema::RequestGetModelBuilder req_get_model_builder(*(fbb.get()));
    req_get_model_builder.add_fl_name(fbs_fl_name);
    iteration_ = fl::worker::FLWorker::GetInstance().fl_iteration_num();
    req_get_model_builder.add_iteration(SizeToInt(iteration_));
    auto req_get_model = req_get_model_builder.Finish();
    fbb->Finish(req_get_model);
    return true;
  }

  std::shared_ptr<fl::FBBuilder> fbb_;
  uint32_t rank_id_;
  uint32_t server_num_;
  uint32_t target_server_rank_;
  std::string fl_name_;
  uint64_t iteration_;
  std::map<std::string, size_t> weight_name_to_input_idx_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_