• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
19 
#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/file_configuration.h"
#include "fl/server/common.h"
#include "fl/server/executor.h"
#include "fl/server/iteration.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_init.h"
#endif
33 
34 namespace mindspore {
35 namespace fl {
36 namespace server {
// How long the server thread sleeps while waiting for networking to be completed.
// NOTE(review): the unit is not stated in this header; presumably milliseconds — confirm at the use site.
constexpr uint32_t kServerSleepTimeForNetworking{1000};
39 
40 // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
41 class Server {
42  public:
GetInstance()43   static Server &GetInstance() {
44     static Server instance;
45     return instance;
46   }
47 
48   void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
49                   const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold);
50 
51   // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
52   // blocked until the server is finalized.
53   // func_graph is the frontend graph which will be parse in server's exector and aggregator.
54   void Run();
55 
56   void SwitchToSafeMode();
57   void CancelSafeMode();
58   bool IsSafeMode() const;
59   void WaitExitSafeMode() const;
60 
61   // Whether the training job of the server is enabled.
62   InstanceState instance_state() const;
63 
64  private:
Server()65   Server()
66       : server_node_(nullptr),
67         task_executor_(nullptr),
68         use_tcp_(false),
69         use_http_(false),
70         http_port_(0),
71         func_graph_(nullptr),
72         executor_threshold_(0),
73         communicator_with_server_(nullptr),
74         communicators_with_worker_({}),
75         iteration_(nullptr),
76         safemode_(true),
77         scheduler_ip_(""),
78         scheduler_port_(0),
79         server_num_(0),
80         worker_num_(0),
81         fl_server_port_(0),
82         cipher_initial_client_cnt_(0),
83         cipher_exchange_keys_cnt_(0),
84         cipher_get_keys_cnt_(0),
85         cipher_share_secrets_cnt_(0),
86         cipher_get_secrets_cnt_(0),
87         cipher_get_clientlist_cnt_(0),
88         cipher_reconstruct_secrets_up_cnt_(0),
89         cipher_reconstruct_secrets_down_cnt_(0),
90         cipher_time_window_(0) {}
91   ~Server() = default;
92   Server(const Server &) = delete;
93   Server &operator=(const Server &) = delete;
94 
95   // Load variables which is set by ps_context.
96   void InitServerContext();
97 
98   // Try to recover server config from persistent storage.
99   void Recovery();
100 
101   // Initialize the server cluster, server node and communicators.
102   void InitCluster();
103   bool InitCommunicatorWithServer();
104   bool InitCommunicatorWithWorker();
105 
106   // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
107   void InitIteration();
108 
109   // Register all message and event callbacks for communicators(TCP and HTTP). This method must be called before
110   // communicators are started.
111   void RegisterCommCallbacks();
112 
113   // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
114   void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
115 
116   // Register message callbacks. These messages are mainly from scheduler.
117   void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
118 
119   // Initialize executor according to the server mode.
120   void InitExecutor();
121 
122   // Initialize cipher according to the public param.
123   void InitCipher();
124 
125   // Create round kernels and bind these kernels with corresponding Round.
126   void RegisterRoundKernel();
127 
128   void InitMetrics();
129 
130   // The communicators should be started after all initializations are completed.
131   void StartCommunicator();
132 
133   // The barriers before scaling operations.
134   void ProcessBeforeScalingOut();
135   void ProcessBeforeScalingIn();
136 
137   // The handlers after scheduler's scaling operations are done.
138   void ProcessAfterScalingOut();
139   void ProcessAfterScalingIn();
140 
141   // Handlers for enableFLS/disableFLS requests from the scheduler.
142   void HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
143   void HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
144 
145   // Finish current instance and start a new one. FLPlan could be changed in this method.
146   void HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
147 
148   // Query current instance information.
149   void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
150 
151   // The server node is initialized in Server.
152   std::shared_ptr<ps::core::ServerNode> server_node_;
153 
154   // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
155   // submitted to this task executor is asynchronous.
156   std::shared_ptr<ps::core::TaskExecutor> task_executor_;
157 
158   // Which protocol should communicators use.
159   bool use_tcp_;
160   bool use_http_;
161   uint16_t http_port_;
162 
163   // The configure of all rounds.
164   std::vector<RoundConfig> rounds_config_;
165   CipherConfig cipher_config_;
166 
167   // The graph passed by the frontend without backend optimizing.
168   FuncGraphPtr func_graph_;
169 
170   // The threshold count for executor to do aggregation or optimizing.
171   size_t executor_threshold_;
172 
173   // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
174   // operations, etc.
175   std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
176 
177   // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
178   // In some cases, both types should be supported in one distributed training job. So here we may have multiple
179   // communicators.
180   std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
181 
182   // Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
183   std::mutex scaling_mtx_;
184 
185   // Iteration consists of multiple kinds of rounds.
186   Iteration *iteration_;
187 
188   // The flag that represents whether server is in safemode.
189   // If true, the server is not available to workers and clients.
190   std::atomic_bool safemode_;
191 
192   // Variables set by ps context.
193 #ifdef ENABLE_ARMOUR
194   armour::CipherInit *cipher_init_{nullptr};
195 #endif
196   std::string scheduler_ip_;
197   uint16_t scheduler_port_;
198   uint32_t server_num_;
199   uint32_t worker_num_;
200   uint16_t fl_server_port_;
201   size_t cipher_initial_client_cnt_;
202   size_t cipher_exchange_keys_cnt_;
203   size_t cipher_get_keys_cnt_;
204   size_t cipher_share_secrets_cnt_;
205   size_t cipher_get_secrets_cnt_;
206   size_t cipher_get_clientlist_cnt_;
207   size_t cipher_reconstruct_secrets_up_cnt_;
208   size_t cipher_reconstruct_secrets_down_cnt_;
209   uint64_t cipher_time_window_;
210 };
211 }  // namespace server
212 }  // namespace fl
213 }  // namespace mindspore
214 #endif  // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
215