1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/iterator_ops.h"
16
17 #include <cstdint>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/memory/memory.h"
24 #include "absl/time/time.h"
25 #include "tensorflow/core/activity_watcher/activity.h"
26 #include "tensorflow/core/activity_watcher/activity_utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_runner.h"
29 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/threadpool_device.h"
32 #include "tensorflow/core/data/captured_function.h"
33 #include "tensorflow/core/data/dataset_utils.h"
34 #include "tensorflow/core/data/finalization_utils.h"
35 #include "tensorflow/core/data/metric_utils.h"
36 #include "tensorflow/core/data/root_dataset.h"
37 #include "tensorflow/core/data/serialization_utils.h"
38 #include "tensorflow/core/data/utils.h"
39 #include "tensorflow/core/framework/cancellation.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/partial_tensor_shape.h"
44 #include "tensorflow/core/framework/resource_op_kernel.h"
45 #include "tensorflow/core/framework/stats_aggregator.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/types.h"
48 #include "tensorflow/core/framework/variant_op_registry.h"
49 #include "tensorflow/core/framework/variant_tensor_data.h"
50 #include "tensorflow/core/kernels/data/optional_ops.h"
51 #include "tensorflow/core/kernels/ops_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/refcount.h"
54 #include "tensorflow/core/lib/core/threadpool.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/lib/strings/strcat.h"
58 #include "tensorflow/core/lib/strings/stringprintf.h"
59 #include "tensorflow/core/platform/casts.h"
60 #include "tensorflow/core/platform/env.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mem.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/refcount.h"
65 #include "tensorflow/core/platform/resource.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/profiler/lib/traceme_encode.h"
68 #include "tensorflow/core/public/session_options.h"
69
70 namespace tensorflow {
71 namespace data {
72 namespace {
73
74 // See documentation in ../../ops/dataset_ops.cc for a high-level
75 // description of the following ops.
76
77 const char kAnonymousIterator[] = "AnonymousIterator";
78 const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
79 const char kAnonymousIteratorV3[] = "AnonymousIteratorV3";
80 const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
81 const char kOutputShapes[] = "output_shapes";
82 const char kOutputTypes[] = "output_types";
83
84 } // namespace
85
86 /* static */ constexpr const char* const
87 SerializeIteratorOp::kExternalStatePolicy;
88
IteratorResource(Env * env,const DataTypeVector & output_dtypes,const std::vector<PartialTensorShape> & output_shapes,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr)89 IteratorResource::IteratorResource(
90 Env* env, const DataTypeVector& output_dtypes,
91 const std::vector<PartialTensorShape>& output_shapes,
92 std::unique_ptr<DeviceMgr> device_mgr,
93 std::unique_ptr<FunctionLibraryDefinition> flib_def,
94 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
95 FunctionLibraryRuntime* flr)
96 : metrics_collector_(flr->device()->device_type(), *env),
97 unbounded_thread_pool_(env, "tf_data_iterator_resource"),
98 device_mgr_(std::move(device_mgr)),
99 iterator_state_(std::make_shared<State>(std::move(flib_def),
100 std::move(pflr), flr,
101 /*iterator=*/nullptr)),
102 output_dtypes_(output_dtypes),
103 output_shapes_(output_shapes) {
104 VLOG(2) << "creating iterator resource";
105 }
106
~IteratorResource()107 IteratorResource::~IteratorResource() {
108 VLOG(2) << "destroying iterator resource";
109 }
110
GetNext(OpKernelContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111 Status IteratorResource::GetNext(OpKernelContext* ctx,
112 std::vector<Tensor>* out_tensors,
113 bool* end_of_sequence) {
114 std::shared_ptr<State> captured_state;
115 {
116 tf_shared_lock l(mu_);
117 captured_state = iterator_state_;
118 }
119 if (!captured_state->iterator()) {
120 return errors::FailedPrecondition(
121 "GetNext() failed because the iterator has not been initialized. "
122 "Ensure that you have run the initializer operation for this iterator "
123 "before getting the next element.");
124 }
125 IteratorContext::Params params(ctx);
126 params.flr = captured_state->flr();
127 params.function_handle_cache = captured_state->function_handle_cache();
128 params.resource_mgr = captured_state->resource_mgr();
129 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
130 params.thread_pool = &unbounded_thread_pool_;
131 params.cancellation_manager = captured_state->cancellation_manager();
132 std::function<void()> deregister_fn;
133 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
134 ctx->cancellation_manager(),
135 [cm = params.cancellation_manager]() { cm->StartCancel(); },
136 &deregister_fn));
137 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
138
139 const absl::Time start_time = metrics_collector_.RecordStart();
140 auto iterator_ = captured_state->iterator();
141 auto status = iterator_->GetNext(IteratorContext(std::move(params)),
142 out_tensors, end_of_sequence);
143 metrics_collector_.RecordStop(start_time, *out_tensors);
144 return status;
145 }
146
Save(SerializationContext * ctx,IteratorStateWriter * writer)147 Status IteratorResource::Save(SerializationContext* ctx,
148 IteratorStateWriter* writer) {
149 std::shared_ptr<State> captured_state;
150 {
151 tf_shared_lock l(mu_);
152 captured_state = iterator_state_;
153 }
154 auto iterator_ = captured_state->iterator();
155 if (iterator_) {
156 return iterator_->Save(ctx, writer);
157 }
158 return errors::FailedPrecondition(
159 "Save() failed because the iterator has not been initialized. Ensure "
160 "that you have run the initializer operation for this iterator before "
161 "saving it.");
162 }
163
Restore(OpKernelContext * ctx,IteratorStateReader * reader)164 Status IteratorResource::Restore(OpKernelContext* ctx,
165 IteratorStateReader* reader) {
166 const DatasetBase* dataset;
167 std::shared_ptr<State> new_state;
168 const DatasetBase* input_dataset;
169 {
170 tf_shared_lock l(mu_);
171 if (!iterator_state_->iterator()) {
172 return errors::FailedPrecondition(
173 "Restore() failed because the iterator has not been initialized. "
174 "Ensure that you have run the initializer operation for this "
175 "iterator before restoring it.");
176 }
177 auto iterator_ = iterator_state_->iterator();
178 dataset = iterator_->dataset();
179 // Hang onto a reference until we've created the new iterator, which will
180 // then hold its own reference to keep the dataset alive.
181 dataset->Ref();
182 new_state =
183 std::make_shared<State>(iterator_state_->flib_def(),
184 iterator_state_->pflr(), iterator_state_->flr(),
185 /*iterator=*/nullptr);
186 input_dataset = iterator_state_->dataset();
187 }
188 core::ScopedUnref scoped_unref(dataset);
189 IteratorContext::Params params(ctx);
190 params.flr = new_state->flr();
191 params.function_handle_cache = new_state->function_handle_cache();
192 params.resource_mgr = new_state->resource_mgr();
193 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
194 params.thread_pool = &unbounded_thread_pool_;
195 params.cancellation_manager = new_state->cancellation_manager();
196 std::function<void()> deregister_fn;
197 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
198 ctx->cancellation_manager(),
199 [cm = params.cancellation_manager]() { cm->StartCancel(); },
200 &deregister_fn));
201 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
202 std::unique_ptr<IteratorBase> iterator_base;
203 TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
204 IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
205 new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
206 input_dataset);
207
208 mutex_lock l(mu_);
209 std::swap(iterator_state_, new_state);
210 return OkStatus();
211 }
212
SetIteratorFromDataset(OpKernelContext * ctx,const DatasetBase * dataset)213 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
214 const DatasetBase* dataset) {
215 std::shared_ptr<State> new_state;
216 {
217 tf_shared_lock l(mu_);
218 new_state =
219 std::make_shared<State>(iterator_state_->flib_def(),
220 iterator_state_->pflr(), iterator_state_->flr(),
221 /*iterator=*/nullptr);
222 }
223
224 // Create new iterator.
225 IteratorContext::Params params(ctx);
226 params.flr = new_state->flr();
227 params.function_handle_cache = new_state->function_handle_cache();
228 params.resource_mgr = new_state->resource_mgr();
229 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
230 params.thread_pool = &unbounded_thread_pool_;
231 params.cancellation_manager = new_state->cancellation_manager();
232 std::function<void()> deregister_fn;
233 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
234 ctx->cancellation_manager(),
235 [cm = params.cancellation_manager]() { cm->StartCancel(); },
236 &deregister_fn));
237 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
238
239 std::unique_ptr<IteratorBase> iterator;
240 if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
241 DatasetBase* finalized_dataset;
242 TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
243 TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
244 IteratorContext(std::move(params)),
245 /*parent=*/nullptr, "Iterator", &iterator));
246 } else {
247 TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
248 /*parent=*/nullptr, "Iterator",
249 &iterator));
250 }
251 TF_RETURN_IF_ERROR(
252 VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
253 TF_RETURN_IF_ERROR(
254 VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
255
256 new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
257
258 mutex_lock l(mu_);
259 std::swap(iterator_state_, new_state);
260 return OkStatus();
261 }
262
263 namespace {
264
265 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
266 // The get() method returns an VariantTensorData object which contains all the
267 // state needed to restore a single iterator.
268 //
269 // Usage example:
270 //
271 // Encoding:
272 //
273 // Tensor t(DT_VARIANT, TensorShape({}));
274 // t->scalar<Variant>()() = IteratorStateVariant();
275 //
276 // Encode() sets the type_name of the VariantTensorData object to
277 // IteratorStateVariant::TypeName().
278 //
279 // Decoding:
280 //
281 // Variant v = <VariantTensorDataProto object>;
282 // DecodeUnaryVariant(&v);
283 // IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
284 // IteratorStateReader reader({wrapper->GetData()});
285 // iterator_resource->Restore(ctx, &reader);
286 //
287 // The type_name of the VariantTensorData object to be decoded must
288 // match IteratorStateVariant::TypeName().
289 class IteratorStateVariant {
290 public:
IteratorStateVariant()291 IteratorStateVariant() : data_(nullptr) {}
IteratorStateVariant(const IteratorStateVariant & other)292 IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
293 if (other.data_) {
294 Decode(*other.data_);
295 }
296 }
297 IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
298 IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
299
300 // Initializes `this` from a VariantTensorData object.
InitializeFromVariantData(std::unique_ptr<VariantTensorData> d)301 Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
302 data_ = std::move(d);
303 return OkStatus();
304 }
305
TypeName() const306 string TypeName() const { return kIteratorVariantTypeName; }
Encode(VariantTensorData * data) const307 void Encode(VariantTensorData* data) const { *data = *data_; }
Decode(VariantTensorData data)308 bool Decode(VariantTensorData data) {
309 if (data.type_name() != TypeName()) {
310 return false;
311 }
312 auto tensor_data = std::make_unique<VariantTensorData>();
313 std::swap(*tensor_data, data);
314 data_ = std::move(tensor_data);
315 return true;
316 }
317
318 // Returns a borrowed pointer to the underlying VariantTensorData.
GetData() const319 const VariantTensorData* GetData() const { return data_.get(); }
320
DebugString() const321 string DebugString() const {
322 if (data_) {
323 return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
324 ">");
325 } else {
326 return strings::StrCat("IteratorStateVariant<empty>");
327 }
328 }
329
330 private:
331 std::unique_ptr<VariantTensorData> data_;
332 };
333
334 // Register the reader class in the global variant decode_fn registry
335 // so that a Variant containing a serialized representation of iterator state
336 // can be decoded using DecodeUnaryVariant. If we don't do this we will need
337 // to manually decode the returned Variant using MaybeDecodeAndCopy in
338 // DeserializeIteratorOp which is not recommended.
339 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
340 kIteratorVariantTypeName);
341
342 // A helper class that uses a list of IteratorStateVariant objects to represent
343 // the state for an iterator resource. It exposes methods that help with
344 // saving and restoring of this state. Sample usage
345 // Saving:
346 // IteratorVariantSerializer serializer;
347 // serializer.InitializeFromIterator(iterator_resource);
348 // Tensor serialized_t;
349 // serializer.Serialize(&serialized_t);
350 //
351 // Restoring:
352 // IteratorVariantSerializer serializer;
353 // serializer.InitFromTensor(ctx->input(0));
354 // IteratorStateReader* reader = serializer.GetReader();
355 // iterator_resource->Restore(ctx, reader);
356 class IteratorVariantSerializer {
357 public:
IteratorVariantSerializer()358 IteratorVariantSerializer() {}
359
360 // Calls `Save` on the iterator_resource to build up the list of
361 // IteratorStateVariant objects.
InitializeFromIterator(SerializationContext * serialization_ctx,IteratorResource * iterator_resource)362 Status InitializeFromIterator(SerializationContext* serialization_ctx,
363 IteratorResource* iterator_resource) {
364 VariantTensorDataWriter writer;
365 TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
366 std::vector<std::unique_ptr<VariantTensorData>> data;
367 writer.ReleaseData(&data);
368 variants_.clear();
369 variants_.reserve(data.size());
370 for (auto& it : data) {
371 IteratorStateVariant v;
372 TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
373 variants_.push_back(v);
374 }
375 num_tensors_ = variants_.size();
376 can_serialize_ = true;
377 return OkStatus();
378 }
379
380 // Initializes `this` from `serialized_t` while restoring the iterator state.
InitFromTensor(const Tensor * serialized_t)381 Status InitFromTensor(const Tensor* serialized_t) {
382 int64_t num_tensors = serialized_t->dim_size(0);
383 auto serialized_vec = serialized_t->vec<Variant>();
384 std::vector<const VariantTensorData*> data;
385 data.reserve(num_tensors);
386 for (int i = 0; i < num_tensors; ++i) {
387 auto* w = serialized_vec(i).get<IteratorStateVariant>();
388 if (!w) {
389 return errors::Internal(
390 "Cannot initialize an iterator from tensor ",
391 serialized_vec(i).DebugString(),
392 ". Expected a variant tensor of type IteratorStateVariant");
393 }
394 data.push_back(w->GetData());
395 }
396 reader_ = std::make_unique<VariantTensorDataReader>(data);
397 num_tensors_ = data.size();
398 return OkStatus();
399 }
400
NumTensors()401 int64_t NumTensors() { return num_tensors_; }
402
403 // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
404 // that InitializeFromIterator was called before.
Serialize(Tensor * serialized)405 Status Serialize(Tensor* serialized) {
406 if (!can_serialize_) {
407 return errors::InvalidArgument(
408 "Please call InitializeFromIterator before calling Serialize.");
409 }
410 int64_t size = variants_.size();
411 for (int64_t i = 0; i < size; ++i) {
412 if (variants_[i].GetData() == nullptr) {
413 return errors::Internal(
414 "Cannot serialize an empty IteratorStateVariant");
415 }
416 serialized->vec<Variant>()(i) = variants_[i];
417 }
418 return OkStatus();
419 }
420
421 // Returns an IteratorStateReader to restore iterator state. Expects that
422 // InitFromTensor was called before.
GetReader()423 IteratorStateReader* GetReader() { return reader_.get(); }
424
425 private:
426 bool can_serialize_ = false;
427 int64_t num_tensors_;
428 std::vector<IteratorStateVariant> variants_;
429 std::unique_ptr<IteratorStateReader> reader_;
430 };
431
432 } // namespace
433
434 // Note that IteratorHandleOp holds a reference to the resource it creates. If
435 // cleaning up resources with DestroyResourceOp is important, consider creating
436 // resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp(OpKernelConstruction * ctx)437 IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
438 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
439 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
440 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
441 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
442 }
443
444 // The resource is deleted from the resource manager only when it is private
445 // to kernel. Ideally the resource should be deleted when it is no longer held
446 // by anyone, but it would break backward compatibility.
~IteratorHandleOp()447 IteratorHandleOp::~IteratorHandleOp() {
448 if (resource_ != nullptr) {
449 resource_->Unref();
450 if (cinfo_.resource_is_private_to_kernel()) {
451 if (!cinfo_.resource_manager()
452 ->template Delete<IteratorResource>(cinfo_.container(),
453 cinfo_.name())
454 .ok()) {
455 // Do nothing; the resource can have been deleted by session resets.
456 }
457 }
458 }
459 }
460
Compute(OpKernelContext * context)461 void IteratorHandleOp::Compute(OpKernelContext* context)
462 TF_LOCKS_EXCLUDED(mu_) {
463 {
464 mutex_lock l(mu_);
465 if (resource_ == nullptr) {
466 FunctionLibraryRuntime* flr;
467 std::unique_ptr<DeviceMgr> device_mgr(nullptr);
468 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
469 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
470 // If the iterator is shared then we construct a new FLR, and pass that
471 // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
472 // functions from the iterator. We may add this functionality if there
473 // is sufficient demand, but it will require a significant refactoring.
474 if (!name_.empty()) {
475 flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
476 } else {
477 OP_REQUIRES_OK(context, context->function_library()->Clone(
478 &flib_def, &pflr, &flr, true));
479 }
480
481 ResourceMgr* mgr = context->resource_manager();
482 OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
483
484 IteratorResource* resource;
485 OP_REQUIRES_OK(
486 context,
487 mgr->LookupOrCreate<IteratorResource>(
488 cinfo_.container(), cinfo_.name(), &resource,
489 [context, flr, &device_mgr, &flib_def, &pflr,
490 this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
491 *ret = new IteratorResource(
492 context->env(), output_dtypes_, output_shapes_,
493 std::move(device_mgr), std::move(flib_def), std::move(pflr),
494 flr);
495 return OkStatus();
496 }));
497
498 Status s = VerifyResource(resource);
499 if (TF_PREDICT_FALSE(!s.ok())) {
500 resource->Unref();
501 context->SetStatus(s);
502 return;
503 }
504
505 resource_ = resource;
506 }
507 }
508 OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
509 context, 0, cinfo_.container(), cinfo_.name(),
510 TypeIndex::Make<IteratorResource>()));
511 }
512
VerifyResource(IteratorResource * resource)513 Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
514 TF_RETURN_IF_ERROR(
515 VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
516 TF_RETURN_IF_ERROR(
517 VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
518 return OkStatus();
519 }
520
CreatePrivateFLR(OpKernelContext * ctx,std::unique_ptr<DeviceMgr> * device_mgr,std::unique_ptr<FunctionLibraryDefinition> * flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * pflr)521 FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
522 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
523 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
524 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
525 // Wrap the existing device in order to see any captured resources
526 // in its resource manager. The existing device will outlive the
527 // IteratorResource, because we are storing the IteratorResource
528 // in that device's resource manager.
529
530 *device_mgr =
531 std::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
532 ctx->device()->name(), down_cast<Device*>(ctx->device()),
533 false /* owns_underlying */, false /* isolate_session_state */));
534 *flib_def = std::make_unique<FunctionLibraryDefinition>(
535 *ctx->function_library()->GetFunctionLibraryDefinition());
536 const auto* config = ctx->function_library()->config_proto();
537 *pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
538 device_mgr->get(), ctx->env(),
539 /*config=*/config, graph_def_version_, flib_def->get(),
540 config->graph_options().optimizer_options());
541
542 return (*pflr)->GetFLR(ctx->device()->name());
543 }
544
545 // Like IteratorHandleOp, but creates handles which are never shared, and does
546 // not hold a reference to these handles. The latter is important for eager
547 // execution, since OpKernel instances generally live as long as the program
548 // running them.
AnonymousIteratorHandleOp(OpKernelConstruction * context)549 AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
550 OpKernelConstruction* context)
551 : AnonymousResourceOp<IteratorResource>(
552 context,
553 /* ref_counting */
554 // Only enable this for V2 (via Python's iter protocol),
555 // AnonymousIteratorV1 requires IteratorToStringHandle, which is
556 // undefined on Refcounting ResourceHandle.
557 context->def().op() == kAnonymousIteratorV2 ||
558 context->def().op() == kAnonymousIteratorV3,
559 // V1 does not return a deleter.
560 /* return_deleter */
561 context->def().op() == kAnonymousIteratorV2),
562 graph_def_version_(context->graph_def_version()) {
563 OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
564 OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
565 }
566
name()567 string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
568
CreateResource(OpKernelContext * ctx,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * lib,IteratorResource ** resource)569 Status AnonymousIteratorHandleOp::CreateResource(
570 OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
571 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
572 FunctionLibraryRuntime* lib, IteratorResource** resource) {
573 std::unique_ptr<DeviceMgr> device_mgr(nullptr);
574 *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
575 std::move(device_mgr), std::move(flib_def),
576 std::move(pflr), lib);
577 return OkStatus();
578 }
579
HybridAsyncOpKernel(OpKernelConstruction * ctx,const char * background_worker_name)580 HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
581 const char* background_worker_name)
582 : AsyncOpKernel(ctx),
583 background_worker_(ctx->env(), background_worker_name) {}
584
ComputeAsync(OpKernelContext * ctx,DoneCallback done)585 void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
586 DoneCallback done) {
587 background_worker_.Schedule([this, ctx, done = std::move(done)]() {
588 ctx->SetStatus(DoCompute(ctx));
589 done();
590 });
591 }
592
Compute(OpKernelContext * ctx)593 void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
594 ctx->SetStatus(DoCompute(ctx));
595 }
596
DoCompute(OpKernelContext * ctx)597 Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
598 tensorflow::ResourceTagger tag(kTFDataResourceTag,
599 ctx->op_kernel().type_string());
600 DatasetBase* dataset;
601 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
602 IteratorResource* iterator_resource;
603 TF_RETURN_IF_ERROR(
604 LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
605 core::ScopedUnref unref_iterator(iterator_resource);
606 return iterator_resource->SetIteratorFromDataset(ctx, dataset);
607 }
608
DoCompute(OpKernelContext * ctx)609 Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
610 tensorflow::ResourceTagger tag(kTFDataResourceTag,
611 ctx->op_kernel().type_string());
612 const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
613 // The iterator resource is guaranteed to exist because the variant tensor
614 // wrapping the deleter is provided as an unused input to this op, which
615 // guarantees that it has not run yet.
616 return DeleteResource(ctx, handle);
617 }
618
619 namespace {
620
621 class ToSingleElementOp : public AsyncOpKernel {
622 public:
ToSingleElementOp(OpKernelConstruction * ctx)623 explicit ToSingleElementOp(OpKernelConstruction* ctx)
624 : AsyncOpKernel(ctx),
625 metrics_collector_(ctx->device()->attributes().device_type(),
626 *ctx->env()),
627 unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
628 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
629 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
630 }
631
ComputeAsync(OpKernelContext * ctx,DoneCallback done)632 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
633 unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
634 ctx->SetStatus(DoCompute(ctx));
635 done();
636 });
637 }
638
Compute(OpKernelContext * ctx)639 void Compute(OpKernelContext* ctx) override {
640 ctx->SetStatus(DoCompute(ctx));
641 }
642
643 private:
DoCompute(OpKernelContext * ctx)644 Status DoCompute(OpKernelContext* ctx) {
645 profiler::TraceMe traceme(
646 [&] {
647 return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
648 {{"id", ctx->step_id()}});
649 },
650 profiler::kInfo);
651 tensorflow::ResourceTagger tag(kTFDataResourceTag,
652 ctx->op_kernel().type_string());
653 DatasetBase* dataset;
654 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
655
656 IteratorContext::Params params(ctx);
657 ResourceMgr resource_mgr;
658 params.resource_mgr = &resource_mgr;
659 CancellationManager cancellation_manager(ctx->cancellation_manager());
660 params.cancellation_manager = &cancellation_manager;
661
662 IteratorContext iter_ctx(std::move(params));
663 std::unique_ptr<IteratorBase> iterator;
664 TF_RETURN_IF_ERROR(dataset->MakeIterator(
665 &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
666
667 std::vector<Tensor> components;
668 components.reserve(dataset->output_dtypes().size());
669 bool end_of_sequence = false;
670
671 const absl::Time start_time = metrics_collector_.RecordStart();
672 TF_RETURN_IF_ERROR(
673 iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
674 metrics_collector_.RecordStop(start_time, components);
675
676 if (end_of_sequence) {
677 return errors::InvalidArgument("Dataset was empty.");
678 }
679 TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
680 TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
681 for (int i = 0; i < components.size(); ++i) {
682 ctx->set_output(i, components[i]);
683 }
684
685 components.clear();
686 TF_RETURN_IF_ERROR(
687 iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
688 if (!end_of_sequence) {
689 return errors::InvalidArgument("Dataset had more than one element.");
690 }
691 return OkStatus();
692 }
693
694 IteratorMetricsCollector metrics_collector_;
695 UnboundedThreadPool unbounded_threadpool_;
696 DataTypeVector output_types_;
697 std::vector<PartialTensorShape> output_shapes_;
698 };
699
700 class OneShotIteratorOp : public AsyncOpKernel {
701 public:
OneShotIteratorOp(OpKernelConstruction * ctx)702 explicit OneShotIteratorOp(OpKernelConstruction* ctx)
703 : AsyncOpKernel(ctx),
704 background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
705 graph_def_version_(ctx->graph_def_version())
706
707 {
708 string shared_name;
709 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
710 OP_REQUIRES(ctx, shared_name.empty(),
711 errors::InvalidArgument("OneShotIteratorOp does not currently "
712 "support the 'shared_name' attr."));
713 OP_REQUIRES_OK(ctx,
714 ctx->GetAttr("dataset_factory", &dataset_factory_func_));
715 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
716 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
717 }
718
// Drops this kernel's reference to the iterator resource and removes the
// resource from its resource manager.
~OneShotIteratorOp() override {
  if (iterator_resource_ == nullptr) return;
  iterator_resource_->Unref();
  // The resource may already have been deleted by a session reset, so a
  // failed Delete is deliberately ignored.
  cinfo_.resource_manager()
      ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
      .IgnoreError();
}
729
730 // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`,
731 // but due to the fact that `ResourceOpKernel<T>::CreateResource()`
732 // does not provide access to the `OpKernelContext*` and we need
733 // this to invoke the factory function, it's not possible to
734 // implement this kernel by implementing `CreateResource()`.
735 // Furthermore, due to the fact that this kernel might block when
736 // running the initialization function, we must implement this
737 // kernel as an async kernel.
ComputeAsync(OpKernelContext * ctx,DoneCallback done)738 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
739 tensorflow::ResourceTagger tag(kTFDataResourceTag,
740 ctx->op_kernel().type_string());
741 {
742 mutex_lock l(mu_);
743 if (iterator_resource_ == nullptr && initialization_status_.ok()) {
744 // The initialization thread will call `done`.
745 if (!initialization_started_) {
746 // TODO(mrry): Convert the initialization code to use
747 // callbacks instead of wasting a thread.
748 background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
749 initialization_started_ = true;
750 } else {
751 done_callbacks_.emplace_back(ctx, std::move(done));
752 }
753 return;
754 }
755 }
756 ProduceOutput(ctx, done);
757 }
758
759 private:
Init(OpKernelContext * ctx,const DoneCallback & done)760 void Init(OpKernelContext* ctx, const DoneCallback& done) {
761 IteratorResource* iterator = nullptr;
762 ContainerInfo cinfo;
763 Status s = TryInit(ctx, &iterator, &cinfo);
764
765 std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
766 {
767 mutex_lock l(mu_);
768 if (s.ok()) {
769 iterator_resource_ = iterator;
770 cinfo_ = cinfo;
771 }
772 initialization_status_ = s;
773 std::swap(done_callbacks_, callbacks_to_run);
774 }
775
776 for (auto&& ctx_done : callbacks_to_run) {
777 ProduceOutput(ctx_done.first, ctx_done.second);
778 }
779 ProduceOutput(ctx, done);
780 }
781
// Performs the one-time initialization for this kernel.
//
// Steps:
// 1. Resolves `cinfo` from the kernel's NodeDef.
// 2. Looks up, or creates, an IteratorResource under that container/name.
// 3. Checks that the resource's dtypes/shapes agree with this op's attrs.
// 4. Runs `dataset_factory_func_` synchronously to build the dataset.
// 5. Points the iterator at that dataset.
//
// On success, `*iterator` holds one reference that the caller owns (the
// explicit Ref() at the end balances the ScopedUnref). On failure, that
// reference has already been released by the ScopedUnref, so `*iterator`
// must not be used; Init() only reads it when the returned status is OK.
Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
               ContainerInfo* cinfo) {
  TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

  // Clone the function library; the clones are handed to a newly created
  // IteratorResource below (moved into its constructor).
  FunctionLibraryRuntime* flr;
  std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));

  // Create an IteratorResource that will hold the iterator for this op.
  // NOTE(review): TryInit is called from Init() before `mu_` is acquired, so
  // the TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) annotation on this lambda does not
  // correspond to a lock actually held here — confirm intent.
  TF_RETURN_IF_ERROR(
      ctx->resource_manager()->LookupOrCreate<IteratorResource>(
          cinfo->container(), cinfo->name(), iterator,
          [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
              TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    ctx->env(), output_dtypes_, output_shapes_,
                    /*device_mgr=*/nullptr, std::move(flib_def),
                    std::move(pflr), flr);
                return OkStatus();
              }));

  // Releases the lookup reference on every early-return path below.
  // NOTE(review): if a later step fails, the resource created above stays
  // registered in the resource manager; nothing in this block deletes it.
  core::ScopedUnref unref_iterator(*iterator);

  // The resource may have been found rather than created, so its
  // signature must be checked against this op's attrs.
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

  // Call the dataset_factory_func_ to create a new dataset,
  // over which this op will iterate.
  FunctionLibraryRuntime::Handle f_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
      &f_handle));
  FunctionLibraryRuntime::Options opts;
  opts.cancellation_manager = ctx->cancellation_manager();
  // Per-step resources created by the factory are cleaned up from the
  // kernel's resource manager when `step_container` goes out of scope.
  ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
    ctx->resource_manager()->Cleanup(name).IgnoreError();
  });
  opts.step_container = &step_container;
  opts.runner = ctx->runner();
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  std::vector<Tensor> return_values;
  TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
      std::move(opts), f_handle, {}, &return_values));
  if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
      !TensorShapeUtils::IsScalar(return_values[0].shape())) {
    return errors::InvalidArgument(
        "The `dataset_factory` function must return "
        "a single scalar of dtype DT_VARIANT.");
  }

  // Create an iterator for the dataset that was created in the
  // factory function.
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
  TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
  // Keep one reference for the caller; it outlives `unref_iterator`.
  (*iterator)->Ref();
  return OkStatus();
}
844
ProduceOutput(OpKernelContext * ctx,const DoneCallback & done)845 void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
846 Tensor* handle;
847 OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
848 done);
849 Status s;
850 {
851 mutex_lock l(mu_);
852 s = initialization_status_;
853 if (s.ok()) {
854 handle->scalar<ResourceHandle>()() =
855 MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
856 cinfo_.name());
857 }
858 }
859 OP_REQUIRES_OK_ASYNC(ctx, s, done);
860 done();
861 }
862
863 NameAttrList dataset_factory_func_;
864 DataTypeVector output_dtypes_;
865 std::vector<PartialTensorShape> output_shapes_;
866
867 BackgroundWorker background_worker_;
868
869 mutex mu_;
870 ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
871 IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;
872
873 bool initialization_started_ TF_GUARDED_BY(mu_) = false;
874 Status initialization_status_ TF_GUARDED_BY(mu_);
875 std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
876 TF_GUARDED_BY(mu_);
877 const int graph_def_version_;
878 };
879
880 } // namespace
881
AsAsync()882 AsyncOpKernel* IteratorGetNextOp::AsAsync() {
883 return type_string() == "IteratorGetNextSync" ? nullptr : this;
884 }
885
RecordElementSize(const std::vector<Tensor> element,profiler::TraceMe * traceme)886 void RecordElementSize(const std::vector<Tensor> element,
887 profiler::TraceMe* traceme) {
888 traceme->AppendMetadata([&]() {
889 int64_t element_size = 0;
890 for (const auto& component : element) {
891 element_size += component.TotalBytes();
892 }
893 return profiler::TraceMeEncode({{"element_size", element_size}});
894 });
895 }
896
DoCompute(OpKernelContext * ctx)897 Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
898 VLOG(3) << "IteratorGetNextOp enter. iter_id=" << ctx->frame_iter().iter_id;
899 auto cleanup = gtl::MakeCleanup([ctx] {
900 VLOG(3) << "IteratorGetNextOp exit. iter_id=" << ctx->frame_iter().iter_id;
901 });
902 activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
903 return activity_watcher::ActivityFromContext(
904 ctx, "IteratorGetNextOp::DoCompute",
905 activity_watcher::ActivityCategory::kDatasetOp);
906 });
907 profiler::TraceMe traceme(
908 [&] {
909 return profiler::TraceMeEncode(
910 "IteratorGetNextOp::DoCompute",
911 {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
912 },
913 profiler::kInfo);
914 tensorflow::ResourceTagger tag(kTFDataResourceTag,
915 ctx->op_kernel().type_string());
916 IteratorResource* iterator;
917 TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
918 core::ScopedUnref unref_iterator(iterator);
919 std::vector<Tensor> components;
920 bool end_of_sequence = false;
921
922 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
923 if (end_of_sequence) {
924 return errors::OutOfRange("End of sequence");
925 }
926 TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
927 TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
928 RecordElementSize(components, &traceme);
929 for (int i = 0; i < components.size(); ++i) {
930 ctx->set_output(i, components[i]);
931 }
932 return OkStatus();
933 }
934
DoCompute(OpKernelContext * ctx)935 Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
936 VLOG(3) << "IteratorGetNextAsOptionalOp enter. iter_id="
937 << ctx->frame_iter().iter_id;
938 auto cleanup = gtl::MakeCleanup([ctx] {
939 VLOG(3) << "IteratorGetNextAsOptionalOp exit. iter_id="
940 << ctx->frame_iter().iter_id;
941 });
942 activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
943 return activity_watcher::ActivityFromContext(
944 ctx, "IteratorGetNextAsOptionalOp::DoCompute",
945 activity_watcher::ActivityCategory::kDatasetOp);
946 });
947 profiler::TraceMe traceme(
948 [&] {
949 return profiler::TraceMeEncode(
950 "IteratorGetNextAsOptionalOp::DoCompute",
951 {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
952 },
953 profiler::kInfo);
954 tensorflow::ResourceTagger tag(kTFDataResourceTag,
955 ctx->op_kernel().type_string());
956 IteratorResource* iterator;
957 TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
958 core::ScopedUnref unref_iterator(iterator);
959 std::vector<Tensor> components;
960 bool end_of_sequence = false;
961
962 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
963
964 if (end_of_sequence) {
965 return WriteOptionalNoneToOutput(ctx, 0);
966 } else {
967 RecordElementSize(components, &traceme);
968 for (int i = 0; i < components.size(); ++i) {
969 if (components[i].dtype() != output_types_[i]) {
970 return errors::InvalidArgument(
971 "The given optional does not match the expected type for "
972 "component ",
973 i, ". Expected: ", DataTypeString(output_types_[i]),
974 ". Actual: ", DataTypeString(components[i].dtype()), ".");
975 }
976 if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
977 return errors::InvalidArgument(
978 "The given optional does not match the expected shape "
979 "for component ",
980 i, ". Expected: ", output_shapes_[i].DebugString(),
981 ". Actual: ", components[i].shape().DebugString(), ".");
982 }
983 }
984 return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
985 }
986 }
987
Compute(OpKernelContext * ctx)988 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
989 const Tensor& resource_handle_t = ctx->input(0);
990 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
991 errors::InvalidArgument("resource_handle must be a scalar"));
992
993 // Validate that the handle corresponds to a real resource, and
994 // that it is an IteratorResource.
995 IteratorResource* iterator_resource;
996 OP_REQUIRES_OK(
997 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
998 iterator_resource->Unref();
999
1000 Tensor* string_handle_t;
1001 OP_REQUIRES_OK(ctx,
1002 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
1003 string_handle_t->scalar<tstring>()() =
1004 resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
1005 }
1006
IteratorFromStringHandleOp(OpKernelConstruction * ctx)1007 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
1008 OpKernelConstruction* ctx)
1009 : OpKernel(ctx) {
1010 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
1011 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1012 OP_REQUIRES(
1013 ctx,
1014 output_dtypes_.empty() || output_shapes_.empty() ||
1015 output_dtypes_.size() == output_shapes_.size(),
1016 errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
1017 "are set, they must have the same length."));
1018 }
1019
Compute(OpKernelContext * ctx)1020 void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
1021 const Tensor& string_handle_t = ctx->input(0);
1022 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
1023 errors::InvalidArgument("string_handle must be a scalar"));
1024
1025 ResourceHandle resource_handle;
1026 OP_REQUIRES(
1027 ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
1028 errors::InvalidArgument(
1029 "Could not parse string_handle as a valid ResourceHandle"));
1030
1031 OP_REQUIRES(
1032 ctx, resource_handle.device() == ctx->device()->attributes().name(),
1033 errors::InvalidArgument("Attempted create an iterator on device \"",
1034 ctx->device()->attributes().name(),
1035 "\" from handle defined on device \"",
1036 resource_handle.device(), "\""));
1037
1038 // Validate that the handle corresponds to a real resource, and
1039 // that it is an IteratorResource.
1040 IteratorResource* iterator_resource;
1041 OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
1042 core::ScopedUnref unref_iterator(iterator_resource);
1043 if (!output_dtypes_.empty()) {
1044 OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
1045 iterator_resource->output_dtypes()));
1046 }
1047 if (!output_shapes_.empty()) {
1048 OP_REQUIRES_OK(ctx,
1049 VerifyShapesCompatible(output_shapes_,
1050 iterator_resource->output_shapes()));
1051 }
1052
1053 Tensor* resource_handle_t;
1054 OP_REQUIRES_OK(ctx,
1055 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
1056 resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
1057 }
1058
SerializeIteratorOp(OpKernelConstruction * ctx)1059 SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
1060 : OpKernel(ctx) {
1061 if (ctx->HasAttr(kExternalStatePolicy)) {
1062 int64_t state_change_option;
1063 OP_REQUIRES_OK(ctx,
1064 ctx->GetAttr(kExternalStatePolicy, &state_change_option));
1065 external_state_policy_ =
1066 SerializationContext::ExternalStatePolicy(state_change_option);
1067 }
1068 }
1069
Compute(OpKernelContext * ctx)1070 void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
1071 tensorflow::ResourceTagger tag(kTFDataResourceTag,
1072 ctx->op_kernel().type_string());
1073 const Tensor& resource_handle_t = ctx->input(0);
1074 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
1075 errors::InvalidArgument("resource_handle must be a scalar"));
1076 // Validate that the handle corresponds to a real resource, and
1077 // that it is an IteratorResource.
1078 IteratorResource* iterator_resource;
1079 OP_REQUIRES_OK(
1080 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1081 core::ScopedUnref unref_iterator(iterator_resource);
1082 IteratorVariantSerializer serializer;
1083 SerializationContext::Params params(ctx);
1084 params.external_state_policy = external_state_policy_;
1085 SerializationContext serialization_ctx(params);
1086 OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
1087 iterator_resource));
1088 Tensor* serialized_t;
1089 OP_REQUIRES_OK(ctx,
1090 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
1091 &serialized_t));
1092 OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
1093 }
1094
Compute(OpKernelContext * ctx)1095 void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
1096 tensorflow::ResourceTagger tag(kTFDataResourceTag,
1097 ctx->op_kernel().type_string());
1098 // Validate that the handle corresponds to a real resource, and
1099 // that it is an IteratorResource.
1100 IteratorResource* iterator_resource;
1101 OP_REQUIRES_OK(
1102 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1103 core::ScopedUnref unref_iterator(iterator_resource);
1104 const Tensor* serialized_t;
1105 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
1106 IteratorVariantSerializer serializer;
1107 OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
1108 Status s = iterator_resource->Restore(ctx, serializer.GetReader());
1109 if (!s.ok()) {
1110 OP_REQUIRES_OK(
1111 ctx,
1112 errors::CreateWithUpdatedMessage(
1113 s, absl::StrCat(
1114 "Failed to restore dataset iterator from checkpoint: ",
1115 s.error_message(),
1116 ". Make sure the dataset definition has not changed between "
1117 "the process that saved the checkpoint and the process that "
1118 "is restoring it.")));
1119 }
1120 }
1121
namespace {

// Kernel registrations for the iterator ops defined in this file. Where an op
// is registered for both CPU and GPU, the CPU kernel is given Priority(2) and
// the GPU kernel Priority(1).

// Iterator creation, initialization, and deletion.
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
// The `dataset` input of the GPU MakeIterator kernel is kept in host memory.
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);

// All anonymous-iterator op versions share one kernel implementation.
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV3").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV3").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);

// CPU-only kernels.
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);

// Element retrieval. The same IteratorGetNextOp class backs both the async
// and the "Sync" op names (see IteratorGetNextOp::AsAsync).
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);

// String-handle conversion. On GPU, `string_handle` is kept in host memory.
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);

// Checkpointing (CPU only).
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace
1200
1201 } // namespace data
1202 } // namespace tensorflow
1203