• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/iterator_ops.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/common_runtime/graph_runner.h"
25 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
26 #include "tensorflow/core/common_runtime/metrics.h"
27 #include "tensorflow/core/common_runtime/renamed_device.h"
28 #include "tensorflow/core/common_runtime/threadpool_device.h"
29 #include "tensorflow/core/data/captured_function.h"
30 #include "tensorflow/core/data/dataset_utils.h"
31 #include "tensorflow/core/data/root_dataset.h"
32 #include "tensorflow/core/data/serialization_utils.h"
33 #include "tensorflow/core/framework/cancellation.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/metrics.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/partial_tensor_shape.h"
38 #include "tensorflow/core/framework/resource_op_kernel.h"
39 #include "tensorflow/core/framework/stats_aggregator.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/types.h"
42 #include "tensorflow/core/framework/variant_op_registry.h"
43 #include "tensorflow/core/framework/variant_tensor_data.h"
44 #include "tensorflow/core/kernels/data/optional_ops.h"
45 #include "tensorflow/core/kernels/ops_util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/refcount.h"
48 #include "tensorflow/core/lib/core/threadpool.h"
49 #include "tensorflow/core/lib/gtl/cleanup.h"
50 #include "tensorflow/core/lib/random/random.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/lib/strings/stringprintf.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/mem.h"
56 #include "tensorflow/core/platform/mutex.h"
57 #include "tensorflow/core/platform/refcount.h"
58 #include "tensorflow/core/platform/resource.h"
59 #include "tensorflow/core/profiler/lib/traceme.h"
60 #include "tensorflow/core/profiler/lib/traceme_encode.h"
61 #include "tensorflow/core/public/session_options.h"
62 
63 namespace tensorflow {
64 namespace data {
65 namespace {
66 
67 // See documentation in ../../ops/dataset_ops.cc for a high-level
68 // description of the following ops.
69 
// Op-name and attr-name string constants shared by the iterator kernels below.
const char kAnonymousIterator[] = "AnonymousIterator";
const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";
75 
// Returns x - y, clamping at zero to avoid unsigned underflow. (Note: this
// subtracts y *from* x, not the other way around as the old comment said.)
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }
78 
79 }  // namespace
80 
// Out-of-line definition of the static constexpr member; required for
// ODR-use before C++17 inline variables.
/* static */ constexpr const char* const
    SerializeIteratorOp::kExternalStatePolicy;
83 
// Constructs an IteratorResource that owns the (optional) private device
// manager, function library, and process FLR backing the iterator. The
// resource starts without an iterator; SetIteratorFromDataset() or Restore()
// installs one later.
IteratorResource::IteratorResource(
    Env* env, const DataTypeVector& output_dtypes,
    const std::vector<PartialTensorShape>& output_shapes,
    std::unique_ptr<DeviceMgr> device_mgr,
    std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* flr)
    : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
      device_mgr_(std::move(device_mgr)),
      iterator_state_(std::make_shared<State>(std::move(flib_def),
                                              std::move(pflr), flr,
                                              /*iterator=*/nullptr)),
      output_dtypes_(output_dtypes),
      output_shapes_(output_shapes),
      // We do not collect iterator resource metrics for non-CPU devices. This
      // is a heuristic to avoid collecting metrics for device-side iterators
      // created by the multi-device iterator mechanism.
      collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
  VLOG(2) << "creating iterator resource";
}
104 
// Destructor only logs; owned members (thread pool, device mgr, state) are
// released by their own destructors (RAII).
IteratorResource::~IteratorResource() {
  VLOG(2) << "destroying iterator resource";
}
108 
// Produces the next element from the currently-installed iterator into
// `out_tensors`, setting `*end_of_sequence` when the iterator is exhausted.
// Fails with FailedPrecondition if no iterator has been initialized yet.
Status IteratorResource::GetNext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) {
  // Snapshot the current state so the (potentially long-running) GetNext()
  // call below does not hold `mu_`.
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  if (!captured_state->iterator()) {
    return errors::FailedPrecondition(
        "GetNext() failed because the iterator has not been initialized. "
        "Ensure that you have run the initializer operation for this iterator "
        "before getting the next element.");
  }
  // Build an IteratorContext backed by this resource's state and thread pool.
  IteratorContext::Params params(ctx);
  params.flr = captured_state->flr();
  params.function_handle_cache = captured_state->function_handle_cache();
  params.resource_mgr = captured_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = captured_state->cancellation_manager();
  // Propagate cancellation from the op's cancellation manager to the
  // iterator's; `cleanup` deregisters the callback on scope exit.
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  const uint64 start_time_us = ctx->env()->NowMicros();
  if (collect_metrics_) {
    mutex_lock l(mu_);
    if (get_next_end_time_us_ == 0) {
      // We initialize `get_next_end_time_us_` to the start time of the first
      // request to make it possible to use the delta between
      // `get_next_end_time_us_` and subsequent `GetNext()` end time to
      // incrementally collect the duration of the iterator's lifetime.
      get_next_end_time_us_ = start_time_us;
    }
    if (num_get_next_calls_ == 0) {
      // First call in a (possibly concurrent) burst; marks the start of a
      // "busy" interval for the iterator-busy metric below.
      get_next_start_time_us_ = start_time_us;
    }
    num_get_next_calls_++;
  }
  auto iterator_ = captured_state->iterator();
  auto status = iterator_->GetNext(IteratorContext(std::move(params)),
                                   out_tensors, end_of_sequence);
  if (collect_metrics_) {
    const uint64 end_time_us = ctx->env()->NowMicros();
    metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
    metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
    mutex_lock l(mu_);
    metrics::RecordTFDataIteratorLifetime(
        safe_sub(end_time_us, get_next_end_time_us_));
    // `end_time_us` from a concurrent call may already have advanced the
    // high-water mark, hence the max.
    get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
    num_get_next_calls_--;
    if (num_get_next_calls_ == 0) {
      // Last outstanding call of the burst: record how long the iterator was
      // busy serving at least one GetNext().
      metrics::RecordTFDataIteratorBusy(
          safe_sub(get_next_end_time_us_, get_next_start_time_us_));
    }
  }
  return status;
}
170 
Save(SerializationContext * ctx,IteratorStateWriter * writer)171 Status IteratorResource::Save(SerializationContext* ctx,
172                               IteratorStateWriter* writer) {
173   std::shared_ptr<State> captured_state;
174   {
175     tf_shared_lock l(mu_);
176     captured_state = iterator_state_;
177   }
178   auto iterator_ = captured_state->iterator();
179   if (iterator_) {
180     return iterator_->Save(ctx, writer);
181   }
182   return errors::FailedPrecondition(
183       "Save() failed because the iterator has not been initialized. Ensure "
184       "that you have run the initializer operation for this iterator before "
185       "saving it.");
186 }
187 
// Rebuilds the iterator from checkpointed state in `reader`, using the
// dataset of the currently-installed iterator, then atomically swaps the new
// state in. Requires that an iterator has already been initialized.
Status IteratorResource::Restore(OpKernelContext* ctx,
                                 IteratorStateReader* reader) {
  const DatasetBase* dataset;
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    if (!iterator_state_->iterator()) {
      return errors::FailedPrecondition(
          "Restore() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before restoring it.");
    }
    auto iterator_ = iterator_state_->iterator();
    dataset = iterator_->dataset();
    // Hang onto a reference until we've created the new iterator, which will
    // then hold its own reference to keep the dataset alive.
    dataset->Ref();
    // Fresh State sharing the existing FLR machinery but with no iterator
    // installed yet.
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  core::ScopedUnref scoped_unref(dataset);
  // Build an IteratorContext backed by the new state.
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  // Propagate cancellation for the duration of the restore; deregistered by
  // `cleanup` on scope exit.
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
      IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
  new_state->DowncastAndSetIterator(std::move(iterator_base));

  // Publish the restored state; the old state is released when `new_state`
  // (now holding it) goes out of scope.
  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}
