• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/http/http_secagg_send_to_server_impl.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 
23 #include "google/longrunning/operations.pb.h"
24 #include "google/rpc/code.pb.h"
25 #include "absl/time/time.h"
26 #include "fcp/base/simulated_clock.h"
27 #include "fcp/client/http/testing/test_helpers.h"
28 #include "fcp/client/test_helpers.h"
29 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
30 #include "fcp/testing/testing.h"
31 
32 namespace fcp {
33 namespace client {
34 namespace http {
35 namespace {
36 
37 using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
38 using ::google::internal::federatedcompute::v1::AbortSecureAggregationResponse;
39 using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
40 using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
41 using ::google::internal::federatedcompute::v1::ByteStreamResource;
42 using ::google::internal::federatedcompute::v1::ForwardingInfo;
43 using ::google::internal::federatedcompute::v1::SecureAggregandExecutionInfo;
44 using ::google::internal::federatedcompute::v1::ShareKeysRequest;
45 using ::google::internal::federatedcompute::v1::ShareKeysResponse;
46 using ::google::internal::federatedcompute::v1::
47     SubmitSecureAggregationResultRequest;
48 using ::google::internal::federatedcompute::v1::
49     SubmitSecureAggregationResultResponse;
50 using ::google::internal::federatedcompute::v1::UnmaskRequest;
51 using ::google::internal::federatedcompute::v1::UnmaskResponse;
52 using ::google::longrunning::Operation;
53 using ::testing::_;
54 using ::testing::NiceMock;
55 using ::testing::Return;
56 using ::testing::StrictMock;
57 
58 constexpr absl::string_view kAggregationId = "aggregation_id";
59 constexpr absl::string_view kClientToken = "client_token";
60 constexpr absl::string_view kSecureAggregationTargetUri =
61     "https://secureaggregation.uri/";
62 constexpr absl::string_view kByteStreamTargetUri = "https://bytestream.uri/";
63 constexpr absl::string_view kMaskedResourceName = "masked_resource";
64 constexpr absl::string_view kNonmaskedResourceName = "nonmasked_resource";
65 constexpr absl::string_view kOperationName = "my_operation";
66 constexpr absl::string_view kApiKey = "API_KEY";
67 constexpr absl::Duration kDelayedInterruptibleRunnerDeadline =
68     absl::Seconds(10);
69 
// The delegate reports the modulus configured for a known tensor key.
TEST(HttpSecAggProtocolDelegateTest, GetModulus) {
  const std::string tensor_name = "tensor_1";
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  aggregands[tensor_name].set_modulus(12345);

  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  auto result = delegate.GetModulus(tensor_name);
  ASSERT_OK(result);
  ASSERT_EQ(*result, 12345);
}
82 
// Asking for the modulus of a tensor key that was never registered fails
// with kInternal.
TEST(HttpSecAggProtocolDelegateTest, GetModulusKeyNotFound) {
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  aggregands["tensor_1"].set_modulus(12345);

  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  ASSERT_THAT(delegate.GetModulus("do_not_exist"),
              IsCode(absl::StatusCode::kInternal));
}
93 
// When the shared holder contains a message, ReceiveServerMessage returns it
// and records its serialized size.
TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageOkResponse) {
  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  // Populate the holder only after the delegate has been constructed.
  secagg::ServerToClientWrapperMessage expected;
  auto* masked_input_request = expected.mutable_masked_input_request();
  masked_input_request->add_encrypted_key_shares("encrypted_key");
  response_holder = expected;

  auto received = delegate.ReceiveServerMessage();
  ASSERT_OK(received);
  ASSERT_THAT(*received, EqualsProto(expected));
  ASSERT_EQ(delegate.last_received_message_size(), expected.ByteSizeLong());
}
109 
// An error status stored in the holder is handed back unchanged, and the
// recorded received-message size stays at zero.
TEST(HttpSecAggProtocolDelegateTest, ReceiveMessageErrorResponse) {
  google::protobuf::Map<std::string, SecureAggregandExecutionInfo> aggregands;
  absl::StatusOr<secagg::ServerToClientWrapperMessage> response_holder;
  HttpSecAggProtocolDelegate delegate(aggregands, &response_holder);

  response_holder = absl::InternalError("Something is broken.");

  ASSERT_THAT(delegate.ReceiveServerMessage(),
              IsCode(absl::StatusCode::kInternal));
  ASSERT_EQ(delegate.last_received_message_size(), 0);
}
120 
121 class HttpSecAggSendToServerImplTest : public ::testing::Test {
122  protected:
SetUp()123   void SetUp() override {
124     request_helper_ = std::make_unique<ProtocolRequestHelper>(
125         &http_client_, &bytes_downloaded_, &bytes_uploaded_,
126         network_stopwatch_.get(), Clock::RealClock());
127     runner_ = std::make_unique<InterruptibleRunner>(
128         &log_manager_, []() { return false; },
129         InterruptibleRunner::TimingConfig{
130             .polling_period = absl::ZeroDuration(),
131             .graceful_shutdown_period = absl::InfiniteDuration(),
132             .extended_shutdown_period = absl::InfiniteDuration()},
133         InterruptibleRunner::DiagnosticsConfig{
134             .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
135             .interrupt_timeout =
136                 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
137             .interrupted_extended = ProdDiagCode::
138                 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
139             .interrupt_timeout_extended = ProdDiagCode::
140                 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
141     *secagg_upload_forwarding_info_.mutable_target_uri_prefix() =
142         kSecureAggregationTargetUri;
143     *masked_result_resource_.mutable_resource_name() = kMaskedResourceName;
144     ForwardingInfo masked_resource_forwarding_info;
145     *masked_resource_forwarding_info.mutable_target_uri_prefix() =
146         kByteStreamTargetUri;
147     *masked_result_resource_.mutable_data_upload_forwarding_info() =
148         masked_resource_forwarding_info;
149     *nonmasked_result_resource_.mutable_resource_name() =
150         kNonmaskedResourceName;
151     ForwardingInfo nonmasked_resource_forwarding_info;
152     *nonmasked_resource_forwarding_info.mutable_target_uri_prefix() =
153         kByteStreamTargetUri;
154     *nonmasked_result_resource_.mutable_data_upload_forwarding_info() =
155         nonmasked_resource_forwarding_info;
156   }
157 
CreateInterruptibleRunner()158   std::unique_ptr<InterruptibleRunner> CreateInterruptibleRunner() {
159     return std::make_unique<InterruptibleRunner>(
160         &log_manager_, [this]() { return interrupted_; },
161         InterruptibleRunner::TimingConfig{
162             .polling_period = absl::ZeroDuration(),
163             .graceful_shutdown_period = absl::InfiniteDuration(),
164             .extended_shutdown_period = absl::InfiniteDuration()},
165         InterruptibleRunner::DiagnosticsConfig{
166             .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP,
167             .interrupt_timeout =
168                 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_HTTP_TIMED_OUT,
169             .interrupted_extended = ProdDiagCode::
170                 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_COMPLETED,
171             .interrupt_timeout_extended = ProdDiagCode::
172                 BACKGROUND_TRAINING_INTERRUPT_HTTP_EXTENDED_TIMED_OUT});
173   }
174 
CreateSecAggSendToServer(std::optional<std::string> tf_checkpoint)175   std::unique_ptr<HttpSecAggSendToServerImpl> CreateSecAggSendToServer(
176       std::optional<std::string> tf_checkpoint) {
177     auto send_to_server = HttpSecAggSendToServerImpl::Create(
178         kApiKey, &clock_, request_helper_.get(), runner_.get(),
179         /* delayed_interruptible_runner_creator=*/
180         [this](absl::Time deadline) {
181           // Ensure that the HttpSecAggSendToServerImpl implementation correctly
182           // passes a deadline that matches the 'waiting period' value we
183           // provide below (with a 1s grace period to account for the delay of
184           // executing the actual test; unfortunately the underlying HTTP code
185           // currently still uses absl::Now() directly, so we're forced to deal
186           // with 'real' time...).
187           //
188           // We don't actually use the deadline value though, since that
189           // would only make testing more complicated.
190           EXPECT_GE(deadline, absl::Now() +
191                                   kDelayedInterruptibleRunnerDeadline -
192                                   absl::Seconds(1));
193           return CreateInterruptibleRunner();
194         },
195         &server_response_holder_, kAggregationId, kClientToken,
196         secagg_upload_forwarding_info_, masked_result_resource_,
197         nonmasked_result_resource_, tf_checkpoint,
198         /* disable_request_body_compression=*/true,
199         /* waiting_period_for_cancellation=*/
200         kDelayedInterruptibleRunnerDeadline);
201     FCP_CHECK(send_to_server.ok());
202     return std::move(*send_to_server);
203   }
204   bool interrupted_ = false;
205   // We set the simulated clock to "now", since a bunch of the HTTP-related FCP
206   // code currently still uses absl::Now() directly, rather than using a more
207   // testable "Clock" object. This ensures various timestamps we may encounter
208   // are more understandable.
209   SimulatedClock clock_ = SimulatedClock(absl::Now());
210   StrictMock<MockHttpClient> http_client_;
211   NiceMock<MockLogManager> log_manager_;
212   int64_t bytes_downloaded_ = 0;
213   int64_t bytes_uploaded_ = 0;
214   std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
215       WallClockStopwatch::Create();
216   std::unique_ptr<ProtocolRequestHelper> request_helper_;
217   std::unique_ptr<InterruptibleRunner> runner_;
218   absl::StatusOr<secagg::ServerToClientWrapperMessage> server_response_holder_;
219   ForwardingInfo secagg_upload_forwarding_info_;
220   ByteStreamResource masked_result_resource_;
221   ByteStreamResource nonmasked_result_resource_;
222 };
223 
// Round 0: the AdvertiseKeys request is POSTed to the secure aggregation
// endpoint, and the returned pending operation is polled once. The
// ShareKeysRequest inside the finished operation is then exposed through
// server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeys) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  // Wire-level request the client should send.
  AdvertiseKeysRequest advertise_request;
  advertise_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());

  const auto pending_operation = CreatePendingOperation("operations/foo#bar");
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:advertisekeys?%24alt=proto",
          HttpRequest::Method::kPost, _,
          advertise_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));

  // Response carried by the finished operation.
  secagg::ShareKeysRequest share_keys_request;
  share_keys_request.add_pairs_of_public_keys()->set_noise_pk("noise_pk");
  AdvertiseKeysResponse advertise_response;
  advertise_response.mutable_share_keys_server_request()->CopyFrom(
      share_keys_request);
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
          HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, advertise_response)
              .SerializeAsString())));

  sender->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_response;
  expected_response.mutable_share_keys_request()->CopyFrom(share_keys_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
}
268 
// Round 0: an HTTP 503 on the AdvertiseKeys POST is surfaced as kUnavailable.
// The StrictMock client rejects any additional (e.g. polling) request.
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR0AdvertiseKeysFailedImmediately) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  AdvertiseKeysRequest advertise_request;
  advertise_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());

  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:advertisekeys?%24alt=proto",
          HttpRequest::Method::kPost, _,
          advertise_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  sender->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
