• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/iterator_ops.h"
16 
17 #include <cstdint>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/time/time.h"
25 #include "tensorflow/core/activity_watcher/activity.h"
26 #include "tensorflow/core/activity_watcher/activity_utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_runner.h"
29 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/threadpool_device.h"
32 #include "tensorflow/core/data/captured_function.h"
33 #include "tensorflow/core/data/dataset_utils.h"
34 #include "tensorflow/core/data/finalization_utils.h"
35 #include "tensorflow/core/data/metric_utils.h"
36 #include "tensorflow/core/data/root_dataset.h"
37 #include "tensorflow/core/data/serialization_utils.h"
38 #include "tensorflow/core/data/utils.h"
39 #include "tensorflow/core/framework/cancellation.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/partial_tensor_shape.h"
44 #include "tensorflow/core/framework/resource_op_kernel.h"
45 #include "tensorflow/core/framework/stats_aggregator.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/types.h"
48 #include "tensorflow/core/framework/variant_op_registry.h"
49 #include "tensorflow/core/framework/variant_tensor_data.h"
50 #include "tensorflow/core/kernels/data/optional_ops.h"
51 #include "tensorflow/core/kernels/ops_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/refcount.h"
54 #include "tensorflow/core/lib/core/threadpool.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/lib/strings/strcat.h"
58 #include "tensorflow/core/lib/strings/stringprintf.h"
59 #include "tensorflow/core/platform/casts.h"
60 #include "tensorflow/core/platform/env.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mem.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/refcount.h"
65 #include "tensorflow/core/platform/resource.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/profiler/lib/traceme_encode.h"
68 #include "tensorflow/core/public/session_options.h"
69 
70 namespace tensorflow {
71 namespace data {
72 namespace {
73 
74 // See documentation in ../../ops/dataset_ops.cc for a high-level
75 // description of the following ops.
76 
77 const char kAnonymousIterator[] = "AnonymousIterator";
78 const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
79 const char kAnonymousIteratorV3[] = "AnonymousIteratorV3";
80 const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
81 const char kOutputShapes[] = "output_shapes";
82 const char kOutputTypes[] = "output_types";
83 
84 }  // namespace
85 
86 /* static */ constexpr const char* const
87     SerializeIteratorOp::kExternalStatePolicy;
88 
IteratorResource(Env * env,const DataTypeVector & output_dtypes,const std::vector<PartialTensorShape> & output_shapes,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr)89 IteratorResource::IteratorResource(
90     Env* env, const DataTypeVector& output_dtypes,
91     const std::vector<PartialTensorShape>& output_shapes,
92     std::unique_ptr<DeviceMgr> device_mgr,
93     std::unique_ptr<FunctionLibraryDefinition> flib_def,
94     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
95     FunctionLibraryRuntime* flr)
96     : metrics_collector_(flr->device()->device_type(), *env),
97       unbounded_thread_pool_(env, "tf_data_iterator_resource"),
98       device_mgr_(std::move(device_mgr)),
99       iterator_state_(std::make_shared<State>(std::move(flib_def),
100                                               std::move(pflr), flr,
101                                               /*iterator=*/nullptr)),
102       output_dtypes_(output_dtypes),
103       output_shapes_(output_shapes) {
104   VLOG(2) << "creating iterator resource";
105 }
106 
~IteratorResource()107 IteratorResource::~IteratorResource() {
108   VLOG(2) << "destroying iterator resource";
109 }
110 
GetNext(OpKernelContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111 Status IteratorResource::GetNext(OpKernelContext* ctx,
112                                  std::vector<Tensor>* out_tensors,
113                                  bool* end_of_sequence) {
114   std::shared_ptr<State> captured_state;
115   {
116     tf_shared_lock l(mu_);
117     captured_state = iterator_state_;
118   }
119   if (!captured_state->iterator()) {
120     return errors::FailedPrecondition(
121         "GetNext() failed because the iterator has not been initialized. "
122         "Ensure that you have run the initializer operation for this iterator "
123         "before getting the next element.");
124   }
125   IteratorContext::Params params(ctx);
126   params.flr = captured_state->flr();
127   params.function_handle_cache = captured_state->function_handle_cache();
128   params.resource_mgr = captured_state->resource_mgr();
129   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
130   params.thread_pool = &unbounded_thread_pool_;
131   params.cancellation_manager = captured_state->cancellation_manager();
132   std::function<void()> deregister_fn;
133   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
134       ctx->cancellation_manager(),
135       [cm = params.cancellation_manager]() { cm->StartCancel(); },
136       &deregister_fn));
137   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
138 
139   const absl::Time start_time = metrics_collector_.RecordStart();
140   auto iterator_ = captured_state->iterator();
141   auto status = iterator_->GetNext(IteratorContext(std::move(params)),
142                                    out_tensors, end_of_sequence);
143   metrics_collector_.RecordStop(start_time, *out_tensors);
144   return status;
145 }
146 
Save(SerializationContext * ctx,IteratorStateWriter * writer)147 Status IteratorResource::Save(SerializationContext* ctx,
148                               IteratorStateWriter* writer) {
149   std::shared_ptr<State> captured_state;
150   {
151     tf_shared_lock l(mu_);
152     captured_state = iterator_state_;
153   }
154   auto iterator_ = captured_state->iterator();
155   if (iterator_) {
156     return iterator_->Save(ctx, writer);
157   }
158   return errors::FailedPrecondition(
159       "Save() failed because the iterator has not been initialized. Ensure "
160       "that you have run the initializer operation for this iterator before "
161       "saving it.");
162 }
163 
Restore(OpKernelContext * ctx,IteratorStateReader * reader)164 Status IteratorResource::Restore(OpKernelContext* ctx,
165                                  IteratorStateReader* reader) {
166   const DatasetBase* dataset;
167   std::shared_ptr<State> new_state;
168   const DatasetBase* input_dataset;
169   {
170     tf_shared_lock l(mu_);
171     if (!iterator_state_->iterator()) {
172       return errors::FailedPrecondition(
173           "Restore() failed because the iterator has not been initialized. "
174           "Ensure that you have run the initializer operation for this "
175           "iterator before restoring it.");
176     }
177     auto iterator_ = iterator_state_->iterator();
178     dataset = iterator_->dataset();
179     // Hang onto a reference until we've created the new iterator, which will
180     // then hold its own reference to keep the dataset alive.
181     dataset->Ref();
182     new_state =
183         std::make_shared<State>(iterator_state_->flib_def(),
184                                 iterator_state_->pflr(), iterator_state_->flr(),
185                                 /*iterator=*/nullptr);
186     input_dataset = iterator_state_->dataset();
187   }
188   core::ScopedUnref scoped_unref(dataset);
189   IteratorContext::Params params(ctx);
190   params.flr = new_state->flr();
191   params.function_handle_cache = new_state->function_handle_cache();
192   params.resource_mgr = new_state->resource_mgr();
193   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
194   params.thread_pool = &unbounded_thread_pool_;
195   params.cancellation_manager = new_state->cancellation_manager();
196   std::function<void()> deregister_fn;
197   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
198       ctx->cancellation_manager(),
199       [cm = params.cancellation_manager]() { cm->StartCancel(); },
200       &deregister_fn));
201   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
202   std::unique_ptr<IteratorBase> iterator_base;
203   TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
204       IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
205   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
206                                               input_dataset);
207 
208   mutex_lock l(mu_);
209   std::swap(iterator_state_, new_state);
210   return OkStatus();
211 }
212 
SetIteratorFromDataset(OpKernelContext * ctx,const DatasetBase * dataset)213 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
214                                                 const DatasetBase* dataset) {
215   std::shared_ptr<State> new_state;
216   {
217     tf_shared_lock l(mu_);
218     new_state =
219         std::make_shared<State>(iterator_state_->flib_def(),
220                                 iterator_state_->pflr(), iterator_state_->flr(),
221                                 /*iterator=*/nullptr);
222   }
223 
224   // Create new iterator.
225   IteratorContext::Params params(ctx);
226   params.flr = new_state->flr();
227   params.function_handle_cache = new_state->function_handle_cache();
228   params.resource_mgr = new_state->resource_mgr();
229   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
230   params.thread_pool = &unbounded_thread_pool_;
231   params.cancellation_manager = new_state->cancellation_manager();
232   std::function<void()> deregister_fn;
233   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
234       ctx->cancellation_manager(),
235       [cm = params.cancellation_manager]() { cm->StartCancel(); },
236       &deregister_fn));
237   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
238 
239   std::unique_ptr<IteratorBase> iterator;
240   if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
241     DatasetBase* finalized_dataset;
242     TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
243     TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
244         IteratorContext(std::move(params)),
245         /*parent=*/nullptr, "Iterator", &iterator));
246   } else {
247     TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
248                                              /*parent=*/nullptr, "Iterator",
249                                              &iterator));
250   }
251   TF_RETURN_IF_ERROR(
252       VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
253   TF_RETURN_IF_ERROR(
254       VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
255 
256   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
257 
258   mutex_lock l(mu_);
259   std::swap(iterator_state_, new_state);
260   return OkStatus();
261 }
262 
263 namespace {
264 
265 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
266 // The get() method returns an VariantTensorData object which contains all the
267 // state needed to restore a single iterator.
268 //
269 // Usage example:
270 //
271 // Encoding:
272 //
273 //   Tensor t(DT_VARIANT, TensorShape({}));
274 //   t->scalar<Variant>()() = IteratorStateVariant();
275 //
276 // Encode() sets the type_name of the VariantTensorData object to
277 // IteratorStateVariant::TypeName().
278 //
279 // Decoding:
280 //
281 //   Variant v = <VariantTensorDataProto object>;
282 //   DecodeUnaryVariant(&v);
283 //   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
284 //   IteratorStateReader reader({wrapper->GetData()});
285 //   iterator_resource->Restore(ctx, &reader);
286 //
287 // The type_name of the VariantTensorData object to be decoded must
288 // match IteratorStateVariant::TypeName().
289 class IteratorStateVariant {
290  public:
IteratorStateVariant()291   IteratorStateVariant() : data_(nullptr) {}
IteratorStateVariant(const IteratorStateVariant & other)292   IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
293     if (other.data_) {
294       Decode(*other.data_);
295     }
296   }
297   IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
298   IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
299 
300   // Initializes `this` from a VariantTensorData object.
InitializeFromVariantData(std::unique_ptr<VariantTensorData> d)301   Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
302     data_ = std::move(d);
303     return OkStatus();
304   }
305 
TypeName() const306   string TypeName() const { return kIteratorVariantTypeName; }
Encode(VariantTensorData * data) const307   void Encode(VariantTensorData* data) const { *data = *data_; }
Decode(VariantTensorData data)308   bool Decode(VariantTensorData data) {
309     if (data.type_name() != TypeName()) {
310       return false;
311     }
312     auto tensor_data = std::make_unique<VariantTensorData>();
313     std::swap(*tensor_data, data);
314     data_ = std::move(tensor_data);
315     return true;
316   }
317 
318   // Returns a borrowed pointer to the underlying VariantTensorData.
GetData() const319   const VariantTensorData* GetData() const { return data_.get(); }
320 
DebugString() const321   string DebugString() const {
322     if (data_) {
323       return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
324                              ">");
325     } else {
326       return strings::StrCat("IteratorStateVariant<empty>");
327     }
328   }
329 
330  private:
331   std::unique_ptr<VariantTensorData> data_;
332 };
333 
334 // Register the reader class in the global variant decode_fn registry
335 // so that a Variant containing a serialized representation of iterator state
336 // can be decoded using DecodeUnaryVariant. If we don't do this we will need
337 // to manually decode the returned Variant using MaybeDecodeAndCopy in
338 // DeserializeIteratorOp which is not recommended.
339 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
340                                        kIteratorVariantTypeName);
341 
342 // A helper class that uses a list of IteratorStateVariant objects to represent
343 // the state for an iterator resource. It exposes methods that help with
344 // saving and restoring of this state. Sample usage
345 // Saving:
346 //   IteratorVariantSerializer serializer;
347 //   serializer.InitializeFromIterator(iterator_resource);
348 //   Tensor serialized_t;
349 //   serializer.Serialize(&serialized_t);
350 //
351 // Restoring:
352 //   IteratorVariantSerializer serializer;
353 //   serializer.InitFromTensor(ctx->input(0));
354 //   IteratorStateReader* reader = serializer.GetReader();
355 //   iterator_resource->Restore(ctx, reader);
356 class IteratorVariantSerializer {
357  public:
IteratorVariantSerializer()358   IteratorVariantSerializer() {}
359 
360   // Calls `Save` on the iterator_resource to build up the list of
361   // IteratorStateVariant objects.
InitializeFromIterator(SerializationContext * serialization_ctx,IteratorResource * iterator_resource)362   Status InitializeFromIterator(SerializationContext* serialization_ctx,
363                                 IteratorResource* iterator_resource) {
364     VariantTensorDataWriter writer;
365     TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
366     std::vector<std::unique_ptr<VariantTensorData>> data;
367     writer.ReleaseData(&data);
368     variants_.clear();
369     variants_.reserve(data.size());
370     for (auto& it : data) {
371       IteratorStateVariant v;
372       TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
373       variants_.push_back(v);
374     }
375     num_tensors_ = variants_.size();
376     can_serialize_ = true;
377     return OkStatus();
378   }
379 
380   // Initializes `this` from `serialized_t` while restoring the iterator state.
InitFromTensor(const Tensor * serialized_t)381   Status InitFromTensor(const Tensor* serialized_t) {
382     int64_t num_tensors = serialized_t->dim_size(0);
383     auto serialized_vec = serialized_t->vec<Variant>();
384     std::vector<const VariantTensorData*> data;
385     data.reserve(num_tensors);
386     for (int i = 0; i < num_tensors; ++i) {
387       auto* w = serialized_vec(i).get<IteratorStateVariant>();
388       if (!w) {
389         return errors::Internal(
390             "Cannot initialize an iterator from tensor ",
391             serialized_vec(i).DebugString(),
392             ". Expected a variant tensor of type IteratorStateVariant");
393       }
394       data.push_back(w->GetData());
395     }
396     reader_ = std::make_unique<VariantTensorDataReader>(data);
397     num_tensors_ = data.size();
398     return OkStatus();
399   }
400 
NumTensors()401   int64_t NumTensors() { return num_tensors_; }
402 
403   // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
404   // that InitializeFromIterator was called before.
Serialize(Tensor * serialized)405   Status Serialize(Tensor* serialized) {
406     if (!can_serialize_) {
407       return errors::InvalidArgument(
408           "Please call InitializeFromIterator before calling Serialize.");
409     }
410     int64_t size = variants_.size();
411     for (int64_t i = 0; i < size; ++i) {
412       if (variants_[i].GetData() == nullptr) {
413         return errors::Internal(
414             "Cannot serialize an empty IteratorStateVariant");
415       }
416       serialized->vec<Variant>()(i) = variants_[i];
417     }
418     return OkStatus();
419   }
420 
421   // Returns an IteratorStateReader to restore iterator state. Expects that
422   // InitFromTensor was called before.
GetReader()423   IteratorStateReader* GetReader() { return reader_.get(); }
424 
425  private:
426   bool can_serialize_ = false;
427   int64_t num_tensors_;
428   std::vector<IteratorStateVariant> variants_;
429   std::unique_ptr<IteratorStateReader> reader_;
430 };
431 
432 }  // namespace
433 
434 // Note that IteratorHandleOp holds a reference to the resource it creates. If
435 // cleaning up resources with DestroyResourceOp is important, consider creating
436 // resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp(OpKernelConstruction * ctx)437 IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
438     : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
439   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
440   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
441   OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
442 }
443 
444 // The resource is deleted from the resource manager only when it is private
445 // to kernel. Ideally the resource should be deleted when it is no longer held
446 // by anyone, but it would break backward compatibility.
~IteratorHandleOp()447 IteratorHandleOp::~IteratorHandleOp() {
448   if (resource_ != nullptr) {
449     resource_->Unref();
450     if (cinfo_.resource_is_private_to_kernel()) {
451       if (!cinfo_.resource_manager()
452                ->template Delete<IteratorResource>(cinfo_.container(),
453                                                    cinfo_.name())
454                .ok()) {
455         // Do nothing; the resource can have been deleted by session resets.
456       }
457     }
458   }
459 }
460 
Compute(OpKernelContext * context)461 void IteratorHandleOp::Compute(OpKernelContext* context)
462     TF_LOCKS_EXCLUDED(mu_) {
463   {
464     mutex_lock l(mu_);
465     if (resource_ == nullptr) {
466       FunctionLibraryRuntime* flr;
467       std::unique_ptr<DeviceMgr> device_mgr(nullptr);
468       std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
469       std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
470       // If the iterator is shared then we construct a new FLR, and pass that
471       // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
472       // functions from the iterator. We may add this functionality if there
473       // is sufficient demand, but it will require a significant refactoring.
474       if (!name_.empty()) {
475         flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
476       } else {
477         OP_REQUIRES_OK(context, context->function_library()->Clone(
478                                     &flib_def, &pflr, &flr, true));
479       }
480 
481       ResourceMgr* mgr = context->resource_manager();
482       OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
483 
484       IteratorResource* resource;
485       OP_REQUIRES_OK(
486           context,
487           mgr->LookupOrCreate<IteratorResource>(
488               cinfo_.container(), cinfo_.name(), &resource,
489               [context, flr, &device_mgr, &flib_def, &pflr,
490                this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
491                 *ret = new IteratorResource(
492                     context->env(), output_dtypes_, output_shapes_,
493                     std::move(device_mgr), std::move(flib_def), std::move(pflr),
494                     flr);
495                 return OkStatus();
496               }));
497 
498       Status s = VerifyResource(resource);
499       if (TF_PREDICT_FALSE(!s.ok())) {
500         resource->Unref();
501         context->SetStatus(s);
502         return;
503       }
504 
505       resource_ = resource;
506     }
507   }
508   OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
509                               context, 0, cinfo_.container(), cinfo_.name(),
510                               TypeIndex::Make<IteratorResource>()));
511 }
512 
VerifyResource(IteratorResource * resource)513 Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
514   TF_RETURN_IF_ERROR(
515       VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
516   TF_RETURN_IF_ERROR(
517       VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
518   return OkStatus();
519 }
520 
CreatePrivateFLR(OpKernelContext * ctx,std::unique_ptr<DeviceMgr> * device_mgr,std::unique_ptr<FunctionLibraryDefinition> * flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * pflr)521 FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
522     OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
523     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
524     std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
525   // Wrap the existing device in order to see any captured resources
526   // in its resource manager. The existing device will outlive the
527   // IteratorResource, because we are storing the IteratorResource
528   // in that device's resource manager.
529 
530   *device_mgr =
531       std::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
532           ctx->device()->name(), down_cast<Device*>(ctx->device()),
533           false /* owns_underlying */, false /* isolate_session_state */));
534   *flib_def = std::make_unique<FunctionLibraryDefinition>(
535       *ctx->function_library()->GetFunctionLibraryDefinition());
536   const auto* config = ctx->function_library()->config_proto();
537   *pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
538       device_mgr->get(), ctx->env(),
539       /*config=*/config, graph_def_version_, flib_def->get(),
540       config->graph_options().optimizer_options());
541 
542   return (*pflr)->GetFLR(ctx->device()->name());
543 }
544 
545 // Like IteratorHandleOp, but creates handles which are never shared, and does
546 // not hold a reference to these handles. The latter is important for eager
547 // execution, since OpKernel instances generally live as long as the program
548 // running them.
AnonymousIteratorHandleOp(OpKernelConstruction * context)549 AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
550     OpKernelConstruction* context)
551     : AnonymousResourceOp<IteratorResource>(
552           context,
553           /* ref_counting */
554           // Only enable this for V2 (via Python's iter protocol),
555           // AnonymousIteratorV1 requires IteratorToStringHandle, which is
556           // undefined on Refcounting ResourceHandle.
557           context->def().op() == kAnonymousIteratorV2 ||
558               context->def().op() == kAnonymousIteratorV3,
559           // V1 does not return a deleter.
560           /* return_deleter */
561           context->def().op() == kAnonymousIteratorV2),
562       graph_def_version_(context->graph_def_version()) {
563   OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
564   OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
565 }
566 
name()567 string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
568 
CreateResource(OpKernelContext * ctx,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * lib,IteratorResource ** resource)569 Status AnonymousIteratorHandleOp::CreateResource(
570     OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
571     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
572     FunctionLibraryRuntime* lib, IteratorResource** resource) {
573   std::unique_ptr<DeviceMgr> device_mgr(nullptr);
574   *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
575                                    std::move(device_mgr), std::move(flib_def),
576                                    std::move(pflr), lib);
577   return OkStatus();
578 }
579 
HybridAsyncOpKernel(OpKernelConstruction * ctx,const char * background_worker_name)580 HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
581                                          const char* background_worker_name)
582     : AsyncOpKernel(ctx),
583       background_worker_(ctx->env(), background_worker_name) {}
584 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)585 void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
586                                        DoneCallback done) {
587   background_worker_.Schedule([this, ctx, done = std::move(done)]() {
588     ctx->SetStatus(DoCompute(ctx));
589     done();
590   });
591 }
592 
Compute(OpKernelContext * ctx)593 void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
594   ctx->SetStatus(DoCompute(ctx));
595 }
596 
DoCompute(OpKernelContext * ctx)597 Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
598   tensorflow::ResourceTagger tag(kTFDataResourceTag,
599                                  ctx->op_kernel().type_string());
600   DatasetBase* dataset;
601   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
602   IteratorResource* iterator_resource;
603   TF_RETURN_IF_ERROR(
604       LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
605   core::ScopedUnref unref_iterator(iterator_resource);
606   return iterator_resource->SetIteratorFromDataset(ctx, dataset);
607 }
608 
DoCompute(OpKernelContext * ctx)609 Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
610   tensorflow::ResourceTagger tag(kTFDataResourceTag,
611                                  ctx->op_kernel().type_string());
612   const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
613   // The iterator resource is guaranteed to exist because the variant tensor
614   // wrapping the deleter is provided as an unused input to this op, which
615   // guarantees that it has not run yet.
616   return DeleteResource(ctx, handle);
617 }
618 
619 namespace {
620 
621 class ToSingleElementOp : public AsyncOpKernel {
622  public:
ToSingleElementOp(OpKernelConstruction * ctx)623   explicit ToSingleElementOp(OpKernelConstruction* ctx)
624       : AsyncOpKernel(ctx),
625         metrics_collector_(ctx->device()->attributes().device_type(),
626                            *ctx->env()),
627         unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
628     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
629     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
630   }
631 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)632   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
633     unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
634       ctx->SetStatus(DoCompute(ctx));
635       done();
636     });
637   }
638 
Compute(OpKernelContext * ctx)639   void Compute(OpKernelContext* ctx) override {
640     ctx->SetStatus(DoCompute(ctx));
641   }
642 
643  private:
DoCompute(OpKernelContext * ctx)644   Status DoCompute(OpKernelContext* ctx) {
645     profiler::TraceMe traceme(
646         [&] {
647           return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
648                                          {{"id", ctx->step_id()}});
649         },
650         profiler::kInfo);
651     tensorflow::ResourceTagger tag(kTFDataResourceTag,
652                                    ctx->op_kernel().type_string());
653     DatasetBase* dataset;
654     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
655 
656     IteratorContext::Params params(ctx);
657     ResourceMgr resource_mgr;
658     params.resource_mgr = &resource_mgr;
659     CancellationManager cancellation_manager(ctx->cancellation_manager());
660     params.cancellation_manager = &cancellation_manager;
661 
662     IteratorContext iter_ctx(std::move(params));
663     std::unique_ptr<IteratorBase> iterator;
664     TF_RETURN_IF_ERROR(dataset->MakeIterator(
665         &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
666 
667     std::vector<Tensor> components;
668     components.reserve(dataset->output_dtypes().size());
669     bool end_of_sequence = false;
670 
671     const absl::Time start_time = metrics_collector_.RecordStart();
672     TF_RETURN_IF_ERROR(
673         iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
674     metrics_collector_.RecordStop(start_time, components);
675 
676     if (end_of_sequence) {
677       return errors::InvalidArgument("Dataset was empty.");
678     }
679     TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
680     TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
681     for (int i = 0; i < components.size(); ++i) {
682       ctx->set_output(i, components[i]);
683     }
684 
685     components.clear();
686     TF_RETURN_IF_ERROR(
687         iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
688     if (!end_of_sequence) {
689       return errors::InvalidArgument("Dataset had more than one element.");
690     }
691     return OkStatus();
692   }
693 
// Brackets the first GetNext call in DoCompute via RecordStart/RecordStop.
694   IteratorMetricsCollector metrics_collector_;
// Thread pool on which ComputeAsync schedules DoCompute.
695   UnboundedThreadPool unbounded_threadpool_;
// Expected component dtypes and shapes; DoCompute verifies the fetched
// element against them with VerifyTypesMatch / VerifyShapesCompatible.
696   DataTypeVector output_types_;
697   std::vector<PartialTensorShape> output_shapes_;
698 };
699 
700 class OneShotIteratorOp : public AsyncOpKernel {
701  public:
OneShotIteratorOp(OpKernelConstruction * ctx)702   explicit OneShotIteratorOp(OpKernelConstruction* ctx)
703       : AsyncOpKernel(ctx),
704         background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
705         graph_def_version_(ctx->graph_def_version())
706 
707   {
708     string shared_name;
709     OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
710     OP_REQUIRES(ctx, shared_name.empty(),
711                 errors::InvalidArgument("OneShotIteratorOp does not currently "
712                                         "support the 'shared_name' attr."));
713     OP_REQUIRES_OK(ctx,
714                    ctx->GetAttr("dataset_factory", &dataset_factory_func_));
715     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
716     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
717   }
718 
~OneShotIteratorOp()719   ~OneShotIteratorOp() override {
720     if (iterator_resource_ != nullptr) {
721       iterator_resource_->Unref();
722       if (!cinfo_.resource_manager()
723                ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
724                .ok()) {
725         // Do nothing; the resource can have been deleted by session resets.
726       }
727     }
728   }
729 
730   // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`,
731   // but due to the fact that `ResourceOpKernel<T>::CreateResource()`
732   // does not provide access to the `OpKernelContext*` and we need
733   // this to invoke the factory function, it's not possible to
734   // implement this kernel by implementing `CreateResource()`.
735   // Furthermore, due to the fact that this kernel might block when
736   // running the initialization function, we must implement this
737   // kernel as an async kernel.
ComputeAsync(OpKernelContext * ctx,DoneCallback done)738   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
739     tensorflow::ResourceTagger tag(kTFDataResourceTag,
740                                    ctx->op_kernel().type_string());
741     {
742       mutex_lock l(mu_);
743       if (iterator_resource_ == nullptr && initialization_status_.ok()) {
744         // The initialization thread will call `done`.
745         if (!initialization_started_) {
746           // TODO(mrry): Convert the initialization code to use
747           // callbacks instead of wasting a thread.
748           background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
749           initialization_started_ = true;
750         } else {
751           done_callbacks_.emplace_back(ctx, std::move(done));
752         }
753         return;
754       }
755     }
756     ProduceOutput(ctx, done);
757   }
758 
759  private:
Init(OpKernelContext * ctx,const DoneCallback & done)760   void Init(OpKernelContext* ctx, const DoneCallback& done) {
761     IteratorResource* iterator = nullptr;
762     ContainerInfo cinfo;
763     Status s = TryInit(ctx, &iterator, &cinfo);
764 
765     std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
766     {
767       mutex_lock l(mu_);
768       if (s.ok()) {
769         iterator_resource_ = iterator;
770         cinfo_ = cinfo;
771       }
772       initialization_status_ = s;
773       std::swap(done_callbacks_, callbacks_to_run);
774     }
775 
776     for (auto&& ctx_done : callbacks_to_run) {
777       ProduceOutput(ctx_done.first, ctx_done.second);
778     }
779     ProduceOutput(ctx, done);
780   }
781 
TryInit(OpKernelContext * ctx,IteratorResource ** iterator,ContainerInfo * cinfo)782   Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
783                  ContainerInfo* cinfo) {
784     TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));
785 
786     FunctionLibraryRuntime* flr;
787     std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
788     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
789     TF_RETURN_IF_ERROR(
790         ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
791 
792     // Create an IteratorResource that will hold the iterator for this op.
793     TF_RETURN_IF_ERROR(
794         ctx->resource_manager()->LookupOrCreate<IteratorResource>(
795             cinfo->container(), cinfo->name(), iterator,
796             [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
797                 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
798                   *ret = new IteratorResource(
799                       ctx->env(), output_dtypes_, output_shapes_,
800                       /*device_mgr=*/nullptr, std::move(flib_def),
801                       std::move(pflr), flr);
802                   return OkStatus();
803                 }));
804 
805     core::ScopedUnref unref_iterator(*iterator);
806 
807     TF_RETURN_IF_ERROR(
808         VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
809     TF_RETURN_IF_ERROR(
810         VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));
811 
812     // Call the dataset_factory_func_ to create a new dataset,
813     // over which this op will iterate.
814     FunctionLibraryRuntime::Handle f_handle;
815     TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
816         dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
817         &f_handle));
818     FunctionLibraryRuntime::Options opts;
819     opts.cancellation_manager = ctx->cancellation_manager();
820     ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
821       ctx->resource_manager()->Cleanup(name).IgnoreError();
822     });
823     opts.step_container = &step_container;
824     opts.runner = ctx->runner();
825     opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
826     std::vector<Tensor> return_values;
827     TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
828         std::move(opts), f_handle, {}, &return_values));
829     if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
830         !TensorShapeUtils::IsScalar(return_values[0].shape())) {
831       return errors::InvalidArgument(
832           "The `dataset_factory` function must return "
833           "a single scalar of dtype DT_VARIANT.");
834     }
835 
836     // Create an iterator for the dataset that was created in the
837     // factory function.
838     DatasetBase* dataset;
839     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
840     TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
841     (*iterator)->Ref();
842     return OkStatus();
843   }
844 
ProduceOutput(OpKernelContext * ctx,const DoneCallback & done)845   void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
846     Tensor* handle;
847     OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
848                          done);
849     Status s;
850     {
851       mutex_lock l(mu_);
852       s = initialization_status_;
853       if (s.ok()) {
854         handle->scalar<ResourceHandle>()() =
855             MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
856                                                  cinfo_.name());
857       }
858     }
859     OP_REQUIRES_OK_ASYNC(ctx, s, done);
860     done();
861   }
862 
863   NameAttrList dataset_factory_func_;
864   DataTypeVector output_dtypes_;
865   std::vector<PartialTensorShape> output_shapes_;
866 
867   BackgroundWorker background_worker_;
868 
869   mutex mu_;
870   ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
871   IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;
872 
873   bool initialization_started_ TF_GUARDED_BY(mu_) = false;
874   Status initialization_status_ TF_GUARDED_BY(mu_);
875   std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
876       TF_GUARDED_BY(mu_);
877   const int graph_def_version_;
878 };
879 
880 }  // namespace
881 
AsAsync()882 AsyncOpKernel* IteratorGetNextOp::AsAsync() {
883   return type_string() == "IteratorGetNextSync" ? nullptr : this;
884 }
885 
RecordElementSize(const std::vector<Tensor> element,profiler::TraceMe * traceme)886 void RecordElementSize(const std::vector<Tensor> element,
887                        profiler::TraceMe* traceme) {
888   traceme->AppendMetadata([&]() {
889     int64_t element_size = 0;
890     for (const auto& component : element) {
891       element_size += component.TotalBytes();
892     }
893     return profiler::TraceMeEncode({{"element_size", element_size}});
894   });
895 }
896 
DoCompute(OpKernelContext * ctx)897 Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
898   VLOG(3) << "IteratorGetNextOp enter. iter_id=" << ctx->frame_iter().iter_id;
899   auto cleanup = gtl::MakeCleanup([ctx] {
900     VLOG(3) << "IteratorGetNextOp exit. iter_id=" << ctx->frame_iter().iter_id;
901   });
902   activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
903     return activity_watcher::ActivityFromContext(
904         ctx, "IteratorGetNextOp::DoCompute",
905         activity_watcher::ActivityCategory::kDatasetOp);
906   });
907   profiler::TraceMe traceme(
908       [&] {
909         return profiler::TraceMeEncode(
910             "IteratorGetNextOp::DoCompute",
911             {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
912       },
913       profiler::kInfo);
914   tensorflow::ResourceTagger tag(kTFDataResourceTag,
915                                  ctx->op_kernel().type_string());
916   IteratorResource* iterator;
917   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
918   core::ScopedUnref unref_iterator(iterator);
919   std::vector<Tensor> components;
920   bool end_of_sequence = false;
921 
922   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
923   if (end_of_sequence) {
924     return errors::OutOfRange("End of sequence");
925   }
926   TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
927   TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
928   RecordElementSize(components, &traceme);
929   for (int i = 0; i < components.size(); ++i) {
930     ctx->set_output(i, components[i]);
931   }
932   return OkStatus();
933 }
934 
DoCompute(OpKernelContext * ctx)935 Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
936   VLOG(3) << "IteratorGetNextAsOptionalOp enter. iter_id="
937           << ctx->frame_iter().iter_id;
938   auto cleanup = gtl::MakeCleanup([ctx] {
939     VLOG(3) << "IteratorGetNextAsOptionalOp exit. iter_id="
940             << ctx->frame_iter().iter_id;
941   });
942   activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
943     return activity_watcher::ActivityFromContext(
944         ctx, "IteratorGetNextAsOptionalOp::DoCompute",
945         activity_watcher::ActivityCategory::kDatasetOp);
946   });
947   profiler::TraceMe traceme(
948       [&] {
949         return profiler::TraceMeEncode(
950             "IteratorGetNextAsOptionalOp::DoCompute",
951             {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
952       },
953       profiler::kInfo);
954   tensorflow::ResourceTagger tag(kTFDataResourceTag,
955                                  ctx->op_kernel().type_string());
956   IteratorResource* iterator;
957   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
958   core::ScopedUnref unref_iterator(iterator);
959   std::vector<Tensor> components;
960   bool end_of_sequence = false;
961 
962   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
963 
964   if (end_of_sequence) {
965     return WriteOptionalNoneToOutput(ctx, 0);
966   } else {
967     RecordElementSize(components, &traceme);
968     for (int i = 0; i < components.size(); ++i) {
969       if (components[i].dtype() != output_types_[i]) {
970         return errors::InvalidArgument(
971             "The given optional does not match the expected type for "
972             "component ",
973             i, ". Expected: ", DataTypeString(output_types_[i]),
974             ". Actual: ", DataTypeString(components[i].dtype()), ".");
975       }
976       if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
977         return errors::InvalidArgument(
978             "The given optional does not match the expected shape "
979             "for component ",
980             i, ". Expected: ", output_shapes_[i].DebugString(),
981             ". Actual: ", components[i].shape().DebugString(), ".");
982       }
983     }
984     return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
985   }
986 }
987 
Compute(OpKernelContext * ctx)988 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
989   const Tensor& resource_handle_t = ctx->input(0);
990   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
991               errors::InvalidArgument("resource_handle must be a scalar"));
992 
993   // Validate that the handle corresponds to a real resource, and
994   // that it is an IteratorResource.
995   IteratorResource* iterator_resource;
996   OP_REQUIRES_OK(
997       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
998   iterator_resource->Unref();
999 
1000   Tensor* string_handle_t;
1001   OP_REQUIRES_OK(ctx,
1002                  ctx->allocate_output(0, TensorShape({}), &string_handle_t));
1003   string_handle_t->scalar<tstring>()() =
1004       resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
1005 }
1006 
IteratorFromStringHandleOp(OpKernelConstruction * ctx)1007 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
1008     OpKernelConstruction* ctx)
1009     : OpKernel(ctx) {
1010   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
1011   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1012   OP_REQUIRES(
1013       ctx,
1014       output_dtypes_.empty() || output_shapes_.empty() ||
1015           output_dtypes_.size() == output_shapes_.size(),
1016       errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
1017                               "are set, they must have the same length."));
1018 }
1019 
Compute(OpKernelContext * ctx)1020 void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
1021   const Tensor& string_handle_t = ctx->input(0);
1022   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
1023               errors::InvalidArgument("string_handle must be a scalar"));
1024 
1025   ResourceHandle resource_handle;
1026   OP_REQUIRES(
1027       ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
1028       errors::InvalidArgument(
1029           "Could not parse string_handle as a valid ResourceHandle"));
1030 
1031   OP_REQUIRES(
1032       ctx, resource_handle.device() == ctx->device()->attributes().name(),
1033       errors::InvalidArgument("Attempted create an iterator on device \"",
1034                               ctx->device()->attributes().name(),
1035                               "\" from handle defined on device \"",
1036                               resource_handle.device(), "\""));
1037 
1038   // Validate that the handle corresponds to a real resource, and
1039   // that it is an IteratorResource.
1040   IteratorResource* iterator_resource;
1041   OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
1042   core::ScopedUnref unref_iterator(iterator_resource);
1043   if (!output_dtypes_.empty()) {
1044     OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
1045                                          iterator_resource->output_dtypes()));
1046   }
1047   if (!output_shapes_.empty()) {
1048     OP_REQUIRES_OK(ctx,
1049                    VerifyShapesCompatible(output_shapes_,
1050                                           iterator_resource->output_shapes()));
1051   }
1052 
1053   Tensor* resource_handle_t;
1054   OP_REQUIRES_OK(ctx,
1055                  ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
1056   resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
1057 }
1058 
SerializeIteratorOp(OpKernelConstruction * ctx)1059 SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
1060     : OpKernel(ctx) {
1061   if (ctx->HasAttr(kExternalStatePolicy)) {
1062     int64_t state_change_option;
1063     OP_REQUIRES_OK(ctx,
1064                    ctx->GetAttr(kExternalStatePolicy, &state_change_option));
1065     external_state_policy_ =
1066         SerializationContext::ExternalStatePolicy(state_change_option);
1067   }
1068 }
1069 
Compute(OpKernelContext * ctx)1070 void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
1071   tensorflow::ResourceTagger tag(kTFDataResourceTag,
1072                                  ctx->op_kernel().type_string());
1073   const Tensor& resource_handle_t = ctx->input(0);
1074   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
1075               errors::InvalidArgument("resource_handle must be a scalar"));
1076   // Validate that the handle corresponds to a real resource, and
1077   // that it is an IteratorResource.
1078   IteratorResource* iterator_resource;
1079   OP_REQUIRES_OK(
1080       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1081   core::ScopedUnref unref_iterator(iterator_resource);
1082   IteratorVariantSerializer serializer;
1083   SerializationContext::Params params(ctx);
1084   params.external_state_policy = external_state_policy_;
1085   SerializationContext serialization_ctx(params);
1086   OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
1087                                                         iterator_resource));
1088   Tensor* serialized_t;
1089   OP_REQUIRES_OK(ctx,
1090                  ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
1091                                       &serialized_t));
1092   OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
1093 }
1094 
Compute(OpKernelContext * ctx)1095 void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
1096   tensorflow::ResourceTagger tag(kTFDataResourceTag,
1097                                  ctx->op_kernel().type_string());
1098   // Validate that the handle corresponds to a real resource, and
1099   // that it is an IteratorResource.
1100   IteratorResource* iterator_resource;
1101   OP_REQUIRES_OK(
1102       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1103   core::ScopedUnref unref_iterator(iterator_resource);
1104   const Tensor* serialized_t;
1105   OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
1106   IteratorVariantSerializer serializer;
1107   OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
1108   Status s = iterator_resource->Restore(ctx, serializer.GetReader());
1109   if (!s.ok()) {
1110     OP_REQUIRES_OK(
1111         ctx,
1112         errors::CreateWithUpdatedMessage(
1113             s, absl::StrCat(
1114                    "Failed to restore dataset iterator from checkpoint: ",
1115                    s.error_message(),
1116                    ". Make sure the dataset definition has not changed between "
1117                    "the process that saved the checkpoint and the process that "
1118                    "is restoring it.")));
1119   }
1120 }
1121 
// Kernel registrations for the iterator ops. Where an op is registered for
// both DEVICE_CPU and DEVICE_GPU below, the CPU registration is given
// Priority(2) and the GPU registration Priority(1).
1122 namespace {
1123 
// Iterator handle creation, initialization, and deletion.
1124 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
1125 REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
1126                         IteratorHandleOp);
1127 REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
1128                         IteratorHandleOp);
1129 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
1130                         MakeIteratorOp);
// NOTE(review): HostMemory("dataset") presumably keeps the "dataset" argument
// in host memory for the GPU kernel -- confirm against the kernel builder docs.
1131 REGISTER_KERNEL_BUILDER(
1132     Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
1133     MakeIteratorOp);
1134 REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
1135                         DeleteIteratorOp);
1136 REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
1137                         DeleteIteratorOp);
// All three AnonymousIterator op versions share AnonymousIteratorHandleOp.
1138 REGISTER_KERNEL_BUILDER(
1139     Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
1140     AnonymousIteratorHandleOp);
1141 REGISTER_KERNEL_BUILDER(
1142     Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
1143     AnonymousIteratorHandleOp);
1144 REGISTER_KERNEL_BUILDER(
1145     Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
1146     AnonymousIteratorHandleOp);
1147 REGISTER_KERNEL_BUILDER(
1148     Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
1149     AnonymousIteratorHandleOp);
1150 REGISTER_KERNEL_BUILDER(
1151     Name("AnonymousIteratorV3").Device(DEVICE_CPU).Priority(2),
1152     AnonymousIteratorHandleOp);
1153 REGISTER_KERNEL_BUILDER(
1154     Name("AnonymousIteratorV3").Device(DEVICE_GPU).Priority(1),
1155     AnonymousIteratorHandleOp);
// Single-element and one-shot iterator ops are registered for CPU only.
1156 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
1157                         ToSingleElementOp);
1158 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
1159                         OneShotIteratorOp);
// "IteratorGetNext" and "IteratorGetNextSync" are both backed by
// IteratorGetNextOp.
1160 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
1161                         IteratorGetNextOp);
1162 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
1163                         IteratorGetNextOp);
1164 REGISTER_KERNEL_BUILDER(
1165     Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
1166     IteratorGetNextOp);
1167 REGISTER_KERNEL_BUILDER(
1168     Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
1169     IteratorGetNextOp);
1170 REGISTER_KERNEL_BUILDER(
1171     Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
1172     IteratorGetNextAsOptionalOp);
1173 REGISTER_KERNEL_BUILDER(
1174     Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
1175     IteratorGetNextAsOptionalOp);
// String-handle conversion ops. The GPU registrations mark "string_handle"
// with HostMemory.
1176 REGISTER_KERNEL_BUILDER(
1177     Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
1178     IteratorToStringHandleOp);
1179 REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
1180                             .Device(DEVICE_GPU)
1181                             .HostMemory("string_handle")
1182                             .Priority(1),
1183                         IteratorToStringHandleOp);
1184 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
1185                         IteratorFromStringHandleOp);
1186 REGISTER_KERNEL_BUILDER(
1187     Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
1188     IteratorFromStringHandleOp);
1189 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
1190                             .Device(DEVICE_GPU)
1191                             .HostMemory("string_handle")
1192                             .Priority(1),
1193                         IteratorFromStringHandleOp);
// Iterator checkpointing ops are registered for CPU only.
1194 REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
1195                         SerializeIteratorOp);
1196 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
1197                         DeserializeIteratorOp);
1198 
1199 }  // namespace
1200 
1201 }  // namespace data
1202 }  // namespace tensorflow
1203