/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/flat_map_dataset_op.h"

#include <string>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
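//
// Roughly speaking, FlatMapDataset (the kernel behind
// `tf.data.Dataset.flat_map`) applies a user-defined function that maps each
// input element to a dataset, then flattens the resulting datasets into a
// single dataset, in order. For example, if the input produces the vectors
// [1, 2, 3] and [4, 5, 6] and the function slices each vector into scalars,
// the resulting dataset produces 1, 2, 3, 4, 5, 6.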

/* static */ constexpr const char* const FlatMapDatasetOp::kDatasetType;
/* static */ constexpr const char* const FlatMapDatasetOp::kInputDataset;
/* static */ constexpr const char* const FlatMapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kFunc;
/* static */ constexpr const char* const FlatMapDatasetOp::kTarguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputShapes;

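// Keys used to serialize and restore iterator state in SaveInternal() and
// RestoreInternal() below.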
constexpr char kElementIndex[] = "element_index";
constexpr char kInputsSize[] = "inputs_size";
constexpr char kInputs[] = "inputs";
constexpr char kCurrentElementIteratorUninitialized[] =
    "current_element_iterator_uninitialized";
constexpr char kExhausted[] = "exhausted";

class FlatMapDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        output_types_(output_types),
        output_shapes_(output_shapes) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},         // Tensor list inputs.
        {std::make_pair(kFunc, f),
         std::make_pair(kTarguments, other_arguments_types_attr)},  // Attrs
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      do {
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (current_element_iterator_) {
          // We are currently processing a mapped element, so try to get the
          // next subelement.
          bool end_of_element;
          TF_RETURN_IF_ERROR(current_element_iterator_->GetNext(
              ctx, out_tensors, &end_of_element));
          if (!end_of_element) {
            // Produce the subelement as output.
            *end_of_sequence = false;
            return Status::OK();
          }

          // We have reached the end of the current element, so maybe move on
          // to the next element.
          current_element_iterator_.reset();
        }

        // Get the next element from the input dataset.
        inputs_.clear();
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, &inputs_, end_of_sequence));
        if (*end_of_sequence) {
          input_impl_.reset();
          return Status::OK();
        }

        TF_RETURN_IF_ERROR(
            BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/true));
      } while (true);
    }

    Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                        bool* end_of_sequence, int* num_skipped) override {
      mutex_lock l(mu_);
      *num_skipped = 0;
      while (*num_skipped < num_to_skip) {
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (!current_element_iterator_) {
          // Get the next element from the input dataset.
          inputs_.clear();
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &inputs_, end_of_sequence));
          if (*end_of_sequence) {
            input_impl_.reset();
            *end_of_sequence = true;
            return Status::OK();
          }
          TF_RETURN_IF_ERROR(
              BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
        }
        bool end_of_element;
        int last_num_skipped;
        TF_RETURN_IF_ERROR(current_element_iterator_->Skip(
            ctx, num_to_skip - *num_skipped, &end_of_element,
            &last_num_skipped));
        *num_skipped += last_num_skipped;
        if (end_of_element) {
          // We have reached the end of the current element, so maybe move on
          // to the next element.
          current_element_iterator_.reset();
        }
      }
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
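      // The output of flat_map is drawn from the iterators it creates for the
      // mapped input elements, so it is represented as an interleave-many node
      // in the performance model.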
      return model::MakeInterleaveManyNode(std::move(args));
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(mu_);
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kElementIndex), element_index_));
        if (current_element_iterator_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kInputsSize), inputs_.size()));
          for (int i = 0; i < inputs_.size(); i++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat(kInputs, "[", i, "]")), inputs_[i]));
          }
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_));
        } else {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(kCurrentElementIteratorUninitialized), ""));
        }
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      input_impl_.reset();
      element_index_ = 0;
      current_element_iterator_.reset();
      inputs_.clear();
      if (!reader->Contains(full_name(kExhausted))) {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        {
          int64_t temp;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kElementIndex), &temp));
          element_index_ = temp;
        }
        if (!reader->Contains(
                full_name(kCurrentElementIteratorUninitialized))) {
          size_t inputs_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name(kInputsSize), &temp));
            inputs_size = static_cast<size_t>(temp);
          }
          inputs_.reserve(inputs_size);
          for (int i = 0; i < inputs_size; i++) {
            inputs_.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat(kInputs, "[", i, "]")),
                &inputs_.back()));
          }

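          // BuildCurrentElementIteratorLocked() increments element_index_, so
          // decrement it first to rebuild the current element iterator with
          // the same index that was saved.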
          element_index_--;
          TF_RETURN_IF_ERROR(
              BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
          TF_RETURN_IF_ERROR(
              RestoreInput(ctx, reader, current_element_iterator_));
        }
      }
      return Status::OK();
    }

   private:
    Status BuildCurrentElementIteratorLocked(IteratorContext* ctx,
                                             bool is_get_next)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (is_get_next) {
        return MakeIteratorFromInputElement(
            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
            prefix(), &current_element_iterator_, model_node());
      } else {
        // NOTE: We intentionally ignore resource modeling outside GetNext().
        return MakeIteratorFromInputElement(
            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
            prefix(), &current_element_iterator_,
            /*node=*/nullptr);
      }
    }

    mutex mu_;
    size_t element_index_ TF_GUARDED_BY(mu_) = 0;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> current_element_iterator_ TF_GUARDED_BY(mu_);
    std::vector<Tensor> inputs_ TF_GUARDED_BY(mu_);
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};

FlatMapDatasetOp::FlatMapDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void FlatMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                   DatasetBase** output) {
  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));
  *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
                        output_shapes_);
}

namespace {

REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU),
                        FlatMapDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("FlatMapDataset");

}  // namespace
}  // namespace data
}  // namespace tensorflow