1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16
17 #include <iterator>
18 #include <vector>
19
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/captured_function.h"
23 #include "tensorflow/core/kernels/data/dataset_utils.h"
24 #include "tensorflow/core/kernels/data/name_utils.h"
25 #include "tensorflow/core/lib/random/random.h"
26
27 namespace tensorflow {
28 namespace data {
29
30 // See documentation in ../../ops/dataset_ops.cc for a high-level
31 // description of the following op.
32
33 /* static */ constexpr const char* const GeneratorDatasetOp::kDatasetType;
34 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFuncOtherArgs;
35 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFuncOtherArgs;
36 /* static */ constexpr const char* const
37 GeneratorDatasetOp::kFinalizeFuncOtherArgs;
38 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFunc;
39 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFunc;
40 /* static */ constexpr const char* const GeneratorDatasetOp::kFinalizeFunc;
41 /* static */ constexpr const char* const GeneratorDatasetOp::kTinitFuncArgs;
42 /* static */ constexpr const char* const GeneratorDatasetOp::kTnextFuncArgs;
43 /* static */ constexpr const char* const GeneratorDatasetOp::kTfinalizeFuncArgs;
44 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputTypes;
45 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputShapes;
46
47 class GeneratorDatasetOp::Dataset : public DatasetBase {
48 public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)49 Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
50 std::unique_ptr<CapturedFunction> next_func,
51 std::unique_ptr<CapturedFunction> finalize_func,
52 const DataTypeVector& output_types,
53 const std::vector<PartialTensorShape>& output_shapes)
54 : DatasetBase(DatasetContext(ctx)),
55 init_func_(std::move(init_func)),
56 next_func_(std::move(next_func)),
57 finalize_func_(std::move(finalize_func)),
58 output_types_(output_types),
59 output_shapes_(output_shapes) {}
60
MakeIteratorInternal(const string & prefix) const61 std::unique_ptr<IteratorBase> MakeIteratorInternal(
62 const string& prefix) const override {
63 return absl::make_unique<Iterator>(Iterator::Params{
64 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
65 }
66
output_dtypes() const67 const DataTypeVector& output_dtypes() const override { return output_types_; }
68
output_shapes() const69 const std::vector<PartialTensorShape>& output_shapes() const override {
70 return output_shapes_;
71 }
72
DebugString() const73 string DebugString() const override {
74 return name_utils::DatasetDebugString(kDatasetType);
75 }
76
InputDatasets(std::vector<const DatasetBase * > * inputs) const77 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
78 return Status::OK();
79 }
80
CheckExternalState() const81 Status CheckExternalState() const override {
82 TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
83 TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
84 return finalize_func_->CheckExternalState();
85 }
86
87 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const88 Status AsGraphDefInternal(SerializationContext* ctx,
89 DatasetGraphDefBuilder* b,
90 Node** output) const override {
91 return errors::Unimplemented(DebugString(),
92 " does not support serialization");
93 }
94
95 private:
96 class Iterator : public DatasetIterator<Dataset> {
97 public:
Iterator(const Params & params)98 explicit Iterator(const Params& params)
99 : DatasetIterator<Dataset>(params) {}
100
~Iterator()101 ~Iterator() override {
102 if (!finalized_ && initialized_) {
103 std::vector<Tensor> ignored;
104 Status s =
105 instantiated_finalize_func_->RunInstantiated(state_, &ignored);
106 if (!s.ok()) {
107 LOG(WARNING)
108 << "Error occurred when finalizing GeneratorDataset iterator: "
109 << s;
110 }
111 }
112 }
113
Initialize(IteratorContext * ctx)114 Status Initialize(IteratorContext* ctx) override {
115 TF_RETURN_IF_ERROR(
116 dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
117 TF_RETURN_IF_ERROR(
118 dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
119 TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
120 ctx, &instantiated_finalize_func_));
121 return Status::OK();
122 }
123
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)124 Status GetNextInternal(IteratorContext* ctx,
125 std::vector<Tensor>* out_tensors,
126 bool* end_of_sequence) override {
127 mutex_lock l(mu_);
128
129 if (!initialized_) {
130 TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
131 ctx, {}, &state_, model_node()));
132 initialized_ = true;
133 }
134
135 if (finalized_) {
136 *end_of_sequence = true;
137 return Status::OK();
138 }
139
140 Status s = instantiated_next_func_->RunWithBorrowedArgs(
141 ctx, state_, out_tensors, model_node());
142 if (s.ok()) {
143 *end_of_sequence = false;
144 } else if (errors::IsOutOfRange(s)) {
145 // `next_func` may deliberately raise `errors::OutOfRange`
146 // to indicate that we should terminate the iteration.
147 s = Status::OK();
148 *end_of_sequence = true;
149
150 // NOTE(mrry): We ignore any tensors returned by the finalize function.
151 std::vector<Tensor> ignored;
152 TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
153 ctx, state_, &ignored, model_node()));
154 finalized_ = true;
155 }
156 return s;
157 }
158
159 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const160 std::shared_ptr<model::Node> CreateNode(
161 IteratorContext* ctx, model::Node::Args args) const override {
162 return model::MakeSourceNode(std::move(args));
163 }
164
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)165 Status SaveInternal(SerializationContext* ctx,
166 IteratorStateWriter* writer) override {
167 return errors::Unimplemented(
168 "GeneratorDataset does not support checkpointing.");
169 }
170
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)171 Status RestoreInternal(IteratorContext* ctx,
172 IteratorStateReader* reader) override {
173 return errors::Unimplemented(
174 "GeneratorDataset does not support checkpointing.");
175 }
176
177 private:
178 mutex mu_;
179 bool initialized_ TF_GUARDED_BY(mu_) = false;
180 bool finalized_ TF_GUARDED_BY(mu_) = false;
181 std::vector<Tensor> state_ TF_GUARDED_BY(mu_);
182 std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
183 std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
184 std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
185 };
186
187 const std::unique_ptr<CapturedFunction> init_func_;
188 const std::unique_ptr<CapturedFunction> next_func_;
189 const std::unique_ptr<CapturedFunction> finalize_func_;
190 const DataTypeVector output_types_;
191 const std::vector<PartialTensorShape> output_shapes_;
192 };
193
GeneratorDatasetOp(OpKernelConstruction * ctx)194 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
195 : DatasetOpKernel(ctx) {
196 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kInitFunc, /*params=*/{},
197 &init_func_metadata_));
198 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kNextFunc, /*params=*/{},
199 &next_func_metadata_));
200 OP_REQUIRES_OK(ctx,
201 FunctionMetadata::Create(ctx, kFinalizeFunc, /*params=*/{},
202 &finalize_func_metadata_));
203 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
204 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
205 }
206
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)207 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
208 DatasetBase** output) {
209 std::unique_ptr<CapturedFunction> init_func;
210 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_,
211 kInitFuncOtherArgs, &init_func));
212
213 std::unique_ptr<CapturedFunction> next_func;
214 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, next_func_metadata_,
215 kNextFuncOtherArgs, &next_func));
216
217 std::unique_ptr<CapturedFunction> finalize_func;
218 OP_REQUIRES_OK(
219 ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
220 kFinalizeFuncOtherArgs, &finalize_func));
221
222 *output =
223 new Dataset(ctx, std::move(init_func), std::move(next_func),
224 std::move(finalize_func), output_types_, output_shapes_);
225 }
226
227 namespace {
228 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
229 GeneratorDatasetOp);
230 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
231 .Device(DEVICE_GPU)
232 .HostMemory("handle")
233 .Priority(1),
234 GeneratorDatasetOp);
235 } // namespace
236
237 } // namespace data
238 } // namespace tensorflow
239