/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

#include <map>
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_unmask.h"
#endif

namespace mindspore {
namespace fl {
namespace server {
// Executor is the entry point for the server to handle aggregation, optimization, model querying, etc.
// It handles the logic related to kernel launching.
class Executor {
 public:
  static Executor &GetInstance() {
    static Executor instance;
    return instance;
  }

  // func_graph is the graph compiled by the frontend; aggregation_count is the threshold used by the
  // aggregators.
  // As noted in parameter_aggregator.h, aggregators are created from the trainable parameters, which are the
  // optimizer cnode's inputs, so the server executor must be initialized with func_graph.
  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);
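
  // A minimal usage sketch (illustrative only; `graph` and the count of 10 are placeholder values):
  //   auto &executor = Executor::GetInstance();
  //   executor.Initialize(graph, 10);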

  // Reinitialize parameter aggregators after scaling operations are done.
  bool ReInitForScaling();

  // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
  bool ReInitForUpdatingHyperParams(size_t aggr_threshold);

  // Called in parameter server training mode to perform the Push operation.
  // For the same trainable parameter, HandlePush must be called aggregation_count_ times before the
  // aggregation is considered complete.
  bool HandlePush(const std::string &param_name, const UploadData &upload_data);
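
  // A sketch of the push contract (names are illustrative; `uploads` is a placeholder container holding
  // one UploadData per contributor):
  //   for (size_t i = 0; i < aggregation_count; ++i) {
  //     executor.HandlePush("fc1.weight", uploads[i]);
  //   }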

  // Called in parameter server training mode to perform the Pull operation.
  // Returns the value of the parameter param_name.
  // HandlePull must be called the same number of times as HandlePush before the operation is considered
  // complete.
  AddressPtr HandlePull(const std::string &param_name);
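
  // Illustrative pull once the pushes above have completed (parameter name is a placeholder):
  //   AddressPtr latest = executor.HandlePull("fc1.weight");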

  // Called in federated learning training mode. Updates the value of the parameter param_name.
  bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);

  // Called in asynchronous federated learning training mode. Updates the current model with the new feature
  // map asynchronously.
  bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);

  // Overwrites the weights on the server with the pushed feature map.
  bool HandlePushWeight(const std::map<std::string, Address> &feature_map);

  // Returns the trainable parameters whose names are given in param_names.
  std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);
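
  // Illustrative weight query (the parameter names are placeholders):
  //   auto weights = executor.HandlePullWeight({"fc1.weight", "fc1.bias"});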

  // Reset the aggregation status of all aggregation kernels on the server.
  void ResetAggregationStatus();

  // Returns whether the aggregation of all weights/gradients is complete.
  bool IsAllWeightAggregationDone();

  // Returns whether the aggregation of the given param_names is complete.
  bool IsWeightAggrDone(const std::vector<std::string> &param_names);
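
  // Sketch of one aggregation round driven by these queries (the polling loop is only an assumption for
  // illustration, not how the server actually waits):
  //   executor.ResetAggregationStatus();
  //   while (!executor.IsAllWeightAggregationDone()) {
  //     std::this_thread::sleep_for(std::chrono::milliseconds(10));
  //   }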

  // Returns the whole model as key-value pairs, where the key is the parameter name.
  std::map<std::string, AddressPtr> GetModel();
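
  // Illustrative iteration over the returned model (kv.first is the parameter name):
  //   for (const auto &kv : executor.GetModel()) {
  //     ProcessParameter(kv.first, kv.second);  // ProcessParameter is a hypothetical consumer.
  //   }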

  // Returns whether the executor singleton is already initialized.
  bool initialized() const;

  const std::vector<std::string> &param_names() const;

  // The unmasking method for the pairwise encryption algorithm.
  bool Unmask();

  // The setter and getter for the unmasked flag, which indicates whether unmasking is complete.
  void set_unmasked(bool unmasked);
  bool unmasked() const;
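
  // Illustrative unmasking flow guarded by the flag above (a sketch of one plausible call order, not a
  // quote of the implementation):
  //   if (!executor.unmasked() && executor.Unmask()) {
  //     executor.set_unmasked(true);
  //   }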

 private:
  Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
  ~Executor() = default;
  Executor(const Executor &) = delete;
  Executor &operator=(const Executor &) = delete;

  // Returns the trainable parameter name parsed from the given cnode.
  std::string GetTrainableParamName(const CNodePtr &cnode);

  // The server's graph is basically the same as the worker's graph, so func_graph provides all the
  // information needed for later computations: forward and backward propagation, aggregation,
  // optimization, etc.
  bool InitParamAggregator(const FuncGraphPtr &func_graph);

  bool initialized_;
  size_t aggregation_count_;
  std::vector<std::string> param_names_;

  // The map from trainable parameter names to their ParameterAggregators, as noted in
  // parameter_aggregator.h.
  std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;

  // This mutex makes operations on the whole model thread-safe. The whole model consists of all
  // trainable parameters.
  std::mutex model_mutex_;

  // Because ParameterAggregator is not thread-safe, a mutex is created for each ParameterAggregator so
  // that a lock can be acquired before calling its methods.
  std::map<std::string, std::mutex> parameter_mutex_;
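
  // Illustrative locking pattern inside the handlers (a sketch under the assumption above; UpdateData is
  // a hypothetical aggregator method, not necessarily the real name):
  //   std::lock_guard<std::mutex> lock(parameter_mutex_[param_name]);
  //   param_aggrs_[param_name]->UpdateData(...);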

#ifdef ENABLE_ARMOUR
  armour::CipherUnmask cipher_unmask_;
#endif

  // The flag represents the unmasking status.
  std::atomic<bool> unmasked_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_