• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/dataset.h"
17 #include "tensorflow/core/framework/partial_tensor_shape.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/kernels/data/window_dataset.h"
20 
21 namespace tensorflow {
22 namespace data {
23 namespace {
24 
25 // See documentation in ../../ops/dataset_ops.cc for a high-level
26 // description of the following op.
27 
28 class WindowDatasetOp : public UnaryDatasetOpKernel {
29  public:
WindowDatasetOp(OpKernelConstruction * ctx)30   explicit WindowDatasetOp(OpKernelConstruction* ctx)
31       : UnaryDatasetOpKernel(ctx) {}
32 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)33   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
34                    DatasetBase** output) override {
35     int64 window_size = 0;
36     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
37     OP_REQUIRES(
38         ctx, window_size > 0,
39         errors::InvalidArgument("Window size must be greater than zero."));
40 
41     int64 window_shift = 0;
42     OP_REQUIRES_OK(ctx,
43                    ParseScalarArgument<int64>(ctx, "shift", &window_shift));
44     OP_REQUIRES(
45         ctx, window_shift > 0,
46         errors::InvalidArgument("Window shift must be greater than zero."));
47 
48     int64 window_stride = 0;
49     OP_REQUIRES_OK(ctx,
50                    ParseScalarArgument<int64>(ctx, "stride", &window_stride));
51     OP_REQUIRES(
52         ctx, window_stride > 0,
53         errors::InvalidArgument("Window stride must be greater than zero."));
54 
55     bool drop_remainder;
56     OP_REQUIRES_OK(
57         ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
58 
59     *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
60                           drop_remainder);
61   }
62 
63  private:
64   class Dataset : public DatasetBase {
65    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 window_size,int64 window_shift,int64 window_stride,bool drop_remainder)66     Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
67             int64 window_shift, int64 window_stride, bool drop_remainder)
68         : DatasetBase(DatasetContext(ctx)),
69           input_(input),
70           window_size_(window_size),
71           window_shift_(window_shift),
72           window_stride_(window_stride),
73           drop_remainder_(drop_remainder) {
74       input_->Ref();
75     }
76 
~Dataset()77     ~Dataset() override { input_->Unref(); }
78 
MakeIteratorInternal(const string & prefix) const79     std::unique_ptr<IteratorBase> MakeIteratorInternal(
80         const string& prefix) const override {
81       return absl::make_unique<Iterator>(
82           Iterator::Params{this, strings::StrCat(prefix, "::Window")});
83     }
84 
output_dtypes() const85     const DataTypeVector& output_dtypes() const override {
86       static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
87       return *output_dtypes;
88     }
89 
output_shapes() const90     const std::vector<PartialTensorShape>& output_shapes() const override {
91       static std::vector<PartialTensorShape>* output_shapes =
92           new std::vector<PartialTensorShape>({TensorShape({})});
93       return *output_shapes;
94     }
95 
DebugString() const96     string DebugString() const override {
97       return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
98                              window_stride_, drop_remainder_, ")::Dataset");
99     }
100 
Cardinality() const101     int64 Cardinality() const override {
102       int64 n = input_->Cardinality();
103       if (n == kInfiniteCardinality || n == kUnknownCardinality) {
104         return n;
105       }
106       return n / window_shift_ +
107              (n % window_shift_ == 0 || drop_remainder_ ? 0 : 1);
108     }
109 
110    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const111     Status AsGraphDefInternal(SerializationContext* ctx,
112                               DatasetGraphDefBuilder* b,
113                               Node** output) const override {
114       Node* input_graph_node = nullptr;
115       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
116       Node* window_size_node = nullptr;
117       TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
118       Node* window_shift_node = nullptr;
119       TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
120       Node* window_stride_node = nullptr;
121       TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
122       Node* drop_remainder_node = nullptr;
123       TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
124       TF_RETURN_IF_ERROR(
125           b->AddDataset(this,
126                         {input_graph_node, window_size_node, window_shift_node,
127                          window_stride_node, drop_remainder_node},
128                         output));
129       return Status::OK();
130     }
131 
132    private:
133     class Iterator : public DatasetIterator<Dataset> {
134      public:
Iterator(const Params & params)135       explicit Iterator(const Params& params)
136           : DatasetIterator<Dataset>(params) {}
137 
Initialize(IteratorContext * ctx)138       Status Initialize(IteratorContext* ctx) override {
139         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
140       }
141 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)142       Status GetNextInternal(IteratorContext* ctx,
143                              std::vector<Tensor>* out_tensors,
144                              bool* end_of_sequence) override {
145         const int64 window_size = dataset()->window_size_;
146         const int64 window_shift = dataset()->window_shift_;
147         const int64 window_stride = dataset()->window_stride_;
148         std::vector<std::vector<Tensor>> window_elements;
149         Status status = Status::OK();
150         {
151           mutex_lock l(mu_);
152           if (!input_impl_ && buffer_.empty()) {
153             *end_of_sequence = true;
154             return Status::OK();
155           }
156 
157           // Add elements to the buffer.
158           size_t target_size = TargetBufferSize(window_size, window_stride);
159           if (input_impl_) {
160             *end_of_sequence = false;
161             for (size_t i = buffer_.size();
162                  i < target_size && !*end_of_sequence; ++i) {
163               std::vector<Tensor> element;
164               Status status =
165                   input_impl_->GetNext(ctx, &element, end_of_sequence);
166               if (!*end_of_sequence) {
167                 RecordBufferEnqueue(ctx, element);
168                 buffer_.emplace_back(std::move(element), status);
169               } else {
170                 input_impl_.reset();
171               }
172             }
173           }
174 
175           // If there are not enough elements and `drop_remainder` is set, we do
176           // not wish to return a smaller window.
177           if (buffer_.empty() ||
178               (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
179             DCHECK(*end_of_sequence);
180             return Status::OK();
181           }
182 
183           int num_elements = 1 + (buffer_.size() - 1) / window_stride;
184           window_elements.reserve(num_elements);
185           for (size_t i = 0; i < num_elements; ++i) {
186             status.Update(buffer_[window_stride * i].status);
187             if (!status.ok()) {
188               break;
189             }
190             window_elements.emplace_back(buffer_[window_stride * i].result);
191           }
192 
193           // Shift the window, discarding elements if necessary.
194           int buffer_size = buffer_.size();
195           if (window_shift >= buffer_size) {
196             for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
197               bool end_of_input;
198               std::vector<Tensor> element;
199               // Ignore non-error status of discarded elements.
200               input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
201               if (end_of_input) {
202                 input_impl_.reset();
203               }
204             }
205             for (size_t i = 0; i < buffer_.size(); ++i) {
206               RecordBufferDequeue(ctx, buffer_.at(i).result);
207             }
208             buffer_.clear();
209           } else {
210             for (size_t i = 0; i < window_shift; ++i) {
211               RecordBufferDequeue(ctx, buffer_.at(i).result);
212             }
213             buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
214           }
215         }
216 
217         if (!status.ok()) {
218           return status;
219         }
220 
221         // Construct output tensors.
222         const size_t num_tuple_components = window_elements[0].size();
223         const int64 num_window_elements = window_elements.size();
224         *end_of_sequence = false;
225         for (size_t idx = 0; idx < num_tuple_components; ++idx) {
226           DatasetBase* window_dataset;
227           std::vector<std::vector<Tensor>> window_component_elements;
228           window_component_elements.reserve(num_window_elements);
229           // Build the output tuple component by copying one slice
230           // from each input element in the window.
231           for (size_t i = 0; i < num_window_elements; ++i) {
232             std::vector<Tensor> component_element;
233             component_element.push_back(std::move(window_elements[i][idx]));
234             window_component_elements.push_back(component_element);
235           }
236           DataTypeVector output_types(
237               {dataset()->input_->output_dtypes()[idx]});
238           std::vector<PartialTensorShape> output_shapes(
239               {dataset()->input_->output_shapes()[idx]});
240           TF_RETURN_IF_ERROR(NewWindowDataset(window_component_elements,
241                                               output_types, output_shapes,
242                                               &window_dataset));
243           out_tensors->emplace_back(DT_VARIANT, TensorShape({}));
244           TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
245                                                          &out_tensors->back()));
246         }
247         return Status::OK();
248       }
249 
250      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const251       std::shared_ptr<model::Node> CreateNode(
252           IteratorContext* ctx, model::Node::Args args) const override {
253         return model::MakeKnownRatioNode(std::move(args),
254                                          dataset()->window_shift_);
255       }
256 
SaveInternal(IteratorStateWriter * writer)257       Status SaveInternal(IteratorStateWriter* writer) override {
258         mutex_lock l(mu_);
259         if (!input_impl_) {
260           TF_RETURN_IF_ERROR(
261               writer->WriteScalar(full_name("input_impl_empty"), ""));
262         } else {
263           TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
264         }
265         // Save buffer.
266         TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
267                                                buffer_.size()));
268         for (int64 i = 0; i < buffer_.size(); i++) {
269           TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
270           TF_RETURN_IF_ERROR(
271               writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
272                                   buffer_[i].result.size()));
273           for (int64 j = 0; j < buffer_[i].result.size(); j++) {
274             TF_RETURN_IF_ERROR(
275                 writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
276                                     buffer_[i].result[j]));
277           }
278         }
279         return Status::OK();
280       }
281 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)282       Status RestoreInternal(IteratorContext* ctx,
283                              IteratorStateReader* reader) override {
284         mutex_lock l(mu_);
285         if (!reader->Contains(full_name("input_impl_empty"))) {
286           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
287         } else {
288           input_impl_.reset();
289         }
290         // Restore buffer.
291         int64 buffer_size;
292         TF_RETURN_IF_ERROR(
293             reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
294         buffer_.resize(buffer_size);
295         for (int64 i = 0; i < buffer_size; i++) {
296           int64 vector_size;
297           TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
298           TF_RETURN_IF_ERROR(reader->ReadScalar(
299               strings::StrCat("buffer[", i, "].size"), &vector_size));
300           buffer_[i].result.resize(vector_size);
301           for (int64 j = 0; j < vector_size; j++) {
302             TF_RETURN_IF_ERROR(
303                 reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
304                                    &buffer_[i].result[j]));
305           }
306         }
307         return Status::OK();
308       }
309 
310      private:
311       struct InvocationResult {
312         InvocationResult() = default;
InvocationResulttensorflow::data::__anon256b48110111::WindowDatasetOp::Dataset::Iterator::InvocationResult313         InvocationResult(std::vector<Tensor>&& result, const Status& status)
314             : result(result), status(status) {}
315 
316         std::vector<Tensor> result;
317         Status status;
318       };
319 
WriteStatusLocked(IteratorStateWriter * writer,size_t index,const Status & status)320       Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
321                                const Status& status)
322           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
323         TF_RETURN_IF_ERROR(writer->WriteScalar(
324             CodeKey(index), static_cast<int64>(status.code())));
325         if (!status.ok()) {
326           TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
327                                                  status.error_message()));
328         }
329         return Status::OK();
330       }
331 
ReadStatusLocked(IteratorStateReader * reader,size_t index,Status * status)332       Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
333                               Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
334         int64 code_int;
335         TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
336         error::Code code = static_cast<error::Code>(code_int);
337 
338         if (code != error::Code::OK) {
339           string error_message;
340           TF_RETURN_IF_ERROR(
341               reader->ReadScalar(ErrorMessageKey(index), &error_message));
342           *status = Status(code, error_message);
343         } else {
344           *status = Status::OK();
345         }
346         return Status::OK();
347       }
348 
CodeKey(size_t index)349       string CodeKey(size_t index) {
350         return full_name(strings::StrCat("buffer[", index, "].code"));
351       }
352 
ErrorMessageKey(size_t index)353       string ErrorMessageKey(size_t index) {
354         return full_name(strings::StrCat("buffer[", index, "].error_message"));
355       }
356 
TargetBufferSize(int64 window_size,int64 window_stride)357       size_t TargetBufferSize(int64 window_size, int64 window_stride) {
358         return (window_size - 1) * window_stride + 1;
359       }
360 
361       mutex mu_;
362       std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
363       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
364     };
365 
366     const DatasetBase* const input_;
367     const int64 window_size_;
368     const int64 window_shift_;
369     const int64 window_stride_;
370     const bool drop_remainder_;
371   };
372 };
373 
374 REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
375                         WindowDatasetOp);
376 }  // namespace
377 }  // namespace data
378 }  // namespace tensorflow
379