/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fcp/client/engine/example_query_plan_engine.h"

// NOTE(review): this copy of the file had every `<...>` span stripped (the
// standard header names and all template arguments). They are reconstructed
// here from how the code uses them. The standard headers below cover what the
// file uses; the original listed eight, so one may be missing. The project
// types `ExampleIterator`, `MockOpStatsLogger`,
// `FunctionalExampleIteratorFactory`, `SimpleExampleIterator` and
// `tf::tstring` are presumed from the included headers -- confirm against the
// upstream file before merging.
#include <cstdint>
#include <filesystem>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "fcp/client/client_runner.h"
#include "fcp/client/engine/common.h"
#include "fcp/client/example_query_result.pb.h"
#include "fcp/client/test_helpers.h"
#include "fcp/protos/plan.pb.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace fcp {
namespace client {
namespace engine {
namespace {

namespace tf = ::tensorflow;

using ::fcp::client::ExampleQueryResult;
using ::google::internal::federated::plan::AggregationConfig;
using ::google::internal::federated::plan::ClientOnlyPlan;
using ::google::internal::federated::plan::Dataset;
using ::google::internal::federated::plan::ExampleQuerySpec;
using ::google::internal::federated::plan::ExampleSelector;
using ::testing::StrictMock;

// Collection URI used by the default example query built in `Initialize()`.
const char* const kCollectionUri = "app:/test_collection";
// Vector names: keys of the vectors inside the serialized ExampleQueryResult.
const char* const kOutputStringVectorName = "vector1";
const char* const kOutputIntVectorName = "vector2";
// Tensor names: keys of the output vector specs and of the tensors read back
// from the output checkpoint.
const char* const kOutputStringTensorName = "tensor1";
const char* const kOutputIntTensorName = "tensor2";

// Factory that reports it cannot handle any selector (`CanHandle` is always
// false) and whose `CreateExampleIterator` always returns an internal error.
class InvalidExampleIteratorFactory : public ExampleIteratorFactory {
 public:
  InvalidExampleIteratorFactory() = default;

  bool CanHandle(const google::internal::federated::plan::ExampleSelector&
                     example_selector) override {
    return false;
  }

  absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
      const ExampleSelector& example_selector) override {
    absl::Status error(absl::StatusCode::kInternal, "");
    return error;
  }

  bool ShouldCollectStats() override { return false; }
};

// Factory that claims to handle every selector (`CanHandle` is always true)
// but whose `CreateExampleIterator` always returns an internal error.
class NoIteratorExampleIteratorFactory : public ExampleIteratorFactory {
 public:
  NoIteratorExampleIteratorFactory() = default;

  bool CanHandle(const google::internal::federated::plan::ExampleSelector&
                     example_selector) override {
    return true;
  }

  absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
      const ExampleSelector& example_selector) override {
    absl::Status error(absl::StatusCode::kInternal, "");
    return error;
  }

  bool ShouldCollectStats() override { return false; }
};

// Factory that serves two collections: the selector's collection URI picks
// which of the two creation callbacks is invoked. Any other URI yields an
// InvalidArgument error. Stats collection is disabled.
class TwoExampleIteratorsFactory : public ExampleIteratorFactory {
 public:
  // Callback that builds an iterator for a given selector.
  using CreateIteratorFunc =
      std::function<absl::StatusOr<std::unique_ptr<ExampleIterator>>(
          const google::internal::federated::plan::ExampleSelector&)>;

  explicit TwoExampleIteratorsFactory(
      CreateIteratorFunc create_first_iterator_func,
      CreateIteratorFunc create_second_iterator_func,
      const std::string& first_collection_uri,
      const std::string& second_collection_uri)
      : create_first_iterator_func_(std::move(create_first_iterator_func)),
        create_second_iterator_func_(std::move(create_second_iterator_func)),
        first_collection_uri_(first_collection_uri),
        second_collection_uri_(second_collection_uri) {}

  bool CanHandle(const google::internal::federated::plan::ExampleSelector&
                     example_selector) override {
    return true;
  }

  absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
      const google::internal::federated::plan::ExampleSelector&
          example_selector) override {
    if (example_selector.collection_uri() == first_collection_uri_) {
      return create_first_iterator_func_(example_selector);
    } else if (example_selector.collection_uri() == second_collection_uri_) {
      return create_second_iterator_func_(example_selector);
    }
    return absl::InvalidArgumentError("Unknown collection URI");
  }

