1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/interleave_dataset_op.h"
16
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/framework/model.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/kernels/data/dataset_utils.h"
23 #include "tensorflow/core/kernels/data/name_utils.h"
24 #include "tensorflow/core/lib/random/random.h"
25 #include "tensorflow/core/platform/cpu_info.h"
26 #include "tensorflow/core/platform/stringprintf.h"
27
28 namespace tensorflow {
29 namespace data {
30
31 // See documentation in ../../ops/dataset_ops.cc for a high-level
32 // description of the following op.
33
// Out-of-class definitions for the static constexpr string members of
// InterleaveDatasetOp (the declarations and their values are not in this
// file).
// NOTE(review): presumably required so the members can be ODR-used when the
// code is built as pre-C++17 -- confirm against the header.
/* static */ constexpr const char* const InterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const InterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const InterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const InterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const InterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const InterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const InterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const InterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const InterleaveDatasetOp::kOutputShapes;
43
// Keys under which the iterator state is written by SaveInternal() and read
// back by RestoreInternal(). Every key is qualified with full_name() before
// it is used.
// Index of the slot in the cycle that is currently being read.
constexpr char kCycleIndex[] = "cycle_index";
// Number of outputs already taken from the current slot in this block.
constexpr char kBlockIndex[] = "block_index";
// Written (with an empty value) only when the input has been exhausted; its
// presence is what RestoreInternal() tests for.
constexpr char kEndOfInput[] = "end_of_input";
// Number of slots that currently hold an open element iterator.
constexpr char kNumOpen[] = "num_open";
// Per-slot keys: the slot index (and, for kArgsList, the argument index) is
// appended in square brackets.
constexpr char kArgsSize[] = "args_size";
constexpr char kArgsList[] = "args_list_";
50
51 class InterleaveDatasetOp::Dataset : public DatasetBase {
52 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,int64 cycle_length,int64 block_length,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)53 Dataset(OpKernelContext* ctx, const DatasetBase* input,
54 std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
55 int64 block_length, const DataTypeVector& output_types,
56 const std::vector<PartialTensorShape>& output_shapes)
57 : DatasetBase(DatasetContext(ctx)),
58 input_(input),
59 captured_func_(std::move(captured_func)),
60 cycle_length_(cycle_length),
61 block_length_(block_length),
62 output_types_(output_types),
63 output_shapes_(output_shapes),
64 traceme_metadata_(
65 {{"block_length",
66 strings::Printf("%lld", static_cast<long long>(block_length))},
67 {"cycle_length",
68 strings::Printf("%lld", static_cast<long long>(cycle_length))}}) {
69 input_->Ref();
70 }
71
~Dataset()72 ~Dataset() override { input_->Unref(); }
73
MakeIteratorInternal(const string & prefix) const74 std::unique_ptr<IteratorBase> MakeIteratorInternal(
75 const string& prefix) const override {
76 return absl::make_unique<Iterator>(Iterator::Params{
77 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
78 }
79
output_dtypes() const80 const DataTypeVector& output_dtypes() const override { return output_types_; }
81
output_shapes() const82 const std::vector<PartialTensorShape>& output_shapes() const override {
83 return output_shapes_;
84 }
85
DebugString() const86 string DebugString() const override {
87 return name_utils::DatasetDebugString(kDatasetType);
88 }
89
InputDatasets(std::vector<const DatasetBase * > * inputs) const90 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
91 inputs->push_back(input_);
92 return Status::OK();
93 }
94
CheckExternalState() const95 Status CheckExternalState() const override {
96 TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
97 return input_->CheckExternalState();
98 }
99
100 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const101 Status AsGraphDefInternal(SerializationContext* ctx,
102 DatasetGraphDefBuilder* b,
103 Node** output) const override {
104 Node* input_node;
105 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
106 Node* cycle_length_node;
107 TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
108 Node* block_length_node;
109 TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
110 std::vector<Node*> other_arguments;
111 DataTypeVector other_arguments_types;
112 TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
113 &other_arguments_types));
114 AttrValue f;
115 b->BuildAttrValue(captured_func_->func(), &f);
116 AttrValue other_arguments_types_attr;
117 b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
118
119 TF_RETURN_IF_ERROR(b->AddDataset(
120 this, {{0, input_node}, {2, cycle_length_node}, {3, block_length_node}},
121 {{1, other_arguments}},
122 {{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
123 return Status::OK();
124 }
125
126 private:
127 class Iterator : public DatasetIterator<Dataset> {
128 public:
Iterator(const Params & params)129 explicit Iterator(const Params& params)
130 : DatasetIterator<Dataset>(params),
131 current_elements_(params.dataset->cycle_length_),
132 args_list_(params.dataset->cycle_length_) {}
133
Initialize(IteratorContext * ctx)134 Status Initialize(IteratorContext* ctx) override {
135 TF_RETURN_IF_ERROR(
136 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
137 return dataset()->captured_func_->Instantiate(
138 ctx, &instantiated_captured_func_);
139 }
140
AdvanceToNextInCycle()141 void AdvanceToNextInCycle() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
142 block_index_ = 0;
143 cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
144 }
145
AdvancePosition()146 void AdvancePosition() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
147 ++block_index_;
148 if (block_index_ == dataset()->block_length_) {
149 AdvanceToNextInCycle();
150 }
151 }
152
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)153 Status GetNextInternal(IteratorContext* ctx,
154 std::vector<Tensor>* out_tensors,
155 bool* end_of_sequence) override {
156 mutex_lock l(mu_);
157 while (!end_of_input_ || num_open_ > 0) {
158 if (current_elements_[cycle_index_]) {
159 // We are currently processing a mapped element, so try to get the
160 // next subelement.
161 bool end_of_element;
162 TF_RETURN_IF_ERROR(current_elements_[cycle_index_]->GetNext(
163 ctx, out_tensors, &end_of_element));
164 if (!end_of_element) {
165 // Produce the subelement as output.
166 AdvancePosition();
167 *end_of_sequence = false;
168 return Status::OK();
169 }
170 // We have reached the end of the current element, so move
171 // on to the next element in the cycle.
172 current_elements_[cycle_index_].reset();
173 args_list_[cycle_index_].clear();
174 --num_open_;
175 AdvanceToNextInCycle();
176 } else if (!end_of_input_) {
177 // Get the next element from the input dataset, and create
178 // an iterator from it.
179 TF_RETURN_IF_ERROR(input_impl_->GetNext(
180 ctx, &args_list_[cycle_index_], &end_of_input_));
181 if (!end_of_input_) {
182 TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
183 ctx, this, args_list_[cycle_index_], cycle_index_,
184 *instantiated_captured_func_, prefix(),
185 ¤t_elements_[cycle_index_], model_node()));
186 ++num_open_;
187 }
188 } else {
189 AdvanceToNextInCycle();
190 }
191 }
192
193 *end_of_sequence = true;
194 return Status::OK();
195 }
196
197 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const198 std::shared_ptr<model::Node> CreateNode(
199 IteratorContext* ctx, model::Node::Args args) const override {
200 return model::MakeInterleaveManyNode(std::move(args));
201 }
202
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)203 Status SaveInternal(SerializationContext* ctx,
204 IteratorStateWriter* writer) override {
205 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
206 dataset()->captured_func_->CheckExternalState()));
207 mutex_lock l(mu_);
208 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
209 TF_RETURN_IF_ERROR(
210 writer->WriteScalar(full_name(kCycleIndex), cycle_index_));
211 TF_RETURN_IF_ERROR(
212 writer->WriteScalar(full_name(kBlockIndex), block_index_));
213 if (end_of_input_) {
214 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), ""));
215 }
216 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumOpen), num_open_));
217 TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer));
218 return Status::OK();
219 }
220
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)221 Status RestoreInternal(IteratorContext* ctx,
222 IteratorStateReader* reader) override {
223 mutex_lock l(mu_);
224 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
225 int64 cycle_index;
226 TF_RETURN_IF_ERROR(
227 reader->ReadScalar(full_name(kCycleIndex), &cycle_index));
228 cycle_index_ = size_t(cycle_index);
229 TF_RETURN_IF_ERROR(
230 reader->ReadScalar(full_name(kBlockIndex), &block_index_));
231 if (reader->Contains(full_name(kEndOfInput))) end_of_input_ = true;
232 int64 num_open;
233 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumOpen), &num_open));
234 num_open_ = size_t(num_open);
235 TF_RETURN_IF_ERROR(RestoreCurrentElements(ctx, reader));
236 return Status::OK();
237 }
238
GetTraceMeMetadata() const239 TraceMeMetadata GetTraceMeMetadata() const override {
240 return dataset()->traceme_metadata_;
241 }
242
243 private:
SaveCurrentElements(SerializationContext * ctx,IteratorStateWriter * writer)244 Status SaveCurrentElements(SerializationContext* ctx,
245 IteratorStateWriter* writer)
246 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
247 for (int idx = 0; idx < current_elements_.size(); idx++) {
248 if (current_elements_[idx]) {
249 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_elements_[idx]));
250 TF_RETURN_IF_ERROR(writer->WriteScalar(
251 full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
252 args_list_[idx].size()));
253 for (int i = 0; i < args_list_[idx].size(); i++) {
254 TF_RETURN_IF_ERROR(writer->WriteTensor(
255 full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
256 args_list_[idx][i]));
257 }
258 }
259 }
260 return Status::OK();
261 }
262
RestoreCurrentElements(IteratorContext * ctx,IteratorStateReader * reader)263 Status RestoreCurrentElements(IteratorContext* ctx,
264 IteratorStateReader* reader)
265 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
266 for (int idx = 0; idx < current_elements_.size(); idx++) {
267 if (reader->Contains(
268 full_name(strings::StrCat(kArgsSize, "[", idx, "]")))) {
269 int64 args_size;
270 TF_RETURN_IF_ERROR(reader->ReadScalar(
271 full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
272 &args_size));
273 args_list_[idx].resize(args_size);
274 for (int i = 0; i < args_size; i++) {
275 TF_RETURN_IF_ERROR(reader->ReadTensor(
276 full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
277 &args_list_[idx][i]));
278 }
279 // NOTE: We intentionally ignore resource modeling outside GetNext().
280 TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
281 ctx, this, args_list_[idx], idx, *instantiated_captured_func_,
282 prefix(), ¤t_elements_[idx], /*node=*/nullptr));
283 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_elements_[idx]));
284 } else {
285 current_elements_[idx].reset();
286 }
287 }
288 return Status::OK();
289 }
290
291 mutex mu_;
292 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
293 std::vector<std::unique_ptr<IteratorBase>> current_elements_
294 TF_GUARDED_BY(mu_);
295 std::vector<std::vector<Tensor>> args_list_ TF_GUARDED_BY(mu_);
296 size_t cycle_index_ TF_GUARDED_BY(mu_) = 0;
297 int64 block_index_ TF_GUARDED_BY(mu_) = 0;
298 bool end_of_input_ TF_GUARDED_BY(mu_) = false;
299 size_t num_open_ TF_GUARDED_BY(mu_) = 0;
300 std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
301 };
302
303 const DatasetBase* const input_;
304 const std::unique_ptr<CapturedFunction> captured_func_;
305 const int64 cycle_length_;
306 const int64 block_length_;
307 const DataTypeVector output_types_;
308 const std::vector<PartialTensorShape> output_shapes_;
309 const TraceMeMetadata traceme_metadata_;
310 };
311
// Reads the kernel's static configuration at construction time: the graph def
// version, the metadata of the function attribute `kFunc`, and the declared
// output types and shapes. Each OP_REQUIRES_OK reports a failure through
// `ctx`.
InterleaveDatasetOp::InterleaveDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
319
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)320 void InterleaveDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
321 DatasetBase** output) {
322 int64 cycle_length = 0;
323 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
324 if (cycle_length == model::kAutotune) {
325 cycle_length = port::MaxParallelism();
326 }
327 OP_REQUIRES(
328 ctx, cycle_length > 0,
329 errors::InvalidArgument("cycle_length must be greater than zero."));
330
331 int64 block_length = 0;
332 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
333 OP_REQUIRES(
334 ctx, block_length > 0,
335 errors::InvalidArgument("block_length must be greater than zero."));
336
337 std::unique_ptr<CapturedFunction> captured_func;
338 OP_REQUIRES_OK(ctx,
339 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
340 &captured_func));
341
342 *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
343 block_length, output_types_, output_shapes_);
344 }
345
namespace {
// Registers InterleaveDatasetOp as the CPU kernel for the "InterleaveDataset"
// op.
REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU),
                        InterleaveDatasetOp);
// NOTE(review): presumably exempts this op from input colocation constraints
// -- confirm against input_colocation_exemption_registry.h.
REGISTER_INPUT_COLOCATION_EXEMPTION("InterleaveDataset");
}  // namespace
351 } // namespace data
352 } // namespace tensorflow
353