• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r2_masked_input_coll_state.h"
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/container/node_hash_set.h"
28 #include "absl/strings/str_cat.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
31 #include "fcp/secagg/server/experiments_interface.h"
32 #include "fcp/secagg/server/experiments_names.h"
33 #include "fcp/secagg/server/secagg_server_state.h"
34 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
35 #include "fcp/secagg/server/send_to_clients_interface.h"
36 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
37 #include "fcp/secagg/shared/ecdh_key_agreement.h"
38 #include "fcp/secagg/shared/ecdh_keys.h"
39 #include "fcp/secagg/shared/input_vector_specification.h"
40 #include "fcp/secagg/shared/secagg_messages.pb.h"
41 #include "fcp/secagg/shared/shamir_secret_sharing.h"
42 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
43 #include "fcp/secagg/testing/fake_prng.h"
44 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
45 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
46 #include "fcp/secagg/testing/server/test_secagg_experiments.h"
47 #include "fcp/testing/testing.h"
48 #include "fcp/tracing/test_tracing_recorder.h"
49 
50 namespace fcp {
51 namespace secagg {
52 namespace {
53 
54 using ::testing::_;
55 using ::testing::Eq;
56 using ::testing::Ge;
57 using ::testing::IsFalse;
58 using ::testing::IsTrue;
59 
60 class FakeScheduler : public Scheduler {
61  public:
Schedule(std::function<void ()> job)62   void Schedule(std::function<void()> job) override { jobs_.push_back(job); }
63 
WaitUntilIdle()64   void WaitUntilIdle() override {}
65 
Run()66   void Run() {
67     for (auto& job : jobs_) {
68       job();
69     }
70     jobs_.clear();
71   }
72 
73  private:
74   std::vector<std::function<void()>> jobs_;
75 };
76 
77 // Default test session_id.
78 SessionId session_id = {"session id number, 32 bytes long"};
79 
// One parameter set for SecaggServerR2MaskedInputCollStateTest.
struct SecAggR2StateTestParams {
  // Name of this parameter set (presumably used to label the instantiation).
  std::string const test_name;
  // When true, the server processes round 2 (masked input) messages
  // asynchronously, and tests must run the fixture's schedulers before the
  // state reports it is ready for the next round.
  bool enable_async_r2;
};
85 
86 class SecaggServerR2MaskedInputCollStateTest
87     : public ::testing::TestWithParam<SecAggR2StateTestParams> {
88  protected:
CreateSecAggServerProtocolImpl(int minimum_number_of_clients_to_proceed,int total_number_of_clients,MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr,bool enable_async_r2=true)89   std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
90       int minimum_number_of_clients_to_proceed, int total_number_of_clients,
91       MockSendToClientsInterface* sender,
92       MockSecAggServerMetricsListener* metrics_listener = nullptr,
93       bool enable_async_r2 = true) {
94     auto input_vector_specs = std::vector<InputVectorSpecification>();
95     input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
96     SecretSharingGraphFactory factory;
97     auto impl = std::make_unique<AesSecAggServerProtocolImpl>(
98         factory.CreateCompleteGraph(total_number_of_clients,
99                                     minimum_number_of_clients_to_proceed),
100         minimum_number_of_clients_to_proceed, input_vector_specs,
101         std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
102         std::make_unique<AesCtrPrngFactory>(), sender,
103         std::make_unique<SecAggScheduler>(&parallel_scheduler_,
104                                           &sequential_scheduler_),
105         std::vector<ClientStatus>(total_number_of_clients,
106                                   ClientStatus::SHARE_KEYS_RECEIVED),
107         ServerVariant::NATIVE_V1,
108         enable_async_r2
109             ? std::make_unique<TestSecAggExperiment>(
110                   TestSecAggExperiment(kSecAggAsyncRound2Experiment))
111             : std::make_unique<TestSecAggExperiment>(TestSecAggExperiment()));
112     impl->set_session_id(std::make_unique<SessionId>(session_id));
113     EcdhPregeneratedTestKeys ecdh_keys;
114     for (int i = 0; i < total_number_of_clients; ++i) {
115       impl->SetPairwisePublicKeys(i, ecdh_keys.GetPublicKey(i));
116     }
117 
118     return impl;
119   }
120 
RunSchedulers()121   void RunSchedulers() {
122     parallel_scheduler_.Run();
123     sequential_scheduler_.Run();
124   }
125 
126  private:
127   FakeScheduler parallel_scheduler_;
128   FakeScheduler sequential_scheduler_;
129 };
130 
TEST_P(SecaggServerR2MaskedInputCollStateTest, IsAbortedReturnsFalse) {
  // A freshly constructed round-2 state has not aborted.
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);

  SecAggServerR2MaskedInputCollState state(
      std::move(impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.IsAborted());
}
145 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       IsCompletedSuccessfullyReturnsFalse) {
  // A freshly constructed round-2 state must not report successful completion.
  const bool async_r2 = GetParam().enable_async_r2;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.IsCompletedSuccessfully());
}
161 
TEST_P(SecaggServerR2MaskedInputCollStateTest, ErrorMessageRaisesErrorStatus) {
  // ErrorMessage() on a freshly constructed (non-aborted) state is expected to
  // return a non-OK status.
  auto sender = std::make_unique<MockSendToClientsInterface>();
  auto impl = CreateSecAggServerProtocolImpl(
      /*minimum_number_of_clients_to_proceed=*/3,
      /*total_number_of_clients=*/4, sender.get(),
      /*metrics_listener=*/nullptr, GetParam().enable_async_r2);

  SecAggServerR2MaskedInputCollState state(
      std::move(impl),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.ErrorMessage().ok());
}
176 
TEST_P(SecaggServerR2MaskedInputCollStateTest, ResultRaisesErrorStatus) {
  // Result() on a freshly constructed round-2 state is expected to return a
  // non-OK status.
  const bool async_r2 = GetParam().enable_async_r2;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     sender.get(),
                                     /*metrics_listener=*/nullptr, async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.Result().ok());
}
191 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       AbortReturnsValidStateAndNotifiesClients) {
  // An external Abort() must do three things:
  // - report EXTERNAL_REQUEST to the metrics listener;
  // - broadcast an Abort message carrying the reason (recorded once by the
  //   tracing recorder);
  // - transition to ABORTED with the reason as its error message.
  TestTracingRecorder tracing_recorder;
  // Ownership of the listener passes to the protocol implementation; the raw
  // pointer is kept only to set expectations on it.
  auto* metrics = new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(/*minimum_number_of_clients_to_proceed=*/3,
                                     /*total_number_of_clients=*/4,
                                     sender.get(), metrics,
                                     GetParam().enable_async_r2),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);

  ServerToClientWrapperMessage expected_abort;
  auto* abort_fields = expected_abort.mutable_abort();
  abort_fields->set_early_success(false);
  abort_fields->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(expected_abort)));

  auto next_state =
      state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(next_state->State(), Eq(SecAggServerStateKind::ABORTED));
  const auto error_message = next_state->ErrorMessage();
  ASSERT_THAT(error_message, IsOk());
  EXPECT_THAT(error_message.value(), Eq("test abort reason"));

  const auto expected_size = expected_abort.ByteSizeLong();
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort), Eq(expected_size))));
}
225 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithAllClientsValid) {
  // In this test, all clients send in their valid masked inputs, and then the
  // server proceeds to the next state.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  // 4 clients in total, 3 needed to proceed; all failure counters start at 0.
  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     nullptr /* metrics_listener */,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // At the top of iteration i, clients 0..i-1 have sent their input. The last
  // iteration (i == 4) only checks the counters and sends nothing.
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    // Counts down toward the threshold of 3, then stays at 0.
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
    }
    // Async mode: never ready inside this loop; readiness only appears after
    // RunSchedulers() below. Sync mode: ready once 3 inputs have arrived.
    if (GetParam().enable_async_r2) {
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
    }

    if (i < 4) {
      // Have client send a vector of the correct size to the server: 4
      // elements, each equal to i + 1, under the fixture's "foobar" name.
      auto client_message = std::make_unique<ClientToServerWrapperMessage>();
      MaskedInputVector encoded_vector;
      SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
      encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
      (*client_message->mutable_masked_input_response()
            ->mutable_vectors())["foobar"] = encoded_vector;
      ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
      // i + 1 inputs have now arrived, which meets the threshold once i >= 2.
      if (GetParam().enable_async_r2) {
        EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      } else {
        EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
      }
    }
  }

  // Execute the queued asynchronous work; only then is the state ready.
  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // No client dropped out, so each of the 4 clients is sent (individually,
  // never via broadcast) an UnmaskingRequest with no dead clients.
  ServerToClientWrapperMessage server_message;
  server_message.mutable_unmasking_request()
      ->mutable_dead_3_client_ids()
      ->Clear();  // Just to set it to an empty vector
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
  // Exactly one IndividualMessageSent trace event per client, in client order.
  EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
              ElementsAre(IsEvent<IndividualMessageSent>(
                              0, Eq(ServerToClientMessageType_UnmaskingRequest),
                              Eq(server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              1, Eq(ServerToClientMessageType_UnmaskingRequest),
                              Eq(server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              2, Eq(ServerToClientMessageType_UnmaskingRequest),
                              Eq(server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              3, Eq(ServerToClientMessageType_UnmaskingRequest),
                              Eq(server_message.ByteSizeLong()))));
}
316 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithoutAllClients) {
  // In this test, clients 0 through 2 send in valid masked inputs, and then we
  // proceed to the next step even without client 3.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  // 4 clients in total, 3 needed to proceed; all failure counters start at 0.
  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     nullptr /* metrics_listener */,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // At the top of iteration i, clients 0..i-1 have sent their input. Client 3
  // never sends, so the last iteration (i == 3) only checks the counters.
  for (int i = 0; i < 4; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    // Counts down toward the threshold of 3, then stays at 0.
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
    }
    // Async mode: readiness only appears after RunSchedulers() below.
    // Sync mode: ready once 3 inputs have arrived.
    if (GetParam().enable_async_r2) {
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
    }

    if (i < 3) {
      // Have client send a vector of the correct size to the server
      auto client_message = std::make_unique<ClientToServerWrapperMessage>();
      MaskedInputVector encoded_vector;
      SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
      encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
      (*client_message->mutable_masked_input_response()
            ->mutable_vectors())["foobar"] = encoded_vector;
      ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
      // i + 1 inputs have now arrived, which meets the threshold once i >= 2.
      if (GetParam().enable_async_r2) {
        EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      } else {
        EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
      }
    }
  }

  // Execute the queued asynchronous work; only then is the state ready.
  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // Clients 0-2 are asked to unmask, with client 3 listed as dead. Per the
  // TODO, the id is currently written as 4 (apparently index + 1, for
  // backwards compatibility). Client 3 itself is sent an individual Abort.
  ServerToClientWrapperMessage server_message;
  // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
  server_message.mutable_unmasking_request()->add_dead_3_client_ids(4);
  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "Client did not send MaskedInputCollectionResponse before round "
      "transition.");
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 3; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
  }
  EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);

  // The silent client is counted as failed *before* sending masked input.
  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