292 
// Round 0: if the AdvertiseKeys POST returns an operation that carries a
// kInternal error, that status ends up in server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR0AdvertiseKeysFailed) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  auto* public_keys =
      client_message.mutable_advertise_keys()->mutable_pair_of_public_keys();
  public_keys->set_enc_pk("enc_pk");
  public_keys->set_noise_pk("noise_pk");

  AdvertiseKeysRequest advertise_request;
  advertise_request.mutable_advertise_keys()->CopyFrom(
      client_message.advertise_keys());

  const auto failed_operation = CreateErrorOperation(
      kOperationName, absl::StatusCode::kInternal, "Something's wrong");
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:advertisekeys?%24alt=proto",
          HttpRequest::Method::kPost, _,
          advertise_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), failed_operation.SerializeAsString())));

  sender->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
}
319 
// Round 1: the ShareKeys request is POSTed to the secure aggregation endpoint,
// and the returned pending operation is polled once. The
// MaskedInputCollectionRequest inside the finished operation is then exposed
// through server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeys) {
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  // Message the SecAgg client hands to Send().
  secagg::ClientToServerWrapperMessage server_message;
  server_message.mutable_share_keys_response()
      ->mutable_encrypted_key_shares()
      ->Add("encrypted_key");
  // Create expected request.
  ShareKeysRequest expected_request;
  *expected_request.mutable_share_keys_client_response() =
      server_message.share_keys_response();

  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:sharekeys?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreatePendingOperation("operations/foo#bar").SerializeAsString())));
  // Create expected response.
  secagg::MaskedInputCollectionRequest masked_input_collection_request;
  masked_input_collection_request.add_encrypted_key_shares(
      "encrypted_key_share");
  ShareKeysResponse share_keys_response;
  *share_keys_response.mutable_masked_input_collection_server_request() =
      masked_input_collection_request;
  Operation complete_operation =
      CreateDoneOperation(kOperationName, share_keys_response);
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
          HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), complete_operation.SerializeAsString())));
  send_to_server->Send(&server_message);
  ASSERT_OK(server_response_holder_);

  secagg::ServerToClientWrapperMessage expected_message;
  *expected_message.mutable_masked_input_request() =
      masked_input_collection_request;
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_message));
}
365 
// Round 1: an HTTP 503 on the ShareKeys POST is surfaced as kUnavailable.
// The StrictMock client rejects any additional request.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailedImmediatedly) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_share_keys_response()->add_encrypted_key_shares(
      "encrypted_key");

  ShareKeysRequest share_request;
  share_request.mutable_share_keys_client_response()->CopyFrom(
      client_message.share_keys_response());

  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:sharekeys?%24alt=proto",
          HttpRequest::Method::kPost, _, share_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));

  sender->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
