/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_

#include <vector>
#include <string>
#include <memory>
#include <functional>
#include <numeric>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"

namespace mindspore {
namespace kernel {
class UpdateModelKernel : public CPUKernel {
 public:
  UpdateModelKernel() = default;
  ~UpdateModelKernel() override = default;

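  // Weights the local model with the local data size, packs it into an UpdateModel request,
  // sends the request to this worker's target server, and verifies the returned response.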
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &) override {
    MS_LOG(INFO) << "Launching client UpdateModelKernel";
    if (inputs.size() != weight_full_names_.size()) {
      MS_LOG(EXCEPTION) << "Input number of UpdateModelKernel should be " << weight_full_names_.size() << ", but got "
                        << inputs.size();
      return false;
    }

    if (!WeightingData(inputs)) {
      MS_LOG(EXCEPTION) << "Weighting data with data_size failed.";
      return false;
    }

    if (!BuildUpdateModelReq(fbb_, inputs)) {
      MS_LOG(EXCEPTION) << "Building request for UpdateModel failed.";
      return false;
    }

    std::shared_ptr<std::vector<unsigned char>> update_model_rsp_msg = nullptr;
    if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(),
                                                          fbb_->GetSize(), ps::core::TcpUserCommand::kUpdateModel,
                                                          &update_model_rsp_msg)) {
      MS_LOG(EXCEPTION) << "Sending request for UpdateModel to server " << target_server_rank_ << " failed.";
      return false;
    }
    MS_EXCEPTION_IF_NULL(update_model_rsp_msg);
    flatbuffers::Verifier verifier(update_model_rsp_msg->data(), update_model_rsp_msg->size());
    if (!verifier.VerifyBuffer<schema::ResponseUpdateModel>()) {
      MS_LOG(EXCEPTION) << "The schema of ResponseUpdateModel is invalid.";
      return false;
    }

    // The buffer was verified as ResponseUpdateModel above, so read it with the same root type.
    const schema::ResponseUpdateModel *update_model_rsp =
      flatbuffers::GetRoot<schema::ResponseUpdateModel>(update_model_rsp_msg->data());
    MS_EXCEPTION_IF_NULL(update_model_rsp);
    auto response_code = update_model_rsp->retcode();
    switch (response_code) {
      case schema::ResponseCode_SUCCEED:
      case schema::ResponseCode_OutOfTime:
        break;
      default:
        MS_LOG(EXCEPTION) << "Launching update model for worker failed. Reason: " << update_model_rsp->reason();
    }
    return true;
  }

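  // Queries the worker context (server number, rank, fl_name, fl_id), picks the target server as
  // rank_id % server_num, and records the full name and byte size of every input weight.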
  void Init(const CNodePtr &kernel_node) override {
    MS_LOG(INFO) << "Initializing UpdateModel kernel";
    fbb_ = std::make_shared<fl::FBBuilder>();
    MS_EXCEPTION_IF_NULL(fbb_);

    MS_EXCEPTION_IF_NULL(kernel_node);
    server_num_ = fl::worker::FLWorker::GetInstance().server_num();
    rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
    if (rank_id_ == UINT32_MAX) {
      MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
      return;
    }
    if (server_num_ == 0) {
      MS_LOG(EXCEPTION) << "Server number should not be 0.";
      return;
    }
    target_server_rank_ = rank_id_ % server_num_;
    fl_name_ = fl::worker::FLWorker::GetInstance().fl_name();
    fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
    MS_LOG(INFO) << "Initializing UpdateModel kernel. fl_name: " << fl_name_ << ", fl_id: " << fl_id_
                 << ". Request will be sent to server " << target_server_rank_;

    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    for (size_t i = 0; i < input_num; i++) {
      auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, i), 0).first;
      MS_EXCEPTION_IF_NULL(input_node);
      auto weight_node = input_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(weight_node);
      std::string weight_name = weight_node->fullname_with_scope();
      MS_LOG(INFO) << "Parameter name is " << weight_name;
      weight_full_names_.push_back(weight_name);

      auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
      // Byte size of one weight: product of its dimensions times sizeof(float).
      size_t weight_size =
        std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(float), std::multiplies<size_t>());
      input_size_list_.push_back(weight_size);
    }
    output_size_list_.push_back(sizeof(float));
  }

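  // All setup is done in Init() above, so InitKernel() and InitSizeLists() are intentionally no-ops.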
  void InitKernel(const CNodePtr &kernel_node) override { return; }

 protected:
  void InitSizeLists() override { return; }

 private:
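  // Serializes fl_name, fl_id, the current iteration number and all weight buffers into a
  // flatbuffers RequestUpdateModel message held by fbb.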
  bool BuildUpdateModelReq(const std::shared_ptr<fl::FBBuilder> &fbb, const std::vector<AddressPtr> &weights) {
    MS_EXCEPTION_IF_NULL(fbb);
    auto fbs_fl_name = fbb->CreateString(fl_name_);
    auto fbs_fl_id = fbb->CreateString(fl_id_);
    std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
    for (size_t i = 0; i < weight_full_names_.size(); i++) {
      const std::string &weight_name = weight_full_names_[i];
      auto fbs_weight_fullname = fbb->CreateString(weight_name);
      // Reinterpret the raw weight buffer as floats and copy it into the flatbuffer.
      auto fbs_weight_data =
        fbb->CreateVector(reinterpret_cast<const float *>(weights[i]->addr), weights[i]->size / sizeof(float));
      auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
      fbs_feature_maps.push_back(fbs_feature_map);
    }
    auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);

    schema::RequestUpdateModelBuilder req_update_model_builder(*(fbb.get()));
    req_update_model_builder.add_fl_name(fbs_fl_name);
    req_update_model_builder.add_fl_id(fbs_fl_id);
    iteration_ = fl::worker::FLWorker::GetInstance().fl_iteration_num();
    req_update_model_builder.add_iteration(SizeToInt(iteration_));
    req_update_model_builder.add_feature_map(fbs_feature_maps_vector);
    auto req_update_model = req_update_model_builder.Finish();
    fbb->Finish(req_update_model);
    return true;
  }

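  // Multiplies every float in each input buffer by the local data size, in place, presumably so the
  // server can aggregate models weighted by each worker's data size.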
  bool WeightingData(const std::vector<AddressPtr> &inputs) {
    data_size_ = fl::worker::FLWorker::GetInstance().data_size();
    for (auto &input : inputs) {
      float *data = reinterpret_cast<float *>(input->addr);
      for (size_t i = 0; i < input->size / sizeof(float); i++) {
        data[i] *= data_size_;
      }
    }
    return true;
  }

  std::shared_ptr<fl::FBBuilder> fbb_ = nullptr;
  uint32_t rank_id_ = UINT32_MAX;
  uint32_t server_num_ = 0;
  uint32_t target_server_rank_ = 0;
  std::string fl_name_;
  std::string fl_id_;
  int data_size_ = 1;
  uint64_t iteration_ = 0;
  std::vector<std::string> weight_full_names_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_