• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/repeat_dataset_op.h"
16 
17 #include <utility>
18 
19 #include "tensorflow/core/data/name_utils.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 
23 namespace tensorflow {
24 namespace data {
25 
26 // See documentation in ../../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28 
29 /* static */ constexpr const char* const RepeatDatasetOp::kDatasetType;
30 /* static */ constexpr const char* const RepeatDatasetOp::kInputDataset;
31 /* static */ constexpr const char* const RepeatDatasetOp::kCount;
32 /* static */ constexpr const char* const RepeatDatasetOp::kOutputTypes;
33 /* static */ constexpr const char* const RepeatDatasetOp::kOutputShapes;
34 
// Labels handed to name_utils::IteratorPrefix(), one per iterator flavour
// selected in Dataset::MakeIteratorInternal().
constexpr char kEmptyRepeat[] = "EmptyRepeat";
constexpr char kFiniteRepeat[] = "FiniteRepeat";
constexpr char kForeverRepeat[] = "ForeverRepeat";

// Keys used by the iterators' SaveInternal()/RestoreInternal() methods.
constexpr char kCurIteration[] = "i";
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kUninitialized[] = "uninitialized";

// Ratio passed to model::MakeKnownRatioNode() by every iterator in this file.
constexpr int64_t kKnownRatio = 1;
42 
43 class RepeatDatasetOp::Dataset : public DatasetBase {
44  public:
Dataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)45   Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input)
46       : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
47     input_->Ref();
48   }
49 
// Releases the reference on the input dataset taken by the constructor.
~Dataset() override {
  input_->Unref();
}
51 
MakeIteratorInternal(const string & prefix) const52   std::unique_ptr<IteratorBase> MakeIteratorInternal(
53       const string& prefix) const override {
54     if (count_ < 0) {
55       return absl::make_unique<ForeverIterator>(ForeverIterator::Params{
56           this, name_utils::IteratorPrefix(kForeverRepeat, prefix)});
57     } else if (count_ == 0) {
58       return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
59           this, name_utils::IteratorPrefix(kEmptyRepeat, prefix)});
60     } else {
61       return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
62           this, name_utils::IteratorPrefix(kFiniteRepeat, prefix)});
63     }
64   }
65 
output_dtypes() const66   const DataTypeVector& output_dtypes() const override {
67     return input_->output_dtypes();
68   }
output_shapes() const69   const std::vector<PartialTensorShape>& output_shapes() const override {
70     return input_->output_shapes();
71   }
72 
DebugString() const73   string DebugString() const override {
74     return name_utils::DatasetDebugString(RepeatDatasetOp::kDatasetType);
75   }
76 
Cardinality() const77   int64 Cardinality() const override {
78     int64_t n = input_->Cardinality();
79     if (count_ < 0) {
80       if (n == 0) {
81         return 0;
82       }
83       return kInfiniteCardinality;
84     }
85     if (count_ == 0) {
86       return 0;
87     }
88     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
89       return n;
90     }
91     return count_ * n;
92   }
93 
InputDatasets(std::vector<const DatasetBase * > * inputs) const94   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
95     inputs->push_back(input_);
96     return Status::OK();
97   }
98 
CheckExternalState() const99   Status CheckExternalState() const override {
100     return input_->CheckExternalState();
101   }
102 
103  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104   Status AsGraphDefInternal(SerializationContext* ctx,
105                             DatasetGraphDefBuilder* b,
106                             Node** output) const override {
107     Node* input_graph_node = nullptr;
108     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
109     Node* count = nullptr;
110     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
111     TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
112     return Status::OK();
113   }
114 
115  private:
116   class EmptyIterator : public DatasetIterator<Dataset> {
117    public:
EmptyIterator(const Params & params)118     explicit EmptyIterator(const Params& params)
119         : DatasetIterator<Dataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)120     Status GetNextInternal(IteratorContext* ctx,
121                            std::vector<Tensor>* out_tensors,
122                            bool* end_of_sequence) override {
123       *end_of_sequence = true;
124       return Status::OK();
125     }
126 
127    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const128     std::shared_ptr<model::Node> CreateNode(
129         IteratorContext* ctx, model::Node::Args args) const override {
130       return model::MakeKnownRatioNode(std::move(args),
131                                        /*ratio=*/kKnownRatio);
132     }
133 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)134     Status SaveInternal(SerializationContext* ctx,
135                         IteratorStateWriter* writer) override {
136       return Status::OK();
137     }
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)138     Status RestoreInternal(IteratorContext* ctx,
139                            IteratorStateReader* reader) override {
140       return Status::OK();
141     }
142   };
143 
144   class FiniteIterator : public DatasetIterator<Dataset> {
145    public:
FiniteIterator(const Params & params)146     explicit FiniteIterator(const Params& params)
147         : DatasetIterator<Dataset>(params), i_(0) {}
148 
Initialize(IteratorContext * ctx)149     Status Initialize(IteratorContext* ctx) override {
150       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
151     }
152 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)153     Status GetNextInternal(IteratorContext* ctx,
154                            std::vector<Tensor>* out_tensors,
155                            bool* end_of_sequence) override {
156       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
157       if (!input_impl_) {
158         *end_of_sequence = true;
159         return Status::OK();
160       }
161       while (i_ < dataset()->count_) {
162         TF_RETURN_IF_ERROR(
163             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
164         if (!*end_of_sequence) {
165           return Status::OK();
166         }
167         ++i_;
168         for (const auto& provider : ctx->split_providers()) {
169           TF_RETURN_IF_ERROR(provider->Reset());
170         }
171         TF_RETURN_IF_ERROR(
172             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
173       }
174       *end_of_sequence = true;
175       input_impl_.reset();
176       return Status::OK();
177     }
178 
179    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const180     std::shared_ptr<model::Node> CreateNode(
181         IteratorContext* ctx, model::Node::Args args) const override {
182       return model::MakeKnownRatioNode(std::move(args),
183                                        /*ratio=*/kKnownRatio);
184     }
185 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)186     Status SaveInternal(SerializationContext* ctx,
187                         IteratorStateWriter* writer) override {
188       mutex_lock l(mu_);
189       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
190       if (!input_impl_) {
191         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
192       } else {
193         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
194       }
195       return Status::OK();
196     }
197 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)198     Status RestoreInternal(IteratorContext* ctx,
199                            IteratorStateReader* reader) override {
200       mutex_lock l(mu_);
201       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIteration), &i_));
202       if (!reader->Contains(full_name(kInputImplEmpty))) {
203         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
204       } else {
205         input_impl_.reset();
206       }
207       return Status::OK();
208     }
209 
210    private:
211     mutex mu_;
212     int64 i_ TF_GUARDED_BY(mu_);
213     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
214   };
215 
216   class ForeverIterator : public DatasetIterator<Dataset> {
217    public:
ForeverIterator(const Params & params)218     explicit ForeverIterator(const Params& params)
219         : DatasetIterator<Dataset>(params),
220           input_impl_(nullptr),
221           first_call_(true) {}
222 
Initialize(IteratorContext * ctx)223     Status Initialize(IteratorContext* ctx) override {
224       mutex_lock l(mu_);
225       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
226     }
227 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)228     Status GetNextInternal(IteratorContext* ctx,
229                            std::vector<Tensor>* out_tensors,
230                            bool* end_of_sequence) override {
231       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
232       do {
233         if (!input_impl_) {
234           TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
235               ctx, this, prefix(), &input_impl_));
236         }
237         TF_RETURN_IF_ERROR(
238             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
239         DCHECK(!*end_of_sequence || out_tensors->empty());
240         if (first_call_ && *end_of_sequence && ctx->split_providers().empty()) {
241           // If the first call to GetNext() fails because the end of sequence
242           // has been reached, we terminate the iteration immediately.
243           // Otherwise, this iterator would loop infinitely and never produce a
244           // value.
245           input_impl_.reset();
246           return Status::OK();
247         }
248         first_call_ = false;
249         if (!*end_of_sequence) {
250           return Status::OK();
251         }
252         for (const auto& provider : ctx->split_providers()) {
253           TF_RETURN_IF_ERROR(provider->Reset());
254         }
255         input_impl_.reset();
256         first_call_ = true;
257       } while (true);
258     }
259 
260    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const261     std::shared_ptr<model::Node> CreateNode(
262         IteratorContext* ctx, model::Node::Args args) const override {
263       return model::MakeKnownRatioNode(std::move(args),
264                                        /*ratio=*/kKnownRatio);
265     }
266 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)267     Status SaveInternal(SerializationContext* ctx,
268                         IteratorStateWriter* writer) override {
269       mutex_lock l(mu_);
270       if (!first_call_)
271         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
272       else
273         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
274       return Status::OK();
275     }
276 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)277     Status RestoreInternal(IteratorContext* ctx,
278                            IteratorStateReader* reader) override {
279       mutex_lock l(mu_);
280       if (reader->Contains(full_name(kUninitialized))) {
281         input_impl_.reset();
282         first_call_ = true;
283       } else {
284         TF_RETURN_IF_ERROR(
285             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
286         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
287         first_call_ = false;
288       }
289       return Status::OK();
290     }
291 
292    private:
293     mutex mu_;
294     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
295     bool first_call_ TF_GUARDED_BY(mu_);
296   };
297 
298   const int64 count_;
299   const DatasetBase* const input_;
300 };
301 
RepeatDatasetOp(OpKernelConstruction * ctx)302 RepeatDatasetOp::RepeatDatasetOp(OpKernelConstruction* ctx)
303     : UnaryDatasetOpKernel(ctx) {}
304 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)305 void RepeatDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
306                                   DatasetBase** output) {
307   // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
308   // container, and return it as the output.
309   int64_t count;
310   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
311   *output = new Dataset(ctx, count, input);
312 }
313 
314 namespace {
315 REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
316                         RepeatDatasetOp);
317 }  // namespace
318 }  // namespace data
319 }  // namespace tensorflow
320