• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16 
17 #include <iterator>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/captured_function.h"
23 #include "tensorflow/core/kernels/data/dataset_utils.h"
24 #include "tensorflow/core/kernels/data/name_utils.h"
25 #include "tensorflow/core/lib/random/random.h"
26 
27 namespace tensorflow {
28 namespace data {
29 
30 // See documentation in ../../ops/dataset_ops.cc for a high-level
31 // description of the following op.
32 
33 /* static */ constexpr const char* const GeneratorDatasetOp::kDatasetType;
34 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFuncOtherArgs;
35 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFuncOtherArgs;
36 /* static */ constexpr const char* const
37     GeneratorDatasetOp::kFinalizeFuncOtherArgs;
38 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFunc;
39 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFunc;
40 /* static */ constexpr const char* const GeneratorDatasetOp::kFinalizeFunc;
41 /* static */ constexpr const char* const GeneratorDatasetOp::kTinitFuncArgs;
42 /* static */ constexpr const char* const GeneratorDatasetOp::kTnextFuncArgs;
43 /* static */ constexpr const char* const GeneratorDatasetOp::kTfinalizeFuncArgs;
44 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputTypes;
45 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputShapes;
46 
47 class GeneratorDatasetOp::Dataset : public DatasetBase {
48  public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)49   Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
50           std::unique_ptr<CapturedFunction> next_func,
51           std::unique_ptr<CapturedFunction> finalize_func,
52           const DataTypeVector& output_types,
53           const std::vector<PartialTensorShape>& output_shapes)
54       : DatasetBase(DatasetContext(ctx)),
55         init_func_(std::move(init_func)),
56         next_func_(std::move(next_func)),
57         finalize_func_(std::move(finalize_func)),
58         output_types_(output_types),
59         output_shapes_(output_shapes) {}
60 
MakeIteratorInternal(const string & prefix) const61   std::unique_ptr<IteratorBase> MakeIteratorInternal(
62       const string& prefix) const override {
63     return absl::make_unique<Iterator>(Iterator::Params{
64         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
65   }
66 
output_dtypes() const67   const DataTypeVector& output_dtypes() const override { return output_types_; }
68 
output_shapes() const69   const std::vector<PartialTensorShape>& output_shapes() const override {
70     return output_shapes_;
71   }
72 
DebugString() const73   string DebugString() const override {
74     return name_utils::DatasetDebugString(kDatasetType);
75   }
76 
InputDatasets(std::vector<const DatasetBase * > * inputs) const77   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
78     return Status::OK();
79   }
80 
CheckExternalState() const81   Status CheckExternalState() const override {
82     TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
83     TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
84     return finalize_func_->CheckExternalState();
85   }
86 
87  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const88   Status AsGraphDefInternal(SerializationContext* ctx,
89                             DatasetGraphDefBuilder* b,
90                             Node** output) const override {
91     return errors::Unimplemented(DebugString(),
92                                  " does not support serialization");
93   }
94 
95  private:
96   class Iterator : public DatasetIterator<Dataset> {
97    public:
Iterator(const Params & params)98     explicit Iterator(const Params& params)
99         : DatasetIterator<Dataset>(params) {}
100 
~Iterator()101     ~Iterator() override {
102       if (!finalized_ && initialized_) {
103         std::vector<Tensor> ignored;
104         Status s =
105             instantiated_finalize_func_->RunInstantiated(state_, &ignored);
106         if (!s.ok()) {
107           LOG(WARNING)
108               << "Error occurred when finalizing GeneratorDataset iterator: "
109               << s;
110         }
111       }
112     }
113 
Initialize(IteratorContext * ctx)114     Status Initialize(IteratorContext* ctx) override {
115       TF_RETURN_IF_ERROR(
116           dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
117       TF_RETURN_IF_ERROR(
118           dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
119       TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
120           ctx, &instantiated_finalize_func_));
121       return Status::OK();
122     }
123 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)124     Status GetNextInternal(IteratorContext* ctx,
125                            std::vector<Tensor>* out_tensors,
126                            bool* end_of_sequence) override {
127       mutex_lock l(mu_);
128 
129       if (!initialized_) {
130         TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
131             ctx, {}, &state_, model_node()));
132         initialized_ = true;
133       }
134 
135       if (finalized_) {
136         *end_of_sequence = true;
137         return Status::OK();
138       }
139 
140       Status s = instantiated_next_func_->RunWithBorrowedArgs(
141           ctx, state_, out_tensors, model_node());
142       if (s.ok()) {
143         *end_of_sequence = false;
144       } else if (errors::IsOutOfRange(s)) {
145         // `next_func` may deliberately raise `errors::OutOfRange`
146         // to indicate that we should terminate the iteration.
147         s = Status::OK();
148         *end_of_sequence = true;
149 
150         // NOTE(mrry): We ignore any tensors returned by the finalize function.
151         std::vector<Tensor> ignored;
152         TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
153             ctx, state_, &ignored, model_node()));
154         finalized_ = true;
155       }
156       return s;
157     }
158 
159    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const160     std::shared_ptr<model::Node> CreateNode(
161         IteratorContext* ctx, model::Node::Args args) const override {
162       return model::MakeSourceNode(std::move(args));
163     }
164 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)165     Status SaveInternal(SerializationContext* ctx,
166                         IteratorStateWriter* writer) override {
167       return errors::Unimplemented(
168           "GeneratorDataset does not support checkpointing.");
169     }
170 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)171     Status RestoreInternal(IteratorContext* ctx,
172                            IteratorStateReader* reader) override {
173       return errors::Unimplemented(
174           "GeneratorDataset does not support checkpointing.");
175     }
176 
177    private:
178     mutex mu_;
179     bool initialized_ TF_GUARDED_BY(mu_) = false;
180     bool finalized_ TF_GUARDED_BY(mu_) = false;
181     std::vector<Tensor> state_ TF_GUARDED_BY(mu_);
182     std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
183     std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
184     std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
185   };
186 
187   const std::unique_ptr<CapturedFunction> init_func_;
188   const std::unique_ptr<CapturedFunction> next_func_;
189   const std::unique_ptr<CapturedFunction> finalize_func_;
190   const DataTypeVector output_types_;
191   const std::vector<PartialTensorShape> output_shapes_;
192 };
193 
GeneratorDatasetOp(OpKernelConstruction * ctx)194 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
195     : DatasetOpKernel(ctx) {
196   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kInitFunc, /*params=*/{},
197                                                &init_func_metadata_));
198   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kNextFunc, /*params=*/{},
199                                                &next_func_metadata_));
200   OP_REQUIRES_OK(ctx,
201                  FunctionMetadata::Create(ctx, kFinalizeFunc, /*params=*/{},
202                                           &finalize_func_metadata_));
203   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
204   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
205 }
206 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)207 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
208                                      DatasetBase** output) {
209   std::unique_ptr<CapturedFunction> init_func;
210   OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_,
211                                                kInitFuncOtherArgs, &init_func));
212 
213   std::unique_ptr<CapturedFunction> next_func;
214   OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, next_func_metadata_,
215                                                kNextFuncOtherArgs, &next_func));
216 
217   std::unique_ptr<CapturedFunction> finalize_func;
218   OP_REQUIRES_OK(
219       ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
220                                     kFinalizeFuncOtherArgs, &finalize_func));
221 
222   *output =
223       new Dataset(ctx, std::move(init_func), std::move(next_func),
224                   std::move(finalize_func), output_types_, output_shapes_);
225 }
226 
227 namespace {
228 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
229                         GeneratorDatasetOp);
230 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
231                             .Device(DEVICE_GPU)
232                             .HostMemory("handle")
233                             .Priority(1),
234                         GeneratorDatasetOp);
235 }  // namespace
236 
237 }  // namespace data
238 }  // namespace tensorflow
239