/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/kernels/data/repeat_dataset_op.h"
16
17 #include <utility>
18
19 #include "tensorflow/core/data/name_utils.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22
23 namespace tensorflow {
24 namespace data {
25
26 // See documentation in ../../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28
29 /* static */ constexpr const char* const RepeatDatasetOp::kDatasetType;
30 /* static */ constexpr const char* const RepeatDatasetOp::kInputDataset;
31 /* static */ constexpr const char* const RepeatDatasetOp::kCount;
32 /* static */ constexpr const char* const RepeatDatasetOp::kOutputTypes;
33 /* static */ constexpr const char* const RepeatDatasetOp::kOutputShapes;
34
35 constexpr char kForeverRepeat[] = "ForeverRepeat";
36 constexpr char kEmptyRepeat[] = "EmptyRepeat";
37 constexpr char kFiniteRepeat[] = "FiniteRepeat";
38 constexpr char kCurIteration[] = "i";
39 constexpr char kInputImplEmpty[] = "input_impl_empty";
40 constexpr char kUninitialized[] = "uninitialized";
41 constexpr int64_t kKnownRatio = 1;
42
43 class RepeatDatasetOp::Dataset : public DatasetBase {
44 public:
Dataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)45 Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input)
46 : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
47 input_->Ref();
48 }
49
~Dataset()50 ~Dataset() override { input_->Unref(); }
51
MakeIteratorInternal(const string & prefix) const52 std::unique_ptr<IteratorBase> MakeIteratorInternal(
53 const string& prefix) const override {
54 if (count_ < 0) {
55 return absl::make_unique<ForeverIterator>(ForeverIterator::Params{
56 this, name_utils::IteratorPrefix(kForeverRepeat, prefix)});
57 } else if (count_ == 0) {
58 return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
59 this, name_utils::IteratorPrefix(kEmptyRepeat, prefix)});
60 } else {
61 return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
62 this, name_utils::IteratorPrefix(kFiniteRepeat, prefix)});
63 }
64 }
65
output_dtypes() const66 const DataTypeVector& output_dtypes() const override {
67 return input_->output_dtypes();
68 }
output_shapes() const69 const std::vector<PartialTensorShape>& output_shapes() const override {
70 return input_->output_shapes();
71 }
72
DebugString() const73 string DebugString() const override {
74 return name_utils::DatasetDebugString(RepeatDatasetOp::kDatasetType);
75 }
76
Cardinality() const77 int64 Cardinality() const override {
78 int64_t n = input_->Cardinality();
79 if (count_ < 0) {
80 if (n == 0) {
81 return 0;
82 }
83 return kInfiniteCardinality;
84 }
85 if (count_ == 0) {
86 return 0;
87 }
88 if (n == kInfiniteCardinality || n == kUnknownCardinality) {
89 return n;
90 }
91 return count_ * n;
92 }
93
InputDatasets(std::vector<const DatasetBase * > * inputs) const94 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
95 inputs->push_back(input_);
96 return Status::OK();
97 }
98
CheckExternalState() const99 Status CheckExternalState() const override {
100 return input_->CheckExternalState();
101 }
102
103 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104 Status AsGraphDefInternal(SerializationContext* ctx,
105 DatasetGraphDefBuilder* b,
106 Node** output) const override {
107 Node* input_graph_node = nullptr;
108 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
109 Node* count = nullptr;
110 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
111 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
112 return Status::OK();
113 }
114
115 private:
116 class EmptyIterator : public DatasetIterator<Dataset> {
117 public:
EmptyIterator(const Params & params)118 explicit EmptyIterator(const Params& params)
119 : DatasetIterator<Dataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)120 Status GetNextInternal(IteratorContext* ctx,
121 std::vector<Tensor>* out_tensors,
122 bool* end_of_sequence) override {
123 *end_of_sequence = true;
124 return Status::OK();
125 }
126
127 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const128 std::shared_ptr<model::Node> CreateNode(
129 IteratorContext* ctx, model::Node::Args args) const override {
130 return model::MakeKnownRatioNode(std::move(args),
131 /*ratio=*/kKnownRatio);
132 }
133
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)134 Status SaveInternal(SerializationContext* ctx,
135 IteratorStateWriter* writer) override {
136 return Status::OK();
137 }
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)138 Status RestoreInternal(IteratorContext* ctx,
139 IteratorStateReader* reader) override {
140 return Status::OK();
141 }
142 };
143
144 class FiniteIterator : public DatasetIterator<Dataset> {
145 public:
FiniteIterator(const Params & params)146 explicit FiniteIterator(const Params& params)
147 : DatasetIterator<Dataset>(params), i_(0) {}
148
Initialize(IteratorContext * ctx)149 Status Initialize(IteratorContext* ctx) override {
150 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
151 }
152
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)153 Status GetNextInternal(IteratorContext* ctx,
154 std::vector<Tensor>* out_tensors,
155 bool* end_of_sequence) override {
156 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
157 if (!input_impl_) {
158 *end_of_sequence = true;
159 return Status::OK();
160 }
161 while (i_ < dataset()->count_) {
162 TF_RETURN_IF_ERROR(
163 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
164 if (!*end_of_sequence) {
165 return Status::OK();
166 }
167 ++i_;
168 for (const auto& provider : ctx->split_providers()) {
169 TF_RETURN_IF_ERROR(provider->Reset());
170 }
171 TF_RETURN_IF_ERROR(
172 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
173 }
174 *end_of_sequence = true;
175 input_impl_.reset();
176 return Status::OK();
177 }
178
179 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const180 std::shared_ptr<model::Node> CreateNode(
181 IteratorContext* ctx, model::Node::Args args) const override {
182 return model::MakeKnownRatioNode(std::move(args),
183 /*ratio=*/kKnownRatio);
184 }
185
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)186 Status SaveInternal(SerializationContext* ctx,
187 IteratorStateWriter* writer) override {
188 mutex_lock l(mu_);
189 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
190 if (!input_impl_) {
191 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
192 } else {
193 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
194 }
195 return Status::OK();
196 }
197
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)198 Status RestoreInternal(IteratorContext* ctx,
199 IteratorStateReader* reader) override {
200 mutex_lock l(mu_);
201 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIteration), &i_));
202 if (!reader->Contains(full_name(kInputImplEmpty))) {
203 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
204 } else {
205 input_impl_.reset();
206 }
207 return Status::OK();
208 }
209
210 private:
211 mutex mu_;
212 int64 i_ TF_GUARDED_BY(mu_);
213 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
214 };
215
216 class ForeverIterator : public DatasetIterator<Dataset> {
217 public:
ForeverIterator(const Params & params)218 explicit ForeverIterator(const Params& params)
219 : DatasetIterator<Dataset>(params),
220 input_impl_(nullptr),
221 first_call_(true) {}
222
Initialize(IteratorContext * ctx)223 Status Initialize(IteratorContext* ctx) override {
224 mutex_lock l(mu_);
225 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
226 }
227
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)228 Status GetNextInternal(IteratorContext* ctx,
229 std::vector<Tensor>* out_tensors,
230 bool* end_of_sequence) override {
231 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
232 do {
233 if (!input_impl_) {
234 TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
235 ctx, this, prefix(), &input_impl_));
236 }
237 TF_RETURN_IF_ERROR(
238 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
239 DCHECK(!*end_of_sequence || out_tensors->empty());
240 if (first_call_ && *end_of_sequence && ctx->split_providers().empty()) {
241 // If the first call to GetNext() fails because the end of sequence
242 // has been reached, we terminate the iteration immediately.
243 // Otherwise, this iterator would loop infinitely and never produce a
244 // value.
245 input_impl_.reset();
246 return Status::OK();
247 }
248 first_call_ = false;
249 if (!*end_of_sequence) {
250 return Status::OK();
251 }
252 for (const auto& provider : ctx->split_providers()) {
253 TF_RETURN_IF_ERROR(provider->Reset());
254 }
255 input_impl_.reset();
256 first_call_ = true;
257 } while (true);
258 }
259
260 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const261 std::shared_ptr<model::Node> CreateNode(
262 IteratorContext* ctx, model::Node::Args args) const override {
263 return model::MakeKnownRatioNode(std::move(args),
264 /*ratio=*/kKnownRatio);
265 }
266
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)267 Status SaveInternal(SerializationContext* ctx,
268 IteratorStateWriter* writer) override {
269 mutex_lock l(mu_);
270 if (!first_call_)
271 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
272 else
273 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
274 return Status::OK();
275 }
276
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)277 Status RestoreInternal(IteratorContext* ctx,
278 IteratorStateReader* reader) override {
279 mutex_lock l(mu_);
280 if (reader->Contains(full_name(kUninitialized))) {
281 input_impl_.reset();
282 first_call_ = true;
283 } else {
284 TF_RETURN_IF_ERROR(
285 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
286 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
287 first_call_ = false;
288 }
289 return Status::OK();
290 }
291
292 private:
293 mutex mu_;
294 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
295 bool first_call_ TF_GUARDED_BY(mu_);
296 };
297
298 const int64 count_;
299 const DatasetBase* const input_;
300 };
301
RepeatDatasetOp(OpKernelConstruction * ctx)302 RepeatDatasetOp::RepeatDatasetOp(OpKernelConstruction* ctx)
303 : UnaryDatasetOpKernel(ctx) {}
304
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)305 void RepeatDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
306 DatasetBase** output) {
307 // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
308 // container, and return it as the output.
309 int64_t count;
310 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
311 *output = new Dataset(ctx, count, input);
312 }
313
314 namespace {
315 REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
316 RepeatDatasetOp);
317 } // namespace
318 } // namespace data
319 } // namespace tensorflow
320