  bool ShouldCollectStats() override { return false; }

 private:
  CreateIteratorFunc create_first_iterator_func_;
  CreateIteratorFunc create_second_iterator_func_;
  std::string first_collection_uri_;
  std::string second_collection_uri_;
};

// Reads every variable listed in the checkpoint at `checkpoint_path` and
// returns the tensors keyed by variable name.
//
// Returns a NotFound error if the checkpoint reader reports a non-OK status
// on construction, or if reading any individual tensor reports a non-OK
// status.
absl::StatusOr<absl::flat_hash_map<std::string, tf::Tensor>> ReadTensors(
    const std::string& checkpoint_path) {
  absl::flat_hash_map<std::string, tf::Tensor> tensors;
  tf::TF_StatusPtr tf_status(TF_NewStatus());
  tf::checkpoint::CheckpointReader tf_checkpoint_reader(checkpoint_path,
                                                        tf_status.get());
  if (TF_GetCode(tf_status.get()) != TF_OK) {
    return absl::NotFoundError("Couldn't read an input checkpoint");
  }
  for (const auto& [name, tf_dtype] :
       tf_checkpoint_reader.GetVariableToDataTypeMap()) {
    std::unique_ptr<tf::Tensor> tensor;
    tf_checkpoint_reader.GetTensor(name, &tensor, tf_status.get());
    if (TF_GetCode(tf_status.get()) != TF_OK) {
      return absl::NotFoundError(
          absl::StrFormat("Checkpoint doesn't have tensor %s", name));
    }
    tensors[name] = *tensor;
  }
  return tensors;
}

// Fixture for ExampleQueryPlanEngine tests. Tests call `Initialize()`
// explicitly to populate the plan, the dataset and the iterator factory.
class ExampleQueryPlanEngineTest : public testing::Test {
 protected:
  // Builds the default scenario:
  //  - one example query on `kCollectionUri` with a STRING output spec
  //    (tensor1 <- vector1) and an INT64 output spec (tensor2 <- vector2);
  //  - a tf_v1_checkpoint_aggregation config for both tensors;
  //  - a dataset holding a single serialized ExampleQueryResult with
  //    vector2 = {42, 24} and vector1 = {"value1", "value2"};
  //  - an iterator factory that serves that dataset.
  void Initialize() {
    std::filesystem::path root_dir(testing::TempDir());
    std::filesystem::path output_path = root_dir / std::string("output.ckpt");
    output_checkpoint_filename_ = output_path.string();

    ExampleQuerySpec::OutputVectorSpec string_vector_spec;
    string_vector_spec.set_vector_name(kOutputStringVectorName);
    string_vector_spec.set_data_type(
        ExampleQuerySpec::OutputVectorSpec::STRING);
    ExampleQuerySpec::OutputVectorSpec int_vector_spec;
    int_vector_spec.set_vector_name(kOutputIntVectorName);
    int_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::INT64);

    ExampleQuerySpec::ExampleQuery example_query;
    example_query.mutable_example_selector()->set_collection_uri(
        kCollectionUri);
    (*example_query.mutable_output_vector_specs())[kOutputStringTensorName] =
        string_vector_spec;
    (*example_query.mutable_output_vector_specs())[kOutputIntTensorName] =
        int_vector_spec;
    client_only_plan_.mutable_phase()
        ->mutable_example_query_spec()
        ->mutable_example_queries()
        ->Add(std::move(example_query));

    AggregationConfig aggregation_config;
    aggregation_config.mutable_tf_v1_checkpoint_aggregation();
    (*client_only_plan_.mutable_phase()
          ->mutable_federated_example_query()
          ->mutable_aggregations())[kOutputStringTensorName] =
        aggregation_config;
    (*client_only_plan_.mutable_phase()
          ->mutable_federated_example_query()
          ->mutable_aggregations())[kOutputIntTensorName] = aggregation_config;

    ExampleQueryResult::VectorData::Values int_values;
    int_values.mutable_int64_values()->add_value(42);
    int_values.mutable_int64_values()->add_value(24);
    (*example_query_result_.mutable_vector_data()
          ->mutable_vectors())[kOutputIntVectorName] = int_values;
    ExampleQueryResult::VectorData::Values string_values;
    string_values.mutable_string_values()->add_value("value1");
    string_values.mutable_string_values()->add_value("value2");
    (*example_query_result_.mutable_vector_data()
          ->mutable_vectors())[kOutputStringVectorName] = string_values;
    std::string example = example_query_result_.SerializeAsString();

