/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_unmask.h"
#endif

namespace mindspore {
namespace fl {
namespace server {
// Executor is the entrance for the server to handle aggregation, optimizing, model querying, etc. It handles
// the logic relevant to kernel launching.
class Executor {
 public:
  static Executor &GetInstance() {
    static Executor instance;
    return instance;
  }

  // FuncGraphPtr func_graph is the graph compiled by the frontend. aggregation_count is the count used by the
  // aggregators.
  // As noted in the header file parameter_aggregator.h, aggregators are created for the trainable parameters,
  // which are the optimizer cnode's inputs, so the server executor must be initialized with func_graph.
  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

  // Reinitialize parameter aggregators after scaling operations are done.
  bool ReInitForScaling();

  // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
  bool ReInitForUpdatingHyperParams(size_t aggr_threshold);

  // Called in parameter server training mode to do the Push operation.
  // For the same trainable parameter, HandlePush must be called aggregation_count_ times before the operation is
  // considered completed.
  bool HandlePush(const std::string &param_name, const UploadData &upload_data);

  // Called in parameter server training mode to do the Pull operation.
  // Returns the value of parameter param_name.
  // HandlePull must be called the same number of times as HandlePush before the operation is considered completed.
  AddressPtr HandlePull(const std::string &param_name);

  // Called in federated learning training mode. Updates the value of parameter param_name.
  bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);

  // Called in asynchronous federated learning training mode. Updates the current model with the new feature map
  // asynchronously.
  bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);

  // Overwrites the weights in the server using the pushed feature map.
  bool HandlePushWeight(const std::map<std::string, Address> &feature_map);

  // Returns multiple trainable parameters selected by param_names.
  std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);

  // Resets the aggregation status for all aggregation kernels in the server.
  void ResetAggregationStatus();

  // Judges whether the aggregation processes for all weights/gradients are completed.
  bool IsAllWeightAggregationDone();

  // Judges whether the aggregation processes for the given param_names are completed.
  bool IsWeightAggrDone(const std::vector<std::string> &param_names);

  // Returns the whole model in key-value form, where the key is the parameter name.
  std::map<std::string, AddressPtr> GetModel();
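  // A minimal usage sketch for one aggregation round in parameter server mode; illustrative only, assuming
  // hypothetical names `graph`, `kAggregationCount`, `name` and `upload`:
  //
  //   auto &executor = Executor::GetInstance();
  //   executor.Initialize(graph, kAggregationCount);
  //   for (size_t i = 0; i < kAggregationCount; ++i) {
  //     (void)executor.HandlePush(name, upload);  // must complete aggregation_count_ times per parameter
  //   }
  //   if (executor.IsWeightAggrDone({name})) {
  //     AddressPtr aggregated = executor.HandlePull(name);  // pulled as many times as it was pushed
  //     executor.ResetAggregationStatus();
  //   }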
  // Returns whether the executor singleton is already initialized.
  bool initialized() const;

  const std::vector<std::string> &param_names() const;

  // The unmasking method for the pairwise encryption algorithm.
  bool Unmask();

  // The setter and getter for the unmasked flag, used to judge whether unmasking is completed.
  void set_unmasked(bool unmasked);
  bool unmasked() const;

 private:
  Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
  ~Executor() = default;
  Executor(const Executor &) = delete;
  Executor &operator=(const Executor &) = delete;

  // Returns the trainable parameter name parsed from this cnode.
  std::string GetTrainableParamName(const CNodePtr &cnode);

  // The server's graph is basically the same as the worker's graph, so all information needed for later
  // computations, including forward and backward propagation, aggregation, optimizing, etc., can be taken from
  // func_graph.
  bool InitParamAggregator(const FuncGraphPtr &func_graph);

  bool initialized_;
  size_t aggregation_count_;
  std::vector<std::string> param_names_;

  // The map from trainable parameter names to their ParameterAggregator, as noted in the header file
  // parameter_aggregator.h.
  std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;

  // This mutex makes operations on the whole model thread safe.
  // The whole model is constructed from all trainable parameters.
  std::mutex model_mutex_;

  // Because ParameterAggregator is not thread safe, a mutex is created for each ParameterAggregator so its lock
  // can be acquired before calling any of its methods.
  std::map<std::string, std::mutex> parameter_mutex_;

#ifdef ENABLE_ARMOUR
  armour::CipherUnmask cipher_unmask_;
#endif

  // The flag represents the unmasking status.
  std::atomic<bool> unmasked_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
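// A minimal sketch of the per-parameter locking pattern described in the private section above; illustrative
// only, with `name` as a hypothetical local inside an Executor method. Since ParameterAggregator is not thread
// safe, its mutex is acquired before touching the aggregator:
//
//   std::unique_lock<std::mutex> lock(parameter_mutex_[name]);
//   auto &param_aggr = param_aggrs_[name];
//   // ... call methods on param_aggr while the lock is held ...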