233 
// Creates a fresh iterator over `dataset`, verifies it matches this
// resource's declared dtypes/shapes, and atomically installs it.
Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
                                                DatasetBase* dataset) {
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    // Fresh State sharing the existing FLR machinery but with no iterator
    // installed yet.
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }

  // Create new iterator.
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  // Propagate cancellation while the iterator is being built; deregistered by
  // `cleanup` on scope exit.
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));

  std::unique_ptr<IteratorBase> iterator;
  // On CPU, finalize (e.g. apply optimizations to) the dataset first; on
  // other devices the dataset is used as-is.
  if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
    DatasetBase* finalized_dataset;
    TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
    core::ScopedUnref unref(finalized_dataset);
    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
        IteratorContext(std::move(params)),
        /*parent=*/nullptr, "Iterator", &iterator));
  } else {
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             /*parent=*/nullptr, "Iterator",
                                             &iterator));
  }
  // Reject datasets whose element signature does not match the handle's
  // declared output_types/output_shapes attrs.
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));

  new_state->DowncastAndSetIterator(std::move(iterator));

  // Publish the new state; the previous state is released when `new_state`
  // (now holding it) goes out of scope.
  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}
284 
285 namespace {
286 
287 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
288 // The get() method returns an VariantTensorData object which contains all the
289 // state needed to restore a single iterator.
290 //
291 // Usage example:
292 //
293 // Encoding:
294 //
295 //   Tensor t(DT_VARIANT, TensorShape({}));
296 //   t->scalar<Variant>()() = IteratorStateVariant();
297 //
298 // Encode() sets the type_name of the VariantTensorData object to
299 // IteratorStateVariant::TypeName().
300 //
301 // Decoding:
302 //
303 //   Variant v = <VariantTensorDataProto object>;
304 //   DecodeUnaryVariant(&v);
305 //   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
306 //   IteratorStateReader reader({wrapper->GetData()});
307 //   iterator_resource->Restore(ctx, &reader);
308 //
309 // The type_name of the VariantTensorData object to be decoded must
310 // match IteratorStateVariant::TypeName().
311 class IteratorStateVariant {
312  public:
IteratorStateVariant()313   IteratorStateVariant() : data_(nullptr) {}
IteratorStateVariant(const IteratorStateVariant & other)314   IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
315     if (other.data_) {
316       Decode(*other.data_);
317     }
318   }
319   IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
320   IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
321 
322   // Initializes `this` from a VariantTensorData object.
InitializeFromVariantData(std::unique_ptr<VariantTensorData> d)323   Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
324     data_ = std::move(d);
325     return Status::OK();
326   }
327 
TypeName() const328   string TypeName() const { return kIteratorVariantTypeName; }
Encode(VariantTensorData * data) const329   void Encode(VariantTensorData* data) const { *data = *data_; }
Decode(VariantTensorData data)330   bool Decode(VariantTensorData data) {
331     if (data.type_name() != TypeName()) {
332       return false;
333     }
334     auto tensor_data = absl::make_unique<VariantTensorData>();
335     std::swap(*tensor_data, data);
336     data_ = std::move(tensor_data);
337     return true;
338   }
339 
340   // Returns a borrowed pointer to the underlying VariantTensorData.
GetData() const341   const VariantTensorData* GetData() const { return data_.get(); }
342 
DebugString() const343   string DebugString() const {
344     if (data_) {
345       return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
346                              ">");
347     } else {
348       return strings::StrCat("IteratorStateVariant<empty>");
349     }
350   }
351 
352  private:
353   std::unique_ptr<VariantTensorData> data_;
354 };
355 
// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp which is not recommended.
// (The registered string must match IteratorStateVariant::TypeName().)
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);
363 
364 // A helper class that uses a list of IteratorStateVariant objects to represent
365 // the state for an iterator resource. It exposes methods that help with
366 // saving and restoring of this state. Sample usage
367 // Saving:
368 //   IteratorVariantSerializer serializer;
369 //   serializer.InitializeFromIterator(iterator_resource);
370 //   Tensor serialized_t;
371 //   serializer.Serialize(&serialized_t);
372 //
373 // Restoring:
374 //   IteratorVariantSerializer serializer;
375 //   serializer.InitFromTensor(ctx->input(0));
376 //   IteratorStateReader* reader = serializer.GetReader();
377 //   iterator_resource->Restore(ctx, reader);
378 class IteratorVariantSerializer {
379  public:
IteratorVariantSerializer()380   IteratorVariantSerializer() {}
381 
382   // Calls `Save` on the iterator_resource to build up the list of
383   // IteratorStateVariant objects.
InitializeFromIterator(SerializationContext * serialization_ctx,IteratorResource * iterator_resource)384   Status InitializeFromIterator(SerializationContext* serialization_ctx,
385                                 IteratorResource* iterator_resource) {
386     VariantTensorDataWriter writer;
387     TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
388     std::vector<std::unique_ptr<VariantTensorData>> data;
389     writer.ReleaseData(&data);
390     variants_.clear();
391     variants_.reserve(data.size());
392     for (auto& it : data) {
393       IteratorStateVariant v;
394       TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
395       variants_.push_back(v);
396     }
397     num_tensors_ = variants_.size();
398     can_serialize_ = true;
399     return Status::OK();
400   }
401 
402   // Initializes `this` from `serialized_t` while restoring the iterator state.
InitFromTensor(const Tensor * serialized_t)403   Status InitFromTensor(const Tensor* serialized_t) {
404     int64_t num_tensors = serialized_t->dim_size(0);
405     auto serialized_vec = serialized_t->vec<Variant>();
406     std::vector<const VariantTensorData*> data;
407     data.reserve(num_tensors);
408     for (int i = 0; i < num_tensors; ++i) {
409       auto* w = serialized_vec(i).get<IteratorStateVariant>();
410       if (!w) {
411         return errors::Internal(
412             "Cannot initialize an iterator from tensor ",
413             serialized_vec(i).DebugString(),
414             ". Expected a variant tensor of type IteratorStateVariant");
415       }
416       data.push_back(w->GetData());
417     }
418     reader_ = absl::make_unique<VariantTensorDataReader>(data);
419     num_tensors_ = data.size();
420     return Status::OK();
421   }
422 
NumTensors()423   int64 NumTensors() { return num_tensors_; }
424 
425   // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
426   // that InitializeFromIterator was called before.
Serialize(Tensor * serialized)427   Status Serialize(Tensor* serialized) {
428     if (!can_serialize_) {
429       return errors::InvalidArgument(
430           "Please call InitializeFromIterator before calling Serialize.");
431     }
432     int64_t size = variants_.size();
433     for (int64_t i = 0; i < size; ++i) {
434       if (variants_[i].GetData() == nullptr) {
435         return errors::Internal(
436             "Cannot serialize an empty IteratorStateVariant");
437       }
438       serialized->vec<Variant>()(i) = variants_[i];
439     }
440     return Status::OK();
441   }
442 
443   // Returns an IteratorStateReader to restore iterator state. Expects that
444   // InitFromTensor was called before.
GetReader()445   IteratorStateReader* GetReader() { return reader_.get(); }
446 
447  private:
448   bool can_serialize_ = false;
449   int64 num_tensors_;
450   std::vector<IteratorStateVariant> variants_;
451   std::unique_ptr<IteratorStateReader> reader_;
452 };
453 
454 }  // namespace
455 
456 // Note that IteratorHandleOp holds a reference to the resource it creates. If
457 // cleaning up resources with DestroyResourceOp is important, consider creating
458 // resource containers with AnonymousIteratorHandleOp instead.
// Reads the signature attrs and the shared_name under which the resource
// may be shared across kernel instances (empty => private resource).
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}
465 
466 // The resource is deleted from the resource manager only when it is private
467 // to kernel. Ideally the resource should be deleted when it is no longer held
468 // by anyone, but it would break backward compatibility.
~IteratorHandleOp()469 IteratorHandleOp::~IteratorHandleOp() {
470   if (resource_ != nullptr) {
471     resource_->Unref();
472     if (cinfo_.resource_is_private_to_kernel()) {
473       if (!cinfo_.resource_manager()
474                ->template Delete<IteratorResource>(cinfo_.container(),
475                                                    cinfo_.name())
476                .ok()) {
477         // Do nothing; the resource can have been deleted by session resets.
478       }
479     }
480   }
481 }
482 
// Lazily creates (or looks up) the IteratorResource on first invocation and
// outputs a resource handle to it. Subsequent calls reuse `resource_`.
void IteratorHandleOp::Compute(OpKernelContext* context)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* flr;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr, true));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      // Look up or, on first use, create the shared resource. The creation
      // lambda transfers ownership of the FLR state into the new resource.
      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, flr, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    std::move(device_mgr), std::move(flib_def), std::move(pflr),
                    flr);
                return Status::OK();
              }));

      // A shared resource may have been created by a kernel with different
      // signature attrs; fail rather than hand out a mismatched handle.
      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              TypeIndex::Make<IteratorResource>()));
}
534 
VerifyResource(IteratorResource * resource)535 Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
536   TF_RETURN_IF_ERROR(
537       VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
538   TF_RETURN_IF_ERROR(
539       VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
540   return Status::OK();
541 }
542 
// Builds a private FunctionLibraryRuntime (plus the DeviceMgr, function
// library copy, and process FLR that back it) for a *shared* iterator
// resource. The out-params transfer ownership to the caller; the returned
// FLR pointer is owned by `*pflr`.
FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.

  *device_mgr =
      absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
          ctx->device()->name(), down_cast<Device*>(ctx->device()),
          false /* owns_underlying */, false /* isolate_session_state */));
  // Deep copy of the function library so the private runtime is isolated
  // from later modifications to the shared one.
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  // NOTE(review): `config` is dereferenced below without a null check —
  // presumably config_proto() is always non-null here; confirm.
  const auto* config = ctx->function_library()->config_proto();
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(),
      /*config=*/config, graph_def_version_, flib_def->get(),
      config->graph_options().optimizer_options());

  return (*pflr)->GetFLR(ctx->device()->name());
}
566 
567 // Like IteratorHandleOp, but creates handles which are never shared, and does
568 // not hold a reference to these handles. The latter is important for eager
569 // execution, since OpKernel instances generally live as long as the program
570 // running them.
// Reads the signature attrs for the anonymous handle. Only the V2 op emits a
// deleter variant tensor alongside the handle.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousResourceOp<IteratorResource>(context),
      graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
}
579 
name()580 string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
581 
CreateResource(OpKernelContext * ctx,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * lib,IteratorResource ** resource)582 Status AnonymousIteratorHandleOp::CreateResource(
583     OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
584     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
585     FunctionLibraryRuntime* lib, IteratorResource** resource) {
586   std::unique_ptr<DeviceMgr> device_mgr(nullptr);
587   *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
588                                    std::move(device_mgr), std::move(flib_def),
589                                    std::move(pflr), lib);
590   return Status::OK();
591 }
592 
// An AsyncOpKernel that can also run synchronously; async work is offloaded
// to a dedicated background worker thread named `background_worker_name`.
HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
                                         const char* background_worker_name)
    : AsyncOpKernel(ctx),
      background_worker_(ctx->env(), background_worker_name) {}