    Dataset::ClientDataset client_dataset;
    client_dataset.set_client_id("client_id");
    client_dataset.add_example(example);
    dataset_.mutable_client_data()->Add(std::move(client_dataset));

    // Values the tests pass to the UpdateDatasetStats expectation.
    num_examples_ = 1;
    example_bytes_ = example.size();

    // The lambda captures `dataset_` by reference, so later changes to the
    // fixture's dataset are visible to iterators created afterwards.
    example_iterator_factory_ =
        std::make_unique<FunctionalExampleIteratorFactory>(
            [&dataset = dataset_](
                const google::internal::federated::plan::ExampleSelector&
                    selector) {
              return std::make_unique<SimpleExampleIterator>(dataset);
            });
  }

  fcp::client::FilesImpl files_impl_;
  StrictMock<MockOpStatsLogger> mock_opstats_logger_;
  std::unique_ptr<ExampleIteratorFactory> example_iterator_factory_;

  ExampleQueryResult example_query_result_;
  ClientOnlyPlan client_only_plan_;
  Dataset dataset_;
  std::string output_checkpoint_filename_;

  int num_examples_ = 0;
  int64_t example_bytes_ = 0;
};

// The default plan succeeds and the output checkpoint contains both tensors
// with the values from the serialized ExampleQueryResult.
TEST_F(ExampleQueryPlanEngineTest, PlanSucceeds) {
  Initialize();

  EXPECT_CALL(
      mock_opstats_logger_,
      UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));

  ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);

  auto tensors = ReadTensors(output_checkpoint_filename_);
  ASSERT_TRUE(tensors.ok());

  tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName];
  ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2}));
  ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64);
  auto int_data = static_cast<int64_t*>(int_tensor.data());
  std::vector<int64_t> expected_int_data({42, 24});
  for (int i = 0; i < 2; ++i) {
    ASSERT_EQ(int_data[i], expected_int_data[i]);
  }

  tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName];
  ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2}));
  ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING);
  auto string_data = static_cast<tf::tstring*>(string_tensor.data());
  std::vector<std::string> expected_string_data({"value1", "value2"});
  for (int i = 0; i < 2; ++i) {
    ASSERT_EQ(static_cast<std::string>(string_data[i]),
              expected_string_data[i]);
  }
}