398 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithOneClientSendingInvalidInput) {
  // In this test, client 0 sends an invalid masked input, so it is aborted. The
  // rest of the round goes normally.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  // 4 clients in total, 3 needed to proceed; all failure counters start at 0.
  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     nullptr /* metrics_listener */,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // Expected traffic: client 0 gets an individual Abort (sent while its bad
  // message is handled). Clients 1-3 later get an UnmaskingRequest listing
  // client 0 as dead, written as id 1 per the TODO (apparently index + 1).
  // These expectations must be set up front, before HandleMessage(0, ...).
  ServerToClientWrapperMessage server_message;
  // TODO(team): 1 -> 0 below, once backwards compatibility not needed.
  server_message.mutable_unmasking_request()->add_dead_3_client_ids(1);
  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "Masked input does not match input vector specification - vector is "
      "wrong size.");

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 1; i < 4; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
  }
  EXPECT_CALL(*sender, Send(0, EqualsProto(abort_message))).Times(1);

  // Have client 0 send an invalid message.
  auto invalid_message = std::make_unique<ClientToServerWrapperMessage>();
  MaskedInputVector encoded_vector;
  encoded_vector.set_encoded_vector("not a real masked input vector - invalid");
  (*invalid_message->mutable_masked_input_response()
        ->mutable_vectors())["foobar"] = encoded_vector;
  // Handling the bad input still returns OK; the client is dropped instead.
  ASSERT_THAT(state.HandleMessage(0, std::move(invalid_message)), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
  // From here on only 3 clients are alive. At the top of iteration i, valid
  // clients 1..i-1 have sent input. The last iteration (i == 4) only checks
  // the counters.
  for (int i = 1; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i - 1));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i - 1));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    // All 3 remaining clients are now needed to reach the threshold of 3.
    if (i < 4) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(4 - i));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
    }
    // Async mode: readiness only appears after RunSchedulers() below.
    // Sync mode: ready once all 3 remaining inputs have arrived.
    if (GetParam().enable_async_r2) {
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 4));
    }

    if (i < 4) {
      // Have client send a vector of the correct size to the server
      auto client_message = std::make_unique<ClientToServerWrapperMessage>();
      MaskedInputVector encoded_vector;
      SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
      encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
      (*client_message->mutable_masked_input_response()
            ->mutable_vectors())["foobar"] = encoded_vector;
      ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
      if (GetParam().enable_async_r2) {
        EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      } else {
        EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
      }
    }
  }

  // Execute the queued asynchronous work; only then is the state ready.
  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // Client 0 is counted as failed *before* sending (valid) masked input.
  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
