• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
16 
17 #include <iterator>
18 #include <vector>
19 
20 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
21 #include "tensorflow/core/data/captured_function.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/name_utils.h"
24 #include "tensorflow/core/framework/partial_tensor_shape.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/random/random.h"
27 
28 namespace tensorflow {
29 namespace data {
30 
31 // See documentation in ../../ops/dataset_ops.cc for a high-level
32 // description of the following op.
33 
34 /* static */ constexpr const char* const GeneratorDatasetOp::kDatasetType;
35 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFuncOtherArgs;
36 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFuncOtherArgs;
37 /* static */ constexpr const char* const
38     GeneratorDatasetOp::kFinalizeFuncOtherArgs;
39 /* static */ constexpr const char* const GeneratorDatasetOp::kInitFunc;
40 /* static */ constexpr const char* const GeneratorDatasetOp::kNextFunc;
41 /* static */ constexpr const char* const GeneratorDatasetOp::kFinalizeFunc;
42 /* static */ constexpr const char* const GeneratorDatasetOp::kTinitFuncArgs;
43 /* static */ constexpr const char* const GeneratorDatasetOp::kTnextFuncArgs;
44 /* static */ constexpr const char* const GeneratorDatasetOp::kTfinalizeFuncArgs;
45 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputTypes;
46 /* static */ constexpr const char* const GeneratorDatasetOp::kOutputShapes;
47 
48 class GeneratorDatasetOp::Dataset : public DatasetBase {
49  public:
Dataset(OpKernelContext * ctx,std::unique_ptr<CapturedFunction> init_func,std::unique_ptr<CapturedFunction> next_func,std::unique_ptr<CapturedFunction> finalize_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)50   Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
51           std::unique_ptr<CapturedFunction> next_func,
52           std::unique_ptr<CapturedFunction> finalize_func,
53           const DataTypeVector& output_types,
54           const std::vector<PartialTensorShape>& output_shapes)
55       : DatasetBase(DatasetContext(ctx)),
56         init_func_(std::move(init_func)),
57         next_func_(std::move(next_func)),
58         finalize_func_(std::move(finalize_func)),
59         output_types_(output_types),
60         output_shapes_(output_shapes) {}
61 
MakeIteratorInternal(const string & prefix) const62   std::unique_ptr<IteratorBase> MakeIteratorInternal(
63       const string& prefix) const override {
64     return std::make_unique<Iterator>(Iterator::Params{
65         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
66   }
67 
output_dtypes() const68   const DataTypeVector& output_dtypes() const override { return output_types_; }
69 
output_shapes() const70   const std::vector<PartialTensorShape>& output_shapes() const override {
71     return output_shapes_;
72   }
73 
DebugString() const74   string DebugString() const override {
75     return name_utils::DatasetDebugString(kDatasetType);
76   }
77 
InputDatasets(std::vector<const DatasetBase * > * inputs) const78   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
79     return OkStatus();
80   }
81 
CheckExternalState() const82   Status CheckExternalState() const override {
83     TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
84     TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
85     return finalize_func_->CheckExternalState();
86   }
87 
88  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const89   Status AsGraphDefInternal(SerializationContext* ctx,
90                             DatasetGraphDefBuilder* b,
91                             Node** output) const override {
92     return errors::Unimplemented(DebugString(),
93                                  " does not support serialization");
94   }
95 
96  private:
97   class Iterator : public DatasetIterator<Dataset> {
98    public:
Iterator(const Params & params)99     explicit Iterator(const Params& params)
100         : DatasetIterator<Dataset>(params) {}
101 
~Iterator()102     ~Iterator() override {
103       if (!finalized_ && initialized_) {
104         std::vector<Tensor> ignored;
105         Status s =
106             instantiated_finalize_func_->RunInstantiated(state_, &ignored);
107         if (!s.ok()) {
108           LOG(WARNING)
109               << "Error occurred when finalizing GeneratorDataset iterator: "
110               << s;
111         }
112       }
113     }
114 
Initialize(IteratorContext * ctx)115     Status Initialize(IteratorContext* ctx) override {
116       TF_RETURN_IF_ERROR(
117           dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_));
118       TF_RETURN_IF_ERROR(
119           dataset()->next_func_->Instantiate(ctx, &instantiated_next_func_));
120       TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(
121           ctx, &instantiated_finalize_func_));
122       return OkStatus();
123     }
124 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)125     Status GetNextInternal(IteratorContext* ctx,
126                            std::vector<Tensor>* out_tensors,
127                            bool* end_of_sequence) override {
128       mutex_lock l(mu_);
129 
130       if (!initialized_) {
131         TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
132             ctx, {}, &state_, model_node()));
133         initialized_ = true;
134       }
135 
136       if (finalized_) {
137         *end_of_sequence = true;
138         return OkStatus();
139       }
140 
141       Status s = instantiated_next_func_->RunWithBorrowedArgs(
142           ctx, state_, out_tensors, model_node());
143       if (s.ok()) {
144         *end_of_sequence = false;
145       } else if (errors::IsOutOfRange(s)) {
146         // `next_func` may deliberately raise `errors::OutOfRange`
147         // to indicate that we should terminate the iteration.
148         s = OkStatus();
149         *end_of_sequence = true;
150 
151         // NOTE(mrry): We ignore any tensors returned by the finalize function.
152         std::vector<Tensor> ignored;
153         TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
154             ctx, state_, &ignored, model_node()));
155         finalized_ = true;
156       }
157       return s;
158     }
159 
160    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const161     std::shared_ptr<model::Node> CreateNode(
162         IteratorContext* ctx, model::Node::Args args) const override {
163       return model::MakeSourceNode(std::move(args));
164     }
165 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)166     Status SaveInternal(SerializationContext* ctx,
167                         IteratorStateWriter* writer) override {
168       return errors::Unimplemented(
169           "GeneratorDataset does not support checkpointing.");
170     }
171 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)172     Status RestoreInternal(IteratorContext* ctx,
173                            IteratorStateReader* reader) override {
174       return errors::Unimplemented(
175           "GeneratorDataset does not support checkpointing.");
176     }
177 
178    private:
179     mutex mu_;
180     bool initialized_ TF_GUARDED_BY(mu_) = false;
181     bool finalized_ TF_GUARDED_BY(mu_) = false;
182     std::vector<Tensor> state_ TF_GUARDED_BY(mu_);
183     std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
184     std::unique_ptr<InstantiatedCapturedFunction> instantiated_next_func_;
185     std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
186   };
187 
188   const std::unique_ptr<CapturedFunction> init_func_;
189   const std::unique_ptr<CapturedFunction> next_func_;
190   const std::unique_ptr<CapturedFunction> finalize_func_;
191   const DataTypeVector output_types_;
192   const std::vector<PartialTensorShape> output_shapes_;
193 };
194 
GeneratorDatasetOp(OpKernelConstruction * ctx)195 GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
196     : DatasetOpKernel(ctx) {
197   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kInitFunc, /*params=*/{},
198                                                &init_func_metadata_));
199   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kNextFunc, /*params=*/{},
200                                                &next_func_metadata_));
201   OP_REQUIRES_OK(ctx,
202                  FunctionMetadata::Create(ctx, kFinalizeFunc, /*params=*/{},
203                                           &finalize_func_metadata_));
204   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
205   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
206 }
207 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)208 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
209                                      DatasetBase** output) {
210   std::unique_ptr<CapturedFunction> init_func;
211   OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_,
212                                                kInitFuncOtherArgs, &init_func));
213 
214   std::unique_ptr<CapturedFunction> next_func;
215   OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, next_func_metadata_,
216                                                kNextFuncOtherArgs, &next_func));
217 
218   std::unique_ptr<CapturedFunction> finalize_func;
219   OP_REQUIRES_OK(
220       ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
221                                     kFinalizeFuncOtherArgs, &finalize_func));
222 
223   *output =
224       new Dataset(ctx, std::move(init_func), std::move(next_func),
225                   std::move(finalize_func), output_types_, output_shapes_);
226 }
227 
228 namespace {
229 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU).Priority(2),
230                         GeneratorDatasetOp);
231 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset")
232                             .Device(DEVICE_GPU)
233                             .HostMemory("handle")
234                             .Priority(1),
235                         GeneratorDatasetOp);
236 REGISTER_INPUT_COLOCATION_EXEMPTION("GeneratorDataset");
237 }  // namespace
238 
239 }  // namespace data
240 }  // namespace tensorflow
241