1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_input_set_state.h"
18
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/client/other_client_state.h"
27 #include "fcp/secagg/client/secagg_client_aborted_state.h"
28 #include "fcp/secagg/client/secagg_client_completed_state.h"
29 #include "fcp/secagg/client/secagg_client_r2_masked_input_coll_base_state.h"
30 #include "fcp/secagg/client/secagg_client_r3_unmasking_state.h"
31 #include "fcp/secagg/client/secagg_client_state.h"
32 #include "fcp/secagg/client/send_to_server_interface.h"
33 #include "fcp/secagg/client/state_transition_listener_interface.h"
34 #include "fcp/secagg/shared/aes_key.h"
35 #include "fcp/secagg/shared/aes_prng_factory.h"
36 #include "fcp/secagg/shared/compute_session_id.h"
37 #include "fcp/secagg/shared/input_vector_specification.h"
38 #include "fcp/secagg/shared/secagg_messages.pb.h"
39 #include "fcp/secagg/shared/secagg_vector.h"
40 #include "fcp/secagg/shared/shamir_secret_sharing.h"
41
42 namespace fcp {
43 namespace secagg {
44
45 SecAggClientR2MaskedInputCollInputSetState::
SecAggClientR2MaskedInputCollInputSetState(uint32_t client_id,uint32_t minimum_surviving_neighbors_for_reconstruction,uint32_t number_of_alive_neighbors,uint32_t number_of_neighbors,std::unique_ptr<SecAggVectorMap> input_map,std::unique_ptr<std::vector<InputVectorSpecification>> input_vector_specs,std::unique_ptr<std::vector<OtherClientState>> other_client_states,std::unique_ptr<std::vector<AesKey>> other_client_enc_keys,std::unique_ptr<std::vector<AesKey>> other_client_prng_keys,std::unique_ptr<ShamirShare> own_self_key_share,std::unique_ptr<AesKey> self_prng_key,std::unique_ptr<SendToServerInterface> sender,std::unique_ptr<StateTransitionListenerInterface> transition_listener,std::unique_ptr<SessionId> session_id,std::unique_ptr<AesPrngFactory> prng_factory,AsyncAbort * async_abort)46 SecAggClientR2MaskedInputCollInputSetState(
47 uint32_t client_id,
48 uint32_t minimum_surviving_neighbors_for_reconstruction,
49 uint32_t number_of_alive_neighbors, uint32_t number_of_neighbors,
50 std::unique_ptr<SecAggVectorMap> input_map,
51 std::unique_ptr<std::vector<InputVectorSpecification> >
52 input_vector_specs,
53 std::unique_ptr<std::vector<OtherClientState> > other_client_states,
54 std::unique_ptr<std::vector<AesKey> > other_client_enc_keys,
55 std::unique_ptr<std::vector<AesKey> > other_client_prng_keys,
56 std::unique_ptr<ShamirShare> own_self_key_share,
57 std::unique_ptr<AesKey> self_prng_key,
58 std::unique_ptr<SendToServerInterface> sender,
59 std::unique_ptr<StateTransitionListenerInterface> transition_listener,
60
61 std::unique_ptr<SessionId> session_id,
62 std::unique_ptr<AesPrngFactory> prng_factory, AsyncAbort* async_abort)
63 : SecAggClientR2MaskedInputCollBaseState(
64 std::move(sender), std::move(transition_listener), async_abort),
65 client_id_(client_id),
66 minimum_surviving_neighbors_for_reconstruction_(
67 minimum_surviving_neighbors_for_reconstruction),
68 number_of_alive_neighbors_(number_of_alive_neighbors),
69 number_of_neighbors_(number_of_neighbors),
70 input_map_(std::move(input_map)),
71 input_vector_specs_(std::move(input_vector_specs)),
72 other_client_states_(std::move(other_client_states)),
73 other_client_enc_keys_(std::move(other_client_enc_keys)),
74 other_client_prng_keys_(std::move(other_client_prng_keys)),
75 own_self_key_share_(std::move(own_self_key_share)),
76 self_prng_key_(std::move(self_prng_key)),
77 session_id_(std::move(session_id)),
78 prng_factory_(std::move(prng_factory)) {
79 FCP_CHECK(client_id_ >= 0)
80 << "Client id must not be negative but was " << client_id_;
81 }
82
83 SecAggClientR2MaskedInputCollInputSetState::
84 ~SecAggClientR2MaskedInputCollInputSetState() = default;
85
86 StatusOr<std::unique_ptr<SecAggClientState> >
HandleMessage(const ServerToClientWrapperMessage & message)87 SecAggClientR2MaskedInputCollInputSetState::HandleMessage(
88 const ServerToClientWrapperMessage& message) {
89 // Handle abort messages or masked input requests only.
90 if (message.has_abort()) {
91 if (message.abort().early_success()) {
92 return {std::make_unique<SecAggClientCompletedState>(
93 std::move(sender_), std::move(transition_listener_))};
94 } else {
95 return {std::make_unique<SecAggClientAbortedState>(
96 "Aborting because of abort message from the server.",
97 std::move(sender_), std::move(transition_listener_))};
98 }
99 } else if (!message.has_masked_input_request()) {
100 // Returns an error indicating that the message is of invalid type.
101 return SecAggClientState::HandleMessage(message);
102 }
103
104 const MaskedInputCollectionRequest& request = message.masked_input_request();
105 std::string error_message;
106 auto pairwise_key_shares = std::make_unique<std::vector<ShamirShare> >();
107 auto self_key_shares = std::make_unique<std::vector<ShamirShare> >();
108
109 std::unique_ptr<SecAggVectorMap> map_of_masks =
110 HandleMaskedInputCollectionRequest(
111 request, client_id_, *input_vector_specs_,
112 minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
113 *other_client_enc_keys_, *other_client_prng_keys_,
114 *own_self_key_share_, *self_prng_key_, *session_id_, *prng_factory_,
115 &number_of_alive_neighbors_, other_client_states_.get(),
116 pairwise_key_shares.get(), self_key_shares.get(), &error_message);
117
118 if (!map_of_masks) {
119 return AbortAndNotifyServer(error_message);
120 }
121
122 SendMaskedInput(std::move(input_map_), std::move(map_of_masks));
123
124 return {std::make_unique<SecAggClientR3UnmaskingState>(
125 client_id_, number_of_alive_neighbors_,
126 minimum_surviving_neighbors_for_reconstruction_, number_of_neighbors_,
127 std::move(other_client_states_), std::move(pairwise_key_shares),
128 std::move(self_key_shares), std::move(sender_),
129 std::move(transition_listener_), async_abort_)};
130 }
131
StateName() const132 std::string SecAggClientR2MaskedInputCollInputSetState::StateName() const {
133 return "R2_MASKED_INPUT_COLL_INPUT_SET";
134 }
135
136 } // namespace secagg
137 } // namespace fcp
138