// Two example queries on two collections: all four output tensors must end up
// in the checkpoint, including two that are fed from vectors sharing the same
// vector name in different queries.
TEST_F(ExampleQueryPlanEngineTest, MultipleQueries) {
  Initialize();

  ExampleQuerySpec::OutputVectorSpec float_vector_spec;
  float_vector_spec.set_vector_name("float_vector");
  float_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::FLOAT);
  ExampleQuerySpec::OutputVectorSpec string_vector_spec;
  // Same vector name as in the other ExampleQuery, but mapped to a different
  // output tensor name, to make sure these vectors are distinguished in
  // example_query_plan_engine.
  string_vector_spec.set_vector_name(kOutputStringVectorName);
  string_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::STRING);

  ExampleQuerySpec::ExampleQuery second_example_query;
  second_example_query.mutable_example_selector()->set_collection_uri(
      "app:/second_collection");
  (*second_example_query.mutable_output_vector_specs())["float_tensor"] =
      float_vector_spec;
  (*second_example_query
        .mutable_output_vector_specs())["another_string_tensor"] =
      string_vector_spec;
  client_only_plan_.mutable_phase()
      ->mutable_example_query_spec()
      ->mutable_example_queries()
      ->Add(std::move(second_example_query));

  AggregationConfig aggregation_config;
  aggregation_config.mutable_tf_v1_checkpoint_aggregation();
  (*client_only_plan_.mutable_phase()
        ->mutable_federated_example_query()
        ->mutable_aggregations())["float_tensor"] = aggregation_config;

  ExampleQueryResult second_example_query_result;
  ExampleQueryResult::VectorData::Values float_values;
  float_values.mutable_float_values()->add_value(0.24f);
  float_values.mutable_float_values()->add_value(0.42f);
  float_values.mutable_float_values()->add_value(0.33f);
  ExampleQueryResult::VectorData::Values string_values;
  string_values.mutable_string_values()->add_value("another_string_value");
  (*second_example_query_result.mutable_vector_data()
        ->mutable_vectors())["float_vector"] = float_values;
  (*second_example_query_result.mutable_vector_data()
        ->mutable_vectors())[kOutputStringVectorName] = string_values;
  std::string example = second_example_query_result.SerializeAsString();

  Dataset::ClientDataset second_client_dataset;
  second_client_dataset.set_client_id("second_client_id");
  second_client_dataset.add_example(example);
  Dataset second_dataset;
  second_dataset.mutable_client_data()->Add(std::move(second_client_dataset));

  // `second_dataset` is captured by reference; it outlives the factory's use
  // because both live until the end of this test body.
  example_iterator_factory_ = std::make_unique<TwoExampleIteratorsFactory>(
      [&dataset = dataset_](
          const google::internal::federated::plan::ExampleSelector& selector) {
        return std::make_unique<SimpleExampleIterator>(dataset);
      },
      [&dataset = second_dataset](
          const google::internal::federated::plan::ExampleSelector& selector) {
        return std::make_unique<SimpleExampleIterator>(dataset);
      },
      kCollectionUri, "app:/second_collection");

  ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kSuccess);

  auto tensors = ReadTensors(output_checkpoint_filename_);
  // ASSERT_OK is not supported in AOSP.
  ASSERT_TRUE(tensors.ok());

  tf::Tensor int_tensor = tensors.value()[kOutputIntTensorName];
  ASSERT_EQ(int_tensor.shape(), tf::TensorShape({2}));
  ASSERT_EQ(int_tensor.dtype(), tf::DT_INT64);
  auto int_data = static_cast<int64_t*>(int_tensor.data());
  std::vector<int64_t> expected_int_data({42, 24});
  for (int i = 0; i < 2; ++i) {
    ASSERT_EQ(int_data[i], expected_int_data[i]);
  }

  tf::Tensor string_tensor = tensors.value()[kOutputStringTensorName];
  ASSERT_EQ(string_tensor.shape(), tf::TensorShape({2}));
  ASSERT_EQ(string_tensor.dtype(), tf::DT_STRING);
  auto string_data = static_cast<tf::tstring*>(string_tensor.data());
  std::vector<std::string> expected_string_data({"value1", "value2"});
  for (int i = 0; i < 2; ++i) {
    ASSERT_EQ(static_cast<std::string>(string_data[i]),
              expected_string_data[i]);
  }

  tf::Tensor float_tensor = tensors.value()["float_tensor"];
  ASSERT_EQ(float_tensor.shape(), tf::TensorShape({3}));
  ASSERT_EQ(float_tensor.dtype(), tf::DT_FLOAT);
  auto float_data = static_cast<float*>(float_tensor.data());
  std::vector<float> expected_float_data({0.24f, 0.42f, 0.33f});
  for (int i = 0; i < 3; ++i) {
    ASSERT_EQ(float_data[i], expected_float_data[i]);
  }

  tf::Tensor second_query_string_tensor =
      tensors.value()["another_string_tensor"];
  ASSERT_EQ(second_query_string_tensor.shape(), tf::TensorShape({1}));
  ASSERT_EQ(second_query_string_tensor.dtype(), tf::DT_STRING);
  auto second_query_string_data =
      static_cast<tf::tstring*>(second_query_string_tensor.data());
  ASSERT_EQ(static_cast<std::string>(*second_query_string_data),
            "another_string_value");
}

// The query declares an extra output spec ("new_tensor" <- "new_vector") but
// the serialized ExampleQueryResult built by `Initialize()` has no vector with
// that name, so the plan must fail.
TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecMissingInResult) {
  Initialize();

  ExampleQuerySpec::OutputVectorSpec new_vector_spec;
  new_vector_spec.set_vector_name("new_vector");
  new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE);

  // Replace the single example query with a copy that carries the extra spec.
  ExampleQuerySpec::ExampleQuery example_query =
      client_only_plan_.phase().example_query_spec().example_queries().at(0);
  (*example_query.mutable_output_vector_specs())["new_tensor"] =
      new_vector_spec;
  client_only_plan_.mutable_phase()
      ->mutable_example_query_spec()
      ->clear_example_queries();
  client_only_plan_.mutable_phase()
      ->mutable_example_query_spec()
      ->mutable_example_queries()
      ->Add(std::move(example_query));

  EXPECT_CALL(
      mock_opstats_logger_,
      UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));

  ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
}