597 
// Async path: run DoCompute() on the background worker, record its status on
// the context, then signal completion via `done`.
void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                       DoneCallback done) {
  background_worker_.Schedule([this, ctx, done = std::move(done)]() {
    ctx->SetStatus(DoCompute(ctx));
    done();
  });
}
605 
// Synchronous path: run DoCompute() inline on the calling thread.
void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
  ctx->SetStatus(DoCompute(ctx));
}
609 
// Initializes the iterator resource (input 1) from the dataset wrapped in
// the variant tensor at input 0.
Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
  // Tag resource usage for tf.data telemetry.
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  // LookupResource added a reference; drop it when we're done.
  core::ScopedUnref unref_iterator(iterator_resource);
  return iterator_resource->SetIteratorFromDataset(ctx, dataset);
}
621 
// Deletes the iterator resource named by the handle at input 0 from the
// resource manager.
Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
  // Tag resource usage for tf.data telemetry.
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
  // The iterator resource is guaranteed to exist because the variant tensor
  // wrapping the deleter is provided as an unused input to this op, which
  // guarantees that it has not run yet.
  return ctx->resource_manager()->Delete(handle);
}
631 
632 namespace {
633 
// Extracts exactly one element from the input dataset, failing if the
// dataset is empty or yields more than one element. Work runs on an
// unbounded threadpool so blocking GetNext() calls don't tie up an inter-op
// thread.
class ToSingleElementOp : public AsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // Offload to the pool; `done` fires after DoCompute's status is recorded.
    unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
      ctx->SetStatus(DoCompute(ctx));
      done();
    });
  }

  void Compute(OpKernelContext* ctx) override {
    // Synchronous fallback path.
    ctx->SetStatus(DoCompute(ctx));
  }

 private:
  Status DoCompute(OpKernelContext* ctx) {
    profiler::TraceMe traceme(
        [&] {
          return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
                                         {{"id", ctx->step_id()}});
        },
        profiler::kInfo);
    // Tag resource usage for tf.data telemetry.
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

    // Private resource/cancellation managers scoped to this single call.
    IteratorContext::Params params(ctx);
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));

    std::vector<Tensor> components;
    components.reserve(dataset->output_dtypes().size());
    bool end_of_sequence = false;

    // First pull: the single element must exist.
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));

    if (end_of_sequence) {
      return errors::InvalidArgument("Dataset was empty.");
    }
    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
    for (int i = 0; i < components.size(); ++i) {
      ctx->set_output(i, components[i]);
    }

    // Second pull: must hit end-of-sequence or the dataset had extra
    // elements.
    components.clear();
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    if (!end_of_sequence) {
      return errors::InvalidArgument("Dataset had more than one element.");
    }
    return Status::OK();
  }

  UnboundedThreadPool unbounded_threadpool_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};
