• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16 
17 #include <iterator>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/captured_function.h"
23 #include "tensorflow/core/lib/random/random.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30 
31 class GeneratorDatasetOp::Dataset : public DatasetBase {
32  public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)33   Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
34           std::unique_ptr<CapturedFunction> next_func,
35           std::unique_ptr<CapturedFunction> finalize_func,
36           const DataTypeVector& output_types,
37           const std::vector<PartialTensorShape>& output_shapes)
38       : DatasetBase(DatasetContext(ctx)),
39         init_func_(std::move(init_func)),
40         next_func_(std::move(next_func)),
41         finalize_func_(std::move(finalize_func)),
42         output_types_(output_types),
43         output_shapes_(output_shapes) {}
44 
MakeIteratorInternal(const string & prefix) const45   std::unique_ptr<IteratorBase> MakeIteratorInternal(
46       const string& prefix) const override {
47     return absl::make_unique<Iterator>(
48         Iterator::Params{this, strings::StrCat(prefix, "::Generator")});
49   }
50 
output_dtypes() const51   const DataTypeVector& output_dtypes() const override { return output_types_; }
52 
output_shapes() const53   const std::vector<PartialTensorShape>& output_shapes() const override {
54     return output_shapes_;
55   }
56 
DebugString() const57   string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
58 
59  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const60   Status AsGraphDefInternal(SerializationContext* ctx,
61                             DatasetGraphDefBuilder* b,
62                             Node** output) const override {
63     return errors::Unimplemented("%s does not support serialization",
64                                  DebugString());
65   }
66 
67  private:
68   class Iterator : public DatasetIterator<Dataset> {
69    public:
Iterator(const Params & params)70     explicit Iterator(const Params& params)
71         : DatasetIterator<Dataset>(params) {}
72 
~Iterator()73     ~Iterator() override {
74       if (!finalized_ && initialized_) {
75         std::vector<Tensor> ignored;
76         Status s =
77             instantiated_finalize_func_->RunInstantiated(state_, &ignored);
78         if (!s.ok()) {
79           LOG(WARNING)
80               << "Error occurred when finalizing GeneratorDataset iterator: "
81               << s;
82         }
83       }
84     }
85 
Initialize(IteratorContext * ctx)86     Status Initialize(IteratorContext* ctx) override {
87       TF_RETURN_IF_ERROR(
88           dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
89       TF_RETURN_IF_ERROR(
90           dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
91       TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
92           ctx, &instantiated_finalize_func_));
93       return Status::OK();
94     }
95 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)96     Status GetNextInternal(IteratorContext* ctx,
97                            std::vector<Tensor>* out_tensors,
98                            bool* end_of_sequence) override {
99       mutex_lock l(mu_);
100 
101       if (!initialized_) {
102         TF_RETURN_IF_ERROR(
103             instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
104         initialized_ = true;
105       }
106 
107       if (finalized_) {
108         *end_of_sequence = true;
109         return Status::OK();
110       }
111 
112       Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
113                                                               out_tensors);
114       if (s.ok()) {
115         *end_of_sequence = false;
116       } else if (errors::IsOutOfRange(s)) {
117         // `next_func` may deliberately raise `errors::OutOfRange`
118         // to indicate that we should terminate the iteration.
119         s = Status::OK();
120         *end_of_sequence = true;
121 
122         // NOTE(mrry): We ignore any tensors returned by the
123         // finalize function.
124         std::vector<Tensor> ignored;
125         TF_RETURN_IF_ERROR(
126             instantiated_finalize_func_->RunInstantiated(state_, &ignored));
127         finalized_ = true;
128       }
129       return s;
130     }
131 
132    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const133     std::shared_ptr<model::Node> CreateNode(
134         IteratorContext* ctx, model::Node::Args args) const override {
135       return model::MakeSourceNode(std::move(args));
136     }
137 
138    private:
139     mutex mu_;
140     bool initialized_ GUARDED_BY(mu_) = false;
141     bool finalized_ GUARDED_BY(mu_) = false;
142     std::vector<Tensor> state_ GUARDED_BY(mu_);
143     std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
144     std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
145     std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
146   };
147 
148   const std::unique_ptr<CapturedFunction> init_func_;
149   const std::unique_ptr<CapturedFunction> next_func_;
150   const std::unique_ptr<CapturedFunction> finalize_func_;
151   const DataTypeVector output_types_;
152   const std::vector<PartialTensorShape> output_shapes_;
153 };
154 
GeneratorDatasetOp(OpKernelConstruction * ctx)155 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
156     : DatasetOpKernel(ctx) {
157   OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
158   OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
159   OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
160   OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
161   OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
162 }
163 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)164 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
165                                      DatasetBase** output) {
166   std::unique_ptr<CapturedFunction> init_func;
167   OP_REQUIRES_OK(ctx, CapturedFunction::Create(
168                           init_func_, ctx, "init_func_other_args", &init_func));
169 
170   std::unique_ptr<CapturedFunction> next_func;
171   OP_REQUIRES_OK(ctx, CapturedFunction::Create(
172                           next_func_, ctx, "next_func_other_args", &next_func));
173 
174   std::unique_ptr<CapturedFunction> finalize_func;
175   OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
176                                                "finalize_func_other_args",
177                                                &finalize_func));
178 
179   *output =
180       new Dataset(ctx, std::move(init_func), std::move(next_func),
181                   std::move(finalize_func), output_types_, output_shapes_);
182 }
183 
184 namespace {
185 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
186                         GeneratorDatasetOp);
187 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
188                             .Device(DEVICE_GPU)
189                             .HostMemory("handle")
190                             .Priority(1),
191                         GeneratorDatasetOp);
192 }  // namespace
193 
194 }  // namespace data
195 }  // namespace tensorflow
196