/* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "fcp/protos/plan.pb.h" #include "google/protobuf/any.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" namespace fcp { using ::google::internal::federated::plan::ExampleSelector; using ::tensorflow::DEVICE_CPU; using ::tensorflow::OpKernel; using ::tensorflow::OpKernelConstruction; using ::tensorflow::OpKernelContext; using ::tensorflow::Tensor; using ::tensorflow::data::ParseScalarArgument; /** * ExampleSelectorFuserOp op-kernel. * * ExampleSelectorFuser fills the resumption token field for an existing * ExampleSelector protobuf message. The resumption token field is an Any proto * which can be any user defined protobuf message. The user needs to provide the * type url and content for the resumption token. * * Inputs: * example_selector: A string scalar encodes an ExampleSelector protobuf * message. * resumption_token_type_url: String scalar. The type_url for the resumption * token. * resumption_token_content: String scalar. The bytes for the resumption * token message. * * Output: * A string tensor contains the fused ExampleSelector message serialized to * string. */ class ExampleSelectorFuserOp : public OpKernel { public: explicit ExampleSelectorFuserOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { tensorflow::tstring example_selector_str; OP_REQUIRES_OK(ctx, ParseScalarArgument( ctx, "example_selector", &example_selector_str)); tensorflow::tstring resumption_token_type_url_str; OP_REQUIRES_OK(ctx, ParseScalarArgument( ctx, "resumption_token_type_url", &resumption_token_type_url_str)); tensorflow::tstring resumption_token_content_str; OP_REQUIRES_OK(ctx, ParseScalarArgument( ctx, "resumption_token_content", &resumption_token_content_str)); ExampleSelector example_selector; if (!example_selector.ParseFromString( std::string(example_selector_str.data()))) { ctx->SetStatus(tensorflow::Status( // Remove the cast after TF 2.12 is released and used in FCP. static_cast( absl::StatusCode::kInvalidArgument), tensorflow::StringPiece("Cannot parse ExampleSelector"))); return; } example_selector.mutable_resumption_token()->set_type_url( std::string(resumption_token_type_url_str.data())); example_selector.mutable_resumption_token()->set_value( std::string(resumption_token_content_str.data())); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor)); output_tensor->flat()(0) = example_selector.SerializeAsString(); } }; REGISTER_OP("ExampleSelectorFuser") .Input("example_selector: string") .Input("resumption_token_type_url: string") .Input("resumption_token_content: string") .Output("fused_example_selector: string") .SetShapeFn(tensorflow::shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("ExampleSelectorFuser").Device(DEVICE_CPU), ExampleSelectorFuserOp); } // namespace fcp