• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/container/node_hash_map.h"
25 #include "absl/strings/str_cat.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/client/other_client_state.h"
28 #include "fcp/secagg/client/secagg_client_aborted_state.h"
29 #include "fcp/secagg/client/secagg_client_alive_base_state.h"
30 #include "fcp/secagg/client/secagg_client_completed_state.h"
31 #include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
32 #include "fcp/secagg/client/secagg_client_state.h"
33 #include "fcp/secagg/client/send_to_server_interface.h"
34 #include "fcp/secagg/shared/aes_gcm_encryption.h"
35 #include "fcp/secagg/shared/aes_key.h"
36 #include "fcp/secagg/shared/aes_prng_factory.h"
37 #include "fcp/secagg/shared/compute_session_id.h"
38 #include "fcp/secagg/shared/input_vector_specification.h"
39 #include "fcp/secagg/shared/map_of_masks.h"
40 #include "fcp/secagg/shared/secagg_messages.pb.h"
41 #include "fcp/secagg/shared/secagg_vector.h"
42 #include "fcp/secagg/shared/shamir_secret_sharing.h"
43 
44 namespace fcp {
45 namespace secagg {
46 
SecAggClientR2MaskedInputCollBaseState(std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,AsyncAbort * async_abort)47 SecAggClientR2MaskedInputCollBaseState::SecAggClientR2MaskedInputCollBaseState(
48     std::unique_ptr<SendToServerInterface> sender,
49     std::unique_ptr<StateTransitionListenerInterface> transition_listener,
50     AsyncAbort* async_abort)
51     : SecAggClientAliveBaseState(std::move(sender),
52                                  std::move(transition_listener),
53                                  ClientState::R2_MASKED_INPUT, async_abort) {}
54 
55 SecAggClientR2MaskedInputCollBaseState::
56     ~SecAggClientR2MaskedInputCollBaseState() = default;
57 
58 std::unique_ptr<SecAggVectorMap>
HandleMaskedInputCollectionRequest(const MaskedInputCollectionRequest & request,uint32_t client_id,const std::vector<InputVectorSpecification> & input_vector_specs,uint32_t minimum_surviving_neighbors_for_reconstruction,uint32_t number_of_clients,const std::vector<AesKey> & other_client_enc_keys,const std::vector<AesKey> & other_client_prng_keys,const ShamirShare & own_self_key_share,const AesKey & self_prng_key,const SessionId & session_id,const AesPrngFactory & prng_factory,uint32_t * number_of_alive_clients,std::vector<OtherClientState> * other_client_states,std::vector<ShamirShare> * pairwise_key_shares,std::vector<ShamirShare> * self_key_shares,std::string * error_message)59 SecAggClientR2MaskedInputCollBaseState::HandleMaskedInputCollectionRequest(
60     const MaskedInputCollectionRequest& request, uint32_t client_id,
61     const std::vector<InputVectorSpecification>& input_vector_specs,
62     uint32_t minimum_surviving_neighbors_for_reconstruction,
63     uint32_t number_of_clients,
64     const std::vector<AesKey>& other_client_enc_keys,
65     const std::vector<AesKey>& other_client_prng_keys,
66     const ShamirShare& own_self_key_share, const AesKey& self_prng_key,
67     const SessionId& session_id, const AesPrngFactory& prng_factory,
68     uint32_t* number_of_alive_clients,
69     std::vector<OtherClientState>* other_client_states,
70     std::vector<ShamirShare>* pairwise_key_shares,
71     std::vector<ShamirShare>* self_key_shares, std::string* error_message) {
72   if (request.encrypted_key_shares_size() !=
73       static_cast<int>(number_of_clients)) {
74     *error_message =
75         "The number of encrypted shares sent by the server does not match "
76         "the number of clients.";
77     return nullptr;
78   }
79 
80   // Parse the request, decrypt and store the key shares from other clients.
81   AesGcmEncryption decryptor;
82   std::string plaintext;
83 
84   for (int i = 0; i < static_cast<int>(number_of_clients); ++i) {
85     if (async_abort_ && async_abort_->Signalled()) {
86       *error_message = async_abort_->Message();
87       return nullptr;
88     }
89     if (i == static_cast<int>(client_id)) {
90       // this client
91       pairwise_key_shares->push_back({""});  // this will never be needed
92       self_key_shares->push_back(own_self_key_share);
93     } else if ((*other_client_states)[i] != OtherClientState::kAlive) {
94       if (request.encrypted_key_shares(i).length() > 0) {
95         // A client who was considered aborted sent key shares.
96         *error_message =
97             "Received encrypted key shares from an aborted client.";
98         return nullptr;
99       } else {
100         pairwise_key_shares->push_back({""});
101         self_key_shares->push_back({""});
102       }
103     } else if (request.encrypted_key_shares(i).length() == 0) {
104       // A client who was considered alive dropped out. Mark it as dead.
105       (*other_client_states)[i] = OtherClientState::kDeadAtRound2;
106       pairwise_key_shares->push_back({""});
107       self_key_shares->push_back({""});
108       --(*number_of_alive_clients);
109     } else {
110       // A living client sent encrypted key shares, so we decrypt and store
111       // them.
112       auto decrypted = decryptor.Decrypt(other_client_enc_keys[i],
113                                          request.encrypted_key_shares(i));
114       if (!decrypted.ok()) {
115         *error_message = "Authentication of encrypted data failed.";
116         return nullptr;
117       } else {
118         plaintext = decrypted.value();
119       }
120 
121       PairOfKeyShares pairwise_and_self_key_shares;
122       if (!pairwise_and_self_key_shares.ParseFromString(plaintext)) {
123         *error_message = "Unable to parse decrypted pair of key shares.";
124         return nullptr;
125       }
126       pairwise_key_shares->push_back(
127           {pairwise_and_self_key_shares.noise_sk_share()});
128       self_key_shares->push_back({pairwise_and_self_key_shares.prf_sk_share()});
129     }
130   }
131 
132   if (*number_of_alive_clients <
133       minimum_surviving_neighbors_for_reconstruction) {
134     *error_message =
135         "There are not enough clients to complete this protocol session. "
136         "Aborting.";
137     return nullptr;
138   }
139 
140   // Compute the map of masks using the other clients' keys.
141   std::vector<AesKey> prng_keys_to_add;
142   std::vector<AesKey> prng_keys_to_subtract;
143 
144   prng_keys_to_add.push_back(self_prng_key);
145 
146   for (int i = 0; i < static_cast<int>(number_of_clients); ++i) {
147     if (async_abort_ && async_abort_->Signalled()) {
148       *error_message = async_abort_->Message();
149       return nullptr;
150     }
151     if (i == static_cast<int>(client_id) ||
152         (*other_client_states)[i] != OtherClientState::kAlive) {
153       continue;
154     } else if (i < static_cast<int>(client_id)) {
155       prng_keys_to_add.push_back(other_client_prng_keys[i]);
156     } else {
157       prng_keys_to_subtract.push_back(other_client_prng_keys[i]);
158     }
159   }
160 
161   std::unique_ptr<SecAggVectorMap> map =
162       MapOfMasks(prng_keys_to_add, prng_keys_to_subtract, input_vector_specs,
163                  session_id, prng_factory, async_abort_);
164   if (!map) {
165     *error_message = async_abort_->Message();
166     return nullptr;
167   }
168   return map;
169 }
170 
171 // TODO(team): Add two SecAggVector values more efficiently, without
172 // having to unpack both vectors and convert the result back into the
173 // packed form.
AddSecAggVectors(SecAggVector v1,SecAggVector v2)174 SecAggVector AddSecAggVectors(SecAggVector v1, SecAggVector v2) {
175   FCP_CHECK(v1.modulus() == v2.modulus());
176   uint64_t modulus = v1.modulus();
177 
178   // The code below moves v1 and v2 to temp instances to "consume" and destroy
179   // the original vectors as soon as possible in order to minimize the number of
180   // concurrent copies of the data in memory.
181   std::vector<uint64_t> vec1 = SecAggVector(std::move(v1)).GetAsUint64Vector();
182 
183   {
184     // Keep vec2 scoped so that it is destroyed as soon as it is no longer used
185     // and before creating the SecAggVector instance below.
186     std::vector<uint64_t> vec2 =
187         SecAggVector(std::move(v2)).GetAsUint64Vector();
188 
189     // Add the two vectors in place assigning the values back into vec1.
190     FCP_CHECK(vec1.size() == vec2.size());
191     for (int i = 0; i < static_cast<int>(vec1.size()); ++i) {
192       vec1[i] = ((vec1[i] + vec2[i]) % modulus);
193     }
194   }
195 
196   return SecAggVector(vec1, modulus);
197 }
198 
SendMaskedInput(std::unique_ptr<SecAggVectorMap> input_map,std::unique_ptr<SecAggVectorMap> map_of_masks)199 void SecAggClientR2MaskedInputCollBaseState::SendMaskedInput(
200     std::unique_ptr<SecAggVectorMap> input_map,
201     std::unique_ptr<SecAggVectorMap> map_of_masks) {
202   ClientToServerWrapperMessage to_send;
203   for (auto& pair : *input_map) {
204     // SetInput should already have guaranteed these
205     FCP_CHECK(map_of_masks->find(pair.first) != map_of_masks->end());
206     SecAggVector& mask = map_of_masks->at(pair.first);
207     SecAggVector sum =
208         AddSecAggVectors(std::move(pair.second), std::move(mask));
209     MaskedInputVector sum_vec_proto;
210     sum_vec_proto.set_encoded_vector(std::move(sum).TakePackedBytes());
211     (*to_send.mutable_masked_input_response()->mutable_vectors())[pair.first] =
212         std::move(sum_vec_proto);
213   }
214   sender_->Send(&to_send);
215 }
216 
217 }  // namespace secagg
218 }  // namespace fcp
219