388 
// Round 1: if the ShareKeys POST returns an operation that carries a kInternal
// error, that status ends up in server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR1ShareKeysFailed) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  client_message.mutable_share_keys_response()->add_encrypted_key_shares(
      "encrypted_key");

  ShareKeysRequest share_request;
  share_request.mutable_share_keys_client_response()->CopyFrom(
      client_message.share_keys_response());

  const auto failed_operation = CreateErrorOperation(
      kOperationName, absl::StatusCode::kInternal, "Something's wrong");
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:sharekeys?%24alt=proto",
          HttpRequest::Method::kPost, _, share_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), failed_operation.SerializeAsString())));

  sender->Send(&client_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
}
415 
// Round 2 without a TF checkpoint: only the masked result is uploaded. The
// submit call then returns a pending operation that is polled once, and the
// UnmaskingRequest inside the finished operation is exposed through
// server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultNoCheckpoint) {
  auto sender = CreateSecAggSendToServer(std::nullopt);

  secagg::ClientToServerWrapperMessage client_message;
  (*client_message.mutable_masked_input_response()
        ->mutable_vectors())["vector_1"]
      .set_encoded_vector("encoded_vector");
  // Raw bytes expected on the masked-result upload.
  const std::string masked_upload_body =
      client_message.masked_input_response().SerializeAsString();

  SubmitSecureAggregationResultRequest submit_request;
  *submit_request.mutable_masked_result_resource_name() = kMaskedResourceName;

  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _, masked_upload_body)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  const auto pending_operation = CreatePendingOperation("operations/foo#bar");
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:submit?%24alt=proto",
          HttpRequest::Method::kPost, _, submit_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(), pending_operation.SerializeAsString())));

  secagg::UnmaskingRequest unmasking_request;
  unmasking_request.add_dead_3_client_ids(12345);
  SubmitSecureAggregationResultResponse submit_response;
  submit_response.mutable_unmasking_server_request()->CopyFrom(
      unmasking_request);
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
          HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, submit_response)
              .SerializeAsString())));

  sender->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_response;
  expected_response.mutable_unmasking_request()->CopyFrom(unmasking_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
}
468 
// Round 2 with a TF checkpoint: the masked result and the checkpoint are each
// uploaded to their own ByteStream resource. The submit call then answers
// directly with a finished operation, whose UnmaskingRequest is exposed
// through server_response_holder_.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR2SubmitResultWithCheckpoint) {
  std::string tf_checkpoint = "trained.ckpt";
  auto sender = CreateSecAggSendToServer(tf_checkpoint);

  secagg::ClientToServerWrapperMessage client_message;
  (*client_message.mutable_masked_input_response()
        ->mutable_vectors())["vector_1"]
      .set_encoded_vector("encoded_vector");
  // Raw bytes expected on the masked-result upload.
  const std::string masked_upload_body =
      client_message.masked_input_response().SerializeAsString();

  SubmitSecureAggregationResultRequest submit_request;
  *submit_request.mutable_masked_result_resource_name() = kMaskedResourceName;
  *submit_request.mutable_nonmasked_result_resource_name() =
      kNonmaskedResourceName;

  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _, masked_upload_body)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  secagg::UnmaskingRequest unmasking_request;
  unmasking_request.add_dead_3_client_ids(12345);
  SubmitSecureAggregationResultResponse submit_response;
  submit_response.mutable_unmasking_server_request()->CopyFrom(
      unmasking_request);
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:submit?%24alt=proto",
          HttpRequest::Method::kPost, _, submit_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          CreateDoneOperation(kOperationName, submit_response)
              .SerializeAsString())));

  sender->Send(&client_message);

  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_response;
  expected_response.mutable_unmasking_request()->CopyFrom(unmasking_request);
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
}
521 
// Round 2 with a TF checkpoint: the masked-result upload fails with HTTP 503
// while the checkpoint upload succeeds. Send() must report kUnavailable.
// Because no expectation is set for the submit call, the StrictMock client
// fails the test if the implementation still issues one.
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR2SubmitResultWithCheckpointUploadFailed) {
  std::string tf_checkpoint = "trained.ckpt";
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
  secagg::ClientToServerWrapperMessage server_message;
  secagg::MaskedInputVector masked_vector;
  masked_vector.set_encoded_vector("encoded_vector");
  auto* vector_map =
      server_message.mutable_masked_input_response()->mutable_vectors();
  (*vector_map)["vector_1"] = masked_vector;

  // Fail the masked-result upload; the checkpoint upload still succeeds.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  server_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
556 
// Round 2 with a TF checkpoint: both uploads succeed, but the submit POST
// fails with HTTP 503. That failure must be surfaced as kUnavailable. No
// operation is returned, so no polling request is expected; the StrictMock
// client rejects any extra call.
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendR2SubmitResultWithCheckpointSubmitResultFailedImmediately) {
  std::string tf_checkpoint = "trained.ckpt";
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
  secagg::ClientToServerWrapperMessage server_message;
  secagg::MaskedInputVector masked_vector;
  masked_vector.set_encoded_vector("encoded_vector");
  auto* vector_map =
      server_message.mutable_masked_input_response()->mutable_vectors();
  (*vector_map)["vector_1"] = masked_vector;
  // Create expected request.
  SubmitSecureAggregationResultRequest expected_request;
  *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
  *expected_request.mutable_nonmasked_result_resource_name() =
      kNonmaskedResourceName;

  // Both uploads succeed.
  EXPECT_CALL(http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  "https://bytestream.uri/upload/v1/media/"
                  "masked_resource?upload_protocol=raw",
                  HttpRequest::Method::kPost, _,
                  server_message.masked_input_response().SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
  EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
                                "https://bytestream.uri/upload/v1/media/"
                                "nonmasked_resource?upload_protocol=raw",
                                HttpRequest::Method::kPost, _, tf_checkpoint)))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));

  // The submit call itself fails before any operation is returned.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:submit?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(503, HeaderList(), "")));
  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kUnavailable));
}
605 
TEST_F(HttpSecAggSendToServerImplTest,TestSendR2SubmitResultWithCheckpointSubmitResultFailed)606 TEST_F(HttpSecAggSendToServerImplTest,
607        TestSendR2SubmitResultWithCheckpointSubmitResultFailed) {
608   std::string tf_checkpoint = "trained.ckpt";
609   std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
610       CreateSecAggSendToServer(std::optional<std::string>(tf_checkpoint));
611   secagg::ClientToServerWrapperMessage server_message;
612   secagg::MaskedInputVector masked_vector;
613   masked_vector.set_encoded_vector("encoded_vector");
614   auto vector_map =
615       server_message.mutable_masked_input_response()->mutable_vectors();
616   (*vector_map)["vector_1"] = masked_vector;
617   // Create expected request
618   SubmitSecureAggregationResultRequest expected_request;
619   *expected_request.mutable_masked_result_resource_name() = kMaskedResourceName;
620   *expected_request.mutable_nonmasked_result_resource_name() =
621       kNonmaskedResourceName;
622 
623   // Create expected responses
624   EXPECT_CALL(http_client_,
625               PerformSingleRequest(SimpleHttpRequestMatcher(
626                   "https://bytestream.uri/upload/v1/media/"
627                   "masked_resource?upload_protocol=raw",
628                   HttpRequest::Method::kPost, _,
629                   server_message.masked_input_response().SerializeAsString())))
630       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
631   EXPECT_CALL(http_client_, PerformSingleRequest(SimpleHttpRequestMatcher(
632                                 "https://bytestream.uri/upload/v1/media/"
633                                 "nonmasked_resource?upload_protocol=raw",
634                                 HttpRequest::Method::kPost, _, tf_checkpoint)))
635       .WillOnce(Return(FakeHttpResponse(200, HeaderList(), "")));
636 
637   EXPECT_CALL(
638       http_client_,
639       PerformSingleRequest(SimpleHttpRequestMatcher(
640           "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
641           "clients/client_token:submit?%24alt=proto",
642           HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
643       .WillOnce(Return(FakeHttpResponse(
644           200, HeaderList(),
645           CreatePendingOperation("operations/foo#bar").SerializeAsString())));
646 
647   secagg::UnmaskingRequest unmasking_request;
648   unmasking_request.add_dead_3_client_ids(12345);
649   SubmitSecureAggregationResultResponse submit_secagg_result_response;
650   *submit_secagg_result_response.mutable_unmasking_server_request() =
651       unmasking_request;
652   Operation complete_operation =
653       CreateDoneOperation(kOperationName, submit_secagg_result_response);
654   EXPECT_CALL(
655       http_client_,
656       PerformSingleRequest(SimpleHttpRequestMatcher(
657           "https://secureaggregation.uri/v1/operations/foo%23bar?%24alt=proto",
658           HttpRequest::Method::kGet, _, "")))
659       .WillOnce(Return(FakeHttpResponse(
660           200, HeaderList(),
661           CreateErrorOperation(kOperationName, absl::StatusCode::kInternal,
662                                "Something's wroing.")
663               .SerializeAsString())));
664   send_to_server->Send(&server_message);
665   EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kInternal));
666 }
667 
// Verifies that an unmasking response is forwarded to the Unmask endpoint
// wrapped in an UnmaskRequest. When the server replies with an empty
// UnmaskResponse, the caller should receive an empty
// ServerToClientWrapperMessage.
TEST_F(HttpSecAggSendToServerImplTest, TestSendR3Unmask) {
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  secagg::ClientToServerWrapperMessage server_message;
  server_message.mutable_unmasking_response()
      ->add_noise_or_prf_key_shares()
      ->set_noise_sk_share("noise_sk_share");
  // Create expected request
  UnmaskRequest expected_request;
  *expected_request.mutable_unmasking_client_response() =
      server_message.unmasking_response();

  // Create expected response
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:unmask?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(200, HeaderList(),
                                        UnmaskResponse().SerializeAsString())));
  send_to_server->Send(&server_message);
  // Inspect the holder directly instead of copying it into a local first.
  ASSERT_OK(server_response_holder_);
  EXPECT_THAT(*server_response_holder_,
              EqualsProto(secagg::ServerToClientWrapperMessage()));
}
694 
// Verifies the abort path when the runner is not interrupted.
// An abort message is POSTed to the Abort endpoint as an
// AbortSecureAggregationRequest whose status is INTERNAL and carries the
// diagnostic info. A successful reply is reported back to the caller as a
// ServerToClientWrapperMessage with the abort field set.
TEST_F(HttpSecAggSendToServerImplTest, TestSendAbortWithoutInterruption) {
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  const std::string diagnostic_info = "Some computation failed.";

  secagg::ClientToServerWrapperMessage abort_message;
  abort_message.mutable_abort()->set_diagnostic_info(diagnostic_info);

  // The request the client is expected to emit: an INTERNAL status carrying
  // the same diagnostic text.
  AbortSecureAggregationRequest expected_request;
  google::rpc::Status* expected_status = expected_request.mutable_status();
  expected_status->set_code(google::rpc::INTERNAL);
  expected_status->set_message(diagnostic_info);

  // This test never sets interrupted_, so the fixture's interruptible runner
  // is expected to let the abort request reach the HTTP client.
  EXPECT_CALL(
      http_client_,
      PerformSingleRequest(SimpleHttpRequestMatcher(
          "https://secureaggregation.uri/v1/secureaggregations/aggregation_id/"
          "clients/client_token:abort?%24alt=proto",
          HttpRequest::Method::kPost, _, expected_request.SerializeAsString())))
      .WillOnce(Return(FakeHttpResponse(
          200, HeaderList(),
          AbortSecureAggregationResponse().SerializeAsString())));

  send_to_server->Send(&abort_message);

  // Sending succeeded and produced an abort-only server message.
  ASSERT_OK(server_response_holder_);
  secagg::ServerToClientWrapperMessage expected_response;
  expected_response.mutable_abort();
  EXPECT_THAT(*server_response_holder_, EqualsProto(expected_response));
}
728 
// Verifies that an abort is never sent once the runner is already interrupted.
// No HTTP request reaches the client, and Send reports kCancelled.
TEST_F(HttpSecAggSendToServerImplTest,
       TestSendAbortShouldBeCancelledIfAlreadyInterruptedForTooLong) {
  std::unique_ptr<HttpSecAggSendToServerImpl> send_to_server =
      CreateSecAggSendToServer(std::optional<std::string>());
  secagg::ClientToServerWrapperMessage server_message;
  server_message.mutable_abort()->set_diagnostic_info(
      "Some computation failed.");

  // With interrupted_ set to true, the request should be cancelled before it
  // is even issued. State that expectation explicitly instead of relying on
  // the mock's default handling of unexpected calls.
  interrupted_ = true;
  EXPECT_CALL(http_client_, PerformSingleRequest(_)).Times(0);

  // Send the request, and verify that sending it failed.
  send_to_server->Send(&server_message);
  EXPECT_THAT(server_response_holder_, IsCode(absl::StatusCode::kCancelled));
}
752 
753 }  // anonymous namespace
754 }  // namespace http
755 }  // namespace client
756 }  // namespace fcp
757