489 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateProceedsCorrectlyWithOneClientAbortingAfterSendingInput) {
  // In this test, all clients send in their valid masked inputs, but then
  // client 2 aborts before the server proceeds to the next state.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  // 4 clients in total, 3 needed to proceed; all failure counters start at 0.
  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     nullptr /* metrics_listener */,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // At the top of iteration i, clients 0..i-1 have sent their input. The last
  // iteration (i == 4) only checks the counters and sends nothing.
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
    }
    // Async mode: readiness only appears after RunSchedulers() below.
    // Sync mode: ready once 3 inputs have arrived.
    if (GetParam().enable_async_r2) {
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
    }
    if (i < 4) {
      // Have client send a vector of the correct size to the server
      auto client_message = std::make_unique<ClientToServerWrapperMessage>();
      MaskedInputVector encoded_vector;
      SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
      encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
      (*client_message->mutable_masked_input_response()
            ->mutable_vectors())["foobar"] = encoded_vector;
      ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
      if (GetParam().enable_async_r2) {
        EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      } else {
        EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
      }
    }
  }

  // Execute the queued asynchronous work; only then is the state ready.
  if (GetParam().enable_async_r2) {
    RunSchedulers();
    EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  }

  // Client 2 aborts after its input was accepted. Alive and ready counts drop
  // to 3, but the received-message count stays at 4 and the state remains
  // ready because the threshold of 3 is still met.
  auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
  abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
  ASSERT_THAT(state.HandleMessage(2, std::move(abort_message)), IsOk());
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
  EXPECT_THAT(state.NeedsToAbort(), IsFalse());
  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
  EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
  EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());

  // The remaining clients get an UnmaskingRequest with an empty dead-client
  // list. The aborted client 2 is sent nothing, and nothing is broadcast.
  ServerToClientWrapperMessage server_message;
  server_message.mutable_unmasking_request()
      ->mutable_dead_3_client_ids()
      ->Clear();  // Just to set it to an empty vector
  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    if (i != 2) {
      EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
    }
  }

  // Client 2 is counted as failed *after* sending masked input.
  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R3_UNMASKING));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
