/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kAnonymousIterator[] = "AnonymousIterator";
const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

// Safely subtracts y from x, clamping the result to zero to avoid underflow.
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }
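// Illustrative values (not part of the original source): safe_sub(10, 4)
// returns 6, while safe_sub(4, 10) returns 0 rather than wrapping around to
// a value near the uint64 maximum.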

}  // namespace

/* static */ constexpr const char* const
    SerializeIteratorOp::kExternalStatePolicy;

IteratorResource::IteratorResource(
    Env* env, const DataTypeVector& output_dtypes,
    const std::vector<PartialTensorShape>& output_shapes,
    std::unique_ptr<DeviceMgr> device_mgr,
    std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* flr)
    : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
      device_mgr_(std::move(device_mgr)),
      iterator_state_(std::make_shared<State>(std::move(flib_def),
                                              std::move(pflr), flr,
                                              /*iterator=*/nullptr)),
      output_dtypes_(output_dtypes),
      output_shapes_(output_shapes),
      // We do not collect iterator resource metrics for non-CPU devices. This
      // is a heuristic to avoid collecting metrics for device-side iterators
      // created by the multi-device iterator mechanism.
      collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
  VLOG(2) << "creating iterator resource";
}

IteratorResource::~IteratorResource() {
  VLOG(2) << "destroying iterator resource";
}

Status IteratorResource::GetNext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  if (!captured_state->iterator()) {
    return errors::FailedPrecondition(
        "GetNext() failed because the iterator has not been initialized. "
        "Ensure that you have run the initializer operation for this iterator "
        "before getting the next element.");
  }
  IteratorContext::Params params(ctx);
  params.flr = captured_state->flr();
  params.function_handle_cache = captured_state->function_handle_cache();
  params.resource_mgr = captured_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = captured_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  const uint64 start_time_us = ctx->env()->NowMicros();
  if (collect_metrics_) {
    mutex_lock l(mu_);
    if (get_next_end_time_us_ == 0) {
      // We initialize `get_next_end_time_us_` to the start time of the first
      // request to make it possible to use the delta between
      // `get_next_end_time_us_` and subsequent `GetNext()` end times to
      // incrementally collect the duration of the iterator's lifetime.
      get_next_end_time_us_ = start_time_us;
    }
    if (num_get_next_calls_ == 0) {
      get_next_start_time_us_ = start_time_us;
    }
    num_get_next_calls_++;
  }
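  // Worked example of the accounting above (illustrative only, not from the
  // original source): if two overlapping GetNext() calls start at t=100us and
  // t=120us and end at t=150us and t=180us, the lifetime metric accumulates
  // the deltas between successive end times (150-100, then 180-150), while
  // the busy metric is recorded once the in-flight count returns to zero, as
  // the last end time minus the first start time (180-100 = 80us).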
  auto iterator_ = captured_state->iterator();
  auto status = iterator_->GetNext(IteratorContext(std::move(params)),
                                   out_tensors, end_of_sequence);
  if (collect_metrics_) {
    const uint64 end_time_us = ctx->env()->NowMicros();
    metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
    metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
    mutex_lock l(mu_);
    metrics::RecordTFDataIteratorLifetime(
        safe_sub(end_time_us, get_next_end_time_us_));
    get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
    num_get_next_calls_--;
    if (num_get_next_calls_ == 0) {
      metrics::RecordTFDataIteratorBusy(
          safe_sub(get_next_end_time_us_, get_next_start_time_us_));
    }
  }
  return status;
}

Status IteratorResource::Save(SerializationContext* ctx,
                              IteratorStateWriter* writer) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  auto iterator_ = captured_state->iterator();
  if (iterator_) {
    return iterator_->Save(ctx, writer);
  }
  return errors::FailedPrecondition(
      "Save() failed because the iterator has not been initialized. Ensure "
      "that you have run the initializer operation for this iterator before "
      "saving it.");
}

Status IteratorResource::Restore(OpKernelContext* ctx,
                                 IteratorStateReader* reader) {
  const DatasetBase* dataset;
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    if (!iterator_state_->iterator()) {
      return errors::FailedPrecondition(
          "Restore() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before restoring it.");
    }
    auto iterator_ = iterator_state_->iterator();
    dataset = iterator_->dataset();
    // Hang onto a reference until we've created the new iterator, which will
    // then hold its own reference to keep the dataset alive.
    dataset->Ref();
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  core::ScopedUnref scoped_unref(dataset);
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
      IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
  new_state->DowncastAndSetIterator(std::move(iterator_base));

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
                                                DatasetBase* dataset) {
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  // Create new iterator.
  std::unique_ptr<IteratorBase> iterator;
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  {
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             /*parent=*/nullptr, "Iterator",
                                             &iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));

    new_state->DowncastAndSetIterator(std::move(iterator));
  }

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

namespace {

// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The GetData() method returns a VariantTensorData object which contains all
// the state needed to restore a single iterator.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant();
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   IteratorStateReader reader({wrapper->GetData()});
//   iterator_resource->Restore(ctx, &reader);
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
  IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;

  // Initializes `this` from a VariantTensorData object.
  Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
    data_ = std::move(d);
    return Status::OK();
  }

  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
  bool Decode(VariantTensorData data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    auto tensor_data = absl::make_unique<VariantTensorData>();
    std::swap(*tensor_data, data);
    data_ = std::move(tensor_data);
    return true;
  }

  // Returns a borrowed pointer to the underlying VariantTensorData.
  const VariantTensorData* GetData() const { return data_.get(); }

  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
                             ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<VariantTensorData> data_;
};
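
// Illustrative round trip through the wrapper above (a sketch, not from the
// original source; assumes `state` is a VariantTensorData whose type_name()
// already matches IteratorStateVariant::TypeName()):
//
//   IteratorStateVariant v;
//   TF_CHECK_OK(v.InitializeFromVariantData(
//       absl::make_unique<VariantTensorData>(state)));
//   VariantTensorData encoded;
//   v.Encode(&encoded);  // copies the wrapped state
//   IteratorStateVariant w;
//   bool ok = w.Decode(std::move(encoded));  // false on type_name mismatch
//   const VariantTensorData* restored = w.GetData();  // borrowed pointer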

// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this, we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

// A helper class that uses a list of IteratorStateVariant objects to represent
// the state for an iterator resource. It exposes methods that help with
// saving and restoring this state. Sample usage:
//
// Saving:
//   IteratorVariantSerializer serializer;
//   serializer.InitializeFromIterator(iterator_resource);
//   Tensor serialized_t;
//   serializer.Serialize(&serialized_t);
//
// Restoring:
//   IteratorVariantSerializer serializer;
//   serializer.InitFromTensor(ctx->input(0));
//   IteratorStateReader* reader = serializer.GetReader();
//   iterator_resource->Restore(ctx, reader);
class IteratorVariantSerializer {
 public:
  IteratorVariantSerializer() {}

  // Calls `Save` on the iterator_resource to build up the list of
  // IteratorStateVariant objects.
  Status InitializeFromIterator(SerializationContext* serialization_ctx,
                                IteratorResource* iterator_resource) {
    VariantTensorDataWriter writer;
    TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
    std::vector<std::unique_ptr<VariantTensorData>> data;
    writer.ReleaseData(&data);
    variants_.clear();
    variants_.reserve(data.size());
    for (auto& it : data) {
      IteratorStateVariant v;
      TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
      variants_.push_back(v);
    }
    num_tensors_ = variants_.size();
    can_serialize_ = true;
    return Status::OK();
  }

  // Initializes `this` from `serialized_t` while restoring the iterator state.
  Status InitFromTensor(const Tensor* serialized_t) {
    int64 num_tensors = serialized_t->dim_size(0);
    auto serialized_vec = serialized_t->vec<Variant>();
    std::vector<const VariantTensorData*> data;
    data.reserve(num_tensors);
    for (int i = 0; i < num_tensors; ++i) {
      auto* w = serialized_vec(i).get<IteratorStateVariant>();
      if (!w) {
        return errors::Internal(
            "Cannot initialize an iterator from tensor ",
            serialized_vec(i).DebugString(),
            ". Expected a variant tensor of type IteratorStateVariant");
      }
      data.push_back(w->GetData());
    }
    reader_ = absl::make_unique<VariantTensorDataReader>(data);
    num_tensors_ = data.size();
    return Status::OK();
  }

  int64 NumTensors() { return num_tensors_; }

  // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
  // that InitializeFromIterator has been called first.
  Status Serialize(Tensor* serialized) {
    if (!can_serialize_) {
      return errors::InvalidArgument(
          "Please call InitializeFromIterator before calling Serialize.");
    }
    int64 size = variants_.size();
    for (int64 i = 0; i < size; ++i) {
      if (variants_[i].GetData() == nullptr) {
        return errors::Internal(
            "Cannot serialize an empty IteratorStateVariant");
      }
      serialized->vec<Variant>()(i) = variants_[i];
    }
    return Status::OK();
  }

  // Returns an IteratorStateReader to restore iterator state. Expects that
  // InitFromTensor has been called first.
  IteratorStateReader* GetReader() { return reader_.get(); }

 private:
  bool can_serialize_ = false;
  int64 num_tensors_;
  std::vector<IteratorStateVariant> variants_;
  std::unique_ptr<IteratorStateReader> reader_;
};

}  // namespace

// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

// The resource is deleted from the resource manager only when it is private
// to the kernel. Ideally the resource should be deleted when it is no longer
// held by anyone, but that would break backward compatibility.
IteratorHandleOp::~IteratorHandleOp() {
  if (resource_ != nullptr) {
    resource_->Unref();
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<IteratorResource>(cinfo_.container(),
                                                   cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }
}

void IteratorHandleOp::Compute(OpKernelContext* context)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* flr;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr, true));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, flr, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    std::move(device_mgr), std::move(flib_def), std::move(pflr),
                    flr);
                return Status::OK();
              }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              TypeIndex::Make<IteratorResource>()));
}

Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return Status::OK();
}

FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.

  *device_mgr =
      absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
          ctx->device()->name(), down_cast<Device*>(ctx->device()),
          false /* owns_underlying */, false /* isolate_session_state */));
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  const auto* config = ctx->function_library()->config_proto();
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(),
      /*config=*/config, graph_def_version_, flib_def->get(),
      config->graph_options().optimizer_options());

  return (*pflr)->GetFLR(ctx->device()->name());
}

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousResourceOp<IteratorResource>(context),
      graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
}

string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }

Status AnonymousIteratorHandleOp::CreateResource(
    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* lib, IteratorResource** resource) {
  std::unique_ptr<DeviceMgr> device_mgr(nullptr);
  *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
                                   std::move(device_mgr), std::move(flib_def),
                                   std::move(pflr), lib);
  return Status::OK();
}

HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
                                         const char* background_worker_name)
    : AsyncOpKernel(ctx),
      background_worker_(ctx->env(), background_worker_name) {}

void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                       DoneCallback done) {
  background_worker_.Schedule([this, ctx, done = std::move(done)]() {
    ctx->SetStatus(DoCompute(ctx));
    done();
  });
}

void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
  ctx->SetStatus(DoCompute(ctx));
}

Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  return iterator_resource->SetIteratorFromDataset(ctx, dataset);
}

Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
  // The iterator resource is guaranteed to exist because the variant tensor
  // wrapping the deleter is provided as an unused input to this op, which
  // guarantees that it has not run yet.
  return ctx->resource_manager()->Delete(handle);
}

namespace {

class ToSingleElementOp : public HybridAsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  Status DoCompute(OpKernelContext* ctx) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

    IteratorContext::Params params(ctx);
    FunctionHandleCache function_handle_cache(params.flr);
    params.function_handle_cache = &function_handle_cache;
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));

    std::vector<Tensor> components;
    components.reserve(dataset->output_dtypes().size());
    bool end_of_sequence = false;

    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));

    if (end_of_sequence) {
      return errors::InvalidArgument("Dataset was empty.");
    }
    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
    for (int i = 0; i < components.size(); ++i) {
      ctx->set_output(i, components[i]);
    }

    components.clear();
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    if (!end_of_sequence) {
      return errors::InvalidArgument("Dataset had more than one element.");
    }
    return Status::OK();
  }

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class ReduceDatasetOp : public HybridAsyncOpKernel {
 public:
  explicit ReduceDatasetOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_reduce_dataset") {
    FunctionMetadata::Params params;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                     &params.use_inter_op_parallelism));
    params.use_default_device = false;
    OP_REQUIRES_OK(ctx,
                   FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 protected:
  Status DoCompute(OpKernelContext* ctx) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    OpInputList inputs;
    TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
    std::vector<Tensor> state(inputs.begin(), inputs.end());

    std::unique_ptr<CapturedFunction> captured_func;
    TF_RETURN_IF_ERROR(CapturedFunction::Create(
        ctx, func_metadata_, "other_arguments", &captured_func));

    IteratorContext::Params params(ctx);
    auto function_handle_cache =
        absl::make_unique<FunctionHandleCache>(params.flr);
    params.function_handle_cache = function_handle_cache.get();
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
    TF_RETURN_IF_ERROR(
        captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));

    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
                                             "ReduceIterator", &iterator));

    // Iterate through the input dataset.
    while (true) {
      if (ctx->cancellation_manager()->IsCancelled()) {
        return errors::Cancelled("Operation was cancelled");
      }
      std::vector<Tensor> next_input_element;
      bool end_of_input;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
      if (end_of_input) {
        break;
      }

      // Run the reduce function to update the current state.
      std::vector<Tensor> args;
      args.reserve(state.size() + next_input_element.size());
      std::copy(state.begin(), state.end(), std::back_inserter(args));
      std::copy(next_input_element.begin(), next_input_element.end(),
                std::back_inserter(args));

      std::vector<Tensor> reduce_func_output;
      TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
          &iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
      if (reduce_func_output.size() != state.size()) {
        return errors::InvalidArgument(
            "The number of components of the initial state and the "
            "reduce function output does not match. (initial_state=",
            state.size(), ", output=", reduce_func_output.size(), ").");
      }
      std::swap(reduce_func_output, state);
    }

    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
    for (size_t i = 0; i < state.size(); ++i) {
      ctx->set_output(i, state[i]);
    }
    return Status::OK();
  }

  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version()) {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but
  // because `ResourceOpKernel<T>::CreateResource()` does not provide access
  // to the `OpKernelContext*`, which we need to invoke the factory function,
  // it is not possible to implement this kernel by overriding
  // `CreateResource()`. Furthermore, because this kernel might block when
  // running the initialization function, we must implement it as an async
  // kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, done);
  }

 private:
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* flr;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
                TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      /*device_mgr=*/nullptr, std::move(flib_def),
                      std::move(pflr), flr);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    std::vector<Tensor> return_values;
    TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
        std::move(opts), f_handle, {}, &return_values));
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ TF_GUARDED_BY(mu_) = false;
  Status initialization_status_ TF_GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      TF_GUARDED_BY(mu_);
  const int graph_def_version_;
};

}  // namespace

AsyncOpKernel* IteratorGetNextOp::AsAsync() {
  return type_string() == "IteratorGetNextSync" ? nullptr : this;
}

void RecordElementSize(const std::vector<Tensor>& element,
                       profiler::TraceMe* traceme) {
  traceme->AppendMetadata([&]() {
    int64 element_size = 0;
    for (const auto& component : element) {
      element_size += component.TotalBytes();
    }
    return profiler::TraceMeEncode({{"element_size", element_size}});
  });
}

Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        int64 mem_bw = port::GetMemoryBandwidthInfo().bw_used;

        if (mem_bw != INT64_MAX) {
          return profiler::TraceMeEncode(
              "IteratorGetNextOp::DoCompute",
              {{"id", ctx->step_id()},
               {"iter_num", ctx->frame_iter().iter_id},
               {"mem_bw_used_megabytes_per_sec", mem_bw}});
        }
        return profiler::TraceMeEncode(
            "IteratorGetNextOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
  if (end_of_sequence) {
    return errors::OutOfRange("End of sequence");
  }
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
  RecordElementSize(components, &traceme);
  for (int i = 0; i < components.size(); ++i) {
    ctx->set_output(i, components[i]);
  }
  return Status::OK();
}

Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return strings::StrCat(
            "IteratorGetNextAsOptionalOp::DoCompute#id=", ctx->step_id(),
            ",iter_num=", ctx->frame_iter().iter_id, "#");
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));

  if (end_of_sequence) {
    return WriteOptionalNoneToOutput(ctx, 0);
  } else {
    RecordElementSize(components, &traceme);
    for (int i = 0; i < components.size(); ++i) {
      if (components[i].dtype() != output_types_[i]) {
        return errors::InvalidArgument(
            "The given optional does not match the expected type for "
            "component ",
            i, ". Expected: ", DataTypeString(output_types_[i]),
            ". Actual: ", DataTypeString(components[i].dtype()), ".");
      }
      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
        return errors::InvalidArgument(
            "The given optional does not match the expected shape "
            "for component ",
            i, ". Expected: ", output_shapes_[i].DebugString(),
            ". Actual: ", components[i].shape().DebugString(), ".");
      }
    }
    return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
  }
}

void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  iterator_resource->Unref();

  Tensor* string_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
  string_handle_t->scalar<tstring>()() =
      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}

IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
          output_dtypes_.size() == output_shapes_.size(),
      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                              "are set, they must have the same length."));
}

void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted to create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}

SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  if (ctx->HasAttr(kExternalStatePolicy)) {
    int64 state_change_option;
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr(kExternalStatePolicy, &state_change_option));
    external_state_policy_ =
        SerializationContext::ExternalStatePolicy(state_change_option);
  }
}

void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  IteratorVariantSerializer serializer;
  SerializationContext::Params params;
  params.external_state_policy = external_state_policy_;
  SerializationContext serialization_ctx(params);
  OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
                                                        iterator_resource));
  Tensor* serialized_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
                                      &serialized_t));
  OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
}

void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  const Tensor* serialized_t;
  OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
  IteratorVariantSerializer serializer;
  OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
  OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader()));
}

namespace {

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");

}  // namespace

}  // namespace data
}  // namespace tensorflow