1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
18
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/container/node_hash_set.h"
28 #include "absl/strings/str_cat.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
31 #include "fcp/secagg/server/experiments_interface.h"
32 #include "fcp/secagg/server/experiments_names.h"
33 #include "fcp/secagg/server/secagg_server_state.h"
34 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
35 #include "fcp/secagg/server/send_to_clients_interface.h"
36 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
37 #include "fcp/secagg/shared/ecdh_key_agreement.h"
38 #include "fcp/secagg/shared/ecdh_keys.h"
39 #include "fcp/secagg/shared/input_vector_specification.h"
40 #include "fcp/secagg/shared/secagg_messages.pb.h"
41 #include "fcp/secagg/shared/shamir_secret_sharing.h"
42 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
43 #include "fcp/secagg/testing/fake_prng.h"
44 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
45 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
46 #include "fcp/secagg/testing/server/test_secagg_experiments.h"
47 #include "fcp/testing/testing.h"
48 #include "fcp/tracing/test_tracing_recorder.h"
49
50 namespace fcp {
51 namespace secagg {
52 namespace {
53
54 using ::testing::_;
55 using ::testing::Eq;
56 using ::testing::Ge;
57 using ::testing::IsFalse;
58 using ::testing::IsTrue;
59
60 class FakeScheduler : public Scheduler {
61 public:
Schedule(std::function<void ()> job)62 void Schedule(std::function<void()> job) override { jobs_.push_back(job); }
63
WaitUntilIdle()64 void WaitUntilIdle() override {}
65
Run()66 void Run() {
67 for (auto& job : jobs_) {
68 job();
69 }
70 jobs_.clear();
71 }
72
73 private:
74 std::vector<std::function<void()>> jobs_;
75 };
76
77 // Default test session_id.
78 SessionId session_id = {"session id number, 32 bytes long"};
79
// Parameters for the value-parameterized SecAggServerR2MaskedInputCollState
// tests.
struct SecAggR2StateTestParams {
  // Name of this parameter set (presumably used to label the instantiated
  // tests — confirm against the INSTANTIATE_TEST_SUITE_P call).
  const std::string test_name;
  // Enables asynchronous processing of round 2 messages by the server.
  bool enable_async_r2;
};
85
86 class SecaggServerR2MaskedInputCollStateTest
87 : public ::testing::TestWithParam<SecAggR2StateTestParams> {
88 protected:
CreateSecAggServerProtocolImpl(int minimum_number_of_clients_to_proceed,int total_number_of_clients,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr,bool enable_async_r2=true)89 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
90 int minimum_number_of_clients_to_proceed, int total_number_of_clients,
91 MockSendToClientsInterface* sender,
92 MockSecAggServerMetricsListener* metrics_listener = nullptr,
93 bool enable_async_r2 = true) {
94 auto input_vector_specs = std::vector<InputVectorSpecification>();
95 input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
96 SecretSharingGraphFactory factory;
97 auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
98 factory.CreateCompleteGraph(total_number_of_clients,
99 minimum_number_of_clients_to_proceed),
100 minimum_number_of_clients_to_proceed, input_vector_specs,
101 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
102 std::make_unique<AesCtrPrngFactory>(), sender,
103 std::make_unique<SecAggScheduler>(¶llel_scheduler_,
104 &sequential_scheduler_),
105 std::vector<ClientStatus>(total_number_of_clients,
106 ClientStatus::SHARE_KEYS_RECEIVED),
107 ServerVariant::NATIVE_V1,
108 enable_async_r2
109 ? std::make_unique<TestSecAggExperiment>(
110 TestSecAggExperiment(kSecAggAsyncRound2Experiment))
111 : std::make_unique<TestSecAggExperiment>(TestSecAggExperiment()));
112 impl->set_session_id(std::make_unique<SessionId>(session_id));
113 EcdhPregeneratedTestKeys ecdh_keys;
114 for (int i = 0; i < total_number_of_clients; ++i) {
115 impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
116 }
117
118 return impl;
119 }
120
RunSchedulers()121 void RunSchedulers() {
122 parallel_scheduler_.Run();
123 sequential_scheduler_.Run();
124 }
125
126 private:
127 FakeScheduler parallel_scheduler_;
128 FakeScheduler sequential_scheduler_;
129 };
130
TEST_P(SecaggServerR2MaskedInputCollStateTest, IsAbortedReturnsFalse) {
  // A freshly constructed round-2 state is not aborted.
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, mock_sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);
  SecAggServerR2MaskedInputCollState r2_state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(r2_state.IsAborted());
}
145
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       IsCompletedSuccessfullyReturnsFalse) {
  // Round 2 is an intermediate state, so it never reports completion.
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, mock_sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);
  SecAggServerR2MaskedInputCollState r2_state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(r2_state.IsCompletedSuccessfully());
}
161
TEST_P(SecaggServerR2MaskedInputCollStateTest, ErrorMessageRaisesErrorStatus) {
  // A non-aborted state has no error message, so the accessor yields an error
  // status instead of a value.
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, mock_sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);
  SecAggServerR2MaskedInputCollState r2_state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(r2_state.ErrorMessage().ok());
}
176
TEST_P(SecaggServerR2MaskedInputCollStateTest, ResultRaisesErrorStatus) {
  // No aggregation result exists yet during round 2.
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto protocol_impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, mock_sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);
  SecAggServerR2MaskedInputCollState r2_state(
      std::move(protocol_impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(r2_state.Result().ok());
}
191
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       AbortReturnsValidStateAndNotifiesClients) {
  // Aborting from round 2 must broadcast an abort message, report the outcome
  // to the metrics listener, trace the broadcast, and return an ABORTED state
  // that carries the abort reason.
  TestTracingRecorder tracing_recorder;
  // Ownership of the listener passes to the protocol impl created below.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
                                     GetParam().enable_async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const std::string reason = "test abort reason";
  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info(reason);

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)));
  auto aborted_state =
      state.Abort(reason, SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(), Eq(reason));

  const auto broadcast_size = expected_broadcast.ByteSizeLong();
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort), Eq(broadcast_size))));
}
225
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithAllClientsValid) {
  // All four clients submit well-formed masked inputs. The server then moves
  // to R3_UNMASKING and sends every client an unmasking request that lists no
  // dead clients.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  const bool async_r2 = GetParam().enable_async_r2;

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Builds a masked-input response holding a single "foobar" vector of four
  // elements, each equal to `value`.
  auto make_masked_input = [](uint64_t value) {
    SecAggVector masked_vector(std::vector<uint64_t>(4, value), 32);
    MaskedInputVector encoded_vector;
    encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
    auto message = std::make_unique<ClientToServerWrapperMessage>();
    (*message->mutable_masked_input_response()->mutable_vectors())["foobar"] =
        encoded_vector;
    return message;
  };

  // Each pass checks the state after `received` inputs have arrived, then
  // (except on the last pass) delivers one more.
  for (int received = 0; received <= 4; ++received) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(received < 3 ? 3 - received : 0));
    // In async mode the state is only ready after the schedulers have run.
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 3));

    if (received == 4) break;
    ASSERT_THAT(state.HandleMessage(received, make_masked_input(received + 1)),
                IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 2));
  }

  if (async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // An unmasking request whose dead-client list is present but empty.
  ServerToClientWrapperMessage unmasking_request;
  unmasking_request.mutable_unmasking_request()
      ->mutable_dead_3_client_ids()
      ->Clear();
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int client = 0; client < 4; ++client) {
    EXPECT_CALL(*sender, Send(client, EqualsProto(unmasking_request)))
        .Times(1);
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& r3_state = next_state.value();
  EXPECT_THAT(r3_state->State(), Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(r3_state->NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  EXPECT_THAT(r3_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(r3_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));

  const auto request_size = unmasking_request.ByteSizeLong();
  EXPECT_THAT(
      tracing_recorder.FindAllEvents<IndividualMessageSent>(),
      ElementsAre(IsEvent<IndividualMessageSent>(
                      0, Eq(ServerToClientMessageType_UnmaskingRequest),
                      Eq(request_size)),
                  IsEvent<IndividualMessageSent>(
                      1, Eq(ServerToClientMessageType_UnmaskingRequest),
                      Eq(request_size)),
                  IsEvent<IndividualMessageSent>(
                      2, Eq(ServerToClientMessageType_UnmaskingRequest),
                      Eq(request_size)),
                  IsEvent<IndividualMessageSent>(
                      3, Eq(ServerToClientMessageType_UnmaskingRequest),
                      Eq(request_size))));
}
316
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithoutAllClients) {
  // Clients 0-2 submit valid masked inputs and client 3 never responds. The
  // server still proceeds: it sends unmasking requests to clients 0-2 and an
  // abort to client 3.
  auto sender = std::make_unique<MockSendToClientsInterface>();
  const bool async_r2 = GetParam().enable_async_r2;

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Builds a masked-input response holding a single "foobar" vector of four
  // elements, each equal to `value`.
  auto make_masked_input = [](uint64_t value) {
    SecAggVector masked_vector(std::vector<uint64_t>(4, value), 32);
    MaskedInputVector encoded_vector;
    encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
    auto message = std::make_unique<ClientToServerWrapperMessage>();
    (*message->mutable_masked_input_response()->mutable_vectors())["foobar"] =
        encoded_vector;
    return message;
  };

  // Each pass checks the state after `received` inputs; client 3 never sends.
  for (int received = 0; received <= 3; ++received) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(received < 3 ? 3 - received : 0));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 3));

    if (received == 3) break;
    ASSERT_THAT(state.HandleMessage(received, make_masked_input(received + 1)),
                IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 2));
  }

  if (async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  ServerToClientWrapperMessage unmasking_request;
  // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
  unmasking_request.mutable_unmasking_request()->add_dead_3_client_ids(4);

  ServerToClientWrapperMessage client_abort;
  auto* abort_body = client_abort.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info(
      "Client did not send MaskedInputCollectionResponse before round "
      "transition.");

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int client = 0; client < 3; ++client) {
    EXPECT_CALL(*sender, Send(client, EqualsProto(unmasking_request)))
        .Times(1);
  }
  EXPECT_CALL(*sender, Send(3, EqualsProto(client_abort))).Times(1);

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& r3_state = next_state.value();
  EXPECT_THAT(r3_state->State(), Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(r3_state->NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  EXPECT_THAT(r3_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(1));
  EXPECT_THAT(r3_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
398
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithOneClientSendingInvalidInput) {
  // Client 0 sends a malformed masked input and is aborted. Clients 1-3 then
  // send valid inputs and the round completes normally without client 0.
  auto sender = std::make_unique<MockSendToClientsInterface>();
  const bool async_r2 = GetParam().enable_async_r2;

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage unmasking_request;
  // TODO(team): 1 -> 0 below, once backwards compatibility not needed.
  unmasking_request.mutable_unmasking_request()->add_dead_3_client_ids(1);

  ServerToClientWrapperMessage client_abort;
  auto* abort_body = client_abort.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info(
      "Masked input does not match input vector specification - vector is "
      "wrong size.");

  // Expectations are installed up front because client 0's abort is sent
  // while its message is being handled.
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int client = 1; client < 4; ++client) {
    EXPECT_CALL(*sender, Send(client, EqualsProto(unmasking_request)))
        .Times(1);
  }
  EXPECT_CALL(*sender, Send(0, EqualsProto(client_abort))).Times(1);

  // Client 0's payload is not a packed vector at all.
  MaskedInputVector garbage_vector;
  garbage_vector.set_encoded_vector("not a real masked input vector - invalid");
  auto invalid_message = std::make_unique<ClientToServerWrapperMessage>();
  (*invalid_message->mutable_masked_input_response()
        ->mutable_vectors())["foobar"] = garbage_vector;
  ASSERT_THAT(state.HandleMessage(0, std::move(invalid_message)), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsFalse());

  // Builds a masked-input response holding a single "foobar" vector of four
  // elements, each equal to `value`.
  auto make_masked_input = [](uint64_t value) {
    SecAggVector masked_vector(std::vector<uint64_t>(4, value), 32);
    MaskedInputVector encoded_vector;
    encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
    auto message = std::make_unique<ClientToServerWrapperMessage>();
    (*message->mutable_masked_input_response()->mutable_vectors())["foobar"] =
        encoded_vector;
    return message;
  };

  // Each pass checks the state before client `client` sends; the pass with
  // client == 4 only checks.
  for (int client = 1; client <= 4; ++client) {
    const int received = client - 1;
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - client));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(client < 4 ? 4 - client : 0));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && client >= 4));

    if (client == 4) break;
    ASSERT_THAT(state.HandleMessage(client, make_masked_input(client + 1)),
                IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && client >= 3));
  }

  if (async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& r3_state = next_state.value();
  EXPECT_THAT(r3_state->State(), Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(r3_state->NumberOfClientsFailedAfterSendingMaskedInput(), Eq(0));
  EXPECT_THAT(r3_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(1));
  EXPECT_THAT(r3_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
489
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithOneClientAbortingAfterSendingInput) {
  // All four clients submit valid masked inputs, then client 2 aborts before
  // the server advances. Client 2 is counted as failed after sending its input
  // and receives no unmasking request.
  auto sender = std::make_unique<MockSendToClientsInterface>();
  const bool async_r2 = GetParam().enable_async_r2;

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Builds a masked-input response holding a single "foobar" vector of four
  // elements, each equal to `value`.
  auto make_masked_input = [](uint64_t value) {
    SecAggVector masked_vector(std::vector<uint64_t>(4, value), 32);
    MaskedInputVector encoded_vector;
    encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
    auto message = std::make_unique<ClientToServerWrapperMessage>();
    (*message->mutable_masked_input_response()->mutable_vectors())["foobar"] =
        encoded_vector;
    return message;
  };

  for (int received = 0; received <= 4; ++received) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(received));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(received));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - received));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(received < 3 ? 3 - received : 0));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 3));

    if (received == 4) break;
    ASSERT_THAT(state.HandleMessage(received, make_masked_input(received + 1)),
                IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && received >= 2));
  }

  if (async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // Client 2 withdraws after its input was accepted.
  auto client_abort = std::make_unique<ClientToServerWrapperMessage>();
  client_abort->mutable_abort()->set_diagnostic_info("Aborting for test");
  ASSERT_THAT(state.HandleMessage(2, std::move(client_abort)), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  EXPECT_THAT(state.NeedsToAbort(), IsFalse());
  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
  EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
  EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());

  // An unmasking request whose dead-client list is present but empty.
  ServerToClientWrapperMessage unmasking_request;
  unmasking_request.mutable_unmasking_request()
      ->mutable_dead_3_client_ids()
      ->Clear();
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int client : {0, 1, 3}) {
    EXPECT_CALL(*sender, Send(client, EqualsProto(unmasking_request)))
        .Times(1);
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& r3_state = next_state.value();
  EXPECT_THAT(r3_state->State(), Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(r3_state->NumberOfClientsFailedAfterSendingMaskedInput(), Eq(1));
  EXPECT_THAT(r3_state->NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(0));
  EXPECT_THAT(r3_state->NumberOfClientsTerminatedWithoutUnmasking(), Eq(0));
}
579
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateForcesAbortIfTooManyClientsAbort) {
  // Clients 0 and 1 abort, which leaves fewer live clients than the threshold
  // of 3, so advancing the round aborts the whole protocol with a broadcast.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     /*metrics_listener=*/nullptr,
                                     GetParam().enable_async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Each pass checks the state after `aborted` clients have aborted, then
  // (except on the last pass) aborts one more.
  for (int aborted = 0; aborted <= 2; ++aborted) {
    const bool too_few_left = aborted >= 2;
    EXPECT_THAT(state.NeedsToAbort(), Eq(too_few_left));
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - aborted));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - aborted));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(too_few_left));

    if (too_few_left) break;
    auto client_abort = std::make_unique<ClientToServerWrapperMessage>();
    client_abort->mutable_abort()->set_diagnostic_info("Aborting for test");
    ASSERT_THAT(state.HandleMessage(aborted, std::move(client_abort)), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(aborted >= 1));
  }

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("Too many clients aborted.");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_broadcast)))
      .Times(1);
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  const auto& aborted_state = next_state.value();
  EXPECT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(),
              Eq("Too many clients aborted."));

  const auto broadcast_size = expected_broadcast.ByteSizeLong();
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort), Eq(broadcast_size))));
}
630
TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsRecordsMessageSizes) {
  // All clients send valid masked inputs, then client 2 aborts before the
  // server advances. The metrics listener must record the size of every
  // received message and of each individual unmasking request sent, the abort
  // must show up in the tracing recorder, and no broadcast of any kind may be
  // sent.
  TestTracingRecorder tracing_recorder;
  // Ownership of the listener passes to the protocol impl created below.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
    }
    if (GetParam().enable_async_r2) {
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
    }
    if (i < 4) {
      // Have client send a vector of the correct size to the server.
      auto client_message = std::make_unique<ClientToServerWrapperMessage>();
      MaskedInputVector encoded_vector;
      SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
      encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
      (*client_message->mutable_masked_input_response()
            ->mutable_vectors())["foobar"] = encoded_vector;
      // The expectation must capture the size before the message is moved.
      EXPECT_CALL(
          *metrics,
          MessageReceivedSizes(Eq(ClientToServerWrapperMessage::
                                      MessageContentCase::kMaskedInputResponse),
                               Eq(true), Eq(client_message->ByteSizeLong())));
      ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
      if (GetParam().enable_async_r2) {
        EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      } else {
        EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
      }
    }
  }

  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
  abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
  // Captured once, before the message is moved into HandleMessage, and used
  // by both the metrics and the tracing expectations.
  const size_t abort_message_size = abort_message->ByteSizeLong();
  EXPECT_CALL(*metrics,
              MessageReceivedSizes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
                  Eq(false), Eq(abort_message_size)));

  ASSERT_THAT(state.HandleMessage(2, std::move(abort_message)), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  EXPECT_THAT(state.NeedsToAbort(), IsFalse());
  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
  EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
  EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  EXPECT_THAT(tracing_recorder.root(),
              Contains(IsEvent<ClientMessageReceived>(
                  Eq(ClientToServerMessageType_Abort), Eq(abort_message_size),
                  Eq(false), Ge(0))));

  ServerToClientWrapperMessage server_message;
  server_message.mutable_unmasking_request()
      ->mutable_dead_3_client_ids()
      ->Clear();  // Just to set it to an empty vector
  // No broadcast of any kind is allowed, not just one equal to server_message;
  // this matches the BroadcastMessageSizes(_, _) expectation below.
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    if (i != 2) {
      EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
    }
  }
  EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
  EXPECT_CALL(*metrics, IndividualMessageSizes(
                            Eq(ServerToClientWrapperMessage::
                                   MessageContentCase::kUnmaskingRequest),
                            Eq(server_message.ByteSizeLong())))
      .Times(3);

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
742
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       ServerAndClientAbortsAreRecordedCorrectly) {
  // In this test clients abort for a variety of reasons, and then ultimately
  // the server aborts. Metrics should record all of these events.
  //
  // Per-client scenario (the protocol is created for 7 clients):
  //   0: sends an abort message.
  //   1: sends a valid masked input, then sends the same message again.
  //   2: sends a masked input containing one vector too many.
  //   3: sends a masked input whose only vector has the wrong name.
  //   4: sends a masked input whose vector has the wrong length.
  //   5: sends a message of the wrong type (AdvertiseKeys).
  //   6: sends nothing.
  // Too few clients remain afterwards, so advancing the round makes the
  // server abort with NOT_ENOUGH_CLIENTS_REMAINING.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // Client 0.
  EXPECT_CALL(*metrics,
              ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
                             Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
  // Client 1 (second, duplicate submission).
  EXPECT_CALL(*metrics,
              ClientsDropped(
                  Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
                  Eq(ClientDropReason::MASKED_INPUT_UNEXPECTED)));
  // Client 5.
  EXPECT_CALL(*metrics,
              ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
                             Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
  // Clients 2, 3 and 4.
  EXPECT_CALL(*metrics,
              ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
                             Eq(ClientDropReason::INVALID_MASKED_INPUT)))
      .Times(3);
  EXPECT_CALL(
      *metrics,
      ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));

  auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
  abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");

  // A well-formed response: a single 4-element vector named "foobar".
  ClientToServerWrapperMessage valid_message;
  MaskedInputVector encoded_vector;
  SecAggVector masked_vector(std::vector<uint64_t>(4, 9), 32);
  encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
  (*valid_message.mutable_masked_input_response()
        ->mutable_vectors())["foobar"] = encoded_vector;

  // Start from a copy of the valid message (which already carries the
  // "foobar" vector) and add a second one, so the only defect is the vector
  // count. Building it from an empty message would instead yield a single
  // wrongly-named vector, duplicating the wrong-name case below and leaving
  // the count check untested.
  auto invalid_message_too_many_vectors =
      std::make_unique<ClientToServerWrapperMessage>(valid_message);
  (*invalid_message_too_many_vectors->mutable_masked_input_response()
        ->mutable_vectors())["extra"] = encoded_vector;

  auto invalid_message_wrong_name =
      std::make_unique<ClientToServerWrapperMessage>();
  (*invalid_message_wrong_name->mutable_masked_input_response()
        ->mutable_vectors())["wrong"] = encoded_vector;

  // Correct name, but 7 elements instead of the 4 used by valid_message.
  auto invalid_message_wrong_size =
      std::make_unique<ClientToServerWrapperMessage>();
  MaskedInputVector large_encoded_vector;
  SecAggVector large_masked_vector(std::vector<uint64_t>(7, 9), 32);
  large_encoded_vector.set_encoded_vector(
      large_masked_vector.GetAsPackedBytes());
  (*invalid_message_wrong_size->mutable_masked_input_response()
        ->mutable_vectors())["foobar"] = large_encoded_vector;

  auto wrong_message = std::make_unique<ClientToServerWrapperMessage>();
  wrong_message->mutable_advertise_keys();  // wrong type of message

  state.HandleMessage(0, std::move(abort_message)).IgnoreError();
  state
      .HandleMessage(
          1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
      .IgnoreError();
  state
      .HandleMessage(
          1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
      .IgnoreError();
  state.HandleMessage(2, std::move(invalid_message_too_many_vectors))
      .IgnoreError();
  state.HandleMessage(3, std::move(invalid_message_wrong_name)).IgnoreError();
  state.HandleMessage(4, std::move(invalid_message_wrong_size)).IgnoreError();
  state.HandleMessage(5, std::move(wrong_message)).IgnoreError();

  // With async R2 processing, queued work must run before the round can be
  // advanced.
  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  state.ProceedToNextRound().IgnoreError();  // causes server abort
}
831
TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsAreRecorded) {
  // Scenario: four clients, of which clients 0, 1 and 2 submit well-formed
  // masked inputs while client 3 stays silent. The round is advanced anyway,
  // and the metrics listener must observe three client response times, one
  // round-time record and a surviving-client count of three.
  const bool async_r2 = GetParam().enable_async_r2;

  // NOTE(review): the listener is allocated with `new` and never deleted in
  // this test; presumably the protocol impl takes ownership of it - confirm.
  MockSecAggServerMetricsListener* metrics_listener =
      new MockSecAggServerMetricsListener();
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, mock_sender.get(), metrics_listener,
                                     async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_CALL(*metrics_listener,
              ClientResponseTimes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::
                         kMaskedInputResponse),
                  Ge(0)))
      .Times(3);

  for (int client = 0; client < 4; ++client) {
    // Client 3 is the one that never answers.
    const bool client_responds = client < 3;

    // State snapshot before this client's message (if any) is handled.
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(client));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(client));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - client));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(),
                Eq(client_responds ? 3 - client : 0));
    // With async R2 processing the state is not ready until the schedulers
    // have run; otherwise it becomes ready once three inputs are in.
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && client >= 3));

    if (!client_responds) {
      continue;
    }

    // A correctly sized vector: 4 elements, each equal to client + 1.
    SecAggVector masked_values(std::vector<uint64_t>(4, client + 1), 32);
    MaskedInputVector packed;
    packed.set_encoded_vector(masked_values.GetAsPackedBytes());
    auto response = std::make_unique<ClientToServerWrapperMessage>();
    (*response->mutable_masked_input_response()->mutable_vectors())["foobar"] =
        packed;
    ASSERT_THAT(state.HandleMessage(client, std::move(response)), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(!async_r2 && client >= 2));
  }

  if (async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // Expected outgoing traffic: every surviving client gets its own
  // UnmaskingRequest naming the dead client, nothing is broadcast, and the
  // silent client 3 is sent an abort.
  ServerToClientWrapperMessage unmasking_request;
  // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
  unmasking_request.mutable_unmasking_request()->add_dead_3_client_ids(4);

  ServerToClientWrapperMessage abort_for_client_3;
  auto* abort_details = abort_for_client_3.mutable_abort();
  abort_details->set_early_success(false);
  abort_details->set_diagnostic_info(
      "Client did not send MaskedInputCollectionResponse before round "
      "transition.");

  EXPECT_CALL(*mock_sender, SendBroadcast(EqualsProto(unmasking_request)))
      .Times(0);
  for (int survivor = 0; survivor < 3; ++survivor) {
    EXPECT_CALL(*mock_sender, Send(survivor, EqualsProto(unmasking_request)))
        .Times(1);
  }
  EXPECT_CALL(*mock_sender, Send(3, EqualsProto(abort_for_client_3))).Times(1);
  EXPECT_CALL(*metrics_listener,
              RoundTimes(Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION),
                         Eq(true), Ge(0)));
  EXPECT_CALL(*metrics_listener,
              RoundSurvivingClients(
                  Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION), Eq(3)));

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
}
917
918 INSTANTIATE_TEST_SUITE_P(
919 SecaggServerR2MaskedInputCollStateTests,
920 SecaggServerR2MaskedInputCollStateTest,
921 ::testing::ValuesIn<SecAggR2StateTestParams>(
922 {{"r2_async_processing_enabled", true},
923 {"r2_async_processing_disabled", false}}),
924 [](const ::testing::TestParamInfo<
__anona314698d0202(const ::testing::TestParamInfo< SecaggServerR2MaskedInputCollStateTest::ParamType>& info) 925 SecaggServerR2MaskedInputCollStateTest::ParamType>& info) {
926 return info.param.test_name;
927 });
928
929 } // namespace
930 } // namespace secagg
931 } // namespace fcp
932