1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16
17 #include <iterator>
18 #include <vector>
19
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/captured_function.h"
23 #include "tensorflow/core/lib/random/random.h"
24
25 namespace tensorflow {
26 namespace data {
27
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30
31 class GeneratorDatasetOp::Dataset : public DatasetBase {
32 public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)33 Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
34 std::unique_ptr<CapturedFunction> next_func,
35 std::unique_ptr<CapturedFunction> finalize_func,
36 const DataTypeVector& output_types,
37 const std::vector<PartialTensorShape>& output_shapes)
38 : DatasetBase(DatasetContext(ctx)),
39 init_func_(std::move(init_func)),
40 next_func_(std::move(next_func)),
41 finalize_func_(std::move(finalize_func)),
42 output_types_(output_types),
43 output_shapes_(output_shapes) {}
44
MakeIteratorInternal(const string & prefix) const45 std::unique_ptr<IteratorBase> MakeIteratorInternal(
46 const string& prefix) const override {
47 return absl::make_unique<Iterator>(
48 Iterator::Params{this, strings::StrCat(prefix, "::Generator")});
49 }
50
output_dtypes() const51 const DataTypeVector& output_dtypes() const override { return output_types_; }
52
output_shapes() const53 const std::vector<PartialTensorShape>& output_shapes() const override {
54 return output_shapes_;
55 }
56
DebugString() const57 string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
58
59 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const60 Status AsGraphDefInternal(SerializationContext* ctx,
61 DatasetGraphDefBuilder* b,
62 Node** output) const override {
63 return errors::Unimplemented("%s does not support serialization",
64 DebugString());
65 }
66
67 private:
68 class Iterator : public DatasetIterator<Dataset> {
69 public:
Iterator(const Params & params)70 explicit Iterator(const Params& params)
71 : DatasetIterator<Dataset>(params) {}
72
~Iterator()73 ~Iterator() override {
74 if (!finalized_ && initialized_) {
75 std::vector<Tensor> ignored;
76 Status s =
77 instantiated_finalize_func_->RunInstantiated(state_, &ignored);
78 if (!s.ok()) {
79 LOG(WARNING)
80 << "Error occurred when finalizing GeneratorDataset iterator: "
81 << s;
82 }
83 }
84 }
85
Initialize(IteratorContext * ctx)86 Status Initialize(IteratorContext* ctx) override {
87 TF_RETURN_IF_ERROR(
88 dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
89 TF_RETURN_IF_ERROR(
90 dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
91 TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
92 ctx, &instantiated_finalize_func_));
93 return Status::OK();
94 }
95
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)96 Status GetNextInternal(IteratorContext* ctx,
97 std::vector<Tensor>* out_tensors,
98 bool* end_of_sequence) override {
99 mutex_lock l(mu_);
100
101 if (!initialized_) {
102 TF_RETURN_IF_ERROR(
103 instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
104 initialized_ = true;
105 }
106
107 if (finalized_) {
108 *end_of_sequence = true;
109 return Status::OK();
110 }
111
112 Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
113 out_tensors);
114 if (s.ok()) {
115 *end_of_sequence = false;
116 } else if (errors::IsOutOfRange(s)) {
117 // `next_func` may deliberately raise `errors::OutOfRange`
118 // to indicate that we should terminate the iteration.
119 s = Status::OK();
120 *end_of_sequence = true;
121
122 // NOTE(mrry): We ignore any tensors returned by the
123 // finalize function.
124 std::vector<Tensor> ignored;
125 TF_RETURN_IF_ERROR(
126 instantiated_finalize_func_->RunInstantiated(state_, &ignored));
127 finalized_ = true;
128 }
129 return s;
130 }
131
132 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const133 std::shared_ptr<model::Node> CreateNode(
134 IteratorContext* ctx, model::Node::Args args) const override {
135 return model::MakeSourceNode(std::move(args));
136 }
137
138 private:
139 mutex mu_;
140 bool initialized_ GUARDED_BY(mu_) = false;
141 bool finalized_ GUARDED_BY(mu_) = false;
142 std::vector<Tensor> state_ GUARDED_BY(mu_);
143 std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
144 std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
145 std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
146 };
147
148 const std::unique_ptr<CapturedFunction> init_func_;
149 const std::unique_ptr<CapturedFunction> next_func_;
150 const std::unique_ptr<CapturedFunction> finalize_func_;
151 const DataTypeVector output_types_;
152 const std::vector<PartialTensorShape> output_shapes_;
153 };
154
GeneratorDatasetOp(OpKernelConstruction * ctx)155 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
156 : DatasetOpKernel(ctx) {
157 OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
158 OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
159 OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
160 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
161 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
162 }
163
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)164 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
165 DatasetBase** output) {
166 std::unique_ptr<CapturedFunction> init_func;
167 OP_REQUIRES_OK(ctx, CapturedFunction::Create(
168 init_func_, ctx, "init_func_other_args", &init_func));
169
170 std::unique_ptr<CapturedFunction> next_func;
171 OP_REQUIRES_OK(ctx, CapturedFunction::Create(
172 next_func_, ctx, "next_func_other_args", &next_func));
173
174 std::unique_ptr<CapturedFunction> finalize_func;
175 OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
176 "finalize_func_other_args",
177 &finalize_func));
178
179 *output =
180 new Dataset(ctx, std::move(init_func), std::move(next_func),
181 std::move(finalize_func), output_types_, output_shapes_);
182 }
183
184 namespace {
185 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
186 GeneratorDatasetOp);
187 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
188 .Device(DEVICE_GPU)
189 .HostMemory("handle")
190 .Priority(1),
191 GeneratorDatasetOp);
192 } // namespace
193
194 } // namespace data
195 } // namespace tensorflow
196