• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_not_set_state.h"
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/node_hash_map.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/client/other_client_state.h"
28 #include "fcp/secagg/client/secagg_client_aborted_state.h"
29 #include "fcp/secagg/client/secagg_client_completed_state.h"
30 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
31 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
32 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_waiting_for_input_state.h"
33 #include "fcp/secagg/client/secagg_client_state.h"
34 #include "fcp/secagg/client/send_to_server_interface.h"
35 #include "fcp/secagg/client/state_transition_listener_interface.h"
36 #include "fcp/secagg/shared/aes_key.h"
37 #include "fcp/secagg/shared/aes_prng_factory.h"
38 #include "fcp/secagg/shared/compute_session_id.h"
39 #include "fcp/secagg/shared/input_vector_specification.h"
40 #include "fcp/secagg/shared/secagg_messages.pb.h"
41 #include "fcp/secagg/shared/secagg_vector.h"
42 #include "fcp/secagg/shared/shamir_secret_sharing.h"
43 
44 namespace fcp {
45 namespace secagg {
46 
47 SecAggClientR2MaskedInputCollInputNotSetState::
SecAggClientR2MaskedInputCollInputNotSetState(uint32_t client_id,uint32_t minimum_surviving_neighbors_for_reconstruction,uint32_t number_of_alive_neighbors,uint32_t number_of_neighbors,std::unique_ptr<std::vector<InputVectorSpecification>> input_vector_specs,std::unique_ptr<std::vector<OtherClientState>> other_client_states,std::unique_ptr<std::vector<AesKey>> other_client_enc_keys,std::unique_ptr<std::vector<AesKey>> other_client_prng_keys,std::unique_ptr<ShamirShare> own_self_key_share,std::unique_ptr<AesKey> self_prng_key,std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,std::unique_ptr<SessionId> session_id,std::unique_ptr<AesPrngFactory> prng_factory,AsyncAbort * async_abort)48     SecAggClientR2MaskedInputCollInputNotSetState(
49         uint32_t client_id,
50         uint32_t minimum_surviving_neighbors_for_reconstruction,
51         uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
52         std::unique_ptr<std::vector<InputVectorSpecification> >
53             input_vector_specs,
54         std::unique_ptr<std::vector<OtherClientState> > other_client_states,
55         std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
56         std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
57         std::unique_ptr<ShamirShare> own_self_key_share,
58         std::unique_ptr<AesKey> self_prng_key,
59         std::unique_ptr<SendToServerInterface> sender,
60         std::unique_ptr<StateTransitionListenerInterface> transition_listener,
61 
62         std::unique_ptr<SessionId> session_id,
63         std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
64     : SecAggClientR2MaskedInputCollBaseState(
65           std::move(sender), std::move(transition_listener), async_abort),
66       client_id_(client_id),
67       minimum_surviving_neighbors_for_reconstruction_(
68           minimum_surviving_neighbors_for_reconstruction),
69       number_of_alive_neighbors_(number_of_alive_neighbors),
70       number_of_neighbors_(number_of_neighbors),
71       input_vector_specs_(std::move(input_vector_specs)),
72       other_client_states_(std::move(other_client_states)),
73       other_client_enc_keys_(std::move(other_client_enc_keys)),
74       other_client_prng_keys_(std::move(other_client_prng_keys)),
75       own_self_key_share_(std::move(own_self_key_share)),
76       self_prng_key_(std::move(self_prng_key)),
77       session_id_(std::move(session_id)),
78       prng_factory_(std::move(prng_factory)) {
79   FCP_CHECK(client_id_ >= 0)
80       << "Client id must not be negative but was " << client_id_;
81 }
82 
83 SecAggClientR2MaskedInputCollInputNotSetState::
84     ~SecAggClientR2MaskedInputCollInputNotSetState() = default;
85 
86 StatusOr<std::unique_ptr<SecAggClientState> >
HandleMessage(const ServerToClientWrapperMessage & message)87 SecAggClientR2MaskedInputCollInputNotSetState::HandleMessage(
88     const ServerToClientWrapperMessage& message) {
89   // Handle abort messages or masked input requests only.
90   if (message.has_abort()) {
91     if (message.abort().early_success()) {
92       return {std::make_unique<SecAggClientCompletedState>(
93           std::move(sender_), std::move(transition_listener_))};
94     } else {
95       return {std::make_unique<SecAggClientAbortedState>(
96           "Aborting because of abort message from the server.",
97           std::move(sender_), std::move(transition_listener_))};
98     }
99   } else if (!message.has_masked_input_request()) {
100     // Returns an error indicating that the message is of invalid type.
101     return SecAggClientState::HandleMessage(message);
102   }
103 
104   const MaskedInputCollectionRequest& request = message.masked_input_request();
105   std::string error_message;
106   auto pairwise_key_shares = std::make_unique<std::vector<ShamirShare> >();
107   auto self_key_shares = std::make_unique<std::vector<ShamirShare> >();
108 
109   std::unique_ptr<SecAggVectorMap> map_of_masks =
110       HandleMaskedInputCollectionRequest(
111           request, client_id_, *input_vector_specs_,
112           minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
113           *other_client_enc_keys_, *other_client_prng_keys_,
114           *own_self_key_share_, *self_prng_key_, *session_id_, *prng_factory_,
115           &number_of_alive_neighbors_, other_client_states_.get(),
116           pairwise_key_shares.get(), self_key_shares.get(), &error_message);
117 
118   if (!map_of_masks) {
119     return AbortAndNotifyServer(error_message);
120   }
121 
122   return {std::make_unique<SecAggClientR2MaskedInputCollWaitingForInputState>(
123       client_id_, minimum_surviving_neighbors_for_reconstruction_,
124       number_of_alive_neighbors_, number_of_neighbors_,
125       std::move(input_vector_specs_), std::move(map_of_masks),
126       std::move(other_client_states_), std::move(pairwise_key_shares),
127       std::move(self_key_shares), std::move(sender_),
128       std::move(transition_listener_), async_abort_)};
129 }
130 
131 StatusOr<std::unique_ptr<SecAggClientState> >
SetInput(std::unique_ptr<SecAggVectorMap> input_map)132 SecAggClientR2MaskedInputCollInputNotSetState::SetInput(
133     std::unique_ptr<SecAggVectorMap> input_map) {
134   if (!ValidateInput(*input_map, *input_vector_specs_)) {
135     return FCP_STATUS(INVALID_ARGUMENT)
136            << "The input to SetInput does not match the "
137               "InputVectorSpecification.";
138   }
139 
140   return {std::make_unique<SecAggClientR2MaskedInputCollInputSetState>(
141       client_id_, minimum_surviving_neighbors_for_reconstruction_,
142       number_of_alive_neighbors_, number_of_neighbors_, std::move(input_map),
143       std::move(input_vector_specs_), std::move(other_client_states_),
144       std::move(other_client_enc_keys_), std::move(other_client_prng_keys_),
145       std::move(own_self_key_share_), std::move(self_prng_key_),
146       std::move(sender_), std::move(transition_listener_),
147       std::move(session_id_), std::move(prng_factory_), async_abort_)};
148 }
149 
StateName() const150 std::string SecAggClientR2MaskedInputCollInputNotSetState::StateName() const {
151   return "R2_MASKED_INPUT_COLL_INPUT_NOT_SET";
152 }
153 
154 }  // namespace secagg
155 }  // namespace fcp
156