// The query declares "new_vector" as DOUBLE, but the serialized
// ExampleQueryResult carries bool values under that name, so the plan must
// fail.
TEST_F(ExampleQueryPlanEngineTest, OutputVectorSpecTypeMismatch) {
  Initialize();

  ExampleQuerySpec::OutputVectorSpec new_vector_spec;
  new_vector_spec.set_vector_name("new_vector");
  new_vector_spec.set_data_type(ExampleQuerySpec::OutputVectorSpec::DOUBLE);

  // Replace the single example query with a copy that carries the extra spec.
  ExampleQuerySpec::ExampleQuery example_query =
      client_only_plan_.phase().example_query_spec().example_queries().at(0);
  (*example_query.mutable_output_vector_specs())["new_tensor"] =
      new_vector_spec;
  client_only_plan_.mutable_phase()
      ->mutable_example_query_spec()
      ->clear_example_queries();
  client_only_plan_.mutable_phase()
      ->mutable_example_query_spec()
      ->mutable_example_queries()
      ->Add(std::move(example_query));

  // Add the vector to the fixture's result with the wrong value type, then
  // rebuild the dataset from the re-serialized result.
  ExampleQueryResult::VectorData::Values bool_values;
  bool_values.mutable_bool_values()->add_value(true);
  (*example_query_result_.mutable_vector_data()
        ->mutable_vectors())["new_vector"] = bool_values;
  std::string example = example_query_result_.SerializeAsString();

  Dataset::ClientDataset client_dataset;
  client_dataset.set_client_id("client_id");
  client_dataset.add_example(example);
  dataset_.clear_client_data();
  dataset_.mutable_client_data()->Add(std::move(client_dataset));

  num_examples_ = 1;
  example_bytes_ = example.size();

  example_iterator_factory_ =
      std::make_unique<FunctionalExampleIteratorFactory>(
          [&dataset = dataset_](
              const google::internal::federated::plan::ExampleSelector&
                  selector) {
            return std::make_unique<SimpleExampleIterator>(dataset);
          });

  EXPECT_CALL(
      mock_opstats_logger_,
      UpdateDatasetStats(kCollectionUri, num_examples_, example_bytes_));

  ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
}

// The only factory given to the engine reports that it cannot handle the
// selector.
TEST_F(ExampleQueryPlanEngineTest, FactoryNotFound) {
  Initialize();
  auto invalid_example_factory =
      std::make_unique<InvalidExampleIteratorFactory>();

  ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
}

// The factory accepts the selector but fails to create an iterator.
TEST_F(ExampleQueryPlanEngineTest, NoIteratorCreated) {
  Initialize();
  auto invalid_example_factory =
      std::make_unique<NoIteratorExampleIteratorFactory>();

  ExampleQueryPlanEngine plan_engine({invalid_example_factory.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
}

// The dataset's only example is a string that is not a serialized
// ExampleQueryResult.
TEST_F(ExampleQueryPlanEngineTest, InvalidExampleQueryResultFormat) {
  Initialize();
  std::string invalid_example = "invalid_example";
  Dataset::ClientDataset client_dataset;
  client_dataset.add_example(invalid_example);
  dataset_.clear_client_data();
  dataset_.mutable_client_data()->Add(std::move(client_dataset));
  example_iterator_factory_ =
      std::make_unique<FunctionalExampleIteratorFactory>(
          [&dataset = dataset_](
              const google::internal::federated::plan::ExampleSelector&
                  selector) {
            return std::make_unique<SimpleExampleIterator>(dataset);
          });

  EXPECT_CALL(mock_opstats_logger_,
              UpdateDatasetStats(kCollectionUri, 1, invalid_example.size()));

  ExampleQueryPlanEngine plan_engine({example_iterator_factory_.get()},
                                     &mock_opstats_logger_);
  engine::PlanResult result =
      plan_engine.RunPlan(client_only_plan_.phase().example_query_spec(),
                          output_checkpoint_filename_);

  EXPECT_THAT(result.outcome, PlanOutcome::kExampleIteratorError);
}

}  // anonymous namespace
}  // namespace engine
}  // namespace client
}  // namespace fcp