707 
// Kernel backing the `OneShotIterator` op. On the first invocation it runs
// the `dataset_factory` function on a background thread to create a dataset
// and an IteratorResource for it; every invocation (including concurrent ones
// that arrive while initialization is in flight) eventually outputs a handle
// to that shared resource.
class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version())

  {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      // Drop the reference taken in TryInit() and best-effort delete the
      // resource from its container.
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`,
  // but due to the fact that `ResourceOpKernel<T>::CreateResource()`
  // does not provide access to the `OpKernelContext*` and we need
  // this to invoke the factory function, it's not possible to
  // implement this kernel by implementing `CreateResource()`.
  // Furthermore, due to the fact that this kernel might block when
  // running the initialization function, we must implement this
  // kernel as an async kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          // Initialization is already in flight; queue this invocation so
          // Init() completes it once initialization finishes.
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    // Already initialized (or initialization failed): produce output now.
    ProduceOutput(ctx, done);
  }

 private:
  // Runs TryInit(), records the outcome under `mu_`, and then completes this
  // invocation plus every invocation queued while initialization ran.
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      // Take the queued callbacks so they can be run outside the lock.
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  // Creates the IteratorResource, runs `dataset_factory_func_` to build the
  // dataset, and points the resource's iterator at it. On success, `*iterator`
  // carries one extra reference owned by this kernel (released in the dtor).
  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* flr;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
                TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      /*device_mgr=*/nullptr, std::move(flib_def),
                      std::move(pflr), flr);
                  return Status::OK();
                }));

    // Releases the lookup's reference on any early return below; the kernel's
    // own long-lived reference is taken explicitly at the end.
    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    std::vector<Tensor> return_values;
    TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
        std::move(opts), f_handle, {}, &return_values));
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  // Writes the resource handle (or the initialization error) to the kernel
  // output and invokes `done`.
  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ TF_GUARDED_BY(mu_) = false;
  Status initialization_status_ TF_GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      TF_GUARDED_BY(mu_);
  const int graph_def_version_;
};
887 
888 }  // namespace
889 
AsAsync()890 AsyncOpKernel* IteratorGetNextOp::AsAsync() {
891   return type_string() == "IteratorGetNextSync" ? nullptr : this;
892 }
893 
RecordElementSize(const std::vector<Tensor> element,profiler::TraceMe * traceme)894 void RecordElementSize(const std::vector<Tensor> element,
895                        profiler::TraceMe* traceme) {
896   traceme->AppendMetadata([&]() {
897     int64_t element_size = 0;
898     for (const auto& component : element) {
899       element_size += component.TotalBytes();
900     }
901     return profiler::TraceMeEncode({{"element_size", element_size}});
902   });
903 }
904 
// Fetches the next element from the iterator resource held by input 0 and
// copies its components to the kernel outputs. Returns OutOfRange when the
// iterator is exhausted.
Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        // Include memory bandwidth in the trace only when the platform
        // reports a real measurement (INT64_MAX means "unavailable").
        int64_t mem_bw = port::GetMemoryBandwidthInfo().bw_used;

        if (mem_bw != INT64_MAX) {
          return profiler::TraceMeEncode(
              "IteratorGetNextOp::DoCompute",
              {{"id", ctx->step_id()},
               {"iter_num", ctx->frame_iter().iter_id},
               {"mem_bw_used_megabytes_per_sec", mem_bw}});
        }
        return profiler::TraceMeEncode(
            "IteratorGetNextOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  // Release the reference taken by LookupResource on every exit path.
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
  if (end_of_sequence) {
    return errors::OutOfRange("End of sequence");
  }
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
  RecordElementSize(components, &traceme);
  for (int i = 0; i < components.size(); ++i) {
    ctx->set_output(i, components[i]);
  }
  return Status::OK();
}
942 
// Fetches the next element from the iterator resource held by input 0 and
// writes it to output 0 as an optional: a value-carrying optional on success,
// or a "none" optional at end-of-sequence (instead of an OutOfRange error).
Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode(
            "IteratorGetNextAsOptionalOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  // Release the reference taken by LookupResource on every exit path.
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));

  if (end_of_sequence) {
    return WriteOptionalNoneToOutput(ctx, 0);
  } else {
    RecordElementSize(components, &traceme);
    // Validate each component against the expected type and shape before
    // packing the element into the optional.
    for (int i = 0; i < components.size(); ++i) {
      if (components[i].dtype() != output_types_[i]) {
        return errors::InvalidArgument(
            "The given optional does not match the expected type for "
            "component ",
            i, ". Expected: ", DataTypeString(output_types_[i]),
            ". Actual: ", DataTypeString(components[i].dtype()), ".");
      }
      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
        return errors::InvalidArgument(
            "The given optional does not match the expected shape "
            "for component ",
            i, ". Expected: ", output_shapes_[i].DebugString(),
            ". Actual: ", components[i].shape().DebugString(), ".");
      }
    }
    return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
  }
}
984 
Compute(OpKernelContext * ctx)985 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
986   const Tensor& resource_handle_t = ctx->input(0);
987   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
988               errors::InvalidArgument("resource_handle must be a scalar"));
989 
990   // Validate that the handle corresponds to a real resource, and
991   // that it is an IteratorResource.
992   IteratorResource* iterator_resource;
993   OP_REQUIRES_OK(
994       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
995   iterator_resource->Unref();
996 
997   Tensor* string_handle_t;
998   OP_REQUIRES_OK(ctx,
999                  ctx->allocate_output(0, TensorShape({}), &string_handle_t));
1000   string_handle_t->scalar<tstring>()() =
1001       resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
1002 }
1003 
IteratorFromStringHandleOp(OpKernelConstruction * ctx)1004 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
1005     OpKernelConstruction* ctx)
1006     : OpKernel(ctx) {
1007   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
1008   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1009   OP_REQUIRES(
1010       ctx,
1011       output_dtypes_.empty() || output_shapes_.empty() ||
1012           output_dtypes_.size() == output_shapes_.size(),
1013       errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
1014                               "are set, they must have the same length."));
1015 }
1016 
// Parses a serialized ResourceHandle from the scalar string input, validates
// that it names a live IteratorResource on this device (optionally checking
// dtypes/shapes against the kernel's attrs), and outputs the handle.
void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  // A handle is only usable on the device that owns the resource.
  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  // Empty attrs mean "don't check" (see the constructor).
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
1055 
SerializeIteratorOp(OpKernelConstruction * ctx)1056 SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
1057     : OpKernel(ctx) {
1058   if (ctx->HasAttr(kExternalStatePolicy)) {
1059     int64_t state_change_option;
1060     OP_REQUIRES_OK(ctx,
1061                    ctx->GetAttr(kExternalStatePolicy, &state_change_option));
1062     external_state_policy_ =
1063         SerializationContext::ExternalStatePolicy(state_change_option);
1064   }
1065 }
1066 
// Serializes the state of the iterator resource named by input 0 into a
// 1-D variant tensor (output 0), honoring the configured external-state
// policy.
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  IteratorVariantSerializer serializer;
  SerializationContext::Params params(ctx);
  params.external_state_policy = external_state_policy_;
  SerializationContext serialization_ctx(params);
  OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
                                                        iterator_resource));
  // The output length is determined by the serializer after it has captured
  // the iterator state.
  Tensor* serialized_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
                                      &serialized_t));
  OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
}
1091 
// Restores the state of the iterator resource named by input 0 from the
// serialized state in the "serialized" input tensor.
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  const Tensor* serialized_t;
  OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
  IteratorVariantSerializer serializer;
  OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
  Status s = iterator_resource->Restore(ctx, serializer.GetReader());
  // On failure, re-wrap the error (preserving its code) with a hint about the
  // most common cause: a dataset definition that changed between save and
  // restore.
  if (!s.ok()) {
    OP_REQUIRES_OK(
        ctx,
        Status(s.code(),
               absl::StrCat(
                   "Failed to restore dataset iterator from checkpoint: ",
                   s.error_message(),
                   ". Make sure the dataset definition has not changed between "
                   "the process that saved the checkpoint and the process that "
                   "is restoring it.")));
  }
}
1118 
namespace {

// Kernel registrations for the iterator ops. For ops registered on both CPU
// and GPU, the CPU kernel is registered with the higher priority.

// Iterator handle creation.
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
// Element retrieval.
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
// Handle <-> string conversion.
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
// Checkpointing.
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace
1191 
1192 }  // namespace data
1193 }  // namespace tensorflow
1194