1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
18 #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
19
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/base/bounds.h"
#include "fcp/tensorflow/host_object.h"
28
29 namespace fcp {
30
31 /**
32 * Interface for an iterator, created from a particular dataset. A single
33 * dataset may be used to create multiple iterators.
34 */
35 class ExternalDatasetIterator {
36 public:
37 virtual ~ExternalDatasetIterator() = default;
38
39 /**
40 * Returns the next element, if possible. Indicates end-of-stream with
41 * OUT_OF_RANGE, even when repeatedly called. Corresponds to
42 * tensorflow::data::IteratorBase::GetNext.
43 *
44 * Implementations must be thread-safe.
45 */
46 virtual absl::StatusOr<std::string> GetNext() = 0;
47 };
48
namespace external_dataset_internal {

// Forward declaration: ExternalDataset::FromFunction refers to this adapter
// before its full definition appears further down in this header.
template <class FuncType>
class DatasetFromFunction;

}  // namespace external_dataset_internal
55
56 /**
57 * Interface for a particular dataset - created from an ExternalDatasetProvider
58 * (during dataset op execution), for a particular selector. A dataset may be
59 * used zero or more times to create an ExternalDatasetIterator.
60 *
61 * Dataset implementations are often trivial, just needing to capture some
62 * values (like the selector) for the iterator constructor. Consider using
63 * ExternalDataset::FromFunction.
64 */
65 class ExternalDataset {
66 public:
67 virtual ~ExternalDataset() = default;
68
69 /**
70 * Creates a new iterator. Corresponds to
71 * tensorflow::data::DatasetBase::MakeIterator.
72 */
73 virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
74
75 /**
76 * Creates an ExternalDataset that wraps a callable object 'f', implementing
77 * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
78 * by-reference lambda captures are almost always unsafe here).
79 */
80 template <typename F>
FromFunction(F f)81 static std::unique_ptr<ExternalDataset> FromFunction(F f) {
82 return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
83 std::move(f));
84 }
85 };
86
87 /**
88 * Interface for an ExternalDataset op's host object.
89 *
90 * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
91 * Here, 'Selector' is a string provided to the dataset op (typically, an
92 * encoded proto). The returned ExternalDataset may be used (perhaps multiple
93 * times) to create an iterator.
94 *
95 * When implementing a dataset provider and the selector is a proto message,
96 * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for
97 * some message type T).
98 */
99 class ExternalDatasetProvider {
100 public:
101 virtual ~ExternalDatasetProvider() = default;
102
103 /**
104 * Creates a dataset for a given selector.
105 *
106 * This function can usually be implemented succinctly, using
107 * ExternalDataset::FromFunction.
108 *
109 * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
110 */
111 virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
112 absl::string_view selector) = 0;
113
114 /**
115 * Base class for dataset providers that expect a selector of a particular
116 * proto message type. If inheriting from UsingProtoSelector<T>, then one
117 * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
118 */
119 template <typename T>
120 class UsingProtoSelector;
121 };
122
/**
 * HostObjectRegistry holding ExternalDatasetProvider instances, i.e. the host
 * objects behind the ExternalDataset op.
 */
using ExternalDatasetProviderRegistry =
    HostObjectRegistry<ExternalDatasetProvider>;
128
129 namespace external_dataset_internal {
130
131 template <typename T>
TryParseProtoSelector(absl::string_view selector)132 absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
133 T msg;
134 if (!msg.ParseFromArray(selector.data(),
135 CastIntegerChecked<int>(selector.size()))) {
136 return absl::InvalidArgumentError(absl::StrCat(
137 "Failed to parse selector proto of type ", msg.GetTypeName()));
138 }
139
140 return msg;
141 }
142
143 template <typename FuncType>
144 class DatasetFromFunction : public ExternalDataset {
145 public:
DatasetFromFunction(FuncType func)146 explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
147
MakeIterator()148 std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
149 return func_();
150 }
151
152 private:
153 FuncType func_;
154 };
155
156 } // namespace external_dataset_internal
157
158 template <typename T>
159 class ExternalDatasetProvider::UsingProtoSelector
160 : public ExternalDatasetProvider {
161 public:
MakeDataset(absl::string_view selector)162 absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
163 absl::string_view selector) final {
164 auto maybe_msg =
165 external_dataset_internal::TryParseProtoSelector<T>(selector);
166 if (!maybe_msg.ok()) {
167 return maybe_msg.status();
168 }
169
170 return MakeDataset(std::move(maybe_msg).value());
171 }
172
173 virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
174 T selector) = 0;
175 };
176
177 } // namespace fcp
178
179 #endif // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
180