/* * Copyright 2020 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_ #define FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_ #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "fcp/secagg/server/experiments_interface.h" #include "fcp/secagg/server/secagg_scheduler.h" #include "fcp/secagg/server/secagg_server_enums.pb.h" #include "fcp/secagg/server/secagg_server_metrics_listener.h" #include "fcp/secagg/server/secret_sharing_graph.h" #include "fcp/secagg/server/send_to_clients_interface.h" #include "fcp/secagg/shared/aes_prng_factory.h" #include "fcp/secagg/shared/compute_session_id.h" #include "fcp/secagg/shared/ecdh_key_agreement.h" #include "fcp/secagg/shared/ecdh_keys.h" #include "fcp/secagg/shared/input_vector_specification.h" #include "fcp/secagg/shared/secagg_messages.pb.h" #include "fcp/secagg/shared/secagg_vector.h" #include "fcp/secagg/shared/shamir_secret_sharing.h" namespace fcp { namespace secagg { // Interface that describes internal implementation of SecAgg protocol. // // The general design is the following // // +--------------+ +-------------------+ +--------------------------+ // | SecAggServer |--->| SecAggServerState |--->| SecAggServerProtocolImpl | // +--------------+ +-------------------+ +--------------------------+ // ^ ^ // /-\ /-\ // | | // +-------------------+ +------------------------+ // | Specific State | | Specific protocol impl | // +-------------------+ +------------------------+ // // Specific states implement logic specific to each logic SecAgg state, such as // R0AdvertiseKeys or PrngRunning, while specific protocol implementation is // shared between all states and is responsible for encapsulating the data // for the protocol and providing methods for manipulating the data. // class SecAggServerProtocolImpl { public: explicit SecAggServerProtocolImpl( std::unique_ptr graph, int minimum_number_of_clients_to_proceed, std::unique_ptr metrics, std::unique_ptr prng_factory, SendToClientsInterface* sender, std::unique_ptr scheduler, std::vector client_statuses, std::unique_ptr experiments = nullptr); virtual ~SecAggServerProtocolImpl() = default; SecAggServerProtocolImpl(const SecAggServerProtocolImpl& other) = delete; SecAggServerProtocolImpl& operator=(const SecAggServerProtocolImpl& other) = delete; // Returns server variant for this protocol implementation. virtual ServerVariant server_variant() const = 0; // Returns the graph that represents the cohort of clients. inline const SecretSharingGraph* secret_sharing_graph() const { return secret_sharing_graph_.get(); } // Returns the minimum threshold number of clients that need to send valid // responses in order for the protocol to proceed from one round to the next. inline int minimum_number_of_clients_to_proceed() const { return minimum_number_of_clients_to_proceed_; } // Returns the callback interface for recording metrics. inline SecAggServerMetricsListener* metrics() const { return metrics_.get(); } // Returns a reference to an instance of a subclass of AesPrngFactory. inline AesPrngFactory* prng_factory() const { return prng_factory_.get(); } // Returns the callback interface for sending protocol buffer messages to the // client. inline SendToClientsInterface* sender() const { return sender_; } // Returns the scheduler for scheduling parallel computation tasks and // callbacks. inline SecAggScheduler* scheduler() const { return scheduler_.get(); } // Returns the experiments inline ExperimentsInterface* experiments() const { return experiments_.get(); } // Getting or setting the protocol result. // // TODO(team): SetResult should not be needed (except for testing) once // PRNG computation is moved into the protocol implementation. void SetResult(std::unique_ptr result); std::unique_ptr TakeResult(); // Gets the client status. inline const ClientStatus& client_status(uint32_t client_id) const { return client_statuses_.at(client_id); } // Sets the client status. inline void set_client_status(uint32_t client_id, ClientStatus status) { client_statuses_[client_id] = status; } // Gets the number of clients that the protocol starts with. inline size_t total_number_of_clients() const { return total_number_of_clients_; } // Returns the number of neighbors of each client. inline const int number_of_neighbors() const { return secret_sharing_graph()->GetDegree(); } // Returns the minimum number of neighbors of a client that must not drop-out // for that client's contribution to be included in the sum. This corresponds // to the threshold in the shamir secret sharing of self and pairwise masks. inline const int minimum_surviving_neighbors_for_reconstruction() const { return secret_sharing_graph()->GetThreshold(); } // Returns client_id's ith neighbor. // This function assumes that 0 <= i < number_of_neighbors() and will throw a // runtime error if that's not the case inline const int GetNeighbor(int client_id, int i) const { return secret_sharing_graph()->GetNeighbor(client_id, i); } // Returns the index of client_id_2 in the list of neighbors of client_id_1, // if present inline const std::optional GetNeighborIndex(int client_id_1, int client_id_2) const { return secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2); } // Returns the index of client_id_2 in the list of neighbors of client_id_1 // This function assumes that client_id_1 and client_id_2 are neighbors, and // will throw a runtime error if that's not the case inline const int GetNeighborIndexOrDie(int client_id_1, int client_id_2) const { auto index = secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2); FCP_CHECK(index.has_value()); return index.value(); } // Returns true if clients client_id_1 and client_id_1 are neighbors, else // false. inline const bool AreNeighbors(int client_id_1, int client_id_2) const { return secret_sharing_graph()->AreNeighbors(client_id_1, client_id_2); } // Returns true if client_id_1 is an outgoing neighbor of client_id_2, else // false. inline const bool IsOutgoingNeighbor(int client_id_1, int client_id_2) const { return secret_sharing_graph()->IsOutgoingNeighbor(client_id_1, client_id_2); } inline void SetPairwisePublicKeys(uint32_t client_id, const EcdhPublicKey& pairwise_key) { pairwise_public_keys_[client_id] = pairwise_key; } inline const EcdhPublicKey& pairwise_public_keys(uint32_t client_id) const { return pairwise_public_keys_[client_id]; } inline const SessionId& session_id() const { FCP_CHECK(session_id_ != nullptr); return *session_id_; } void set_session_id(std::unique_ptr session_id) { FCP_CHECK(session_id != nullptr); session_id_ = std::move(session_id); } // TODO(team): Review whether getters and setters below are needed. // Most of these fields are needed only for testing. void set_pairwise_shamir_share_table( std::unique_ptr>> pairwise_shamir_share_table) { pairwise_shamir_share_table_ = std::move(pairwise_shamir_share_table); } void set_self_shamir_share_table( std::unique_ptr>> self_shamir_share_table) { self_shamir_share_table_ = std::move(self_shamir_share_table); } // --------------------------------------------------------------------------- // Round 0 methods // --------------------------------------------------------------------------- // Sets the public key pairs for a client. Status HandleAdvertiseKeys(uint32_t client_id, const AdvertiseKeys& advertise_keys); // Erases public key pairs for a client. void ErasePublicKeysForClient(uint32_t client_id); // Compute session ID based on public key pairs advertised by clients. void ComputeSessionId(); // This method allows a protocol implementation to populate fields that are // common to the ShareKeysRequest sent to all clients. virtual Status InitializeShareKeysRequest( ShareKeysRequest* request) const = 0; // Prepares ShareKeysRequest message to send to the client. // This method will update fields in the request as needed, but will not clear // any fields that are not specific to the share keys request for the specific // client. The caller can therefore set up a single ShareKeysRequest object, // populate fields that will be common to all clients, and repeatedly call // this method to set the client-specific fields before serializing the // message and sending it. void PrepareShareKeysRequestForClient(uint32_t client_id, ShareKeysRequest* request) const; // Clears all pairs of public keys. void ClearPairsOfPublicKeys(); // --------------------------------------------------------------------------- // Round 1 methods // --------------------------------------------------------------------------- // Sets the encrypted shares received from a client. Status HandleShareKeysResponse(uint32_t client_id, const ShareKeysResponse& share_keys_response); // Erases the encrypted shares for a client. void EraseShareKeysForClient(uint32_t client_id); // Prepares MaskedInputCollectionRequest message to send to the client. // This method will update fields in the request as needed, but will not clear // any fields that are not specific to the share keys request for the specific // client. The caller can therefore set up a single ShareKeysRequest object, // populate fields that will be common to all clients, and repeatedly call // this method to set the client-specific fields before serializing the // message and sending it. void PrepareMaskedInputCollectionRequestForClient( uint32_t client_id, MaskedInputCollectionRequest* request) const; // Clears all encrypted shares. void ClearShareKeys(); // --------------------------------------------------------------------------- // Round 2 methods // --------------------------------------------------------------------------- // Sets up the sum of encrypted vectors received by the clients in R1. This // must be called before any other R2 methods are called. virtual std::shared_ptr> SetupMaskedInputCollection() = 0; // Finalizes the async aggregation of R2 messages before moving to R3. virtual void FinalizeMaskedInputCollection() = 0; // Check that an encrypted vector received by the user is valid, and add it to // the sum of encrypted vectors. virtual Status HandleMaskedInputCollectionResponse( std::unique_ptr masked_input_response) = 0; // --------------------------------------------------------------------------- // Round 3 methods // --------------------------------------------------------------------------- // This must be called in the beginning of round 3 to setup Shamir shares // tables based on client states at the beginning of the round. void SetUpShamirSharesTables(); // Populates Shamir shares tables with the data from UnmaskingResponse. // Returning an error status means that the unmasking response was invalid. Status HandleUnmaskingResponse(uint32_t client_id, const UnmaskingResponse& unmasking_response); // --------------------------------------------------------------------------- // PRNG computation methods // --------------------------------------------------------------------------- // Result of performing Shamir secret sharing keys reconstruction. struct ShamirReconstructionResult { absl::flat_hash_map aborted_client_key_agreements; absl::node_hash_map self_keys; }; // Performs reconstruction secret sharing keys reconstruction step of // the PRNG stage of the protocol. StatusOr HandleShamirReconstruction(); struct PrngWorkItems { std::vector prng_keys_to_add; std::vector prng_keys_to_subtract; }; // Initializes PRNG work items. StatusOr InitializePrng( const ShamirReconstructionResult& shamir_reconstruction_result) const; // Tells the PRNG stage of the protocol to start running asynchronously by // executing PRNG work items. // The returned cancellation token can be used to abort the asynchronous // execution. virtual CancellationToken StartPrng( const PrngWorkItems& work_items, std::function done_callback) = 0; private: std::unique_ptr secret_sharing_graph_; int minimum_number_of_clients_to_proceed_; std::vector input_vector_specs_; std::unique_ptr metrics_; std::unique_ptr prng_factory_; SendToClientsInterface* sender_; std::unique_ptr scheduler_; std::unique_ptr result_; size_t total_number_of_clients_; std::vector client_statuses_; std::unique_ptr experiments_; // This vector collects the public keys sent by the clients that will be used // for running the PRNG later on. std::vector pairwise_public_keys_; // This vector collects all pairs of public keys sent by the clients, so they // can be forwarded at the end of Advertise Keys round. std::vector pairs_of_public_keys_; std::unique_ptr session_id_; // Track the encrypted shares received from clients in preparation for sending // them. encrypted_shares_table_[i][j] is an encryption of the pair of shares // to be sent to client i, received from client j. std::vector> encrypted_shares_; // Shamir shares tables. // These store shares that have been collected from clients, and will be built // up over the course of round 3. For both tables, the map key represents // the client whose key these are shares of; the index in the vector // represents the client who provided that key share. std::unique_ptr>> pairwise_shamir_share_table_; std::unique_ptr>> self_shamir_share_table_; }; } // namespace secagg } // namespace fcp #endif // FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_