/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "fcp/base/random_token.h"
#include "fcp/tensorflow/external_dataset.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

31 namespace fcp {
32 
33 /**
34  * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found
35  * from the ExternalDatasetProviderRegistry (a HostObjectRegistry).
36  *
37  * Inputs:
38  *   selector: An opaque string scalar. Forwarded to the stub.
39  *   token: String scalar. It should encode a token obtained from
40  *          ExternalDatasetProviderRegistry::Register.
41  *
42  * See TensorFlow's guide to making custom dataset ops:
43  * https://www.tensorflow.org/guide/extend/formats
44  */
45 class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel {
46  public:
47   using tensorflow::data::DatasetOpKernel::DatasetOpKernel;
48 
MakeDataset(tensorflow::OpKernelContext * ctx,tensorflow::data::DatasetBase ** output)49   void MakeDataset(tensorflow::OpKernelContext* ctx,
50                    tensorflow::data::DatasetBase** output) override {
51     tensorflow::tstring token_str;
52     OP_REQUIRES_OK(ctx,
53                    tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
54                        ctx, "token", &token_str));
55     absl::Span<char const> token_bytes = token_str;
56     OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes,
57                 tensorflow::errors::InvalidArgument(absl::StrFormat(
58                     "Tokens have a fixed size. Expected: %d; Actual %d",
59                     kRandomTokenSizeInBytes, token_bytes.size())));
60     RandomToken token = RandomToken::FromBytes(token_bytes);
61 
62     tensorflow::tstring selector_str;
63     OP_REQUIRES_OK(ctx,
64                    tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
65                        ctx, "selector", &selector_str));
66 
67     std::optional<std::shared_ptr<ExternalDatasetProvider>> maybe_provider =
68         ExternalDatasetProviderRegistry::TryLookup(token);
69     OP_REQUIRES(ctx, maybe_provider.has_value(),
70                 tensorflow::errors::InvalidArgument(
71                     "A dataset provider is not currently registered for the "
72                     "provided token: ",
73                     token.ToPrintableString()));
74 
75     std::shared_ptr<ExternalDatasetProvider> provider =
76         *std::move(maybe_provider);
77     StatusOr<std::unique_ptr<ExternalDataset>> maybe_dataset =
78         provider->MakeDataset(selector_str);
79     // The provider might not like the given selector.
80     if (!maybe_dataset.ok()) {
81       ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status()));
82       return;
83     }
84 
85     *output = new Dataset(ctx, std::move(maybe_dataset).value());
86   }
87 
88  private:
89   class Dataset : public tensorflow::data::DatasetBase {
90    public:
Dataset(tensorflow::OpKernelContext * ctx,std::unique_ptr<ExternalDataset> stub)91     Dataset(tensorflow::OpKernelContext* ctx,
92             std::unique_ptr<ExternalDataset> stub)
93         : DatasetBase(tensorflow::data::DatasetContext(ctx)),
94           stub_(std::move(stub)) {}
95 
MakeIteratorInternal(const std::string & prefix) const96     std::unique_ptr<tensorflow::data::IteratorBase> MakeIteratorInternal(
97         const std::string& prefix) const override {
98       std::unique_ptr<ExternalDatasetIterator> iter = stub_->MakeIterator();
99       Iterator::Params params{
100           this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")};
101       return std::unique_ptr<tensorflow::data::IteratorBase>(
102           new Iterator(params, std::move(iter)));
103     }
104 
105     // Each iterator element is just a scalar string.
106 
output_dtypes() const107     const tensorflow::DataTypeVector& output_dtypes() const override {
108       static auto* const dtypes =
109           new tensorflow::DataTypeVector({tensorflow::DT_STRING});
110       return *dtypes;
111     }
112 
output_shapes() const113     const std::vector<tensorflow::PartialTensorShape>& output_shapes()
114         const override {
115       static std::vector<tensorflow::PartialTensorShape>* shapes =
116           new std::vector<tensorflow::PartialTensorShape>({{}});
117       return *shapes;
118     }
119 
DebugString() const120     std::string DebugString() const override {
121       return "ExternalDatasetOp::Dataset";
122     }
123 
InputDatasets(std::vector<const DatasetBase * > * inputs) const124     tensorflow::Status InputDatasets(
125         std::vector<const DatasetBase*>* inputs) const override {
126       // ExternalDatast has no input datasets, so just return OK.
127       return tensorflow::OkStatus();
128     }
129 
130 // The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We
131 // use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if
132 // we should add its override.
133 #if TF_GRAPH_DEF_VERSION > 125
CheckExternalState() const134     tensorflow::Status CheckExternalState() const override {
135       return tensorflow::OkStatus();
136     }
137 #endif
138 
139    protected:
AsGraphDefInternal(tensorflow::data::SerializationContext * ctx,DatasetGraphDefBuilder * b,tensorflow::Node ** output) const140     tensorflow::Status AsGraphDefInternal(
141         tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
142         tensorflow::Node** output) const override {
143       return ::tensorflow::errors::Unimplemented(
144           DebugString(), " does not support serialization.");
145     }
146 
147    private:
148     class Iterator : public tensorflow::data::DatasetIterator<Dataset> {
149      public:
Iterator(const Params & params,std::unique_ptr<ExternalDatasetIterator> stub)150       explicit Iterator(const Params& params,
151                         std::unique_ptr<ExternalDatasetIterator> stub)
152           : DatasetIterator<Dataset>(params), stub_(std::move(stub)) {}
153 
GetNextInternal(tensorflow::data::IteratorContext * ctx,std::vector<tensorflow::Tensor> * out_tensors,bool * end_of_sequence)154       tensorflow::Status GetNextInternal(
155           tensorflow::data::IteratorContext* ctx,
156           std::vector<tensorflow::Tensor>* out_tensors,
157           bool* end_of_sequence) override {
158         StatusOr<std::string> maybe_element;
159         {
160           absl::MutexLock _(&mu_);
161           maybe_element = stub_->GetNext();
162         }
163 
164         if (maybe_element.ok()) {
165           std::string element = std::move(maybe_element).value();
166 
167           // The {} at the end specifies a scalar tensor.
168           tensorflow::Tensor element_tensor(ctx->allocator({}),
169                                             tensorflow::DT_STRING, {});
170           element_tensor.scalar<tensorflow::tstring>()() = element;
171 
172           *end_of_sequence = false;
173           out_tensors->push_back(std::move(element_tensor));
174           return tensorflow::OkStatus();
175         } else {
176           *end_of_sequence = true;
177           if (maybe_element.status().code() == StatusCode::kOutOfRange) {
178             return tensorflow::OkStatus();
179           } else {
180             return ConvertToTensorFlowStatus(maybe_element.status());
181           }
182         }
183       }
184 
185      protected:
SaveInternal(tensorflow::data::SerializationContext * ctx,tensorflow::data::IteratorStateWriter * writer)186       tensorflow::Status SaveInternal(
187 // `::tensorflow::data::SerializationContext` argument was added on
188 // 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343.
189 #if TF_GRAPH_DEF_VERSION > 343
190           tensorflow::data::SerializationContext* ctx,
191 #endif
192           tensorflow::data::IteratorStateWriter* writer) override {
193         return ::tensorflow::errors::Unimplemented(
194             "Save / Restore of an ExternalDataset iterator is not supported");
195       }
RestoreInternal(tensorflow::data::IteratorContext * ctx,tensorflow::data::IteratorStateReader * reader)196       tensorflow::Status RestoreInternal(
197           tensorflow::data::IteratorContext* ctx,
198           tensorflow::data::IteratorStateReader* reader) override {
199         return ::tensorflow::errors::Unimplemented(
200             "Save / Restore of an ExternalDataset iterator is not supported");
201       }
202 
203      private:
204       std::unique_ptr<ExternalDatasetIterator> stub_;
205       absl::Mutex mu_;
206     };
207 
208     // Private members of Dataset
209 
210     std::unique_ptr<ExternalDataset> stub_;
211   };
212 };
213 
214 REGISTER_OP("ExternalDataset")
215     .Input("token: string")
216     .Input("selector: string")
217     .Output("handle: variant")
218     .SetIsStateful()
219     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
220 
221 REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU),
222                         ExternalDatasetOp);
223 
224 }  // namespace fcp
225