• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/client/other_client_state.h"
27 #include "fcp/secagg/client/secagg_client_aborted_state.h"
28 #include "fcp/secagg/client/secagg_client_completed_state.h"
29 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
30 #include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
31 #include "fcp/secagg/client/secagg_client_state.h"
32 #include "fcp/secagg/client/send_to_server_interface.h"
33 #include "fcp/secagg/client/state_transition_listener_interface.h"
34 #include "fcp/secagg/shared/aes_key.h"
35 #include "fcp/secagg/shared/aes_prng_factory.h"
36 #include "fcp/secagg/shared/compute_session_id.h"
37 #include "fcp/secagg/shared/input_vector_specification.h"
38 #include "fcp/secagg/shared/secagg_messages.pb.h"
39 #include "fcp/secagg/shared/secagg_vector.h"
40 #include "fcp/secagg/shared/shamir_secret_sharing.h"
41 
42 namespace fcp {
43 namespace secagg {
44 
45 SecAggClientR2MaskedInputCollInputSetState::
SecAggClientR2MaskedInputCollInputSetState(uint32_t client_id,uint32_t minimum_surviving_neighbors_for_reconstruction,uint32_t number_of_alive_neighbors,uint32_t number_of_neighbors,std::unique_ptr<SecAggVectorMap> input_map,std::unique_ptr<std::vector<InputVectorSpecification>> input_vector_specs,std::unique_ptr<std::vector<OtherClientState>> other_client_states,std::unique_ptr<std::vector<AesKey>> other_client_enc_keys,std::unique_ptr<std::vector<AesKey>> other_client_prng_keys,std::unique_ptr<ShamirShare> own_self_key_share,std::unique_ptr<AesKey> self_prng_key,std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,std::unique_ptr<SessionId> session_id,std::unique_ptr<AesPrngFactory> prng_factory,AsyncAbort * async_abort)46     SecAggClientR2MaskedInputCollInputSetState(
47         uint32_t client_id,
48         uint32_t minimum_surviving_neighbors_for_reconstruction,
49         uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
50         std::unique_ptr<SecAggVectorMap> input_map,
51         std::unique_ptr<std::vector<InputVectorSpecification> >
52             input_vector_specs,
53         std::unique_ptr<std::vector<OtherClientState> > other_client_states,
54         std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
55         std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
56         std::unique_ptr<ShamirShare> own_self_key_share,
57         std::unique_ptr<AesKey> self_prng_key,
58         std::unique_ptr<SendToServerInterface> sender,
59         std::unique_ptr<StateTransitionListenerInterface> transition_listener,
60 
61         std::unique_ptr<SessionId> session_id,
62         std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
63     : SecAggClientR2MaskedInputCollBaseState(
64           std::move(sender), std::move(transition_listener), async_abort),
65       client_id_(client_id),
66       minimum_surviving_neighbors_for_reconstruction_(
67           minimum_surviving_neighbors_for_reconstruction),
68       number_of_alive_neighbors_(number_of_alive_neighbors),
69       number_of_neighbors_(number_of_neighbors),
70       input_map_(std::move(input_map)),
71       input_vector_specs_(std::move(input_vector_specs)),
72       other_client_states_(std::move(other_client_states)),
73       other_client_enc_keys_(std::move(other_client_enc_keys)),
74       other_client_prng_keys_(std::move(other_client_prng_keys)),
75       own_self_key_share_(std::move(own_self_key_share)),
76       self_prng_key_(std::move(self_prng_key)),
77       session_id_(std::move(session_id)),
78       prng_factory_(std::move(prng_factory)) {
79   FCP_CHECK(client_id_ >= 0)
80       << "Client id must not be negative but was " << client_id_;
81 }
82 
83 SecAggClientR2MaskedInputCollInputSetState::
84     ~SecAggClientR2MaskedInputCollInputSetState() = default;
85 
86 StatusOr<std::unique_ptr<SecAggClientState> >
HandleMessage(const ServerToClientWrapperMessage & message)87 SecAggClientR2MaskedInputCollInputSetState::HandleMessage(
88     const ServerToClientWrapperMessage& message) {
89   // Handle abort messages or masked input requests only.
90   if (message.has_abort()) {
91     if (message.abort().early_success()) {
92       return {std::make_unique<SecAggClientCompletedState>(
93           std::move(sender_), std::move(transition_listener_))};
94     } else {
95       return {std::make_unique<SecAggClientAbortedState>(
96           "Aborting because of abort message from the server.",
97           std::move(sender_), std::move(transition_listener_))};
98     }
99   } else if (!message.has_masked_input_request()) {
100     // Returns an error indicating that the message is of invalid type.
101     return SecAggClientState::HandleMessage(message);
102   }
103 
104   const MaskedInputCollectionRequest& request = message.masked_input_request();
105   std::string error_message;
106   auto pairwise_key_shares = std::make_unique<std::vector<ShamirShare> >();
107   auto self_key_shares = std::make_unique<std::vector<ShamirShare> >();
108 
109   std::unique_ptr<SecAggVectorMap> map_of_masks =
110       HandleMaskedInputCollectionRequest(
111           request, client_id_, *input_vector_specs_,
112           minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
113           *other_client_enc_keys_, *other_client_prng_keys_,
114           *own_self_key_share_, *self_prng_key_, *session_id_, *prng_factory_,
115           &number_of_alive_neighbors_, other_client_states_.get(),
116           pairwise_key_shares.get(), self_key_shares.get(), &error_message);
117 
118   if (!map_of_masks) {
119     return AbortAndNotifyServer(error_message);
120   }
121 
122   SendMaskedInput(std::move(input_map_), std::move(map_of_masks));
123 
124   return {std::make_unique<SecAggClientR3UnmaskingState>(
125       client_id_, number_of_alive_neighbors_,
126       minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
127       std::move(other_client_states_), std::move(pairwise_key_shares),
128       std::move(self_key_shares), std::move(sender_),
129       std::move(transition_listener_), async_abort_)};
130 }
131 
StateName() const132 std::string SecAggClientR2MaskedInputCollInputSetState::StateName() const {
133   return "R2_MASKED_INPUT_COLL_INPUT_SET";
134 }
135 
136 }  // namespace secagg
137 }  // namespace fcp
138