579 
TEST_P(SecaggServerR2MaskedInputCollStateTest,
       StateForcesAbortIfTooManyClientsAbort) {
  // In this test, clients 0 and 1 abort, so the state aborts.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  // 4 clients in total, 3 needed to proceed; all failure counters start at 0.
  SecAggServerR2MaskedInputCollState state(
      CreateSecAggServerProtocolImpl(3, 4, sender.get(),
                                     nullptr /* metrics_listener */,
                                     GetParam().enable_async_r2),
      0,  // number_of_clients_failed_after_sending_masked_input
      0,  // number_of_clients_failed_before_sending_masked_input
      0   // number_of_clients_terminated_without_unmasking
  );

  // At the top of iteration i, clients 0..i-1 have aborted. Once 2 of 4 have
  // aborted, the threshold of 3 can no longer be met. NeedsToAbort() and
  // ReadyForNextRound() then both become true, with no RunSchedulers() call.
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), Eq(i >= 2));
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - i));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    if (i < 2) {
      // Have client abort
      auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
      abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
      ASSERT_THAT(state.HandleMessage(i, std::move(abort_message)), IsOk());
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 1));
    }
  }

  // Proceeding must broadcast a single Abort to everyone. It must not send
  // any individual messages.
  ServerToClientWrapperMessage server_message;
  server_message.mutable_abort()->set_early_success(false);
  server_message.mutable_abort()->set_diagnostic_info(
      "Too many clients aborted.");
  EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(1);
  EXPECT_CALL(*sender, Send(_, _)).Times(0);

  // The next state is ABORTED with the same diagnostic as its error message,
  // and exactly one Abort broadcast is traced.
  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(next_state.value()->ErrorMessage(), IsOk());
  EXPECT_THAT(next_state.value()->ErrorMessage().value(),
              Eq("Too many clients aborted."));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(server_message.ByteSizeLong()))));
}
630 
TEST_P(SecaggServerR2MaskedInputCollStateTest,MetricsRecordsMessageSizes)631 TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsRecordsMessageSizes) {
632   // In this test, all clients send in their valid masked inputs, but then
633   // client 2 aborts before the server proceeds to the next state.
634   TestTracingRecorder tracing_recorder;
635   MockSecAggServerMetricsListener* metrics =
636       new MockSecAggServerMetricsListener();
637   auto sender = std::make_unique<MockSendToClientsInterface>();
638 
639   SecAggServerR2MaskedInputCollState state(
640       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
641                                      GetParam().enable_async_r2),
642       0,  // number_of_clients_failed_after_sending_masked_input
643       0,  // number_of_clients_failed_before_sending_masked_input
644       0   // number_of_clients_terminated_without_unmasking
645   );
646 
647   for (int i = 0; i < 5; ++i) {
648     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
649     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
650     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
651     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
652     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
653     if (i < 3) {
654       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
655     } else {
656       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
657     }
658     if (GetParam().enable_async_r2) {
659       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
660     } else {
661       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
662     }
663     if (i < 4) {
664       // Have client send a vector of the correct size to the server
665       auto client_message = std::make_unique<ClientToServerWrapperMessage>();
666       MaskedInputVector encoded_vector;
667       SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
668       encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
669       (*client_message->mutable_masked_input_response()
670             ->mutable_vectors())["foobar"] = encoded_vector;
671       EXPECT_CALL(
672           *metrics,
673           MessageReceivedSizes(Eq(ClientToServerWrapperMessage::
674                                       MessageContentCase::kMaskedInputResponse),
675                                Eq(true), Eq(client_message->ByteSizeLong())));
676       ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
677       if (GetParam().enable_async_r2) {
678         EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
679       } else {
680         EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
681       }
682     }
683   }
684 
685   if (GetParam().enable_async_r2) {
686     RunSchedulers();
687     EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
688   }
689 
690   auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
691   abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
692   EXPECT_CALL(*metrics,
693               MessageReceivedSizes(
694                   Eq(ClientToServerWrapperMessage::MessageContentCase::kAbort),
695                   Eq(false), Eq(abort_message->ByteSizeLong())));
696 
697   size_t abort_message_size = abort_message->ByteSizeLong();
698   ASSERT_THAT(state.HandleMessage(2, std::move(abort_message)), IsOk());
699   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
700   EXPECT_THAT(state.NeedsToAbort(), IsFalse());
701   EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
702   EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
703   EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(4));
704   EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
705   EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
706   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
707   EXPECT_THAT(tracing_recorder.root(),
708               Contains(IsEvent<ClientMessageReceived>(
709                   Eq(ClientToServerMessageType_Abort), Eq(abort_message_size),
710                   Eq(false), Ge(0))));
711 
712   ServerToClientWrapperMessage server_message;
713   server_message.mutable_unmasking_request()
714       ->mutable_dead_3_client_ids()
715       ->Clear();  // Just to set it to an empty vector
716   EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(0);
717   for (int i = 0; i < 4; ++i) {
718     if (i != 2) {
719       EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
720     }
721   }
722   EXPECT_CALL(*metrics, BroadcastMessageSizes(_, _)).Times(0);
723   EXPECT_CALL(*metrics, IndividualMessageSizes(
724                             Eq(ServerToClientWrapperMessage::
725                                    MessageContentCase::kUnmaskingRequest),
726                             Eq(server_message.ByteSizeLong())))
727       .Times(3);
728 
729   auto next_state = state.ProceedToNextRound();
730   ASSERT_THAT(next_state, IsOk());
731   EXPECT_THAT(next_state.value()->State(),
732               Eq(SecAggServerStateKind::R3_UNMASKING));
733   EXPECT_THAT(
734       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
735       Eq(1));
736   EXPECT_THAT(
737       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
738       Eq(0));
739   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
740               Eq(0));
741 }
742 
TEST_P(SecaggServerR2MaskedInputCollStateTest,ServerAndClientAbortsAreRecordedCorrectly)743 TEST_P(SecaggServerR2MaskedInputCollStateTest,
744        ServerAndClientAbortsAreRecordedCorrectly) {
745   // In this test clients abort for a variety of reasons, and then ultimately
746   // the server aborts. Metrics should record all of these events.
747   MockSecAggServerMetricsListener* metrics =
748       new MockSecAggServerMetricsListener();
749   auto sender = std::make_unique<MockSendToClientsInterface>();
750 
751   SecAggServerR2MaskedInputCollState state(
752       CreateSecAggServerProtocolImpl(2, 7, sender.get(), metrics,
753                                      GetParam().enable_async_r2),
754       0,  // number_of_clients_failed_after_sending_masked_input
755       0,  // number_of_clients_failed_before_sending_masked_input
756       0   // number_of_clients_terminated_without_unmasking
757   );
758 
759   EXPECT_CALL(*metrics,
760               ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
761                              Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
762   EXPECT_CALL(*metrics,
763               ClientsDropped(
764                   Eq(ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED),
765                   Eq(ClientDropReason::MASKED_INPUT_UNEXPECTED)));
766   EXPECT_CALL(*metrics,
767               ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
768                              Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
769   EXPECT_CALL(*metrics,
770               ClientsDropped(Eq(ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED),
771                              Eq(ClientDropReason::INVALID_MASKED_INPUT)))
772       .Times(3);
773   EXPECT_CALL(
774       *metrics,
775       ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
776 
777   auto abort_message = std::make_unique<ClientToServerWrapperMessage>();
778   abort_message->mutable_abort()->set_diagnostic_info("Aborting for test");
779 
780   ClientToServerWrapperMessage valid_message;
781   MaskedInputVector encoded_vector;
782   SecAggVector masked_vector(std::vector<uint64_t>(4, 9), 32);
783   encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
784   (*valid_message.mutable_masked_input_response()
785         ->mutable_vectors())["foobar"] = encoded_vector;
786 
787   auto invalid_message_too_many_vectors =
788       std::make_unique<ClientToServerWrapperMessage>();
789   (*invalid_message_too_many_vectors->mutable_masked_input_response()
790         ->mutable_vectors())["extra"] = encoded_vector;
791 
792   auto invalid_message_wrong_name =
793       std::make_unique<ClientToServerWrapperMessage>();
794   (*invalid_message_wrong_name->mutable_masked_input_response()
795         ->mutable_vectors())["wrong"] = encoded_vector;
796 
797   auto invalid_message_wrong_size =
798       std::make_unique<ClientToServerWrapperMessage>();
799   MaskedInputVector large_encoded_vector;
800   SecAggVector large_masked_vector(std::vector<uint64_t>(7, 9), 32);
801   large_encoded_vector.set_encoded_vector(
802       large_masked_vector.GetAsPackedBytes());
803   (*invalid_message_wrong_size->mutable_masked_input_response()
804         ->mutable_vectors())["foobar"] = large_encoded_vector;
805 
806   auto wrong_message = std::make_unique<ClientToServerWrapperMessage>();
807   wrong_message->mutable_advertise_keys();  // wrong type of message
808 
809   state.HandleMessage(0, std::move(abort_message)).IgnoreError();
810   state
811       .HandleMessage(
812           1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
813       .IgnoreError();
814   state
815       .HandleMessage(
816           1, std::make_unique<ClientToServerWrapperMessage>(valid_message))
817       .IgnoreError();
818   state.HandleMessage(2, std::move(invalid_message_too_many_vectors))
819       .IgnoreError();
820   state.HandleMessage(3, std::move(invalid_message_wrong_name)).IgnoreError();
821   state.HandleMessage(4, std::move(invalid_message_wrong_size)).IgnoreError();
822   state.HandleMessage(5, std::move(wrong_message)).IgnoreError();
823 
824   if (GetParam().enable_async_r2) {
825     RunSchedulers();
826     EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
827   }
828 
829   state.ProceedToNextRound().IgnoreError();  // causes server abort
830 }
831 
TEST_P(SecaggServerR2MaskedInputCollStateTest,MetricsAreRecorded)832 TEST_P(SecaggServerR2MaskedInputCollStateTest, MetricsAreRecorded) {
833   // In this test, clients 0 through 2 send in valid masked inputs, and then we
834   // proceed to the next step even without client 3.
835   MockSecAggServerMetricsListener* metrics =
836       new MockSecAggServerMetricsListener();
837   auto sender = std::make_unique<MockSendToClientsInterface>();
838 
839   SecAggServerR2MaskedInputCollState state(
840       CreateSecAggServerProtocolImpl(3, 4, sender.get(), metrics,
841                                      GetParam().enable_async_r2),
842       0,  // number_of_clients_failed_after_sending_masked_input
843       0,  // number_of_clients_failed_before_sending_masked_input
844       0   // number_of_clients_terminated_without_unmasking
845   );
846 
847   EXPECT_CALL(*metrics, ClientResponseTimes(
848                             Eq(ClientToServerWrapperMessage::
849                                    MessageContentCase::kMaskedInputResponse),
850                             Ge(0)))
851       .Times(3);
852 
853   for (int i = 0; i < 4; ++i) {
854     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
855     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
856     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
857     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
858     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
859     if (i < 3) {
860       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
861     } else {
862       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
863     }
864     if (GetParam().enable_async_r2) {
865       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
866     } else {
867       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 3));
868     }
869     if (i < 3) {
870       // Have client send a vector of the correct size to the server
871       auto client_message = std::make_unique<ClientToServerWrapperMessage>();
872       MaskedInputVector encoded_vector;
873       SecAggVector masked_vector(std::vector<uint64_t>(4, i + 1), 32);
874       encoded_vector.set_encoded_vector(masked_vector.GetAsPackedBytes());
875       (*client_message->mutable_masked_input_response()
876             ->mutable_vectors())["foobar"] = encoded_vector;
877       ASSERT_THAT(state.HandleMessage(i, std::move(client_message)), IsOk());
878       if (GetParam().enable_async_r2) {
879         EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
880       } else {
881         EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
882       }
883     }
884   }
885 
886   if (GetParam().enable_async_r2) {
887     RunSchedulers();
888     EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
889   }
890 
891   ServerToClientWrapperMessage server_message;
892   // TODO(team): 4 -> 3 below, once backwards compatibility not needed.
893   server_message.mutable_unmasking_request()->add_dead_3_client_ids(4);
894   ServerToClientWrapperMessage abort_message;
895   abort_message.mutable_abort()->set_early_success(false);
896   abort_message.mutable_abort()->set_diagnostic_info(
897       "Client did not send MaskedInputCollectionResponse before round "
898       "transition.");
899   EXPECT_CALL(*sender, SendBroadcast(EqualsProto(server_message))).Times(0);
900   for (int i = 0; i < 3; ++i) {
901     EXPECT_CALL(*sender, Send(i, EqualsProto(server_message))).Times(1);
902   }
903   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message))).Times(1);
904   EXPECT_CALL(*metrics,
905               RoundTimes(Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION),
906                          Eq(true), Ge(0)));
907   EXPECT_CALL(
908       *metrics,
909       RoundSurvivingClients(
910           Eq(SecAggServerStateKind::R2_MASKED_INPUT_COLLECTION), Eq(3)));
911 
912   auto next_state = state.ProceedToNextRound();
913   ASSERT_THAT(next_state, IsOk());
914   EXPECT_THAT(next_state.value()->State(),
915               Eq(SecAggServerStateKind::R3_UNMASKING));
916 }
917 
918 INSTANTIATE_TEST_SUITE_P(
919     SecaggServerR2MaskedInputCollStateTests,
920     SecaggServerR2MaskedInputCollStateTest,
921     ::testing::ValuesIn<SecAggR2StateTestParams>(
922         {{"r2_async_processing_enabled", true},
923          {"r2_async_processing_disabled", false}}),
924     [](const ::testing::TestParamInfo<
__anona314698d0202(const ::testing::TestParamInfo< SecaggServerR2MaskedInputCollStateTest::ParamType>& info) 925         SecaggServerR2MaskedInputCollStateTest::ParamType>& info) {
926       return info.param.test_name;
927     });
928 
929 }  // namespace
930 }  // namespace secagg
931 }  // namespace fcp
932