/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_

#include <memory>
#include <string>
#include <vector>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "ps/ps_context.h"
#include "ps/core/worker_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_communicator.h"

namespace mindspore {
namespace fl {
using FBBuilder = flatbuffers::FlatBufferBuilder;

// The step numbers used by the worker to judge whether to communicate with the server.
constexpr uint32_t kTrainBeginStepNum = 1;
constexpr uint32_t kTrainEndStepNum = 0;
constexpr uint32_t kOneStepPerIteration = 1;

// The sleeping time of the worker thread before the networking is completed.
constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;

// The duration between retries when the server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;

// The rank of the leader server.
constexpr uint32_t kLeaderServerRank = 0;

// The timeout for the worker sending a message to the server in case of network jitter.
constexpr uint32_t kWorkerTimeout = 30;

enum class IterationState {
  // This iteration is still in process.
  kRunning,
  // This iteration is completed and the next iteration is not started yet.
  kCompleted
};

namespace worker {
// This class is used for hybrid training mode for now. In a later version, parameter server mode will also use this
// class as the worker.
class FLWorker {
 public:
  static FLWorker &GetInstance() {
    static FLWorker instance;
    return instance;
  }
  void Run();
  void Finalize();
  bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
                    std::shared_ptr<std::vector<unsigned char>> *output = nullptr);

  uint32_t server_num() const;
  uint32_t worker_num() const;
  uint32_t rank_id() const;
  uint64_t worker_step_num_per_iteration() const;

  // Check whether the worker is still running, i.e., has not exited.
  bool running() const;

  // These methods set the worker's iteration state.
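  // They record the worker-side state in worker_iteration_state_ (declared below); together with the
  // server-triggered kIterationRunning/kIterationCompleted events that drive server_iteration_state_,
  // this state decides whether the worker stays in safemode.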
  void SetIterationRunning();
  void SetIterationCompleted();

  void set_fl_iteration_num(uint64_t iteration_num);
  uint64_t fl_iteration_num() const;

  void set_data_size(int data_size);
  int data_size() const;

  std::string fl_name() const;
  std::string fl_id() const;

 private:
  FLWorker()
      : running_(false),
        server_num_(0),
        worker_num_(0),
        scheduler_ip_(""),
        scheduler_port_(0),
        worker_node_(nullptr),
        rank_id_(UINT32_MAX),
        iteration_num_(0),
        data_size_(0),
        worker_step_num_per_iteration_(1),
        server_iteration_state_(IterationState::kCompleted),
        worker_iteration_state_(IterationState::kCompleted),
        safemode_(false) {}
  ~FLWorker() = default;
  FLWorker(const FLWorker &) = delete;
  FLWorker &operator=(const FLWorker &) = delete;

  // Initialize the scaler for the worker.
  void InitializeFollowerScaler();

  // The handlers for the iteration state events.
  void HandleIterationRunningEvent();
  void HandleIterationCompletedEvent();

  // The barriers before scaling operations.
  void ProcessBeforeScalingOut();
  void ProcessBeforeScalingIn();

  // The handlers after the scheduler's scaling operations are done.
  void ProcessAfterScalingOut();
  void ProcessAfterScalingIn();

  std::atomic_bool running_;
  uint32_t server_num_;
  uint32_t worker_num_;
  std::string scheduler_ip_;
  uint16_t scheduler_port_;
  std::shared_ptr<ps::core::WorkerNode> worker_node_;
  uint32_t rank_id_;

  // The federated learning iteration number.
  std::atomic<uint64_t> iteration_num_;

  // Data size for this federated learning job.
  int data_size_;

  // The worker's standalone training step number before communicating with the server. This is used in hybrid
  // training mode.
  uint64_t worker_step_num_per_iteration_;

  // The iteration state is either running or completed.
  // This variable represents the server iteration state and should be changed by the events
  // kIterationRunning/kIterationCompleted triggered by the server.
  std::atomic<IterationState> server_iteration_state_;

  // This variable represents the worker iteration state and should be changed by the worker training process.
  std::atomic<IterationState> worker_iteration_state_;

  // The flag that represents whether the worker is in safemode, which is decided by both the worker and server
  // iteration states.
  std::atomic_bool safemode_;
};
}  // namespace worker
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
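
// A minimal usage sketch of the public API above (an illustrative assumption, not code from this module):
// a hybrid-mode training loop would typically drive FLWorker roughly as follows. The names `local_data_size`,
// `buffer`, `buffer_size` and `command` are hypothetical placeholders supplied by the caller.
//
//   auto &fl_worker = mindspore::fl::worker::FLWorker::GetInstance();
//   fl_worker.Run();                            // Connect to the scheduler and servers.
//   fl_worker.set_data_size(local_data_size);   // Report the local data size for this FL job.
//   fl_worker.SetIterationRunning();            // Mark the worker-side iteration as running.
//   std::shared_ptr<std::vector<unsigned char>> output = nullptr;
//   // `command` is a ps::core::TcpUserCommand value chosen by the caller for this request.
//   (void)fl_worker.SendToServer(kLeaderServerRank, buffer, buffer_size, command, &output);
//   fl_worker.SetIterationCompleted();          // Mark the worker-side iteration as completed.
//   fl_worker.Finalize();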