• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
18 #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
19 
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/base/bounds.h"
#include "fcp/tensorflow/host_object.h"
28 
29 namespace fcp {
30 
31 /**
32  * Interface for an iterator, created from a particular dataset. A single
33  * dataset may be used to create multiple iterators.
34  */
35 class ExternalDatasetIterator {
36  public:
37   virtual ~ExternalDatasetIterator() = default;
38 
39   /**
40    * Returns the next element, if possible. Indicates end-of-stream with
41    * OUT_OF_RANGE, even when repeatedly called. Corresponds to
42    * tensorflow::data::IteratorBase::GetNext.
43    *
44    * Implementations must be thread-safe.
45    */
46   virtual absl::StatusOr<std::string> GetNext() = 0;
47 };
48 
// Forward declaration: ExternalDataset::FromFunction (below) names this
// template before its definition appears later in this header.
namespace external_dataset_internal {
template <typename FuncType>
class DatasetFromFunction;
}  // namespace external_dataset_internal
55 
56 /**
57  * Interface for a particular dataset - created from an ExternalDatasetProvider
58  * (during dataset op execution), for a particular selector. A dataset may be
59  * used zero or more times to create an ExternalDatasetIterator.
60  *
61  * Dataset implementations are often trivial, just needing to capture some
62  * values (like the selector) for the iterator constructor. Consider using
63  * ExternalDataset::FromFunction.
64  */
65 class ExternalDataset {
66  public:
67   virtual ~ExternalDataset() = default;
68 
69   /**
70    * Creates a new iterator. Corresponds to
71    * tensorflow::data::DatasetBase::MakeIterator.
72    */
73   virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
74 
75   /**
76    * Creates an ExternalDataset that wraps a callable object 'f', implementing
77    * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
78    * by-reference lambda captures are almost always unsafe here).
79    */
80   template <typename F>
FromFunction(F f)81   static std::unique_ptr<ExternalDataset> FromFunction(F f) {
82     return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
83         std::move(f));
84   }
85 };
86 
87 /**
88  * Interface for an ExternalDataset op's host object.
89  *
90  * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
91  * Here, 'Selector' is a string provided to the dataset op (typically, an
92  * encoded proto). The returned ExternalDataset may be used (perhaps multiple
93  * times) to create an iterator.
94  *
95  * When implementing a dataset provider and the selector is a proto message,
96  * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for
97  * some message type T).
98  */
99 class ExternalDatasetProvider {
100  public:
101   virtual ~ExternalDatasetProvider() = default;
102 
103   /**
104    * Creates a dataset for a given selector.
105    *
106    * This function can usually be implemented succinctly, using
107    * ExternalDataset::FromFunction.
108    *
109    * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
110    */
111   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
112       absl::string_view selector) = 0;
113 
114   /**
115    * Base class for dataset providers that expect a selector of a particular
116    * proto message type. If inheriting from UsingProtoSelector<T>, then one
117    * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
118    */
119   template <typename T>
120   class UsingProtoSelector;
121 };
122 
123 /**
124  * HostObjectRegistry for the ExternalDataset interface.
125  */
126 using ExternalDatasetProviderRegistry =
127     HostObjectRegistry<ExternalDatasetProvider>;
128 
129 namespace external_dataset_internal {
130 
131 template <typename T>
TryParseProtoSelector(absl::string_view selector)132 absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
133   T msg;
134   if (!msg.ParseFromArray(selector.data(),
135                           CastIntegerChecked<int>(selector.size()))) {
136     return absl::InvalidArgumentError(absl::StrCat(
137         "Failed to parse selector proto of type ", msg.GetTypeName()));
138   }
139 
140   return msg;
141 }
142 
143 template <typename FuncType>
144 class DatasetFromFunction : public ExternalDataset {
145  public:
DatasetFromFunction(FuncType func)146   explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
147 
MakeIterator()148   std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
149     return func_();
150   }
151 
152  private:
153   FuncType func_;
154 };
155 
156 }  // namespace external_dataset_internal
157 
158 template <typename T>
159 class ExternalDatasetProvider::UsingProtoSelector
160     : public ExternalDatasetProvider {
161  public:
MakeDataset(absl::string_view selector)162   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
163       absl::string_view selector) final {
164     auto maybe_msg =
165         external_dataset_internal::TryParseProtoSelector<T>(selector);
166     if (!maybe_msg.ok()) {
167       return maybe_msg.status();
168     }
169 
170     return MakeDataset(std::move(maybe_msg).value());
171   }
172 
173   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
174       T selector) = 0;
175 };
176 
177 }  // namespace fcp
178 
179 #endif  // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
180