• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/interleave_dataset_op.h"
16 
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/framework/model.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/dataset_utils.h"
23 #include "tensorflow/core/kernels/data/name_utils.h"
24 #include "tensorflow/core/lib/random/random.h"
25 #include "tensorflow/core/platform/cpu_info.h"
26 #include "tensorflow/core/platform/stringprintf.h"
27 
28 namespace tensorflow {
29 namespace data {
30 
31 // See documentation in ../../ops/dataset_ops.cc for a high-level
32 // description of the following op.
33 
34 /* static */ constexpr const char* const InterleaveDatasetOp::kDatasetType;
35 /* static */ constexpr const char* const InterleaveDatasetOp::kInputDataset;
36 /* static */ constexpr const char* const InterleaveDatasetOp::kOtherArguments;
37 /* static */ constexpr const char* const InterleaveDatasetOp::kCycleLength;
38 /* static */ constexpr const char* const InterleaveDatasetOp::kBlockLength;
39 /* static */ constexpr const char* const InterleaveDatasetOp::kFunc;
40 /* static */ constexpr const char* const InterleaveDatasetOp::kTarguments;
41 /* static */ constexpr const char* const InterleaveDatasetOp::kOutputTypes;
42 /* static */ constexpr const char* const InterleaveDatasetOp::kOutputShapes;
43 
// Keys under which the interleave iterator stores its state in a checkpoint.
//
// Per-slot keys: "<kArgsSize>[idx]" holds the number of tensors in the input
// element that created the iterator in slot `idx`, and
// "<kArgsList>[idx][i]" holds the i-th of those tensors.
constexpr char kArgsSize[] = "args_size";
constexpr char kArgsList[] = "args_list_";

// Scalar iterator-position keys.
constexpr char kCycleIndex[] = "cycle_index";
constexpr char kBlockIndex[] = "block_index";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kNumOpen[] = "num_open";
50 
51 class InterleaveDatasetOp::Dataset : public DatasetBase {
52  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,int64 cycle_length,int64 block_length,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)53   Dataset(OpKernelContext* ctx, const DatasetBase* input,
54           std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
55           int64 block_length, const DataTypeVector& output_types,
56           const std::vector<PartialTensorShape>& output_shapes)
57       : DatasetBase(DatasetContext(ctx)),
58         input_(input),
59         captured_func_(std::move(captured_func)),
60         cycle_length_(cycle_length),
61         block_length_(block_length),
62         output_types_(output_types),
63         output_shapes_(output_shapes),
64         traceme_metadata_(
65             {{"block_length",
66               strings::Printf("%lld", static_cast<long long>(block_length))},
67              {"cycle_length",
68               strings::Printf("%lld", static_cast<long long>(cycle_length))}}) {
69     input_->Ref();
70   }
71 
~Dataset()72   ~Dataset() override { input_->Unref(); }
73 
MakeIteratorInternal(const string & prefix) const74   std::unique_ptr<IteratorBase> MakeIteratorInternal(
75       const string& prefix) const override {
76     return absl::make_unique<Iterator>(Iterator::Params{
77         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
78   }
79 
output_dtypes() const80   const DataTypeVector& output_dtypes() const override { return output_types_; }
81 
output_shapes() const82   const std::vector<PartialTensorShape>& output_shapes() const override {
83     return output_shapes_;
84   }
85 
DebugString() const86   string DebugString() const override {
87     return name_utils::DatasetDebugString(kDatasetType);
88   }
89 
InputDatasets(std::vector<const DatasetBase * > * inputs) const90   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
91     inputs->push_back(input_);
92     return Status::OK();
93   }
94 
CheckExternalState() const95   Status CheckExternalState() const override {
96     TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
97     return input_->CheckExternalState();
98   }
99 
100  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const101   Status AsGraphDefInternal(SerializationContext* ctx,
102                             DatasetGraphDefBuilder* b,
103                             Node** output) const override {
104     Node* input_node;
105     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
106     Node* cycle_length_node;
107     TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
108     Node* block_length_node;
109     TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
110     std::vector<Node*> other_arguments;
111     DataTypeVector other_arguments_types;
112     TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
113                                                   &other_arguments_types));
114     AttrValue f;
115     b->BuildAttrValue(captured_func_->func(), &f);
116     AttrValue other_arguments_types_attr;
117     b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
118 
119     TF_RETURN_IF_ERROR(b->AddDataset(
120         this, {{0, input_node}, {2, cycle_length_node}, {3, block_length_node}},
121         {{1, other_arguments}},
122         {{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
123     return Status::OK();
124   }
125 
126  private:
127   class Iterator : public DatasetIterator<Dataset> {
128    public:
Iterator(const Params & params)129     explicit Iterator(const Params& params)
130         : DatasetIterator<Dataset>(params),
131           current_elements_(params.dataset->cycle_length_),
132           args_list_(params.dataset->cycle_length_) {}
133 
Initialize(IteratorContext * ctx)134     Status Initialize(IteratorContext* ctx) override {
135       TF_RETURN_IF_ERROR(
136           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
137       return dataset()->captured_func_->Instantiate(
138           ctx, &instantiated_captured_func_);
139     }
140 
AdvanceToNextInCycle()141     void AdvanceToNextInCycle() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
142       block_index_ = 0;
143       cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
144     }
145 
AdvancePosition()146     void AdvancePosition() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
147       ++block_index_;
148       if (block_index_ == dataset()->block_length_) {
149         AdvanceToNextInCycle();
150       }
151     }
152 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)153     Status GetNextInternal(IteratorContext* ctx,
154                            std::vector<Tensor>* out_tensors,
155                            bool* end_of_sequence) override {
156       mutex_lock l(mu_);
157       while (!end_of_input_ || num_open_ > 0) {
158         if (current_elements_[cycle_index_]) {
159           // We are currently processing a mapped element, so try to get the
160           // next subelement.
161           bool end_of_element;
162           TF_RETURN_IF_ERROR(current_elements_[cycle_index_]->GetNext(
163               ctx, out_tensors, &end_of_element));
164           if (!end_of_element) {
165             // Produce the subelement as output.
166             AdvancePosition();
167             *end_of_sequence = false;
168             return Status::OK();
169           }
170           // We have reached the end of the current element, so move
171           // on to the next element in the cycle.
172           current_elements_[cycle_index_].reset();
173           args_list_[cycle_index_].clear();
174           --num_open_;
175           AdvanceToNextInCycle();
176         } else if (!end_of_input_) {
177           // Get the next element from the input dataset, and create
178           // an iterator from it.
179           TF_RETURN_IF_ERROR(input_impl_->GetNext(
180               ctx, &args_list_[cycle_index_], &end_of_input_));
181           if (!end_of_input_) {
182             TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
183                 ctx, this, args_list_[cycle_index_], cycle_index_,
184                 *instantiated_captured_func_, prefix(),
185                 &current_elements_[cycle_index_], model_node()));
186             ++num_open_;
187           }
188         } else {
189           AdvanceToNextInCycle();
190         }
191       }
192 
193       *end_of_sequence = true;
194       return Status::OK();
195     }
196 
197    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const198     std::shared_ptr<model::Node> CreateNode(
199         IteratorContext* ctx, model::Node::Args args) const override {
200       return model::MakeInterleaveManyNode(std::move(args));
201     }
202 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)203     Status SaveInternal(SerializationContext* ctx,
204                         IteratorStateWriter* writer) override {
205       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
206           dataset()->captured_func_->CheckExternalState()));
207       mutex_lock l(mu_);
208       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
209       TF_RETURN_IF_ERROR(
210           writer->WriteScalar(full_name(kCycleIndex), cycle_index_));
211       TF_RETURN_IF_ERROR(
212           writer->WriteScalar(full_name(kBlockIndex), block_index_));
213       if (end_of_input_) {
214         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), ""));
215       }
216       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumOpen), num_open_));
217       TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer));
218       return Status::OK();
219     }
220 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)221     Status RestoreInternal(IteratorContext* ctx,
222                            IteratorStateReader* reader) override {
223       mutex_lock l(mu_);
224       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
225       int64 cycle_index;
226       TF_RETURN_IF_ERROR(
227           reader->ReadScalar(full_name(kCycleIndex), &cycle_index));
228       cycle_index_ = size_t(cycle_index);
229       TF_RETURN_IF_ERROR(
230           reader->ReadScalar(full_name(kBlockIndex), &block_index_));
231       if (reader->Contains(full_name(kEndOfInput))) end_of_input_ = true;
232       int64 num_open;
233       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumOpen), &num_open));
234       num_open_ = size_t(num_open);
235       TF_RETURN_IF_ERROR(RestoreCurrentElements(ctx, reader));
236       return Status::OK();
237     }
238 
GetTraceMeMetadata() const239     TraceMeMetadata GetTraceMeMetadata() const override {
240       return dataset()->traceme_metadata_;
241     }
242 
243    private:
SaveCurrentElements(SerializationContext * ctx,IteratorStateWriter * writer)244     Status SaveCurrentElements(SerializationContext* ctx,
245                                IteratorStateWriter* writer)
246         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
247       for (int idx = 0; idx < current_elements_.size(); idx++) {
248         if (current_elements_[idx]) {
249           TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_elements_[idx]));
250           TF_RETURN_IF_ERROR(writer->WriteScalar(
251               full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
252               args_list_[idx].size()));
253           for (int i = 0; i < args_list_[idx].size(); i++) {
254             TF_RETURN_IF_ERROR(writer->WriteTensor(
255                 full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
256                 args_list_[idx][i]));
257           }
258         }
259       }
260       return Status::OK();
261     }
262 
RestoreCurrentElements(IteratorContext * ctx,IteratorStateReader * reader)263     Status RestoreCurrentElements(IteratorContext* ctx,
264                                   IteratorStateReader* reader)
265         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
266       for (int idx = 0; idx < current_elements_.size(); idx++) {
267         if (reader->Contains(
268                 full_name(strings::StrCat(kArgsSize, "[", idx, "]")))) {
269           int64 args_size;
270           TF_RETURN_IF_ERROR(reader->ReadScalar(
271               full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
272               &args_size));
273           args_list_[idx].resize(args_size);
274           for (int i = 0; i < args_size; i++) {
275             TF_RETURN_IF_ERROR(reader->ReadTensor(
276                 full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
277                 &args_list_[idx][i]));
278           }
279           // NOTE: We intentionally ignore resource modeling outside GetNext().
280           TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
281               ctx, this, args_list_[idx], idx, *instantiated_captured_func_,
282               prefix(), &current_elements_[idx], /*node=*/nullptr));
283           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_elements_[idx]));
284         } else {
285           current_elements_[idx].reset();
286         }
287       }
288       return Status::OK();
289     }
290 
291     mutex mu_;
292     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
293     std::vector<std::unique_ptr<IteratorBase>> current_elements_
294         TF_GUARDED_BY(mu_);
295     std::vector<std::vector<Tensor>> args_list_ TF_GUARDED_BY(mu_);
296     size_t cycle_index_ TF_GUARDED_BY(mu_) = 0;
297     int64 block_index_ TF_GUARDED_BY(mu_) = 0;
298     bool end_of_input_ TF_GUARDED_BY(mu_) = false;
299     size_t num_open_ TF_GUARDED_BY(mu_) = 0;
300     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
301   };
302 
303   const DatasetBase* const input_;
304   const std::unique_ptr<CapturedFunction> captured_func_;
305   const int64 cycle_length_;
306   const int64 block_length_;
307   const DataTypeVector output_types_;
308   const std::vector<PartialTensorShape> output_shapes_;
309   const TraceMeMetadata traceme_metadata_;
310 };
311 
InterleaveDatasetOp(OpKernelConstruction * ctx)312 InterleaveDatasetOp::InterleaveDatasetOp(OpKernelConstruction* ctx)
313     : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
314   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
315                                                &func_metadata_));
316   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
317   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
318 }
319 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)320 void InterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
321                                       DatasetBase** output) {
322   int64 cycle_length = 0;
323   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
324   if (cycle_length == model::kAutotune) {
325     cycle_length = port::MaxParallelism();
326   }
327   OP_REQUIRES(
328       ctx, cycle_length > 0,
329       errors::InvalidArgument("cycle_length must be greater than zero."));
330 
331   int64 block_length = 0;
332   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
333   OP_REQUIRES(
334       ctx, block_length > 0,
335       errors::InvalidArgument("block_length must be greater than zero."));
336 
337   std::unique_ptr<CapturedFunction> captured_func;
338   OP_REQUIRES_OK(ctx,
339                  CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
340                                           &captured_func));
341 
342   *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
343                         block_length, output_types_, output_shapes_);
344 }
345 
346 namespace {
347 REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU),
348                         InterleaveDatasetOp);
349 REGISTER_INPUT_COLOCATION_EXEMPTION("InterleaveDataset");
350 }  // namespace
351 }  // namespace data
352 }  // namespace tensorflow
353