1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/http/http_secagg_send_to_server_impl.h"
17
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22
23 #include "google/longrunning/operations.pb.h"
24 #include "google/rpc/code.pb.h"
25 #include "absl/time/time.h"
26 #include "fcp/base/simulated_clock.h"
27 #include "fcp/client/http/testing/test_helpers.h"
28 #include "fcp/client/test_helpers.h"
29 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
30 #include "fcp/testing/testing.h"
31
32 namespace fcp {
33 namespace client {
34 namespace http {
35 namespace {
36
37 using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
38 using ::google::internal::federatedcompute::v1::AbortSecureAggregationResponse;
39 using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
40 using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
41 using ::google::internal::federatedcompute::v1::ByteStreamResource;
42 using ::google::internal::federatedcompute::v1::ForwardingInfo;
43 using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
44 using ::google::internal::federatedcompute::v1::ShareKeysRequest;
45 using ::google::internal::federatedcompute::v1::ShareKeysResponse;
46 using ::google::internal::federatedcompute::v1::
47 SubmitSecureAggregationResultRequest;
48 using ::google::internal::federatedcompute::v1::
49 SubmitSecureAggregationResultResponse;
50 using ::google::internal::federatedcompute::v1::UnmaskRequest;
51 using ::google::internal::federatedcompute::v1::UnmaskResponse;
52 using ::google::longrunning::Operation;
53 using ::testing::_;
54 using ::testing::NiceMock;
55 using ::testing::Return;
56 using ::testing::StrictMock;
57
58 constexpr absl::string_view kAggregationId = "aggregation_id";
59 constexpr absl::string_view kClientToken = "client_token";
60 constexpr absl::string_view kSecureAggregationTargetUri =
61 "https://secureaggregation.uri/";
62 constexpr absl::string_view kByteStreamTargetUri = "https://bytestream.uri/";
63 constexpr absl::string_view kMaskedResourceName = "masked_resource";
64 constexpr absl::string_view kNonmaskedResourceName = "nonmasked_resource";
65 constexpr absl::string_view kOperationName = "my_operation";
66 constexpr absl::string_view kApiKey = "API_KEY";
67 constexpr absl::Duration kDelayedInterruptibleRunnerDeadline =
68 absl::Seconds(10);
69
TEST(HttpSecAggProtocolDelegateTest, GetModulus) {
  // A registered aggregand's modulus is returned when looked up by its key.
  const std::string key = "tensor_1";
  SecureAggregandExecutionInfo execution_info;
  execution_info.set_modulus(12345);
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  aggregands[key] = execution_info;

  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  const auto result = delegate.GetModulus(key);
  ASSERT_OK(result);
  ASSERT_EQ(*result, 12345);
}
82
TEST(HttpSecAggProtocolDelegateTest, GetModulusKeyNotFound) {
  // Asking for a key that was never registered yields a kInternal error.
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  aggregands["tensor_1"].set_modulus(12345);

  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  const auto missing = delegate.GetModulus("do_not_exist");
  ASSERT_THAT(missing, IsCode(absl::StatusCode::kInternal));
}
93
TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageOkResponse) {
  // A message placed in the holder is returned unchanged, and its serialized
  // size is reported by last_received_message_size().
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo>
      no_aggregands;
  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(no_aggregands, &response_holder);

  secagg::ServerToClientWrapperMessage expected;
  expected.mutable_masked_input_request()->add_encrypted_key_shares(
      "encrypted_key");
  response_holder = expected;

  const auto received = delegate.ReceiveServerMessage();
  ASSERT_OK(received);
  ASSERT_THAT(*received, EqualsProto(expected));
  ASSERT_EQ(delegate.last_received_message_size(), expected.ByteSizeLong());
}
109
TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageErrorResponse) {
  // An error status in the holder is propagated, and the reported message
  // size stays at 0.
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo>
      no_aggregands;
  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(no_aggregands, &response_holder);
  response_holder = absl::InternalError("Something is broken.");

  const auto received = delegate.ReceiveServerMessage();
  ASSERT_THAT(received, IsCode(absl::StatusCode::kInternal));
  ASSERT_EQ(delegate.last_received_message_size(), 0);
}
120
121 class HttpSecAggSendToServerImplTest : public ::testing::Test {
122 protected:
SetUp()123 void SetUp() override {
124 request_helper_ = std::make_unique<ProtocolRequestHelper>(
125 &http_client_, &bytes_downloaded_, &bytes_uploaded_,
126 network_stopwatch_.get(), Clock::RealClock());
127 runner_ = std::make_unique<InterruptibleRunner>(
128 &log_manager_, []() { return false; },
129 InterruptibleRunner::TimingConfig{
130 .polling_period = absl::ZeroDuration(),
131 .graceful_shutdown_period = absl::InfiniteDuration(),
132 .extended_shutdown_period = absl::InfiniteDuration()},
133 InterruptibleRunner::DiagnosticsConfig{
134 .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
135 .interrupt_timeout =
136 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
137 .interrupted_extended = ProdDiagCode::
138 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
139 .interrupt_timeout_extended = ProdDiagCode::
140 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
141 *secagg_upload_forwarding_info_.mutable_target_uri_prefix() =
142 kSecureAggregationTargetUri;
143 *masked_result_resource_.mutable_resource_name() = kMaskedResourceName;
144 ForwardingInfo masked_resource_forwarding_info;
145 *masked_resource_forwarding_info.mutable_target_uri_prefix() =
146 kByteStreamTargetUri;
147 *masked_result_resource_.mutable_data_upload_forwarding_info() =
148 masked_resource_forwarding_info;
149 *nonmasked_result_resource_.mutable_resource_name() =
150 kNonmaskedResourceName;
151 ForwardingInfo nonmasked_resource_forwarding_info;
152 *nonmasked_resource_forwarding_info.mutable_target_uri_prefix() =
153 kByteStreamTargetUri;
154 *nonmasked_result_resource_.mutable_data_upload_forwarding_info() =
155 nonmasked_resource_forwarding_info;
156 }
157
CreateInterruptibleRunner()158 std::unique_ptr<InterruptibleRunner> CreateInterruptibleRunner() {
159 return std::make_unique<InterruptibleRunner>(
160 &log_manager_, [this]() { return interrupted_; },
161 InterruptibleRunner::TimingConfig{
162 .polling_period = absl::ZeroDuration(),
163 .graceful_shutdown_period = absl::InfiniteDuration(),
164 .extended_shutdown_period = absl::InfiniteDuration()},
165 InterruptibleRunner::DiagnosticsConfig{
166 .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
167 .interrupt_timeout =
168 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
169 .interrupted_extended = ProdDiagCode::
170 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
171 .interrupt_timeout_extended = ProdDiagCode::
172 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
173 }
174
CreateSecAggSendToServer(std::optional<std::string> tf_checkpoint)175 std::unique_ptr<HttpSecAggSendToServerImpl> CreateSecAggSendToServer(
176 std::optional<std::string> tf_checkpoint) {
177 auto send_to_server = HttpSecAggSendToServerImpl::Create(
178 kApiKey, &clock_, request_helper_.get(), runner_.get(),
179 /* delayed_interruptible_runner_creator=*/
180 [this](absl::Time deadline) {
181 // Ensure that the HttpSecAggSendToServerImpl implementation correctly
182 // passes a deadline that matches the 'waiting period' value we
183 // provide below (with a 1s grace period to account for the delay of
184 // executing the actual test; unfortunately the underlying HTTP code
185 // currently still uses absl::Now() directly, so we're forced to deal
186 // with 'real' time...).
187 //
188 // We don't actually use the deadline value though, since that
189 // would only make testing more complicated.
190 EXPECT_GE(deadline, absl::Now() +
191 kDelayedInterruptibleRunnerDeadline -
192 absl::Seconds(1));
193 return CreateInterruptibleRunner();
194 },
195 &server_response_holder_, kAggregationId, kClientToken,
196 secagg_upload_forwarding_info_, masked_result_resource_,
197 nonmasked_result_resource_, tf_checkpoint,
198 /* disable_request_body_compression=*/true,
199 /* waiting_period_for_cancellation=*/
200 kDelayedInterruptibleRunnerDeadline);
201 FCP_CHECK(send_to_server.ok());
202 return std::move(*send_to_server);
203 }
204 bool interrupted_ = false;
205 // We set the simulated clock to "now", since a bunch of the HTTP-related FCP
206 // code currently still uses absl::Now() directly, rather than using a more
207 // testable "Clock" object. This ensures various timestamps we may encounter
208 // are more understandable.
209 SimulatedClock clock_ = SimulatedClock(absl::Now());
210 StrictMock<MockHttpClient> http_client_;
211 NiceMock<MockLogManager> log_manager_;
212 int64_t bytes_downloaded_ = 0;
213 int64_t bytes_uploaded_ = 0;
214 std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
215 WallClockStopwatch::Create();
216 std::unique_ptr<ProtocolRequestHelper> request_helper_;
217 std::unique_ptr<InterruptibleRunner> runner_;
218 absl::StatusOr<secagg::ServerToClientWrapperMessage> server_response_holder_;
219 ForwardingInfo secagg_upload_forwarding_info_;
220 ByteStreamResource masked_result_resource_;
221 ByteStreamResource nonmasked_result_resource_;
222 };
223
TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeys) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  // Round 0: the client advertises its pair of public keys.
  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  AdvertiseKeysRequest expected_request;
  expected_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());
  const std::string expected_body = expected_request.SerializeAsString();
  const std::string pending_operation =
      CreatePendingOperation("operations/foo#bar").SerializeAsString();

  // The advertise-keys POST answers with a pending operation...
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:advertisekeys"
                  "?%24alt=proto",
                  HttpRequest::Method::kPost, _, expected_body)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), pending_operation)));

  // ...and polling that operation yields the share-keys request.
  secagg::ShareKeysRequest share_keys_request;
  share_keys_request.add_pairs_of_public_keys()->set_noise_pk("noise_pk");
  AdvertiseKeysResponse advertise_keys_response;
  advertise_keys_response.mutable_share_keys_server_request()->CopyFrom(
      share_keys_request);
  const std::string done_operation =
      CreateDoneOperation(kOperationName, advertise_keys_response)
          .SerializeAsString();
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/operations/"
                  "foo%23bar?%24alt=proto",
                  HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), done_operation)));

  send_to_server->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_message;
  expected_message.mutable_share_keys_request()->CopyFrom(share_keys_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
}
268
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR0AdvertiseKeysFailedImmediately) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  AdvertiseKeysRequest expected_request;
  expected_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());

  // The advertise-keys POST is rejected outright with HTTP 503; no polling
  // request is expected afterwards.
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://secureaggregation.uri/v1/"
                                "secureaggregations/aggregation_id/clients/"
                                "client_token:advertisekeys?%24alt=proto",
                                HttpRequest::Method::kPost, _,
                                expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  send_to_server->Send(&client_message);

  // The 503 surfaces in the holder as kUnavailable.
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
292
TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeysFailed) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  AdvertiseKeysRequest expected_request;
  expected_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());

  // The POST itself succeeds (HTTP 200), but the returned operation carries
  // an INTERNAL error.
  const std::string failed_operation =
      CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
                           "Something's wrong")
          .SerializeAsString();
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:advertisekeys"
                  "?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), failed_operation)));

  send_to_server->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
}
319
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeys) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  // Round 1: the client replies with its encrypted key shares.
  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_share_keys_response()->add_encrypted_key_shares(
      "encrypted_key");

  ShareKeysRequest expected_request;
  expected_request.mutable_share_keys_client_response()->CopyFrom(
      client_message.share_keys_response());

  const std::string share_keys_uri =
      "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
      "clients/client_token:sharekeys?%24alt=proto";
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  share_keys_uri, HttpRequest::Method::kPost, _,
                  expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));

  // Polling the pending operation yields the masked input collection request.
  secagg::MaskedInputCollectionRequest masked_input_collection_request;
  masked_input_collection_request.add_encrypted_key_shares(
      "encryoted_key_share");
  ShareKeysResponse share_keys_response;
  share_keys_response.mutable_masked_input_collection_server_request()
      ->CopyFrom(masked_input_collection_request);
  const Operation done_operation =
      CreateDoneOperation(kOperationName, share_keys_response);
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/operations/"
                  "foo%23bar?%24alt=proto",
                  HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(),
                                        done_operation.SerializeAsString())));

  send_to_server->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_message;
  expected_message.mutable_masked_input_request()->CopyFrom(
      masked_input_collection_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
}
365
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailedImmediatedly) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_share_keys_response()->add_encrypted_key_shares(
      "encrypted_key");

  ShareKeysRequest expected_request;
  expected_request.mutable_share_keys_client_response()->CopyFrom(
      client_message.share_keys_response());
  const std::string expected_body = expected_request.SerializeAsString();

  // The share-keys POST is rejected with HTTP 503...
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:sharekeys"
                  "?%24alt=proto",
                  HttpRequest::Method::kPost, _, expected_body)))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  send_to_server->Send(&client_message);

  // ...which is reported in the holder as kUnavailable.
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
388
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailed) {
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_share_keys_response()->add_encrypted_key_shares(
      "encrypted_key");

  ShareKeysRequest expected_request;
  expected_request.mutable_share_keys_client_response()->CopyFrom(
      client_message.share_keys_response());

  // HTTP 200, but the returned operation carries an INTERNAL error.
  const std::string failed_operation =
      CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
                           "Something's wrong")
          .SerializeAsString();
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:sharekeys"
                  "?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), failed_operation)));

  send_to_server->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
}
415
TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultNoCheckpoint) {
  // Without a TF checkpoint, only the masked result is uploaded before the
  // submit call.
  auto send_to_server = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto& masked_vector = (*client_message.mutable_masked_input_response()
                              ->mutable_vectors())["vector_1"];
  masked_vector.set_encoded_vector("encoded_vector");
  const std::string masked_payload =
      client_message.masked_input_response().SerializeAsString();

  SubmitSecureAggregationResultRequest expected_request;
  expected_request.set_masked_result_resource_name(
      std::string(kMaskedResourceName));

  // Masked result upload to the ByteStream service.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _, masked_payload)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // The submit call returns a pending operation.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:submit?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));

  // Polling that operation yields the unmasking request.
  secagg::UnmaskingRequest unmasking_request;
  unmasking_request.add_dead_3_client_ids(12345);
  SubmitSecureAggregationResultResponse submit_response;
  submit_response.mutable_unmasking_server_request()->CopyFrom(
      unmasking_request);
  const std::string done_operation =
      CreateDoneOperation(kOperationName, submit_response).SerializeAsString();
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/operations/"
                  "foo%23bar?%24alt=proto",
                  HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), done_operation)));

  send_to_server->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_message;
  expected_message.mutable_unmasking_request()->CopyFrom(unmasking_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
}
468
TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultWithCheckpoint) {
  const std::string tf_checkpoint = "trained.ckpt";
  auto send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));

  secagg::ClientToServerWrapperMessage client_message;
  auto& masked_vector = (*client_message.mutable_masked_input_response()
                              ->mutable_vectors())["vector_1"];
  masked_vector.set_encoded_vector("encoded_vector");

  // With a checkpoint, the submit request names both uploaded resources.
  SubmitSecureAggregationResultRequest expected_request;
  expected_request.set_masked_result_resource_name(
      std::string(kMaskedResourceName));
  expected_request.set_nonmasked_result_resource_name(
      std::string(kNonmaskedResourceName));

  // Masked result upload.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  client_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  // Non-masked (checkpoint) upload.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "nonmasked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // The submit call answers directly with a done operation, so no polling
  // request is expected.
  secagg::UnmaskingRequest unmasking_request;
  unmasking_request.add_dead_3_client_ids(12345);
  SubmitSecureAggregationResultResponse submit_response;
  submit_response.mutable_unmasking_server_request()->CopyFrom(
      unmasking_request);
  const Operation done_operation =
      CreateDoneOperation(kOperationName, submit_response);
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://secureaggregation.uri/v1/secureaggregations/"
                  "aggregation_id/clients/client_token:submit?%24alt=proto",
                  HttpRequest::Method::kPost, _,
                  expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(),
                                        done_operation.SerializeAsString())));

  send_to_server->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_message;
  expected_message.mutable_unmasking_request()->CopyFrom(unmasking_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
}
521
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR2SubmitResultWithCheckpointUploadFailed) {
  std::string tf_checkpoint = "trained.ckpt";
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
  secagg::ClientToServerWrapperMessage server_message;
  secagg::MaskedInputVector masked_vector;
  masked_vector.set_encoded_vector("encoded_vector");
  auto vector_map =
      server_message.mutable_masked_input_response()->mutable_vectors();
  (*vector_map)["vector_1"] = masked_vector;

  // Fail the masked result upload with HTTP 503; the non-masked (checkpoint)
  // upload succeeds.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  server_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  // Deliberately no expectation for the ":submit" request: any request that
  // matches none of the expectations above fails the test, so this also
  // checks that the result is not submitted after a failed upload.

  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
556
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR2SubmitResultWithCheckpointSubmitResultFailedImmediately) {
  std::string tf_checkpoint = "trained.ckpt";
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
  secagg::ClientToServerWrapperMessage server_message;
  secagg::MaskedInputVector masked_vector;
  masked_vector.set_encoded_vector("encoded_vector");
  auto vector_map =
      server_message.mutable_masked_input_response()->mutable_vectors();
  (*vector_map)["vector_1"] = masked_vector;
  // Create expected request
  SubmitSecureAggregationResultRequest expected_request;
  *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
  *expected_request.mutable_nonmasked_result_resource_name() =
      kNonmaskedResourceName;

  // Both uploads succeed.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  server_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // The submit call itself is rejected with HTTP 503, so no operation is
  // returned and no polling request is expected. No unmasking response is
  // needed for this scenario.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:submit?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
605
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR2SubmitResultWithCheckpointSubmitResultFailed) {
  std::string tf_checkpoint = "trained.ckpt";
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
  secagg::ClientToServerWrapperMessage server_message;
  secagg::MaskedInputVector masked_vector;
  masked_vector.set_encoded_vector("encoded_vector");
  auto vector_map =
      server_message.mutable_masked_input_response()->mutable_vectors();
  (*vector_map)["vector_1"] = masked_vector;
  // Create expected request
  SubmitSecureAggregationResultRequest expected_request;
  *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
  *expected_request.mutable_nonmasked_result_resource_name() =
      kNonmaskedResourceName;

  // Both uploads succeed.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  server_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // The submit call returns a pending operation...
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:submit?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));

  // ...and polling it yields an operation carrying an INTERNAL error, so no
  // successful unmasking response is ever produced.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
          HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
                               "Something's wrong.")
              .SerializeAsString())));
  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
}
667
TEST_F(HttpSecAggSendToServerImplTest, TestSendR3Unmask) {
  // An R3 (unmasking) client message must be wrapped unchanged in an
  // UnmaskRequest and POSTed to the ":unmask" endpoint. The server's empty
  // UnmaskResponse must surface as an empty ServerToClientWrapperMessage.
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());

  secagg::ClientToServerWrapperMessage client_message;
  auto* unmasking_response = client_message.mutable_unmasking_response();
  unmasking_response->add_noise_or_prf_key_shares()->set_noise_sk_share(
      "noise_sk_share");

  // The HTTP body is expected to carry the client's unmasking response as-is.
  UnmaskRequest expected_request;
  *expected_request.mutable_unmasking_client_response() =
      client_message.unmasking_response();

  // The server replies to the unmask call with an empty UnmaskResponse.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:unmask?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(),
                                        UnmaskResponse().SerializeAsString())));

  send_to_server->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  EXPECT_THAT(*server_response_holder_,
              EqualsProto(secagg::ServerToClientWrapperMessage()));
}
694
TEST_F(HttpSecAggSendToServerImplTest, TestSendAbortWithoutInterruption) {
  // When no interruption is signalled, sending a client abort message must
  // issue an AbortSecureAggregationRequest to the ":abort" endpoint. The
  // outcome must be reported as an (empty) abort message from the server.
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  const std::string abort_reason = "Some computation failed.";

  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_abort()->set_diagnostic_info(abort_reason);

  // The abort request reports the diagnostic info under an INTERNAL status.
  AbortSecureAggregationRequest expected_request;
  google::rpc::Status* expected_status = expected_request.mutable_status();
  expected_status->set_code(google::rpc::INTERNAL);
  expected_status->set_message(abort_reason);

  // This test leaves interrupted_ false, so the abort request is expected to
  // reach the server rather than being cancelled.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:abort?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          AbortSecureAggregationResponse().SerializeAsString())));

  send_to_server->Send(&client_message);

  // A successful abort comes back as an abort message with no fields set.
  secagg::ServerToClientWrapperMessage expected_response;
  expected_response.mutable_abort();
  ASSERT_OK(server_response_holder_);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
}
728
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendAbortShouldBeCancelledIfAlreadyInterruptedForTooLong) {
  // Verifies that an abort message is not sent to the server when the
  // interruption flag is already set. Send() must report kCancelled and must
  // not issue any HTTP request.
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  std::string diagnostic_info = "Some computation failed.";
  secagg::ClientToServerWrapperMessage server_message;
  server_message.mutable_abort()->set_diagnostic_info(diagnostic_info);

  // We do *not* expect any HTTP request to actually be issued, since the
  // interrupted_ flag is true, and therefore the request should be cancelled
  // before it is even issued. Assert this explicitly instead of relying on
  // how strict the fixture's mock happens to be.
  EXPECT_CALL(http_client_, PerformSingleRequest(_)).Times(0);
  interrupted_ = true;

  // Send the request, and verify that sending it failed.
  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kCancelled));
}
752
753 } // anonymous namespace
754 } // namespace http
755 } // namespace client
756 } // namespace fcp
757