1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_r0_advertise_keys_state.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
25 #include "fcp/secagg/server/secagg_server_enums.pb.h"
26 #include "fcp/secagg/server/secagg_server_state.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
29 #include "fcp/secagg/shared/compute_session_id.h"
30 #include "fcp/secagg/shared/ecdh_keys.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
34 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
35 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
36 #include "fcp/testing/testing.h"
37 #include "fcp/tracing/test_tracing_recorder.h"
38
39 namespace fcp {
40 namespace secagg {
41 namespace {
42
43 using ::testing::_;
44 using ::testing::Eq;
45 using ::testing::Ge;
46 using ::testing::IsFalse;
47 using ::testing::IsTrue;
48
CreateAesSecAggServerProtocolImpl(MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)49 std::unique_ptr<AesSecAggServerProtocolImpl> CreateAesSecAggServerProtocolImpl(
50 MockSendToClientsInterface* sender,
51 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
52 auto input_vector_specs = std::vector<InputVectorSpecification>();
53 input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32));
54 SecretSharingGraphFactory factory;
55
56 return std::make_unique<AesSecAggServerProtocolImpl>(
57 factory.CreateCompleteGraph(4, 3), // total number of clients is 4
58 3, // minimum_number_of_clients_to_proceed
59 input_vector_specs,
60 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
61 std::make_unique<AesCtrPrngFactory>(), sender,
62 nullptr, // prng_runner
63 std::vector<ClientStatus>(4, ClientStatus::READY_TO_START),
64 ServerVariant::NATIVE_V1);
65 }
66
TEST(SecaggServerR0AdvertiseKeysStateTest, IsAbortedReturnsFalse) {
  // A freshly constructed R0 state must not report itself as aborted.
  MockSendToClientsInterface mock_sender;
  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(&mock_sender));

  EXPECT_FALSE(r0_state.IsAborted());
}
75
TEST(SecaggServerR0AdvertiseKeysStateTest,
     IsCompletedSuccessfullyReturnsFalse) {
  // Round 0 is the first round, so a fresh state cannot already be complete.
  MockSendToClientsInterface mock_sender;
  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(&mock_sender));

  EXPECT_FALSE(r0_state.IsCompletedSuccessfully());
}
85
TEST(SecaggServerR0AdvertiseKeysStateTest, ErrorMessageRaisesErrorStatus) {
  // A freshly constructed R0 state exposes no error message: asking for one
  // must yield a non-OK status.
  MockSendToClientsInterface mock_sender;
  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(&mock_sender));

  const bool error_message_available = r0_state.ErrorMessage().ok();
  EXPECT_FALSE(error_message_available);
}
94
TEST(SecaggServerR0AdvertiseKeysStateTest, ResultRaisesErrorStatus) {
  // A freshly constructed R0 state has no result yet: Result() must return a
  // non-OK status.
  MockSendToClientsInterface mock_sender;
  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(&mock_sender));

  const bool result_available = r0_state.Result().ok();
  EXPECT_FALSE(result_available);
}
103
TEST(SecaggServerR0AdvertiseKeysStateTest,
     AbortReturnsValidStateAndNotifiesClients) {
  // Aborting on an external request must:
  //   * report EXTERNAL_REQUEST as the protocol outcome,
  //   * broadcast an abort message carrying the reason,
  //   * trace exactly one broadcast of that message,
  //   * return an ABORTED state whose error message is the reason.
  TestTracingRecorder tracing_recorder;
  auto* metrics_listener = new MockSecAggServerMetricsListener();
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(mock_sender.get(), metrics_listener));

  ServerToClientWrapperMessage expected_abort;
  auto* abort_body = expected_abort.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("test abort reason");

  EXPECT_CALL(*metrics_listener,
              ProtocolOutcomes(Eq(SecAggServerOutcome::EXTERNAL_REQUEST)));
  EXPECT_CALL(*mock_sender, SendBroadcast(EqualsProto(expected_abort)));

  auto aborted_state = r0_state.Abort("test abort reason",
                                      SecAggServerOutcome::EXTERNAL_REQUEST);

  ASSERT_THAT(aborted_state->State(), Eq(SecAggServerStateKind::ABORTED));
  auto reported_error = aborted_state->ErrorMessage();
  ASSERT_THAT(reported_error, IsOk());
  EXPECT_THAT(reported_error.value(), Eq("test abort reason"));

  const auto expected_size = expected_abort.ByteSizeLong();
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort), Eq(expected_size))));
}
132
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateProceedsCorrectlyWithAllClientsValid) {
  // In this test, all clients send two valid ECDH public keys apiece, and then
  // the server proceeds to the next state.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get()));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(4);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 4; ++i) {
    // Client i advertises pregenerated key i for encryption and key i + 4 for
    // noise. The server is expected to echo the same pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
  }
  // The session id is derived from the full set of advertised keys.
  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }

  // Iteration i checks the counters after i messages. The fifth iteration
  // (i == 4) only checks the final counters and sends nothing.
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
    if (i < 4) {
      ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    }
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
  EXPECT_THAT(tracing_recorder.FindAllEvents<IndividualMessageSent>(),
              ElementsAre(IsEvent<IndividualMessageSent>(
                              0, Eq(ServerToClientMessageType_ShareKeysRequest),
                              Eq(expected_server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              1, Eq(ServerToClientMessageType_ShareKeysRequest),
                              Eq(expected_server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              2, Eq(ServerToClientMessageType_ShareKeysRequest),
                              Eq(expected_server_message.ByteSizeLong())),
                          IsEvent<IndividualMessageSent>(
                              3, Eq(ServerToClientMessageType_ShareKeysRequest),
                              Eq(expected_server_message.ByteSizeLong()))));
}
216
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateProceedsCorrectlyWithInvalidKeysFromOneClient) {
  // In this test, client 3 sends invalid public keys, so it should be forced to
  // abort. But this should not stop the rest of the state proceeding normally.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get()));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(4);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 3; ++i) {
    // Clients 0-2 advertise valid keys: pregenerated key i for encryption and
    // key i + 4 for noise. The server echoes each pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
  }
  // Client 3 sends a valid encryption key but a noise key of the wrong size.
  auto* malformed = client_messages[3]
                        .mutable_advertise_keys()
                        ->mutable_pair_of_public_keys();
  malformed->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
  malformed->set_noise_pk("This is too long to be a valid key.");
  share_keys_request->add_pairs_of_public_keys();  // this one will be empty

  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 3; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }

  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "A public key sent by the client was not the correct size.");

  EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));

  for (int i = 0; i < 4; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
    // Even the malformed message is handled without an error status; the
    // client is dropped rather than the call failing.
    ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
  }
  // Client 3 was dropped, so only three clients remain alive.
  EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
  EXPECT_THAT(state.NeedsToAbort(), IsFalse());
  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
  EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
