1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_ 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include <utility> 25 #include <functional> 26 #include "backend/kernel_compiler/cpu/cpu_kernel.h" 27 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" 28 #include "fl/worker/fl_worker.h" 29 30 namespace mindspore { 31 namespace kernel { 32 class GetModelKernel : public CPUKernel { 33 public: 34 GetModelKernel() = default; 35 ~GetModelKernel() override = default; 36 Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> &)37 bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) { 38 MS_LOG(INFO) << "Launching client GetModelKernel"; 39 if (!BuildGetModelReq(fbb_, inputs)) { 40 MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed."; 41 return false; 42 } 43 44 const schema::ResponseGetModel *get_model_rsp = nullptr; 45 std::shared_ptr<std::vector<unsigned char>> get_model_rsp_msg = nullptr; 46 int response_code = schema::ResponseCode_SucNotReady; 47 while (response_code == schema::ResponseCode_SucNotReady) { 48 if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), 49 fbb_->GetSize(), ps::core::TcpUserCommand::kGetModel, 50 &get_model_rsp_msg)) { 51 MS_LOG(EXCEPTION) << "Sending request for GetModel to server " << target_server_rank_ << " failed."; 52 return false; 53 } 54 flatbuffers::Verifier verifier(get_model_rsp_msg->data(), get_model_rsp_msg->size()); 55 if (!verifier.VerifyBuffer<schema::ResponseGetModel>()) { 56 MS_LOG(EXCEPTION) << "The schema of ResponseGetModel is invalid."; 57 return false; 58 } 59 60 get_model_rsp = flatbuffers::GetRoot<schema::ResponseGetModel>(get_model_rsp_msg->data()); 61 MS_EXCEPTION_IF_NULL(get_model_rsp); 62 response_code = get_model_rsp->retcode(); 63 if (response_code == schema::ResponseCode_SUCCEED) { 64 break; 65 } else if (response_code == schema::ResponseCode_SucNotReady) { 66 std::this_thread::sleep_for(std::chrono::milliseconds(200)); 67 continue; 68 } else { 69 MS_LOG(EXCEPTION) << "Launching get model for worker failed. Reason: " << get_model_rsp->reason(); 70 } 71 } 72 73 auto feature_map = get_model_rsp->feature_map(); 74 MS_EXCEPTION_IF_NULL(feature_map); 75 if (feature_map->size() == 0) { 76 MS_LOG(EXCEPTION) << "Feature map after GetModel is empty."; 77 return false; 78 } 79 for (size_t i = 0; i < feature_map->size(); i++) { 80 std::string weight_full_name = feature_map->Get(i)->weight_fullname()->str(); 81 float *weight_data = const_cast<float *>(feature_map->Get(i)->data()->data()); 82 size_t weight_size = feature_map->Get(i)->data()->size() * sizeof(float); 83 if (weight_name_to_input_idx_.count(weight_full_name) == 0) { 84 MS_LOG(EXCEPTION) << "Weight " << weight_full_name << " doesn't exist in FL worker."; 85 return false; 86 } 87 MS_LOG(INFO) << "Cover weight " << weight_full_name << " by the model in server."; 88 size_t index = weight_name_to_input_idx_[weight_full_name]; 89 int ret = memcpy_s(inputs[index]->addr, inputs[index]->size, weight_data, weight_size); 90 if (ret != 0) { 91 MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; 92 return false; 93 } 94 } 95 return true; 96 } 97 Init(const CNodePtr & kernel_node)98 void Init(const CNodePtr &kernel_node) { 99 MS_LOG(INFO) << "Initializing GetModel kernel"; 100 fbb_ = std::make_shared<fl::FBBuilder>(); 101 MS_EXCEPTION_IF_NULL(fbb_); 102 103 MS_EXCEPTION_IF_NULL(kernel_node); 104 server_num_ = fl::worker::FLWorker::GetInstance().server_num(); 105 rank_id_ = fl::worker::FLWorker::GetInstance().rank_id(); 106 if (rank_id_ == UINT32_MAX) { 107 MS_LOG(EXCEPTION) << "Federated worker is not initialized yet."; 108 return; 109 } 110 target_server_rank_ = rank_id_ % server_num_; 111 fl_name_ = fl::worker::FLWorker::GetInstance().fl_name(); 112 MS_LOG(INFO) << "Initializing GetModel kernel. fl_name: " << fl_name_ << ". Request will be sent to server " 113 << target_server_rank_; 114 115 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); 116 for (size_t i = 0; i < input_num; i++) { 117 auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, i), 0).first; 118 MS_EXCEPTION_IF_NULL(input_node); 119 auto weight_node = input_node->cast<ParameterPtr>(); 120 MS_EXCEPTION_IF_NULL(weight_node); 121 std::string weight_name = weight_node->fullname_with_scope(); 122 MS_LOG(INFO) << "Parameter name is " << weight_name; 123 weight_name_to_input_idx_.insert(std::make_pair(weight_name, i)); 124 125 auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); 126 size_t weight_size_ = 127 std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(float), std::multiplies<float>()); 128 input_size_list_.push_back(weight_size_); 129 } 130 output_size_list_.push_back(sizeof(float)); 131 } 132 InitKernel(const CNodePtr & kernel_node)133 void InitKernel(const CNodePtr &kernel_node) { return; } 134 135 protected: InitSizeLists()136 void InitSizeLists() { return; } 137 138 private: BuildGetModelReq(const std::shared_ptr<fl::FBBuilder> & fbb,const std::vector<AddressPtr> & weights)139 bool BuildGetModelReq(const std::shared_ptr<fl::FBBuilder> &fbb, const std::vector<AddressPtr> &weights) { 140 MS_EXCEPTION_IF_NULL(fbb_); 141 auto fbs_fl_name = fbb->CreateString(fl_name_); 142 schema::RequestGetModelBuilder req_get_model_builder(*(fbb.get())); 143 req_get_model_builder.add_fl_name(fbs_fl_name); 144 iteration_ = fl::worker::FLWorker::GetInstance().fl_iteration_num(); 145 req_get_model_builder.add_iteration(SizeToInt(iteration_)); 146 auto req_get_model = req_get_model_builder.Finish(); 147 fbb->Finish(req_get_model); 148 return true; 149 } 150 151 std::shared_ptr<fl::FBBuilder> fbb_; 152 uint32_t rank_id_; 153 uint32_t server_num_; 154 uint32_t target_server_rank_; 155 std::string fl_name_; 156 uint64_t iteration_; 157 std::map<std::string, size_t> weight_name_to_input_idx_; 158 }; 159 } // namespace kernel 160 } // namespace mindspore 161 162 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_MODEL_H_ 163