1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <functional> 24 #include "backend/kernel_compiler/cpu/cpu_kernel.h" 25 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" 26 #include "fl/worker/fl_worker.h" 27 28 namespace mindspore { 29 namespace kernel { 30 class UpdateModelKernel : public CPUKernel { 31 public: 32 UpdateModelKernel() = default; 33 ~UpdateModelKernel() override = default; 34 Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> &)35 bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) { 36 MS_LOG(INFO) << "Launching client UpdateModelKernel"; 37 if (inputs.size() != weight_full_names_.size()) { 38 MS_LOG(EXCEPTION) << "Input number of UpdateModelKernel should be " << weight_full_names_.size() << ", but got " 39 << inputs.size(); 40 return false; 41 } 42 43 if (!WeightingData(inputs)) { 44 MS_LOG(EXCEPTION) << "Weighting data with data_size failed."; 45 return false; 46 } 47 48 if (!BuildUpdateModelReq(fbb_, inputs)) { 49 MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed."; 50 return false; 51 } 52 53 std::shared_ptr<std::vector<unsigned char>> update_model_rsp_msg = nullptr; 54 if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), 55 fbb_->GetSize(), ps::core::TcpUserCommand::kUpdateModel, 56 &update_model_rsp_msg)) { 57 MS_LOG(EXCEPTION) << "Sending request for UpdateModel to server " << target_server_rank_ << " failed."; 58 return false; 59 } 60 flatbuffers::Verifier verifier(update_model_rsp_msg->data(), update_model_rsp_msg->size()); 61 if (!verifier.VerifyBuffer<schema::ResponseUpdateModel>()) { 62 MS_LOG(EXCEPTION) << "The schema of ResponseUpdateModel is invalid."; 63 return false; 64 } 65 66 const schema::ResponseFLJob *update_model_rsp = 67 flatbuffers::GetRoot<schema::ResponseFLJob>(update_model_rsp_msg->data()); 68 MS_EXCEPTION_IF_NULL(update_model_rsp); 69 auto response_code = update_model_rsp->retcode(); 70 switch (response_code) { 71 case schema::ResponseCode_SUCCEED: 72 case schema::ResponseCode_OutOfTime: 73 break; 74 default: 75 MS_LOG(EXCEPTION) << "Launching start fl job for worker failed. Reason: " << update_model_rsp->reason(); 76 } 77 return true; 78 } 79 Init(const CNodePtr & kernel_node)80 void Init(const CNodePtr &kernel_node) { 81 MS_LOG(INFO) << "Initializing UpdateModel kernel"; 82 fbb_ = std::make_shared<fl::FBBuilder>(); 83 MS_EXCEPTION_IF_NULL(fbb_); 84 85 MS_EXCEPTION_IF_NULL(kernel_node); 86 server_num_ = fl::worker::FLWorker::GetInstance().server_num(); 87 rank_id_ = fl::worker::FLWorker::GetInstance().rank_id(); 88 if (rank_id_ == UINT32_MAX) { 89 MS_LOG(EXCEPTION) << "Federated worker is not initialized yet."; 90 return; 91 } 92 target_server_rank_ = rank_id_ % server_num_; 93 fl_name_ = fl::worker::FLWorker::GetInstance().fl_name(); 94 fl_id_ = fl::worker::FLWorker::GetInstance().fl_id(); 95 MS_LOG(INFO) << "Initializing StartFLJob kernel. fl_name: " << fl_name_ << ", fl_id: " << fl_id_ 96 << ". Request will be sent to server " << target_server_rank_; 97 98 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); 99 for (size_t i = 0; i < input_num; i++) { 100 auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, i), 0).first; 101 MS_EXCEPTION_IF_NULL(input_node); 102 auto weight_node = input_node->cast<ParameterPtr>(); 103 MS_EXCEPTION_IF_NULL(weight_node); 104 std::string weight_name = weight_node->fullname_with_scope(); 105 MS_LOG(INFO) << "Parameter name is " << weight_name; 106 weight_full_names_.push_back(weight_name); 107 108 auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); 109 size_t weight_size_ = 110 std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(float), std::multiplies<float>()); 111 input_size_list_.push_back(weight_size_); 112 } 113 output_size_list_.push_back(sizeof(float)); 114 } 115 InitKernel(const CNodePtr & kernel_node)116 void InitKernel(const CNodePtr &kernel_node) { return; } 117 118 protected: InitSizeLists()119 void InitSizeLists() { return; } 120 121 private: BuildUpdateModelReq(const std::shared_ptr<fl::FBBuilder> & fbb,const std::vector<AddressPtr> & weights)122 bool BuildUpdateModelReq(const std::shared_ptr<fl::FBBuilder> &fbb, const std::vector<AddressPtr> &weights) { 123 MS_EXCEPTION_IF_NULL(fbb_); 124 auto fbs_fl_name = fbb->CreateString(fl_name_); 125 auto fbs_fl_id = fbb->CreateString(fl_id_); 126 std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; 127 for (size_t i = 0; i < weight_full_names_.size(); i++) { 128 const std::string &weight_name = weight_full_names_[i]; 129 auto fbs_weight_fullname = fbb->CreateString(weight_name); 130 auto fbs_weight_data = 131 fbb->CreateVector(reinterpret_cast<const float *>(weights[i]->addr), weights[i]->size / sizeof(float)); 132 auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data); 133 fbs_feature_maps.push_back(fbs_feature_map); 134 } 135 auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); 136 137 schema::RequestUpdateModelBuilder req_update_model_builder(*(fbb.get())); 138 req_update_model_builder.add_fl_name(fbs_fl_name); 139 req_update_model_builder.add_fl_id(fbs_fl_id); 140 iteration_ = fl::worker::FLWorker::GetInstance().fl_iteration_num(); 141 req_update_model_builder.add_iteration(SizeToInt(iteration_)); 142 req_update_model_builder.add_feature_map(fbs_feature_maps_vector); 143 auto req_update_model = req_update_model_builder.Finish(); 144 fbb->Finish(req_update_model); 145 return true; 146 } 147 WeightingData(const std::vector<AddressPtr> & inputs)148 bool WeightingData(const std::vector<AddressPtr> &inputs) { 149 data_size_ = fl::worker::FLWorker::GetInstance().data_size(); 150 for (auto &input : inputs) { 151 float *data = reinterpret_cast<float *>(input->addr); 152 for (size_t i = 0; i < input->size / sizeof(float); i++) { 153 data[i] *= data_size_; 154 } 155 } 156 return true; 157 } 158 159 std::shared_ptr<fl::FBBuilder> fbb_; 160 uint32_t rank_id_; 161 uint32_t server_num_; 162 uint32_t target_server_rank_; 163 std::string fl_name_; 164 std::string fl_id_; 165 int data_size_; 166 uint64_t iteration_; 167 std::vector<std::string> weight_full_names_; 168 }; 169 } // namespace kernel 170 } // namespace mindspore 171 172 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_UPDATE_MODEL_H_ 173