309
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateProceedsCorrectlyWithNoMessageFromOneClient) {
  // In this test, we proceed to the next state before client 3 sends any
  // message, so it should be forced to abort. But this should not stop the rest
  // of the state proceeding normally.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get()));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(3);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 3; ++i) {
    // Clients 0-2 advertise pregenerated key i for encryption and key i + 4
    // for noise. The server echoes each pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
  }
  share_keys_request->add_pairs_of_public_keys();  // this one will be empty

  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 3; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }
  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "Client did not send AdvertiseKeys message before round transition.");

  EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));

  // Clients 0-2 send their keys. The last iteration (i == 3) only checks the
  // counters; client 3 stays silent.
  for (int i = 0; i < 4; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
      ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
389
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateNeedsToAbortIfTooManyClientsAbort) {
  // Clients 0 and 1 each send an abort message. That leaves fewer than the
  // three clients needed to proceed. When asked to advance, the server must
  // broadcast its own abort and move to ABORTED.
  TestTracingRecorder tracing_recorder;
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(mock_sender.get()));

  for (int aborted_so_far = 0; aborted_so_far <= 2; ++aborted_so_far) {
    // After the second client aborts, the state both needs to abort and is
    // ready to leave the round.
    const bool must_abort = aborted_so_far >= 2;
    EXPECT_THAT(r0_state.NeedsToAbort(), Eq(must_abort));
    EXPECT_THAT(r0_state.NumberOfAliveClients(), Eq(4 - aborted_so_far));
    EXPECT_THAT(r0_state.NumberOfClientsReadyForNextRound(), Eq(0));
    EXPECT_THAT(r0_state.NumberOfMessagesReceivedInThisRound(), Eq(0));
    EXPECT_THAT(r0_state.NumberOfPendingClients(), Eq(4 - aborted_so_far));
    EXPECT_THAT(r0_state.MinimumMessagesNeededForNextRound(), Eq(3));
    EXPECT_THAT(r0_state.ReadyForNextRound(), Eq(must_abort));
    if (must_abort) {
      break;
    }

    ClientToServerWrapperMessage client_abort;
    client_abort.mutable_abort()->set_diagnostic_info("Aborting for test");
    ASSERT_THAT(r0_state.HandleMessage(aborted_so_far, client_abort), IsOk());
    EXPECT_THAT(r0_state.ReadyForNextRound(), Eq(aborted_so_far >= 1));
  }

  ServerToClientWrapperMessage expected_broadcast;
  auto* abort_body = expected_broadcast.mutable_abort();
  abort_body->set_early_success(false);
  abort_body->set_diagnostic_info("Too many clients aborted.");
  EXPECT_CALL(*mock_sender, SendBroadcast(EqualsProto(expected_broadcast)))
      .Times(1);
  EXPECT_CALL(*mock_sender, Send(_, _)).Times(0);

  auto advanced = r0_state.ProceedToNextRound();
  ASSERT_THAT(advanced, IsOk());
  EXPECT_THAT(advanced.value()->State(), Eq(SecAggServerStateKind::ABORTED));
  auto reported_error = advanced.value()->ErrorMessage();
  ASSERT_THAT(reported_error, IsOk());
  EXPECT_THAT(reported_error.value(), Eq("Too many clients aborted."));

  const auto broadcast_size = expected_broadcast.ByteSizeLong();
  EXPECT_THAT(tracing_recorder.FindAllEvents<BroadcastMessageSent>(),
              ElementsAre(IsEvent<BroadcastMessageSent>(
                  Eq(ServerToClientMessageType_Abort), Eq(broadcast_size))));
}
435
TEST(SecaggServerR0AdvertiseKeysStateTest,
     StateProceedsCorrectlyWithAllUncompressedClientMessages) {
  // In this test, all clients send two valid ECDH public keys apiece in
  // uncompressed form. The server must accept them, echo them back unchanged
  // in the ShareKeysRequest, and proceed to the next state.
  auto sender = std::make_unique<MockSendToClientsInterface>();

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get()));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(4);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 4; ++i) {
    // Client i advertises uncompressed key i for encryption and key i + 4 for
    // noise. The server echoes the same pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetUncompressedPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetUncompressedPublicKeyString(i + 4));
  }

  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }

  // Iteration i checks the counters after i messages. The fifth iteration
  // (i == 4) only checks the final counters.
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
    if (i < 4) {
      ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    }
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
506
TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsStart) {
  // Building the protocol and R0 state must report the protocol start to the
  // metrics listener (once, per gMock's default cardinality).
  auto* metrics_listener = new MockSecAggServerMetricsListener();
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  EXPECT_CALL(*metrics_listener, ProtocolStarts(_));

  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(mock_sender.get(), metrics_listener));

  EXPECT_FALSE(r0_state.Result().ok());
}
519
TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsRecordsMessageSizes) {
  // Clients 0-2 send valid keys and client 3 sends an invalid noise key, so
  // client 3 is dropped. The test checks that the metrics listener and tracing
  // recorder see the size of every message received and every individual
  // message sent (three ShareKeysRequests plus one abort).
  TestTracingRecorder tracing_recorder;
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();

  EXPECT_CALL(*metrics, ProtocolStarts(_));

  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get(), metrics));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(4);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 3; ++i) {
    // Clients 0-2 advertise pregenerated key i for encryption and key i + 4
    // for noise. The server echoes each pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
  }
  // Client 3 sends a valid encryption key but a noise key of the wrong size.
  auto* malformed = client_messages[3]
                        .mutable_advertise_keys()
                        ->mutable_pair_of_public_keys();
  malformed->set_enc_pk(ecdh_keys.GetPublicKeyString(3));
  malformed->set_noise_pk("This is too long to be a valid key.");
  share_keys_request->add_pairs_of_public_keys();  // this one will be empty

  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 3; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }
  ServerToClientWrapperMessage abort_message;
  abort_message.mutable_abort()->set_early_success(false);
  abort_message.mutable_abort()->set_diagnostic_info(
      "A public key sent by the client was not the correct size.");
  EXPECT_CALL(*sender, Send(3, EqualsProto(abort_message)));

  EXPECT_CALL(*metrics, IndividualMessageSizes(
                            Eq(ServerToClientWrapperMessage::
                                   MessageContentCase::kShareKeysRequest),
                            Eq(expected_server_message.ByteSizeLong())))
      .Times(3);
  EXPECT_CALL(*metrics,
              IndividualMessageSizes(
                  Eq(ServerToClientWrapperMessage::MessageContentCase::kAbort),
                  Eq(abort_message.ByteSizeLong())));
  // Clients 0-2 use fixed-size keys, so their messages share client 0's size.
  EXPECT_CALL(
      *metrics,
      MessageReceivedSizes(
          Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
          Eq(true), Eq(client_messages[0].ByteSizeLong())))
      .Times(3);
  EXPECT_CALL(
      *metrics,
      MessageReceivedSizes(
          Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
          Eq(true), Eq(client_messages[3].ByteSizeLong())));

  for (int i = 0; i < 4; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
    ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
    EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    // The i-th top-level trace event is the receipt of client i's message.
    EXPECT_THAT(tracing_recorder.root()[i],
                IsEvent<ClientMessageReceived>(
                    Eq(ClientToServerMessageType_AdvertiseKeys),
                    Eq(client_messages[i].ByteSizeLong()), Eq(true), Ge(0)));
  }
  EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
  EXPECT_THAT(state.NeedsToAbort(), IsFalse());
  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
  EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(3));
  EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(3));
  EXPECT_THAT(state.NumberOfPendingClients(), Eq(0));
  EXPECT_THAT(state.ReadyForNextRound(), IsTrue());

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedAfterSendingMaskedInput(),
      Eq(0));
  EXPECT_THAT(
      next_state.value()->NumberOfClientsFailedBeforeSendingMaskedInput(),
      Eq(1));
  EXPECT_THAT(next_state.value()->NumberOfClientsTerminatedWithoutUnmasking(),
              Eq(0));
}
639
TEST(SecaggServerR0AdvertiseKeysStateTest,
     ServerAndClientAbortsAreRecordedCorrectly) {
  TestTracingRecorder tracing_recorder;
  // Clients are dropped for four different reasons, after which the server
  // itself aborts for lack of clients. Both the metrics listener and the
  // tracing recorder must observe every drop and the final outcome.
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto* metrics_listener = new MockSecAggServerMetricsListener();
  EcdhPregeneratedTestKeys test_keys;

  SecAggServerR0AdvertiseKeysState r0_state(
      CreateAesSecAggServerProtocolImpl(mock_sender.get(), metrics_listener));

  // One expected ClientsDropped report per drop reason.
  for (ClientDropReason reason : {ClientDropReason::SENT_ABORT_MESSAGE,
                                  ClientDropReason::ADVERTISE_KEYS_UNEXPECTED,
                                  ClientDropReason::UNEXPECTED_MESSAGE_TYPE,
                                  ClientDropReason::INVALID_PUBLIC_KEY}) {
    EXPECT_CALL(*metrics_listener,
                ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING),
                               Eq(reason)));
  }
  EXPECT_CALL(
      *metrics_listener,
      ProtocolOutcomes(Eq(SecAggServerOutcome::NOT_ENOUGH_CLIENTS_REMAINING)));

  ClientToServerWrapperMessage client_abort;
  client_abort.mutable_abort()->set_diagnostic_info("Aborting for test");

  ClientToServerWrapperMessage well_formed_keys;
  auto* good_pair =
      well_formed_keys.mutable_advertise_keys()->mutable_pair_of_public_keys();
  good_pair->set_enc_pk(test_keys.GetPublicKeyString(0));
  good_pair->set_noise_pk(test_keys.GetPublicKeyString(4));

  ClientToServerWrapperMessage malformed_keys;
  auto* bad_pair =
      malformed_keys.mutable_advertise_keys()->mutable_pair_of_public_keys();
  bad_pair->set_enc_pk(test_keys.GetPublicKeyString(3));
  bad_pair->set_noise_pk("This is too long to be a valid key.");

  ClientToServerWrapperMessage unexpected_type;
  unexpected_type.mutable_share_keys_response();  // wrong type of message

  // The status of each call is deliberately ignored; only the recorded drops
  // and outcome matter here.
  r0_state.HandleMessage(0, client_abort).IgnoreError();
  r0_state.HandleMessage(1, well_formed_keys).IgnoreError();
  r0_state.HandleMessage(1, well_formed_keys).IgnoreError();  // duplicate
  r0_state.HandleMessage(2, malformed_keys).IgnoreError();
  r0_state.HandleMessage(3, unexpected_type).IgnoreError();
  r0_state.ProceedToNextRound().IgnoreError();  // causes server abort

  EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
              ElementsAre(IsEvent<SecAggProtocolOutcome>(
                  Eq(TracingSecAggServerOutcome_NotEnoughClientsRemaining))));
  EXPECT_THAT(
      tracing_recorder.FindAllEvents<ClientsDropped>(),
      ElementsAre(IsEvent<ClientsDropped>(
                      Eq(TracingClientStatus_DeadBeforeSendingAnything),
                      Eq(TracingClientDropReason_SentAbortMessage)),
                  IsEvent<ClientsDropped>(
                      Eq(TracingClientStatus_DeadBeforeSendingAnything),
                      Eq(TracingClientDropReason_AdvertiseKeysUnexpected)),
                  IsEvent<ClientsDropped>(
                      Eq(TracingClientStatus_DeadBeforeSendingAnything),
                      Eq(TracingClientDropReason_InvalidPublicKey)),
                  IsEvent<ClientsDropped>(
                      Eq(TracingClientStatus_DeadBeforeSendingAnything),
                      Eq(TracingClientDropReason_UnexpectedMessageType))));
}
713
TEST(SecaggServerR0AdvertiseKeysStateTest, MetricsAreRecorded) {
  // All four clients advertise valid keys and the server advances to R1. The
  // metrics listener must receive:
  //   * one AdvertiseKeys response-time sample per client,
  //   * the R0 round time,
  //   * the R0 surviving-client count.
  // The tracing recorder must log a single R0 StateCompletion event.
  TestTracingRecorder tracing_recorder;
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();
  auto sender = std::make_unique<MockSendToClientsInterface>();
  SecAggServerR0AdvertiseKeysState state(
      CreateAesSecAggServerProtocolImpl(sender.get(), metrics));

  EcdhPregeneratedTestKeys ecdh_keys;
  std::vector<ClientToServerWrapperMessage> client_messages(4);
  ServerToClientWrapperMessage expected_server_message;
  auto* share_keys_request =
      expected_server_message.mutable_share_keys_request();
  for (int i = 0; i < 4; ++i) {
    // Client i advertises pregenerated key i for encryption and key i + 4 for
    // noise. The server echoes the same pair back at index i.
    auto* advertised = client_messages[i]
                           .mutable_advertise_keys()
                           ->mutable_pair_of_public_keys();
    advertised->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    advertised->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));

    PairOfPublicKeys* echoed = share_keys_request->add_pairs_of_public_keys();
    echoed->set_enc_pk(ecdh_keys.GetPublicKeyString(i));
    echoed->set_noise_pk(ecdh_keys.GetPublicKeyString(i + 4));
  }

  share_keys_request->set_session_id(
      ComputeSessionId(expected_server_message.share_keys_request()).data);

  EXPECT_CALL(*sender, SendBroadcast(_)).Times(0);
  for (int i = 0; i < 4; ++i) {
    EXPECT_CALL(*sender, Send(i, EqualsProto(expected_server_message)))
        .Times(1);
  }
  EXPECT_CALL(*metrics, RoundTimes(Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS),
                                   Eq(true), Ge(0)));
  EXPECT_CALL(*metrics,
              RoundSurvivingClients(
                  Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS), Eq(4)));
  EXPECT_CALL(
      *metrics,
      ClientResponseTimes(
          Eq(ClientToServerWrapperMessage::MessageContentCase::kAdvertiseKeys),
          Ge(0)))
      .Times(4);

  // Iteration i checks the counters after i messages. The fifth iteration
  // (i == 4) only checks the final counters.
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(state.NeedsToAbort(), IsFalse());
    EXPECT_THAT(state.NumberOfAliveClients(), Eq(4));
    EXPECT_THAT(state.NumberOfClientsReadyForNextRound(), Eq(i));
    EXPECT_THAT(state.NumberOfMessagesReceivedInThisRound(), Eq(i));
    EXPECT_THAT(state.NumberOfPendingClients(), Eq(4 - i));
    if (i < 3) {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(3 - i));
      EXPECT_THAT(state.ReadyForNextRound(), IsFalse());
    } else {
      EXPECT_THAT(state.MinimumMessagesNeededForNextRound(), Eq(0));
      EXPECT_THAT(state.ReadyForNextRound(), IsTrue());
    }
    if (i < 4) {
      ASSERT_THAT(state.HandleMessage(i, client_messages[i]), IsOk());
      EXPECT_THAT(state.ReadyForNextRound(), Eq(i >= 2));
    }
  }

  auto next_state = state.ProceedToNextRound();
  ASSERT_THAT(next_state, IsOk());
  EXPECT_THAT(next_state.value()->State(),
              Eq(SecAggServerStateKind::R1_SHARE_KEYS));
  EXPECT_THAT(
      tracing_recorder.FindAllEvents<StateCompletion>(),
      ElementsAre(IsEvent<StateCompletion>(
          Eq(SecAggServerTraceState_R0AdvertiseKeys), Eq(true), Ge(0), Eq(4))));
}
792
793 } // namespace
794 } // namespace secagg
795 } // namespace fcp
796