• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fl/armour/cipher/cipher_meta_storage.h"

#include <utility>
18 
19 namespace mindspore {
20 namespace armour {
GetClientSharesFromServer(const char * list_name,std::map<std::string,std::vector<clientshare_str>> * clients_shares_list)21 void CipherMetaStorage::GetClientSharesFromServer(
22   const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
23   if (clients_shares_list == nullptr) {
24     MS_LOG(ERROR) << "input clients_shares_list is nullptr";
25     return;
26   }
27   const fl::PBMetadata &clients_shares_pb_out =
28     fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
29   const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
30   auto iter = clients_shares_pb.client_secret_shares().begin();
31   for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
32     std::string fl_id = iter->first;
33     const fl::SharesPb &shares_pb = iter->second;
34     std::vector<clientshare_str> encrpted_shares_new;
35     size_t client_share_num = IntToSize(shares_pb.clientsharestrs_size());
36     for (size_t index_shares = 0; index_shares < client_share_num; ++index_shares) {
37       const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
38       clientshare_str new_clientshare;
39       new_clientshare.fl_id = client_share_str_pb.fl_id();
40       new_clientshare.index = client_share_str_pb.index();
41       new_clientshare.share.assign(client_share_str_pb.share().begin(), client_share_str_pb.share().end());
42       encrpted_shares_new.push_back(new_clientshare);
43     }
44     clients_shares_list->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_id, encrpted_shares_new));
45   }
46 }
47 
GetClientListFromServer(const char * list_name,std::vector<std::string> * clients_list)48 void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
49   if (clients_list == nullptr) {
50     MS_LOG(ERROR) << "input clients_list is nullptr";
51     return;
52   }
53   const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
54   const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
55   size_t client_list_num = IntToSize(client_list_pb.fl_id_size());
56   for (size_t i = 0; i < client_list_num; ++i) {
57     std::string fl_id = client_list_pb.fl_id(SizeToInt(i));
58     clients_list->push_back(fl_id);
59   }
60 }
61 
GetClientKeysFromServer(const char * list_name,std::map<std::string,std::vector<std::vector<uint8_t>>> * clients_keys_list)62 void CipherMetaStorage::GetClientKeysFromServer(
63   const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
64   if (clients_keys_list == nullptr) {
65     MS_LOG(ERROR) << "input clients_keys_list is nullptr";
66     return;
67   }
68   const fl::PBMetadata &clients_keys_pb_out =
69     fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
70   const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
71 
72   for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
73     std::string fl_id = iter->first;
74     fl::KeysPb keys_pb = iter->second;
75     std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
76     std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
77     std::vector<std::vector<uint8_t>> cur_keys;
78     cur_keys.push_back(cpk);
79     cur_keys.push_back(spk);
80     (void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
81   }
82 }
83 
GetClientIVsFromServer(const char * list_name,std::map<std::string,std::vector<std::vector<uint8_t>>> * clients_ivs_list)84 void CipherMetaStorage::GetClientIVsFromServer(
85   const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
86   if (clients_ivs_list == nullptr) {
87     MS_LOG(ERROR) << "input clients_ivs_list is nullptr";
88     return;
89   }
90   const fl::PBMetadata &clients_keys_pb_out =
91     fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
92   const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
93 
94   for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
95     std::string fl_id = iter->first;
96     fl::KeysPb keys_pb = iter->second;
97     std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
98     std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
99     std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
100 
101     std::vector<std::vector<uint8_t>> cur_ivs;
102     cur_ivs.push_back(ind_iv);
103     cur_ivs.push_back(pw_iv);
104     cur_ivs.push_back(pw_salt);
105     (void)clients_ivs_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
106   }
107 }
108 
GetClientNoisesFromServer(const char * list_name,std::vector<float> * cur_public_noise)109 bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
110   if (cur_public_noise == nullptr) {
111     MS_LOG(ERROR) << "input cur_public_noise is nullptr";
112     return false;
113   }
114   const fl::PBMetadata &clients_noises_pb_out =
115     fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
116   const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
117   int count = 0;
118   const int count_thld = 1000;
119   while (clients_noises_pb.has_one_client_noises() == false) {
120     const int register_time = 500;
121     std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
122     count++;
123     if (count >= count_thld) break;
124   }
125   if (clients_noises_pb.has_one_client_noises() == false) {
126     MS_LOG(WARNING) << "GetClientNoisesFromServer Count: " << count;
127     return false;
128   }
129   cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
130                            clients_noises_pb.one_client_noises().noise().end());
131   return true;
132 }
133 
GetPrimeFromServer(const char * prime_name,uint8_t * prime)134 bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, uint8_t *prime) {
135   if (prime == nullptr) {
136     MS_LOG(ERROR) << "input prime is nullptr";
137     return false;
138   }
139   const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
140   fl::Prime prime_pb(prime_pb_out.prime());
141   std::string str = *(prime_pb.mutable_prime());
142   if (str.size() != PRIME_MAX_LEN) {
143     MS_LOG(ERROR) << "get prime size is :" << str.size();
144     return false;
145   } else {
146     if (memcpy_s(prime, PRIME_MAX_LEN, str.data(), str.size()) != 0) {
147       MS_LOG(ERROR) << "Memcpy_s error";
148       return false;
149     }
150     return true;
151   }
152 }
153 
UpdateClientToServer(const char * list_name,const std::string & fl_id)154 bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
155   fl::FLId fl_id_pb;
156   fl_id_pb.set_fl_id(fl_id);
157   fl::PBMetadata client_pb;
158   client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
159   bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
160   return retcode;
161 }
162 
RegisterPrime(const char * list_name,const std::string & prime)163 void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
164   MS_LOG(INFO) << "register prime: " << prime;
165   fl::Prime prime_id_pb;
166   prime_id_pb.set_prime(prime);
167   fl::PBMetadata prime_pb;
168   prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
169   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
170   uint32_t time = 1;
171   (void)sleep(time);
172 }
173 
UpdateClientKeyToServer(const char * list_name,const std::string & fl_id,const std::vector<std::vector<uint8_t>> & cur_public_key)174 bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
175                                                 const std::vector<std::vector<uint8_t>> &cur_public_key) {
176   const size_t correct_size = 2;
177   if (cur_public_key.size() < correct_size) {
178     MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
179     return false;
180   }
181   // update new item to memory server.
182   fl::KeysPb keys;
183   keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
184   keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
185   fl::PairClientKeys pair_client_keys_pb;
186   pair_client_keys_pb.set_fl_id(fl_id);
187   pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
188   fl::PBMetadata client_and_keys_pb;
189   client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
190   bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
191   return retcode;
192 }
193 
UpdateClientKeyToServer(const char * list_name,const schema::RequestExchangeKeys * exchange_keys_req)194 bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
195                                                 const schema::RequestExchangeKeys *exchange_keys_req) {
196   std::string fl_id = exchange_keys_req->fl_id()->str();
197   auto fbs_cpk = exchange_keys_req->c_pk();
198   auto fbs_spk = exchange_keys_req->s_pk();
199   if (fbs_cpk == nullptr || fbs_spk == nullptr) {
200     MS_LOG(ERROR) << "public key from exchange_keys_req is null";
201     return false;
202   }
203 
204   size_t spk_len = fbs_spk->size();
205   size_t cpk_len = fbs_cpk->size();
206 
207   // transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
208   std::vector<std::vector<uint8_t>> cur_public_key;
209   std::vector<uint8_t> cpk(cpk_len);
210   std::vector<uint8_t> spk(spk_len);
211   bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
212   bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
213   if (!(ret_create_code_cpk && ret_create_code_spk)) {
214     MS_LOG(ERROR) << "create array for public keys failed";
215     return false;
216   }
217   cur_public_key.push_back(cpk);
218   cur_public_key.push_back(spk);
219 
220   auto fbs_ind_iv = exchange_keys_req->ind_iv();
221   std::vector<char> ind_iv;
222   if (fbs_ind_iv == nullptr) {
223     MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
224   } else {
225     ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
226   }
227 
228   auto fbs_pw_iv = exchange_keys_req->pw_iv();
229   std::vector<char> pw_iv;
230   if (fbs_pw_iv == nullptr) {
231     MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
232   } else {
233     pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
234   }
235 
236   auto fbs_pw_salt = exchange_keys_req->pw_salt();
237   std::vector<char> pw_salt;
238   if (fbs_pw_salt == nullptr) {
239     MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
240   } else {
241     pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
242   }
243 
244   // update new item to memory server.
245   fl::KeysPb keys;
246   keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
247   keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
248   keys.set_ind_iv(ind_iv.data(), ind_iv.size());
249   keys.set_pw_iv(pw_iv.data(), pw_iv.size());
250   keys.set_pw_salt(pw_salt.data(), pw_salt.size());
251   fl::PairClientKeys pair_client_keys_pb;
252   pair_client_keys_pb.set_fl_id(fl_id);
253   pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
254   fl::PBMetadata client_and_keys_pb;
255   client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
256   bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
257   return retcode;
258 }
259 
UpdateClientNoiseToServer(const char * list_name,const std::vector<float> & cur_public_noise)260 bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
261   // update new item to memory server.
262   fl::OneClientNoises noises_pb;
263   *noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
264   fl::PBMetadata client_noises_pb;
265   client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
266   bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
267   return ret;
268 }
269 
UpdateClientShareToServer(const char * list_name,const std::string & fl_id,const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> * shares)270 bool CipherMetaStorage::UpdateClientShareToServer(
271   const char *list_name, const std::string &fl_id,
272   const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
273   if (shares == nullptr) {
274     return false;
275   }
276   size_t size_shares = shares->size();
277   fl::SharesPb shares_pb;
278   for (size_t index = 0; index < size_shares; ++index) {
279     // new item
280     fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
281     std::string fl_id_new = (*shares)[SizeToInt(index)]->fl_id()->str();
282     int index_new = (*shares)[SizeToInt(index)]->index();
283     auto share = (*shares)[SizeToInt(index)]->share();
284     if (share == nullptr) return false;
285     client_share_str_new_p->set_share(reinterpret_cast<const char *>(share->data()), share->size());
286     client_share_str_new_p->set_fl_id(fl_id_new);
287     client_share_str_new_p->set_index(index_new);
288   }
289   fl::PairClientShares pair_client_shares_pb;
290   pair_client_shares_pb.set_fl_id(fl_id);
291   pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
292   fl::PBMetadata client_and_shares_pb;
293   client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
294   bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
295   return retcode;
296 }
297 
RegisterClass()298 void CipherMetaStorage::RegisterClass() {
299   fl::PBMetadata exchange_keys_client_list;
300   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
301                                                                        exchange_keys_client_list);
302   fl::PBMetadata get_keys_client_list;
303   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
304                                                                        get_keys_client_list);
305   fl::PBMetadata clients_keys;
306   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
307   fl::PBMetadata reconstruct_client_list;
308   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
309                                                                        reconstruct_client_list);
310   fl::PBMetadata clients_reconstruct_shares;
311   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
312                                                                        clients_reconstruct_shares);
313   fl::PBMetadata share_secretes_client_list;
314   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
315                                                                        share_secretes_client_list);
316   fl::PBMetadata get_secretes_client_list;
317   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
318                                                                        get_secretes_client_list);
319   fl::PBMetadata clients_encrypt_shares;
320   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
321                                                                        clients_encrypt_shares);
322   fl::PBMetadata get_update_clients_list;
323   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
324                                                                        get_update_clients_list);
325   fl::PBMetadata client_noises;
326   fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises);
327 }
328 }  // namespace armour
329 }  // namespace mindspore
330