/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/data/captured_function.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kAnonymousIterator[] = "AnonymousIterator";
const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

// Safely subtracts `y` from `x`, clamping to zero to avoid underflow.
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }
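// For example, safe_sub(5, 2) == 3, while safe_sub(2, 5) == 0 rather than
// wrapping around to a huge unsigned value.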

}  // namespace

/* static */ constexpr const char* const
    SerializeIteratorOp::kExternalStatePolicy;

IteratorResource::IteratorResource(
    Env* env, const DataTypeVector& output_dtypes,
    const std::vector<PartialTensorShape>& output_shapes,
    std::unique_ptr<DeviceMgr> device_mgr,
    std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* flr)
    : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
      device_mgr_(std::move(device_mgr)),
      iterator_state_(std::make_shared<State>(std::move(flib_def),
                                              std::move(pflr), flr,
                                              /*iterator=*/nullptr)),
      output_dtypes_(output_dtypes),
      output_shapes_(output_shapes),
      // We do not collect iterator resource metrics for non-CPU devices. This
      // is a heuristic to avoid collecting metrics for device-side iterators
      // created by the multi-device iterator mechanism.
      collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
  VLOG(2) << "creating iterator resource";
}

IteratorResource::~IteratorResource() {
  VLOG(2) << "destroying iterator resource";
}

Status IteratorResource::GetNext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  if (!captured_state->iterator()) {
    return errors::FailedPrecondition(
        "GetNext() failed because the iterator has not been initialized. "
        "Ensure that you have run the initializer operation for this iterator "
        "before getting the next element.");
  }
  IteratorContext::Params params(ctx);
  params.flr = captured_state->flr();
  params.function_handle_cache = captured_state->function_handle_cache();
  params.resource_mgr = captured_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = captured_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  const uint64 start_time_us = ctx->env()->NowMicros();
  if (collect_metrics_) {
    mutex_lock l(mu_);
    if (get_next_end_time_us_ == 0) {
      // We initialize `get_next_end_time_us_` to the start time of the first
      // request to make it possible to use the deltas between
      // `get_next_end_time_us_` and subsequent `GetNext()` end times to
      // incrementally collect the duration of the iterator's lifetime.
      get_next_end_time_us_ = start_time_us;
    }
    if (num_get_next_calls_ == 0) {
      get_next_start_time_us_ = start_time_us;
    }
    num_get_next_calls_++;
  }
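  // Example of the accounting above: if two concurrent GetNext() calls span
  // [t0, t2] and [t1, t3], `num_get_next_calls_` returns to zero at t3, and
  // the busy interval recorded below is t3 - t0.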
  auto iterator = captured_state->iterator();
  auto status = iterator->GetNext(IteratorContext(std::move(params)),
                                  out_tensors, end_of_sequence);
  if (collect_metrics_) {
    const uint64 end_time_us = ctx->env()->NowMicros();
    metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
    metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
    mutex_lock l(mu_);
    metrics::RecordTFDataIteratorLifetime(
        safe_sub(end_time_us, get_next_end_time_us_));
    get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
    num_get_next_calls_--;
    if (num_get_next_calls_ == 0) {
      metrics::RecordTFDataIteratorBusy(
          safe_sub(get_next_end_time_us_, get_next_start_time_us_));
    }
  }
  return status;
}

Status IteratorResource::Save(SerializationContext* ctx,
                              IteratorStateWriter* writer) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  auto iterator = captured_state->iterator();
  if (iterator) {
    return iterator->Save(ctx, writer);
  }
  return errors::FailedPrecondition(
      "Save() failed because the iterator has not been initialized. Ensure "
      "that you have run the initializer operation for this iterator before "
      "saving it.");
}

Status IteratorResource::Restore(OpKernelContext* ctx,
                                 IteratorStateReader* reader) {
  const DatasetBase* dataset;
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    if (!iterator_state_->iterator()) {
      return errors::FailedPrecondition(
          "Restore() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before restoring it.");
    }
    auto iterator = iterator_state_->iterator();
    dataset = iterator->dataset();
    // Hang onto a reference until we've created the new iterator, which will
    // then hold its own reference to keep the dataset alive.
    dataset->Ref();
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  core::ScopedUnref scoped_unref(dataset);
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
      IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
  new_state->DowncastAndSetIterator(std::move(iterator_base));

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
                                                DatasetBase* dataset) {
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }

  // Create new iterator.
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));

  std::unique_ptr<IteratorBase> iterator;
  if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
    DatasetBase* finalized_dataset;
    TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
    core::ScopedUnref unref(finalized_dataset);
    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
        IteratorContext(std::move(params)),
        /*parent=*/nullptr, "Iterator", &iterator));
  } else {
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             /*parent=*/nullptr, "Iterator",
                                             &iterator));
  }
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));

  new_state->DowncastAndSetIterator(std::move(iterator));

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

namespace {

// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The GetData() method returns a VariantTensorData object which contains all
// the state needed to restore a single iterator.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant();
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   IteratorStateReader reader({wrapper->GetData()});
//   iterator_resource->Restore(ctx, &reader);
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
  IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;

  // Initializes `this` from a VariantTensorData object.
  Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
    data_ = std::move(d);
    return Status::OK();
  }

  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
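  // Takes ownership of `data` if its type_name matches; returns false and
  // leaves `this` unmodified otherwise.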
  bool Decode(VariantTensorData data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    auto tensor_data = absl::make_unique<VariantTensorData>();
    std::swap(*tensor_data, data);
    data_ = std::move(tensor_data);
    return true;
  }

  // Returns a borrowed pointer to the underlying VariantTensorData.
  const VariantTensorData* GetData() const { return data_.get(); }

  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
                             ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<VariantTensorData> data_;
};

// Register the reader class in the global variant decode_fn registry so that
// a Variant containing a serialized representation of iterator state can be
// decoded using DecodeUnaryVariant. If we don't do this, we will need to
// manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

// A helper class that uses a list of IteratorStateVariant objects to represent
// the state for an iterator resource. It exposes methods that help with
// saving and restoring of this state. Sample usage:
//
// Saving:
//   IteratorVariantSerializer serializer;
//   serializer.InitializeFromIterator(serialization_ctx, iterator_resource);
//   Tensor serialized_t;
//   serializer.Serialize(&serialized_t);
//
// Restoring:
//   IteratorVariantSerializer serializer;
//   serializer.InitFromTensor(ctx->input(0));
//   IteratorStateReader* reader = serializer.GetReader();
//   iterator_resource->Restore(ctx, reader);
class IteratorVariantSerializer {
 public:
  IteratorVariantSerializer() {}

  // Calls `Save` on the iterator_resource to build up the list of
  // IteratorStateVariant objects.
  Status InitializeFromIterator(SerializationContext* serialization_ctx,
                                IteratorResource* iterator_resource) {
    VariantTensorDataWriter writer;
    TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
    std::vector<std::unique_ptr<VariantTensorData>> data;
    writer.ReleaseData(&data);
    variants_.clear();
    variants_.reserve(data.size());
    for (auto& it : data) {
      IteratorStateVariant v;
      TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
      variants_.push_back(v);
    }
    num_tensors_ = variants_.size();
    can_serialize_ = true;
    return Status::OK();
  }

  // Initializes `this` from `serialized_t` while restoring the iterator state.
  Status InitFromTensor(const Tensor* serialized_t) {
    int64_t num_tensors = serialized_t->dim_size(0);
    auto serialized_vec = serialized_t->vec<Variant>();
    std::vector<const VariantTensorData*> data;
    data.reserve(num_tensors);
    for (int i = 0; i < num_tensors; ++i) {
      auto* w = serialized_vec(i).get<IteratorStateVariant>();
      if (!w) {
        return errors::Internal(
            "Cannot initialize an iterator from tensor ",
            serialized_vec(i).DebugString(),
            ". Expected a variant tensor of type IteratorStateVariant");
      }
      data.push_back(w->GetData());
    }
    reader_ = absl::make_unique<VariantTensorDataReader>(data);
    num_tensors_ = data.size();
    return Status::OK();
  }

  int64 NumTensors() { return num_tensors_; }

  // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
  // that InitializeFromIterator() has been called first.
  Status Serialize(Tensor* serialized) {
    if (!can_serialize_) {
      return errors::InvalidArgument(
          "Please call InitializeFromIterator before calling Serialize.");
    }
    int64_t size = variants_.size();
    for (int64_t i = 0; i < size; ++i) {
      if (variants_[i].GetData() == nullptr) {
        return errors::Internal(
            "Cannot serialize an empty IteratorStateVariant");
      }
      serialized->vec<Variant>()(i) = variants_[i];
    }
    return Status::OK();
  }

  // Returns an IteratorStateReader to restore iterator state. Expects that
  // InitFromTensor() has been called first.
  IteratorStateReader* GetReader() { return reader_.get(); }

 private:
  bool can_serialize_ = false;
  int64 num_tensors_ = 0;
  std::vector<IteratorStateVariant> variants_;
  std::unique_ptr<IteratorStateReader> reader_;
};

}  // namespace

// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

// The resource is deleted from the resource manager only when it is private
// to the kernel. Ideally the resource should be deleted when it is no longer
// held by anyone, but that would break backward compatibility.
IteratorHandleOp::~IteratorHandleOp() {
  if (resource_ != nullptr) {
    resource_->Unref();
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<IteratorResource>(cinfo_.container(),
                                                   cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }
}

void IteratorHandleOp::Compute(OpKernelContext* context)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* flr;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr, true));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, flr, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    std::move(device_mgr), std::move(flib_def), std::move(pflr),
                    flr);
                return Status::OK();
              }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              TypeIndex::Make<IteratorResource>()));
}

Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return Status::OK();
}

FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.

  *device_mgr =
      absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
          ctx->device()->name(), down_cast<Device*>(ctx->device()),
          false /* owns_underlying */, false /* isolate_session_state */));
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  const auto* config = ctx->function_library()->config_proto();
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(),
      /*config=*/config, graph_def_version_, flib_def->get(),
      config->graph_options().optimizer_options());

  return (*pflr)->GetFLR(ctx->device()->name());
}

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousResourceOp<IteratorResource>(context),
      graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
}

string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }

Status AnonymousIteratorHandleOp::CreateResource(
    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* lib, IteratorResource** resource) {
  std::unique_ptr<DeviceMgr> device_mgr(nullptr);
  *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
                                   std::move(device_mgr), std::move(flib_def),
                                   std::move(pflr), lib);
  return Status::OK();
}

HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
                                         const char* background_worker_name)
    : AsyncOpKernel(ctx),
      background_worker_(ctx->env(), background_worker_name) {}

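// Subclasses implement DoCompute(). ComputeAsync() runs DoCompute() on the
// kernel's background worker thread, while Compute() runs it inline on the
// calling thread, so the same kernel can serve both execution styles.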
void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                       DoneCallback done) {
  background_worker_.Schedule([this, ctx, done = std::move(done)]() {
    ctx->SetStatus(DoCompute(ctx));
    done();
  });
}

void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
  ctx->SetStatus(DoCompute(ctx));
}

Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  return iterator_resource->SetIteratorFromDataset(ctx, dataset);
}

Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
  // The iterator resource is guaranteed to exist because the variant tensor
  // wrapping the deleter is provided as an unused input to this op, which
  // guarantees that the deleter has not yet run.
  return ctx->resource_manager()->Delete(handle);
}

namespace {

class ToSingleElementOp : public AsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
      ctx->SetStatus(DoCompute(ctx));
      done();
    });
  }

  void Compute(OpKernelContext* ctx) override {
    ctx->SetStatus(DoCompute(ctx));
  }

 private:
  Status DoCompute(OpKernelContext* ctx) {
    profiler::TraceMe traceme(
        [&] {
          return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
                                         {{"id", ctx->step_id()}});
        },
        profiler::kInfo);
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

    IteratorContext::Params params(ctx);
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));

    std::vector<Tensor> components;
    components.reserve(dataset->output_dtypes().size());
    bool end_of_sequence = false;

    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));

    if (end_of_sequence) {
      return errors::InvalidArgument("Dataset was empty.");
    }
    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
    for (int i = 0; i < components.size(); ++i) {
      ctx->set_output(i, components[i]);
    }

    components.clear();
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    if (!end_of_sequence) {
      return errors::InvalidArgument("Dataset had more than one element.");
    }
    return Status::OK();
  }

  UnboundedThreadPool unbounded_threadpool_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version()) {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but
  // because `ResourceOpKernel<T>::CreateResource()` does not provide access
  // to the `OpKernelContext*` that we need to invoke the factory function,
  // it is not possible to implement this kernel by overriding
  // `CreateResource()`. Furthermore, because this kernel might block when
  // running the initialization function, we must implement it as an async
  // kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, done);
  }

 private:
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* flr;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
                TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      /*device_mgr=*/nullptr, std::move(flib_def),
                      std::move(pflr), flr);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    std::vector<Tensor> return_values;
    TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
        std::move(opts), f_handle, {}, &return_values));
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ TF_GUARDED_BY(mu_) = false;
  Status initialization_status_ TF_GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      TF_GUARDED_BY(mu_);
  const int graph_def_version_;
};

}  // namespace

AsyncOpKernel* IteratorGetNextOp::AsAsync() {
  return type_string() == "IteratorGetNextSync" ? nullptr : this;
}

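// Records the total byte size of `element` as metadata on `traceme`.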
void RecordElementSize(const std::vector<Tensor>& element,
                       profiler::TraceMe* traceme) {
  traceme->AppendMetadata([&]() {
    int64_t element_size = 0;
    for (const auto& component : element) {
      element_size += component.TotalBytes();
    }
    return profiler::TraceMeEncode({{"element_size", element_size}});
  });
}

Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        int64_t mem_bw = port::GetMemoryBandwidthInfo().bw_used;

        if (mem_bw != INT64_MAX) {
          return profiler::TraceMeEncode(
              "IteratorGetNextOp::DoCompute",
              {{"id", ctx->step_id()},
               {"iter_num", ctx->frame_iter().iter_id},
               {"mem_bw_used_megabytes_per_sec", mem_bw}});
        }
        return profiler::TraceMeEncode(
            "IteratorGetNextOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
  if (end_of_sequence) {
    return errors::OutOfRange("End of sequence");
  }
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
  RecordElementSize(components, &traceme);
  for (int i = 0; i < components.size(); ++i) {
    ctx->set_output(i, components[i]);
  }
  return Status::OK();
}

Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode(
            "IteratorGetNextAsOptionalOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));

  if (end_of_sequence) {
    return WriteOptionalNoneToOutput(ctx, 0);
  } else {
    RecordElementSize(components, &traceme);
    for (int i = 0; i < components.size(); ++i) {
      if (components[i].dtype() != output_types_[i]) {
        return errors::InvalidArgument(
            "The given optional does not match the expected type for "
            "component ",
            i, ". Expected: ", DataTypeString(output_types_[i]),
            ". Actual: ", DataTypeString(components[i].dtype()), ".");
      }
      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
        return errors::InvalidArgument(
            "The given optional does not match the expected shape "
            "for component ",
            i, ". Expected: ", output_shapes_[i].DebugString(),
            ". Actual: ", components[i].shape().DebugString(), ".");
      }
    }
    return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
  }
}

void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  iterator_resource->Unref();

  Tensor* string_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
  string_handle_t->scalar<tstring>()() =
      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}

IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
          output_dtypes_.size() == output_shapes_.size(),
      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                              "are set, they must have the same length."));
}

void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted to create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from a handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}

SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  if (ctx->HasAttr(kExternalStatePolicy)) {
    int64_t state_change_option;
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr(kExternalStatePolicy, &state_change_option));
    external_state_policy_ =
        SerializationContext::ExternalStatePolicy(state_change_option);
  }
}

void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  IteratorVariantSerializer serializer;
  SerializationContext::Params params(ctx);
  params.external_state_policy = external_state_policy_;
  SerializationContext serialization_ctx(params);
  OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
                                                        iterator_resource));
  Tensor* serialized_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
                                      &serialized_t));
  OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
}

void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  const Tensor* serialized_t;
  OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
  IteratorVariantSerializer serializer;
  OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
  Status s = iterator_resource->Restore(ctx, serializer.GetReader());
  if (!s.ok()) {
    OP_REQUIRES_OK(
        ctx,
        Status(s.code(),
               absl::StrCat(
                   "Failed to restore dataset iterator from checkpoint: ",
                   s.error_message(),
                   ". Make sure the dataset definition has not changed between "
                   "the process that saved the checkpoint and the process that "
                   "is restoring it.")));
  }
}

namespace {

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace

}  // namespace data
}  // namespace tensorflow