1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16
17 #include <iterator>
18 #include <vector>
19
20 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
21 #include "tensorflow/core/data/captured_function.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/name_utils.h"
24 #include "tensorflow/core/framework/partial_tensor_shape.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/random/random.h"
27
28 namespace tensorflow {
29 namespace data {
30
31 // See documentation in ../../ops/dataset_ops.cc for a high-level
32 // description of the following op.
33
34 /* static */ constexpr const char* const GeneratorDatasetOp::kDatasetType;
35 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFuncOtherArgs;
36 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFuncOtherArgs;
37 /* static */ constexpr const char* const
38 GeneratorDatasetOp::kFinalizeFuncOtherArgs;
39 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFunc;
40 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFunc;
41 /* static */ constexpr const char* const GeneratorDatasetOp::kFinalizeFunc;
42 /* static */ constexpr const char* const GeneratorDatasetOp::kTinitFuncArgs;
43 /* static */ constexpr const char* const GeneratorDatasetOp::kTnextFuncArgs;
44 /* static */ constexpr const char* const GeneratorDatasetOp::kTfinalizeFuncArgs;
45 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputTypes;
46 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputShapes;
47
48 class GeneratorDatasetOp::Dataset : public DatasetBase {
49 public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)50 Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
51 std::unique_ptr<CapturedFunction> next_func,
52 std::unique_ptr<CapturedFunction> finalize_func,
53 const DataTypeVector& output_types,
54 const std::vector<PartialTensorShape>& output_shapes)
55 : DatasetBase(DatasetContext(ctx)),
56 init_func_(std::move(init_func)),
57 next_func_(std::move(next_func)),
58 finalize_func_(std::move(finalize_func)),
59 output_types_(output_types),
60 output_shapes_(output_shapes) {}
61
MakeIteratorInternal(const string & prefix) const62 std::unique_ptr<IteratorBase> MakeIteratorInternal(
63 const string& prefix) const override {
64 return std::make_unique<Iterator>(Iterator::Params{
65 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
66 }
67
output_dtypes() const68 const DataTypeVector& output_dtypes() const override { return output_types_; }
69
output_shapes() const70 const std::vector<PartialTensorShape>& output_shapes() const override {
71 return output_shapes_;
72 }
73
DebugString() const74 string DebugString() const override {
75 return name_utils::DatasetDebugString(kDatasetType);
76 }
77
InputDatasets(std::vector<const DatasetBase * > * inputs) const78 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
79 return OkStatus();
80 }
81
CheckExternalState() const82 Status CheckExternalState() const override {
83 TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
84 TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
85 return finalize_func_->CheckExternalState();
86 }
87
88 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const89 Status AsGraphDefInternal(SerializationContext* ctx,
90 DatasetGraphDefBuilder* b,
91 Node** output) const override {
92 return errors::Unimplemented(DebugString(),
93 " does not support serialization");
94 }
95
96 private:
97 class Iterator : public DatasetIterator<Dataset> {
98 public:
Iterator(const Params & params)99 explicit Iterator(const Params& params)
100 : DatasetIterator<Dataset>(params) {}
101
~Iterator()102 ~Iterator() override {
103 if (!finalized_ && initialized_) {
104 std::vector<Tensor> ignored;
105 Status s =
106 instantiated_finalize_func_->RunInstantiated(state_, &ignored);
107 if (!s.ok()) {
108 LOG(WARNING)
109 << "Error occurred when finalizing GeneratorDataset iterator: "
110 << s;
111 }
112 }
113 }
114
Initialize(IteratorContext * ctx)115 Status Initialize(IteratorContext* ctx) override {
116 TF_RETURN_IF_ERROR(
117 dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
118 TF_RETURN_IF_ERROR(
119 dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
120 TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
121 ctx, &instantiated_finalize_func_));
122 return OkStatus();
123 }
124
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)125 Status GetNextInternal(IteratorContext* ctx,
126 std::vector<Tensor>* out_tensors,
127 bool* end_of_sequence) override {
128 mutex_lock l(mu_);
129
130 if (!initialized_) {
131 TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
132 ctx, {}, &state_, model_node()));
133 initialized_ = true;
134 }
135
136 if (finalized_) {
137 *end_of_sequence = true;
138 return OkStatus();
139 }
140
141 Status s = instantiated_next_func_->RunWithBorrowedArgs(
142 ctx, state_, out_tensors, model_node());
143 if (s.ok()) {
144 *end_of_sequence = false;
145 } else if (errors::IsOutOfRange(s)) {
146 // `next_func` may deliberately raise `errors::OutOfRange`
147 // to indicate that we should terminate the iteration.
148 s = OkStatus();
149 *end_of_sequence = true;
150
151 // NOTE(mrry): We ignore any tensors returned by the finalize function.
152 std::vector<Tensor> ignored;
153 TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
154 ctx, state_, &ignored, model_node()));
155 finalized_ = true;
156 }
157 return s;
158 }
159
160 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const161 std::shared_ptr<model::Node> CreateNode(
162 IteratorContext* ctx, model::Node::Args args) const override {
163 return model::MakeSourceNode(std::move(args));
164 }
165
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)166 Status SaveInternal(SerializationContext* ctx,
167 IteratorStateWriter* writer) override {
168 return errors::Unimplemented(
169 "GeneratorDataset does not support checkpointing.");
170 }
171
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)172 Status RestoreInternal(IteratorContext* ctx,
173 IteratorStateReader* reader) override {
174 return errors::Unimplemented(
175 "GeneratorDataset does not support checkpointing.");
176 }
177
178 private:
179 mutex mu_;
180 bool initialized_ TF_GUARDED_BY(mu_) = false;
181 bool finalized_ TF_GUARDED_BY(mu_) = false;
182 std::vector<Tensor> state_ TF_GUARDED_BY(mu_);
183 std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
184 std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
185 std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
186 };
187
188 const std::unique_ptr<CapturedFunction> init_func_;
189 const std::unique_ptr<CapturedFunction> next_func_;
190 const std::unique_ptr<CapturedFunction> finalize_func_;
191 const DataTypeVector output_types_;
192 const std::vector<PartialTensorShape> output_shapes_;
193 };
194
GeneratorDatasetOp(OpKernelConstruction * ctx)195 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
196 : DatasetOpKernel(ctx) {
197 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kInitFunc, /*params=*/{},
198 &init_func_metadata_));
199 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kNextFunc, /*params=*/{},
200 &next_func_metadata_));
201 OP_REQUIRES_OK(ctx,
202 FunctionMetadata::Create(ctx, kFinalizeFunc, /*params=*/{},
203 &finalize_func_metadata_));
204 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
205 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
206 }
207
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)208 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
209 DatasetBase** output) {
210 std::unique_ptr<CapturedFunction> init_func;
211 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_,
212 kInitFuncOtherArgs, &init_func));
213
214 std::unique_ptr<CapturedFunction> next_func;
215 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, next_func_metadata_,
216 kNextFuncOtherArgs, &next_func));
217
218 std::unique_ptr<CapturedFunction> finalize_func;
219 OP_REQUIRES_OK(
220 ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
221 kFinalizeFuncOtherArgs, &finalize_func));
222
223 *output =
224 new Dataset(ctx, std::move(init_func), std::move(next_func),
225 std::move(finalize_func), output_types_, output_shapes_);
226 }
227
228 namespace {
229 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
230 GeneratorDatasetOp);
231 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
232 .Device(DEVICE_GPU)
233 .HostMemory("handle")
234 .Priority(1),
235 GeneratorDatasetOp);
236 REGISTER_INPUT_COLOCATION_EXEMPTION("GeneratorDataset");
237 } // namespace
238
239 } // namespace data
240 } // namespace tensorflow
241