/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

#include <map>
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_unmask.h"
#endif

namespace mindspore {
namespace fl {
namespace server {
// Executor is the server's entry point for aggregation, optimizing, model querying, etc. It handles the logic
// related to kernel launching.
class Executor {
 public:
  static Executor &GetInstance() {
    static Executor instance;
    return instance;
  }

  // func_graph is the graph compiled by the frontend. aggregation_count is the count used by the aggregators.
  // As noted in parameter_aggregator.h, aggregators are created for the trainable parameters, which are the
  // optimizer cnodes' inputs, so the server executor must be initialized from func_graph.
  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

  // Reinitialize the parameter aggregators after scaling operations are done.
  bool ReInitForScaling();

  // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
  bool ReInitForUpdatingHyperParams(size_t aggr_threshold);

  // Called in parameter server training mode to do the Push operation.
  // For the same trainable parameter, HandlePush must be called aggregation_count_ times before the aggregation is
  // considered complete.
  bool HandlePush(const std::string &param_name, const UploadData &upload_data);

  // Called in parameter server training mode to do the Pull operation.
  // Returns the value of parameter param_name.
  // HandlePull must be called the same number of times as HandlePush before the operation is considered complete.
  AddressPtr HandlePull(const std::string &param_name);

  // Called in federated learning training mode. Updates the value of parameter param_name.
  bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);

  // Called in asynchronous federated learning training mode. Updates the current model with the new feature map
  // asynchronously.
  bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);

  // Overwrite the weights in the server with the pushed feature map.
  bool HandlePushWeight(const std::map<std::string, Address> &feature_map);

  // Returns the trainable parameters specified by param_names.
  std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);

  // Reset the aggregation status for all aggregation kernels in the server.
  void ResetAggregationStatus();

  // Check whether aggregation for all weights/gradients is complete.
  bool IsAllWeightAggregationDone();

  // Check whether aggregation for the given param_names is complete.
  bool IsWeightAggrDone(const std::vector<std::string> &param_names);

  // Returns the whole model as key-value pairs, where the key is the parameter name.
  std::map<std::string, AddressPtr> GetModel();

  // Returns whether the executor singleton is already initialized.
  bool initialized() const;

  const std::vector<std::string> &param_names() const;

  // The unmasking method for the pairwise encryption algorithm.
  bool Unmask();

  // The setter and getter for the unmasked flag, which indicates whether unmasking is complete.
  void set_unmasked(bool unmasked);
  bool unmasked() const;

 private:
  Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
  ~Executor() = default;
  Executor(const Executor &) = delete;
  Executor &operator=(const Executor &) = delete;

  // Returns the trainable parameter name parsed from this cnode.
  std::string GetTrainableParamName(const CNodePtr &cnode);

  // The server's graph is basically the same as the worker's graph, so all the information needed for later
  // computations (forward and backward propagation, aggregation, optimizing, etc.) can be obtained from func_graph.
  bool InitParamAggregator(const FuncGraphPtr &func_graph);

  bool initialized_;
  size_t aggregation_count_;
  std::vector<std::string> param_names_;

  // The map from trainable parameter names to their ParameterAggregator, as noted in parameter_aggregator.h.
  std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;

  // This mutex makes operations on the whole model thread-safe.
  // The whole model consists of all trainable parameters.
  std::mutex model_mutex_;

  // Because ParameterAggregator is not thread-safe, a mutex is created for each ParameterAggregator and must be
  // acquired before calling its methods.
  std::map<std::string, std::mutex> parameter_mutex_;

#ifdef ENABLE_ARMOUR
  armour::CipherUnmask cipher_unmask_;
#endif

  // The flag that represents the unmasking status.
  std::atomic<bool> unmasked_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
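
// Usage sketch (illustrative only, kept as a comment so the header stays unchanged for the compiler). It assumes a
// compiled FuncGraphPtr `graph`, an aggregation threshold `kAggregationCount`, a populated UploadData `upload`, and
// the parameter name "fc1.weight" are available from the caller's context; all of these names are hypothetical and
// not defined by this header.
//
//   auto &executor = mindspore::fl::server::Executor::GetInstance();
//   executor.Initialize(graph, kAggregationCount);
//
//   // Parameter server mode: each push contributes to the aggregation of "fc1.weight"; per the contract above,
//   // aggregation finishes only after HandlePush has been called aggregation_count_ times for that parameter.
//   if (executor.HandlePush("fc1.weight", upload)) {
//     AddressPtr value = executor.HandlePull("fc1.weight");
//   }
//
//   // Federated learning mode: query the whole model once every weight has been aggregated.
//   if (executor.IsAllWeightAggregationDone()) {
//     std::map<std::string, AddressPtr> model = executor.GetModel();
//   }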