• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
25 #include "fcp/secagg/server/secagg_server_enums.pb.h"
26 #include "fcp/secagg/server/secagg_server_state.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
29 #include "fcp/secagg/shared/compute_session_id.h"
30 #include "fcp/secagg/shared/ecdh_keys.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
34 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
35 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
36 #include "fcp/testing/testing.h"
37 #include "fcp/tracing/test_tracing_recorder.h"
38 
39 namespace fcp {
40 namespace secagg {
41 namespace {
42 
43 using ::testing::_;
44 using ::testing::Eq;
45 using ::testing::Ge;
46 using ::testing::IsFalse;
47 using ::testing::IsTrue;
48 
CreateAesSecAggServerProtocolImpl(MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)49 std::unique_ptr<AesSecAggServerProtocolImpl> CreateAesSecAggServerProtocolImpl(
50     MockSendToClientsInterface* sender,
51     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
52   auto input_vector_specs = std::vector<InputVectorSpecification>();
53   input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
54   SecretSharingGraphFactory factory;
55 
56   return std::make_unique<AesSecAggServerProtocolImpl>(
57       factory.CreateCompleteGraph(4, 3),  // total number of clients is 4
58       3,  // minimum_number_of_clients_to_proceed
59       input_vector_specs,
60       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
61       std::make_unique<AesCtrPrngFactory>(), sender,
62       nullptr,  // prng_runner
63       std::vector<ClientStatus>(4, ClientStatus::READY_TO_START),
64       ServerVariant::NATIVE_V1);
65 }
66 
// A state that was only just constructed must not report itself as aborted.
TEST(SecaggServerR0AdvertiseKeysStateTest, IsAbortedReturnsFalse) {
  MockSendToClientsInterface sender;
  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender));

  EXPECT_FALSE(state.IsAborted());
}
75 
// A state that was only just constructed must not report successful
// completion.
TEST(SecaggServerR0AdvertiseKeysStateTest,
     IsCompletedSuccessfullyReturnsFalse) {
  MockSendToClientsInterface sender;
  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender));

  EXPECT_FALSE(state.IsCompletedSuccessfully());
}
85 
// ErrorMessage() must come back with a non-OK status on a state that was only
// just constructed.
TEST(SecaggServerR0AdvertiseKeysStateTest, ErrorMessageRaisesErrorStatus) {
  MockSendToClientsInterface sender;
  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender));

  EXPECT_FALSE(state.ErrorMessage().ok());
}
94 
// Result() must come back with a non-OK status on a state that was only just
// constructed.
TEST(SecaggServerR0AdvertiseKeysStateTest, ResultRaisesErrorStatus) {
  MockSendToClientsInterface sender;
  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender));

  EXPECT_FALSE(state.Result().ok());
}
103 
// Aborting from R0 is expected to broadcast an abort message, report the
// outcome to the metrics listener, record a BroadcastMessageSent trace event,
// and return an ABORTED state that carries the abort reason.
TEST(SecaggServerR0AdvertiseKeysStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  TestTracingRecorder tracing_recorder;
  // CreateAesSecAggServerProtocolImpl wraps this pointer in a unique_ptr, so
  // it must not be deleted here.
  auto* metrics = new MockSecAggServerMetricsListener();
  MockSendToClientsInterface sender;

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender, metrics));

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_fields = expected_broadcast.mutable_abort();
  abort_fields->set_early_success(false);
  abort_fields->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(sender, SendBroadcast(EqualsProto(expected_broadcast)));

  auto aborted_state =
      state.Abort("test abort reason", SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(aborted_state->ErrorMessage(), IsOk());
  EXPECT_THAT(aborted_state->ErrorMessage().value(), Eq("test abort reason"));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
132 
TEST(SecaggServerR0AdvertiseKeysStateTest,StateProceedsCorrectlyWithAllClientsValid)133 TEST(SecaggServerR0AdvertiseKeysStateTest,
134      StateProceedsCorrectlyWithAllClientsValid) {
135   // In this test, all clients send two valid ECDH public keys apiece, and then
136   // the server proceeds to the next state.
137   TestTracingRecorder tracing_recorder;
138   auto sender = std::make_unique<MockSendToClientsInterface>();
139 
140   SecAggServerR0AdvertiseKeysState state(
141       CreateAesSecAggServerProtocolImpl(sender.get()));
142 
143   EcdhPregeneratedTestKeys ecdh_keys;
144   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
145   std::vector<ClientToServerWrapperMessage> client_messages(4);
146   ServerToClientWrapperMessage expected_server_message;
147   for (int i = 0; i < 4; ++i) {
148     PairOfPublicKeys* public_keys =
149         expected_server_message.mutable_share_keys_request()
150             ->add_pairs_of_public_keys();
151     client_messages[i]
152         .mutable_advertise_keys()
153         ->mutable_pair_of_public_keys()
154         ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
155     client_messages[i]
156         .mutable_advertise_keys()
157         ->mutable_pair_of_public_keys()
158         ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
159     public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
160     public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
161   }
162   expected_server_message.mutable_share_keys_request()->set_session_id(
163       ComputeSessionId(expected_server_message.share_keys_request()).data);
164 
165   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
166   for (int i = 0; i < 4; ++i) {
167     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
168         .Times(1);
169   }
170 
171   for (int i = 0; i < 5; ++i) {
172     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
173     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
174     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
175     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
176     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
177     if (i < 3) {
178       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
179       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
180     } else {
181       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
182       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
183     }
184     if (i < 4) {
185       ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
186       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
187     }
188   }
189 
190   auto next_state = state.ProceedToNextRound();
191   ASSERT_THAT(next_state, IsOk());
192   EXPECT_THAT(next_state.value()->State(),
193               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
194   EXPECT_THAT(
195       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
196       Eq(0));
197   EXPECT_THAT(
198       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
199       Eq(0));
200   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
201               Eq(0));
202   EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
203               ElementsAre(IsEvent<IndividualMessageSent>(
204                               0, Eq(ServerToClientMessageType_ShareKeysRequest),
205                               Eq(expected_server_message.ByteSizeLong())),
206                           IsEvent<IndividualMessageSent>(
207                               1, Eq(ServerToClientMessageType_ShareKeysRequest),
208                               Eq(expected_server_message.ByteSizeLong())),
209                           IsEvent<IndividualMessageSent>(
210                               2, Eq(ServerToClientMessageType_ShareKeysRequest),
211                               Eq(expected_server_message.ByteSizeLong())),
212                           IsEvent<IndividualMessageSent>(
213                               3, Eq(ServerToClientMessageType_ShareKeysRequest),
214                               Eq(expected_server_message.ByteSizeLong()))));
215 }
216 
TEST(SecaggServerR0AdvertiseKeysStateTest,StateProceedsCorrectlyWithInvalidKeysFromOneClient)217 TEST(SecaggServerR0AdvertiseKeysStateTest,
218      StateProceedsCorrectlyWithInvalidKeysFromOneClient) {
219   // In this test, client 3 sends invalid public keys, so it should be forced to
220   // abort. But this should not stop the rest of the state proceeding normally.
221   auto sender = std::make_unique<MockSendToClientsInterface>();
222 
223   SecAggServerR0AdvertiseKeysState state(
224       CreateAesSecAggServerProtocolImpl(sender.get()));
225 
226   EcdhPregeneratedTestKeys ecdh_keys;
227   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
228   std::vector<ClientToServerWrapperMessage> client_messages(4);
229   ServerToClientWrapperMessage expected_server_message;
230   for (int i = 0; i < 3; ++i) {
231     PairOfPublicKeys* public_keys =
232         expected_server_message.mutable_share_keys_request()
233             ->add_pairs_of_public_keys();
234     client_messages[i]
235         .mutable_advertise_keys()
236         ->mutable_pair_of_public_keys()
237         ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
238     client_messages[i]
239         .mutable_advertise_keys()
240         ->mutable_pair_of_public_keys()
241         ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
242     public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
243     public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
244   }
245   client_messages[3]
246       .mutable_advertise_keys()
247       ->mutable_pair_of_public_keys()
248       ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
249   client_messages[3]
250       .mutable_advertise_keys()
251       ->mutable_pair_of_public_keys()
252       ->set_noise_pk("This is too long to be a valid key.");
253   expected_server_message.mutable_share_keys_request()
254       ->add_pairs_of_public_keys();  // this one will be empty
255 
256   expected_server_message.mutable_share_keys_request()->set_session_id(
257       ComputeSessionId(expected_server_message.share_keys_request()).data);
258 
259   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
260   for (int i = 0; i < 3; ++i) {
261     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
262         .Times(1);
263   }
264 
265   ServerToClientWrapperMessage abort_message;
266   abort_message.mutable_abort()->set_early_success(false);
267   abort_message.mutable_abort()->set_diagnostic_info(
268       "A public key sent by the client was not the correct size.");
269 
270   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
271 
272   for (int i = 0; i < 4; ++i) {
273     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
274     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
275     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
276     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
277     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
278     if (i < 3) {
279       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
280       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
281     } else {
282       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
283       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
284     }
285     ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
286     EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
287   }
288   EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
289   EXPECT_THAT(state.NeedsToAbort(), IsFalse());
290   EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
291   EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
292   EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
293   EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
294   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
295 
296   auto next_state = state.ProceedToNextRound();
297   ASSERT_THAT(next_state, IsOk());
298   EXPECT_THAT(next_state.value()->State(),
299               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
300   EXPECT_THAT(
301       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
302       Eq(0));
303   EXPECT_THAT(
304       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
305       Eq(1));
306   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
307               Eq(0));
308 }
309 
TEST(SecaggServerR0AdvertiseKeysStateTest,StateProceedsCorrectlyWithNoMessageFromOneClient)310 TEST(SecaggServerR0AdvertiseKeysStateTest,
311      StateProceedsCorrectlyWithNoMessageFromOneClient) {
312   // In this test, we proceed to the next state before client 3 sends any
313   // message, so it should be forced to abort. But this should not stop the rest
314   // of the state proceeding normally.
315   auto sender = std::make_unique<MockSendToClientsInterface>();
316 
317   SecAggServerR0AdvertiseKeysState state(
318       CreateAesSecAggServerProtocolImpl(sender.get()));
319 
320   EcdhPregeneratedTestKeys ecdh_keys;
321   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
322   std::vector<ClientToServerWrapperMessage> client_messages(3);
323   ServerToClientWrapperMessage expected_server_message;
324   for (int i = 0; i < 3; ++i) {
325     PairOfPublicKeys* public_keys =
326         expected_server_message.mutable_share_keys_request()
327             ->add_pairs_of_public_keys();
328     client_messages[i]
329         .mutable_advertise_keys()
330         ->mutable_pair_of_public_keys()
331         ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
332     client_messages[i]
333         .mutable_advertise_keys()
334         ->mutable_pair_of_public_keys()
335         ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
336     public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
337     public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
338   }
339   expected_server_message.mutable_share_keys_request()
340       ->add_pairs_of_public_keys();  // this one will be empty
341 
342   expected_server_message.mutable_share_keys_request()->set_session_id(
343       ComputeSessionId(expected_server_message.share_keys_request()).data);
344 
345   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
346   for (int i = 0; i < 3; ++i) {
347     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
348         .Times(1);
349   }
350   ServerToClientWrapperMessage abort_message;
351   abort_message.mutable_abort()->set_early_success(false);
352   abort_message.mutable_abort()->set_diagnostic_info(
353       "Client did not send AdvertiseKeys message before round transition.");
354 
355   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
356 
357   for (int i = 0; i < 4; ++i) {
358     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
359     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
360     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
361     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
362     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
363     if (i < 3) {
364       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
365       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
366     } else {
367       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
368       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
369     }
370     if (i < 3) {
371       ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
372       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
373     }
374   }
375 
376   auto next_state = state.ProceedToNextRound();
377   ASSERT_THAT(next_state, IsOk());
378   EXPECT_THAT(next_state.value()->State(),
379               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
380   EXPECT_THAT(
381       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
382       Eq(0));
383   EXPECT_THAT(
384       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
385       Eq(1));
386   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
387               Eq(0));
388 }
389 
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateNeedsToAbortIfTooManyClientsAbort) {
  // Clients 0 and 1 send abort messages. That leaves two of the four clients
  // alive, fewer than the three needed to proceed, so the server should
  // register that it needs to abort.
  TestTracingRecorder tracing_recorder;
  MockSendToClientsInterface sender;

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender));

  // Verifies every counter at the point where `aborted` clients have dropped.
  const auto expect_counters = [&state](int aborted) {
    const bool too_few_left = aborted >= 2;
    EXPECT_THAT(state.NeedsToAbort(), Eq(too_few_left));
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4 - aborted));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - aborted));
    EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_THAT(state.ReadyForNextRound(), Eq(too_few_left));
  };

  for (int client = 0; client < 2; ++client) {
    expect_counters(client);
    // Have client abort
    ClientToServerWrapperMessage client_abort;
    client_abort.mutable_abort()->set_diagnostic_info("Aborting for test");
    ASSERT_THAT(state.HandleMessage(client, client_abort), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(client >= 1));
  }
  expect_counters(2);

  // Expectations are installed only now, immediately before the transition.
  ServerToClientWrapperMessage expected_broadcast;
  expected_broadcast.mutable_abort()->set_early_success(false);
  expected_broadcast.mutable_abort()->set_diagnostic_info(
      "Too many clients aborted.");
  EXPECT_CALL(sender, SendBroadcast(EqualsProto(expected_broadcast))).Times(1);
  EXPECT_CALL(sender, Send(_, _)).Times(0);

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(), Eq(SecAggServerStateKind::ABORTED));
  ASSERT_THAT(next_state.value()->ErrorMessage(), IsOk());
  EXPECT_THAT(next_state.value()->ErrorMessage().value(),
              Eq("Too many clients aborted."));
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort),
                  Eq(expected_broadcast.ByteSizeLong()))));
}
435 
TEST(SecaggServerR0AdvertiseKeysStateTest,StateProceedsCorrectlyWithAllUncompressedClientMessages)436 TEST(SecaggServerR0AdvertiseKeysStateTest,
437      StateProceedsCorrectlyWithAllUncompressedClientMessages) {
438   // In this test, all clients send two valid ECDH public keys apiece, and then
439   // the server proceeds to the next state.
440   auto sender = std::make_unique<MockSendToClientsInterface>();
441 
442   SecAggServerR0AdvertiseKeysState state(
443       CreateAesSecAggServerProtocolImpl(sender.get()));
444 
445   EcdhPregeneratedTestKeys ecdh_keys;
446   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
447   std::vector<ClientToServerWrapperMessage> client_messages(4);
448   ServerToClientWrapperMessage expected_server_message;
449   for (int i = 0; i < 4; ++i) {
450     PairOfPublicKeys* public_keys =
451         expected_server_message.mutable_share_keys_request()
452             ->add_pairs_of_public_keys();
453     client_messages[i]
454         .mutable_advertise_keys()
455         ->mutable_pair_of_public_keys()
456         ->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
457     client_messages[i]
458         .mutable_advertise_keys()
459         ->mutable_pair_of_public_keys()
460         ->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));
461     public_keys->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
462     public_keys->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));
463   }
464 
465   expected_server_message.mutable_share_keys_request()->set_session_id(
466       ComputeSessionId(expected_server_message.share_keys_request()).data);
467 
468   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
469   for (int i = 0; i < 4; ++i) {
470     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
471         .Times(1);
472   }
473 
474   for (int i = 0; i < 5; ++i) {
475     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
476     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
477     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
478     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
479     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
480     if (i < 3) {
481       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
482       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
483     } else {
484       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
485       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
486     }
487     if (i < 4) {
488       ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
489       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
490     }
491   }
492 
493   auto next_state = state.ProceedToNextRound();
494   ASSERT_THAT(next_state, IsOk());
495   EXPECT_THAT(next_state.value()->State(),
496               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
497   EXPECT_THAT(
498       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
499       Eq(0));
500   EXPECT_THAT(
501       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
502       Eq(0));
503   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
504               Eq(0));
505 }
506 
// The metrics listener is expected to receive ProtocolStarts while the state
// is being built and queried.
TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsStart) {
  // CreateAesSecAggServerProtocolImpl wraps this pointer in a unique_ptr, so
  // it must not be deleted here.
  auto* metrics = new MockSecAggServerMetricsListener();
  MockSendToClientsInterface sender;

  // Installed before the state exists so a call made during construction is
  // matched.
  EXPECT_CALL(*metrics, ProtocolStarts(_));

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(&sender, metrics));

  EXPECT_FALSE(state.Result().ok());
}
519 
TEST(SecaggServerR0AdvertiseKeysStateTest,MetricsRecordsMessageSizes)520 TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsMessageSizes) {
521   // In this test, client 3 sends invalid public keys, so it should be forced to
522   // abort. But this should not stop the rest of the state proceeding normally.
523   TestTracingRecorder tracing_recorder;
524   MockSecAggServerMetricsListener* metrics =
525       new MockSecAggServerMetricsListener();
526   auto sender = std::make_unique<MockSendToClientsInterface>();
527 
528   EXPECT_CALL(*metrics, ProtocolStarts(_));
529 
530   SecAggServerR0AdvertiseKeysState state(
531       CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
532 
533   EcdhPregeneratedTestKeys ecdh_keys;
534   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
535   std::vector<ClientToServerWrapperMessage> client_messages(4);
536   ServerToClientWrapperMessage expected_server_message;
537   for (int i = 0; i < 3; ++i) {
538     PairOfPublicKeys* public_keys =
539         expected_server_message.mutable_share_keys_request()
540             ->add_pairs_of_public_keys();
541     client_messages[i]
542         .mutable_advertise_keys()
543         ->mutable_pair_of_public_keys()
544         ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
545     client_messages[i]
546         .mutable_advertise_keys()
547         ->mutable_pair_of_public_keys()
548         ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
549     public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
550     public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
551   }
552   client_messages[3]
553       .mutable_advertise_keys()
554       ->mutable_pair_of_public_keys()
555       ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
556   client_messages[3]
557       .mutable_advertise_keys()
558       ->mutable_pair_of_public_keys()
559       ->set_noise_pk("This is too long to be a valid key.");
560   expected_server_message.mutable_share_keys_request()
561       ->add_pairs_of_public_keys();  // this one will be empty
562 
563   expected_server_message.mutable_share_keys_request()->set_session_id(
564       ComputeSessionId(expected_server_message.share_keys_request()).data);
565 
566   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
567   for (int i = 0; i < 3; ++i) {
568     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
569         .Times(1);
570   }
571   ServerToClientWrapperMessage abort_message;
572   abort_message.mutable_abort()->set_early_success(false);
573   abort_message.mutable_abort()->set_diagnostic_info(
574       "A public key sent by the client was not the correct size.");
575   EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));
576 
577   EXPECT_CALL(*metrics, IndividualMessageSizes(
578                             Eq(ServerToClientWrapperMessage::
579                                    MessageContentCase::kShareKeysRequest),
580                             Eq(expected_server_message.ByteSizeLong())))
581       .Times(3);
582   EXPECT_CALL(*metrics,
583               IndividualMessageSizes(
584                   Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
585                   Eq(abort_message.ByteSizeLong())));
586   EXPECT_CALL(
587       *metrics,
588       MessageReceivedSizes(
589           Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
590           Eq(true), Eq(client_messages[0].ByteSizeLong())))
591       .Times(3);
592   EXPECT_CALL(
593       *metrics,
594       MessageReceivedSizes(
595           Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
596           Eq(true), Eq(client_messages[3].ByteSizeLong())));
597 
598   for (int i = 0; i < 4; ++i) {
599     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
600     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
601     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
602     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
603     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
604     if (i < 3) {
605       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
606       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
607     } else {
608       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
609       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
610     }
611     ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
612     EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
613     EXPECT_THAT(tracing_recorder.root()[i],
614                 IsEvent<ClientMessageReceived>(
615                     Eq(ClientToServerMessageType_AdvertiseKeys),
616                     Eq(client_messages[i].ByteSizeLong()), Eq(true), Ge(0)));
617   }
618   EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
619   EXPECT_THAT(state.NeedsToAbort(), IsFalse());
620   EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
621   EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
622   EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
623   EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
624   EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
625 
626   auto next_state = state.ProceedToNextRound();
627   ASSERT_THAT(next_state, IsOk());
628   EXPECT_THAT(next_state.value()->State(),
629               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
630   EXPECT_THAT(
631       next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
632       Eq(0));
633   EXPECT_THAT(
634       next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
635       Eq(1));
636   EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
637               Eq(0));
638 }
639 
TEST(SecaggServerR0AdvertiseKeysStateTest,ServerAndClientAbortsAreRecordedCorrectly)640 TEST(SecaggServerR0AdvertiseKeysStateTest,
641      ServerAndClientAbortsAreRecordedCorrectly) {
642   TestTracingRecorder tracing_recorder;
643   // In this test clients abort for a variety of reasons, and then ultimately
644   // the server aborts. Metrics should record all of these events.
645   auto sender = std::make_unique<MockSendToClientsInterface>();
646   MockSecAggServerMetricsListener* metrics =
647       new MockSecAggServerMetricsListener();
648   EcdhPregeneratedTestKeys ecdh_keys;
649 
650   SecAggServerR0AdvertiseKeysState state(
651       CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
652 
653   EXPECT_CALL(*metrics,
654               ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
655                              Eq(ClientDropReason::SENT_ABORT_MESSAGE)));
656   EXPECT_CALL(*metrics,
657               ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
658                              Eq(ClientDropReason::ADVERTISE_KEYS_UNEXPECTED)));
659   EXPECT_CALL(*metrics,
660               ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
661                              Eq(ClientDropReason::UNEXPECTED_MESSAGE_TYPE)));
662   EXPECT_CALL(*metrics,
663               ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
664                              Eq(ClientDropReason::INVALID_PUBLIC_KEY)));
665   EXPECT_CALL(
666       *metrics,
667       ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));
668 
669   ClientToServerWrapperMessage abort_message;
670   abort_message.mutable_abort()->set_diagnostic_info("Aborting for test");
671   ClientToServerWrapperMessage valid_message;
672   valid_message.mutable_advertise_keys()
673       ->mutable_pair_of_public_keys()
674       ->set_enc_pk(ecdh_keys.GetPublicKeyString(0));
675   valid_message.mutable_advertise_keys()
676       ->mutable_pair_of_public_keys()
677       ->set_noise_pk(ecdh_keys.GetPublicKeyString(4));
678   ClientToServerWrapperMessage invalid_message;
679   invalid_message.mutable_advertise_keys()
680       ->mutable_pair_of_public_keys()
681       ->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
682   invalid_message.mutable_advertise_keys()
683       ->mutable_pair_of_public_keys()
684       ->set_noise_pk("This is too long to be a valid key.");
685   ClientToServerWrapperMessage wrong_message;
686   wrong_message.mutable_share_keys_response();  // wrong type of message
687 
688   state.HandleMessage(0, abort_message).IgnoreError();
689   state.HandleMessage(1, valid_message).IgnoreError();
690   state.HandleMessage(1, valid_message).IgnoreError();
691   state.HandleMessage(2, invalid_message).IgnoreError();
692   state.HandleMessage(3, wrong_message).IgnoreError();
693   state.ProceedToNextRound().IgnoreError();  // causes server abort
694 
695   EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
696               ElementsAre(IsEvent<SecAggProtocolOutcome>(
697                   Eq(TracingSecAggServerOutcome_NotEnoughClientsRemaining))));
698   EXPECT_THAT(
699       tracing_recorder.FindAllEvents<ClientsDropped>(),
700       ElementsAre(IsEvent<ClientsDropped>(
701                       Eq(TracingClientStatus_DeadBeforeSendingAnything),
702                       Eq(TracingClientDropReason_SentAbortMessage)),
703                   IsEvent<ClientsDropped>(
704                       Eq(TracingClientStatus_DeadBeforeSendingAnything),
705                       Eq(TracingClientDropReason_AdvertiseKeysUnexpected)),
706                   IsEvent<ClientsDropped>(
707                       Eq(TracingClientStatus_DeadBeforeSendingAnything),
708                       Eq(TracingClientDropReason_InvalidPublicKey)),
709                   IsEvent<ClientsDropped>(
710                       Eq(TracingClientStatus_DeadBeforeSendingAnything),
711                       Eq(TracingClientDropReason_UnexpectedMessageType))));
712 }
713 
TEST(SecaggServerR0AdvertiseKeysStateTest,MetricsAreRecorded)714 TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsAreRecorded) {
715   // In this test, all clients send two valid ECDH public keys apiece, and then
716   // the server proceeds to the next state.
717   TestTracingRecorder tracing_recorder;
718   MockSecAggServerMetricsListener* metrics =
719       new MockSecAggServerMetricsListener();
720   auto sender = std::make_unique<MockSendToClientsInterface>();
721   SecAggServerR0AdvertiseKeysState state(
722       CreateAesSecAggServerProtocolImpl(sender.get(), metrics));
723 
724   EcdhPregeneratedTestKeys ecdh_keys;
725   auto pairwise_public_keys = std::make_unique<std::vector<EcdhPublicKey>>();
726   std::vector<ClientToServerWrapperMessage> client_messages(4);
727   ServerToClientWrapperMessage expected_server_message;
728   for (int i = 0; i < 4; ++i) {
729     PairOfPublicKeys* public_keys =
730         expected_server_message.mutable_share_keys_request()
731             ->add_pairs_of_public_keys();
732     client_messages[i]
733         .mutable_advertise_keys()
734         ->mutable_pair_of_public_keys()
735         ->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
736     client_messages[i]
737         .mutable_advertise_keys()
738         ->mutable_pair_of_public_keys()
739         ->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
740     public_keys->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
741     public_keys->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
742   }
743 
744   expected_server_message.mutable_share_keys_request()->set_session_id(
745       ComputeSessionId(expected_server_message.share_keys_request()).data);
746 
747   EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
748   for (int i = 0; i < 4; ++i) {
749     EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
750         .Times(1);
751   }
752   EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS),
753                                    Eq(true), Ge(0)));
754   EXPECT_CALL(*metrics,
755               RoundSurvivingClients(
756                   Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS), Eq(4)));
757   EXPECT_CALL(
758       *metrics,
759       ClientResponseTimes(
760           Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
761           Ge(0)))
762       .Times(4);
763 
764   for (int i = 0; i < 5; ++i) {
765     EXPECT_THAT(state.NeedsToAbort(), IsFalse());
766     EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
767     EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
768     EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
769     EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
770     if (i < 3) {
771       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
772       EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
773     } else {
774       EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
775       EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
776     }
777     if (i < 4) {
778       ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
779       EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
780     }
781   }
782 
783   auto next_state = state.ProceedToNextRound();
784   ASSERT_THAT(next_state, IsOk());
785   EXPECT_THAT(next_state.value()->State(),
786               Eq(SecAggServerStateKind::R1_SHARE_KEYS));
787   EXPECT_THAT(
788       tracing_recorder.FindAllEvents<StateCompletion>(),
789       ElementsAre(IsEvent<StateCompletion>(
790           Eq(SecAggServerTraceState_R0AdvertiseKeys), Eq(true), Ge(0), Eq(4))));
791 }
792 
793 }  // namespace
794 }  // namespace secagg
795 }  // namespace fcp
796