/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

#include <map>
#include <string>
#include <numeric>
#include <climits>
#include <memory>
#include <functional>
#include "proto/ps.pb.h"
#include "proto/fl.pb.h"
#include "ir/anf.h"
#include "utils/utils.h"
#include "ir/dtype/type_id.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/communicator/message_handler.h"

namespace mindspore {
namespace fl {
namespace server {
// Definitions for the server framework.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };

struct RoundConfig {
  // The name of the round. Refer to the round kernel *.cc files.
  std::string name;
  // Whether this round has a time window limit.
  bool check_timeout = false;
  // The length of the time window. Only used when check_timeout is set to true.
  size_t time_window = 3000;
  // Whether this round checks that the request count has reached the threshold.
  bool check_count = false;
  // This round's request threshold count. Only used when check_count is set to true.
  size_t threshold_count = 0;
  // Whether this round uses the server number as the threshold count. This is vital for some rounds in the elastic
  // scaling scenario.
  bool server_num_as_threshold = false;
};
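
// Illustrative sketch only (not part of the original header): how a counting round could be configured.
// The round name "updateModel" and the threshold value are assumptions for demonstration.
//   RoundConfig update_model_cfg;
//   update_model_cfg.name = "updateModel";
//   update_model_cfg.check_timeout = true;
//   update_model_cfg.time_window = 3000;
//   update_model_cfg.check_count = true;
//   update_model_cfg.threshold_count = 16;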

struct CipherConfig {
  float share_secrets_ratio = 1.0;
  uint64_t cipher_time_window = 300000;
  size_t exchange_keys_threshold = 0;
  size_t get_keys_threshold = 0;
  size_t share_secrets_threshold = 0;
  size_t get_secrets_threshold = 0;
  size_t client_list_threshold = 0;
  size_t reconstruct_secrets_threshold = 0;
};

// Every instance is one training loop that runs fl_iteration_num iterations of federated learning.
// During every instance, the server's training process can be controlled by the scheduler, which changes the state of
// this instance.
enum class InstanceState {
  // If this instance is in the kRunning state, the server can communicate with clients/workers and the training
  // process moves on.
  kRunning = 0,
  // The server is not available to clients/workers while in the kDisable state.
  kDisable,
  // The server is not available to clients/workers while in the kFinish state, which means this instance has finished.
  // In other words, fl_iteration_num iterations are completed.
  kFinish
};
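
// Illustrative sketch only (not part of the original header): reacting to the instance state.
// The variable name `instance_state` is an assumption for demonstration.
//   if (instance_state == InstanceState::kFinish) {
//     // fl_iteration_num iterations are completed, so this instance no longer serves clients/workers.
//   }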

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;

// Information about whether the server kernel will reuse kernel node memory from the front end.
// The key is the server kernel's parameter name, like "weights", "grad", "learning_rate".
// The value is the kernel node's parameter index.
using ReuseKernelNodeInfo = std::map<std::string, size_t>;

// UploadData refers to the data uploaded by workers.
// The key is the data name, for example "weights", "grad", "learning_rate", etc. This is set by the worker.
// The value is the data for that key.

// We use Address instead of AddressPtr because:
// 1. Address doesn't need to call make_shared<T>, so it has better performance.
// 2. The data uploaded by a worker is normally parsed from FlatBuffers or Protobuf, for example the learning rate,
// new weights, etc. Address is enough to store such data.

// Note that Address only stores the void* pointer to the data, so the data must not be released before the related
// logic is done.
using UploadData = std::map<std::string, Address>;
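
// Illustrative sketch only (not part of the original header): filling UploadData after parsing a worker request.
// The local variable names are assumptions for demonstration.
//   float learning_rate = 0.01f;
//   Address lr_addr;
//   lr_addr.addr = &learning_rate;
//   lr_addr.size = sizeof(learning_rate);
//   UploadData upload_data;
//   upload_data[kLearningRate] = lr_addr;
//   // The buffer behind `learning_rate` must stay alive until the related logic is done, because Address only
//   // stores the raw pointer.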

constexpr auto kWeight = "weight";
constexpr auto kNewWeight = "new_weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";
constexpr auto kNewGradient = "new_grad";
constexpr auto kMomentum = "momentum";
constexpr auto kIndices = "indices";
constexpr auto kAdamM = "m";
constexpr auto kAdamV = "v";
constexpr auto kAdamBeta1Power = "beta1_power";
constexpr auto kAdamBeta2Power = "beta2_power";
constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
constexpr auto kStat = "stat";

// OptimParamNameToIndex maps each inputs/workspace/outputs parameter name to its offset when an optimizer kernel is
// launched.
using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;
const OptimParamNameToIndex kMomentumNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
                                               {{kWeight, 0},
                                                {kAdamM, 1},
                                                {kAdamV, 2},
                                                {kAdamBeta1Power, 3},
                                                {kAdamBeta2Power, 4},
                                                {kLearningRate, 5},
                                                {kAdamBeta1, 6},
                                                {kAdamBeta2, 7},
                                                {kAdamEps, 8},
                                                {kGradient, 9}}},
                                              {"outputs", {}}};
const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
                                                     {{kWeight, 0},
                                                      {kAdamM, 1},
                                                      {kAdamV, 2},
                                                      {kAdamBeta1Power, 3},
                                                      {kAdamBeta2Power, 4},
                                                      {kLearningRate, 5},
                                                      {kAdamBeta1, 6},
                                                      {kAdamBeta2, 7},
                                                      {kAdamEps, 8},
                                                      {kGradient, 9},
                                                      {kIndices, 10}}},
                                                    {"outputs", {}}};
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
                                                          {{kWeight, 0},
                                                           {kAdamM, 1},
                                                           {kAdamV, 2},
                                                           {kLearningRate, 3},
                                                           {kAdamBeta1, 4},
                                                           {kAdamBeta2, 5},
                                                           {kAdamEps, 6},
                                                           {"weight_decay", 7},
                                                           {kGradient, 8}}},
                                                         {"outputs", {}}};
const OptimParamNameToIndex kSGDNameToIdx = {
  {"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}},
  {"outputs", {}}};

const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
  {kApplyMomentumOpName, kMomentumNameToIdx},     {kFusedSparseAdamName, kSparseAdamNameToIdx},
  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx},
  {"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}};
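
// Illustrative sketch only (not part of the original header): resolving a launch-time input offset, here the
// gradient input of ApplyMomentum. Error handling at the call site is omitted for brevity.
//   const OptimParamNameToIndex &momentum_idx = kNameToIdxMap.at(kApplyMomentumOpName);
//   size_t grad_input_idx = momentum_idx.at("inputs").at(kGradient);  // 3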

constexpr uint32_t kLeaderServerRank = 0;
constexpr size_t kWorkerMgrThreadPoolSize = 32;
constexpr size_t kWorkerMgrMaxTaskNum = 64;
constexpr size_t kCipherMgrThreadPoolSize = 32;
constexpr size_t kCipherMgrMaxTaskNum = 64;
constexpr size_t kExecutorThreadPoolSize = 32;
constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr uint32_t kThreadSleepTime = 50;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kSuccess = "Success";
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxClientsKeys = "clients_keys";
constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";

// This macro expands to the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \
  std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
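
// Illustrative sketch only (not part of the original header): measuring elapsed time in milliseconds.
// The variable names are assumptions for demonstration.
//   uint64_t start_ms = static_cast<uint64_t>(CURRENT_TIME_MILLI.count());
//   // ... run one round ...
//   uint64_t elapsed_ms = static_cast<uint64_t>(CURRENT_TIME_MILLI.count()) - start_ms;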

// This function returns the size in bytes of the given TypeId.
inline size_t GetTypeIdByte(const TypeId &type) {
  switch (type) {
    case kNumberTypeFloat16:
      return 2;
    case kNumberTypeUInt32:
    case kNumberTypeFloat32:
      return 4;
    case kNumberTypeUInt64:
      return 8;
    default:
      MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
      return 0;
  }
}
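
// Illustrative sketch only (not part of the original header): sizing a float32 buffer.
// `elem_count` is an assumed element count for demonstration.
//   size_t elem_count = 1024;
//   size_t buffer_bytes = elem_count * GetTypeIdByte(kNumberTypeFloat32);  // 4096 bytes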

inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
  MS_ERROR_IF_NULL_W_RET_VAL(kernel_node, nullptr);
  auto param_node =
    AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
  MS_ERROR_IF_NULL_W_RET_VAL(param_node, nullptr);
  auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
  MS_ERROR_IF_NULL_W_RET_VAL(param_tensor, nullptr);
  AddressPtr addr = std::make_shared<kernel::Address>();
  addr->addr = param_tensor->data_c();
  addr->size = param_tensor->data().nbytes();
  return addr;
}
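
// Illustrative sketch only (not part of the original header): fetching the weight parameter's address from an
// optimizer kernel node. `kernel_node` is an assumed CNodePtr for demonstration.
//   size_t weight_idx = kMomentumNameToIdx.at("inputs").at(kWeight);  // 0
//   AddressPtr weight_addr = GenerateParameterNodeAddrPtr(kernel_node, weight_idx);
//   if (weight_addr == nullptr) {
//     // The parameter node or its default tensor was missing.
//   }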

// Definitions for Federated Learning.

constexpr auto kNetworkError = "Cluster networking failed.";

// The result code used for round kernels.
enum class ResultCode {
  // If the method is successfully called and the round kernel's remaining methods should be called, return kSuccess.
  kSuccess = 0,
  // If an error occurred in the method and the remaining methods should not be called, but this iteration should
  // continue, return kSuccessAndReturn so that the framework won't drop this iteration.
  kSuccessAndReturn,
  // If an error occurred and this iteration should be dropped, return kFail.
  kFail
};

inline bool ConvertResultCode(ResultCode result_code) {
  switch (result_code) {
    case ResultCode::kSuccess:
      return true;
    case ResultCode::kSuccessAndReturn:
      return true;
    case ResultCode::kFail:
      return false;
    default:
      return true;
  }
}
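
// Illustrative sketch only (not part of the original header): mapping a round kernel result back to the bool-based
// interfaces. The local variable names are assumptions for demonstration.
//   ResultCode result_code = ResultCode::kSuccessAndReturn;
//   bool ok = ConvertResultCode(result_code);  // true: the iteration is not dropped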

// Definitions for Parameter Server.

}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_