/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kAnonymousIterator[] = "AnonymousIterator";
const char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

// Safely subtracts `y` from `x`, clamping the result at zero to avoid
// underflow.
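// For example, safe_sub(5, 3) returns 2, while safe_sub(3, 5) returns 0
// rather than wrapping around.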
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }

}  // namespace

/* static */ constexpr const char* const
    SerializeIteratorOp::kExternalStatePolicy;

IteratorResource::IteratorResource(
    Env* env, const DataTypeVector& output_dtypes,
    const std::vector<PartialTensorShape>& output_shapes,
    std::unique_ptr<DeviceMgr> device_mgr,
    std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* flr)
    : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
      device_mgr_(std::move(device_mgr)),
      iterator_state_(std::make_shared<State>(std::move(flib_def),
                                              std::move(pflr), flr,
                                              /*iterator=*/nullptr)),
      output_dtypes_(output_dtypes),
      output_shapes_(output_shapes),
      // We do not collect iterator resource metrics for non-CPU devices. This
      // is a heuristic to avoid collecting metrics for device-side iterators
      // created by the multi-device iterator mechanism.
      collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
  VLOG(2) << "creating iterator resource";
}

IteratorResource::~IteratorResource() {
  VLOG(2) << "destroying iterator resource";
}

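// Advances the iterator and returns the next element in `out_tensors`, or
// sets `end_of_sequence` if the iterator is exhausted. The current state is
// snapshotted under a shared lock, so concurrent GetNext() calls do not
// serialize on each other.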
Status IteratorResource::GetNext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  if (!captured_state->iterator()) {
    return errors::FailedPrecondition(
        "GetNext() failed because the iterator has not been initialized. "
        "Ensure that you have run the initializer operation for this iterator "
        "before getting the next element.");
  }
  IteratorContext::Params params(ctx);
  params.flr = captured_state->flr();
  params.function_handle_cache = captured_state->function_handle_cache();
  params.resource_mgr = captured_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = captured_state->cancellation_manager();
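  // Propagate cancellation from the op's cancellation manager to the
  // iterator's; the callback is deregistered (via `cleanup`) when this call
  // returns.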
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  const uint64 start_time_us = ctx->env()->NowMicros();
  if (collect_metrics_) {
    mutex_lock l(mu_);
    if (get_next_end_time_us_ == 0) {
      // We initialize `get_next_end_time_us_` to the start time of the first
      // request to make it possible to use the delta between
      // `get_next_end_time_us_` and subsequent `GetNext()` end time to
      // incrementally collect the duration of the iterator's lifetime.
      get_next_end_time_us_ = start_time_us;
    }
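    // `get_next_start_time_us_` marks the start of a window during which at
    // least one GetNext() call is in flight; together with
    // `num_get_next_calls_` it is used below to measure how long the iterator
    // is busy.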
    if (num_get_next_calls_ == 0) {
      get_next_start_time_us_ = start_time_us;
    }
    num_get_next_calls_++;
  }
  auto iterator = captured_state->iterator();
  auto status = iterator->GetNext(IteratorContext(std::move(params)),
                                  out_tensors, end_of_sequence);
  if (collect_metrics_) {
    const uint64 end_time_us = ctx->env()->NowMicros();
    metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
    metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
    mutex_lock l(mu_);
    metrics::RecordTFDataIteratorLifetime(
        safe_sub(end_time_us, get_next_end_time_us_));
    get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
    num_get_next_calls_--;
    if (num_get_next_calls_ == 0) {
      metrics::RecordTFDataIteratorBusy(
          safe_sub(get_next_end_time_us_, get_next_start_time_us_));
    }
  }
  return status;
}

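// Serializes the state of the underlying iterator via `writer`. Fails if the
// iterator has not been initialized.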
Status IteratorResource::Save(SerializationContext* ctx,
                              IteratorStateWriter* writer) {
  std::shared_ptr<State> captured_state;
  {
    tf_shared_lock l(mu_);
    captured_state = iterator_state_;
  }
  auto iterator = captured_state->iterator();
  if (iterator) {
    return iterator->Save(ctx, writer);
  }
  return errors::FailedPrecondition(
      "Save() failed because the iterator has not been initialized. Ensure "
      "that you have run the initializer operation for this iterator before "
      "saving it.");
}

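// Restores the iterator from checkpoint state read through `reader`. A new
// `State` is built from the current dataset and only swapped in once the
// restore succeeds, so a failed restore leaves the existing iterator
// untouched.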
Status IteratorResource::Restore(OpKernelContext* ctx,
                                 IteratorStateReader* reader) {
  const DatasetBase* dataset;
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    if (!iterator_state_->iterator()) {
      return errors::FailedPrecondition(
          "Restore() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before restoring it.");
    }
    auto iterator = iterator_state_->iterator();
    dataset = iterator->dataset();
    // Hang onto a reference until we've created the new iterator, which will
    // then hold its own reference to keep the dataset alive.
    dataset->Ref();
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  core::ScopedUnref scoped_unref(dataset);
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  std::unique_ptr<IteratorBase> iterator_base;
  TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
      IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
  new_state->DowncastAndSetIterator(std::move(iterator_base));

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

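// Creates a fresh iterator for `dataset`, verifies that its output types and
// shapes match what this resource was created with, and swaps it in as the
// current iterator state.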
Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
                                                DatasetBase* dataset) {
  std::shared_ptr<State> new_state;
  {
    tf_shared_lock l(mu_);
    new_state =
        std::make_shared<State>(iterator_state_->flib_def(),
                                iterator_state_->pflr(), iterator_state_->flr(),
                                /*iterator=*/nullptr);
  }
  // Create the new iterator.
  std::unique_ptr<IteratorBase> iterator;
  IteratorContext::Params params(ctx);
  params.flr = new_state->flr();
  params.function_handle_cache = new_state->function_handle_cache();
  params.resource_mgr = new_state->resource_mgr();
  params.thread_factory = unbounded_thread_pool_.get_thread_factory();
  params.thread_pool = &unbounded_thread_pool_;
  params.cancellation_manager = new_state->cancellation_manager();
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      ctx->cancellation_manager(),
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
      &deregister_fn));
  {
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             /*parent=*/nullptr, "Iterator",
                                             &iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));

    new_state->DowncastAndSetIterator(std::move(iterator));
  }

  mutex_lock l(mu_);
  std::swap(iterator_state_, new_state);
  return Status::OK();
}

namespace {

// Wrapper for encoding/decoding the iterator state stored in a Variant
// tensor. The GetData() method returns a VariantTensorData object which
// contains all the state needed to restore a single iterator.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant();
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   IteratorStateReader reader({wrapper->GetData()});
//   iterator_resource->Restore(ctx, &reader);
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
  IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;

  // Initializes `this` from a VariantTensorData object.
  Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
    data_ = std::move(d);
    return Status::OK();
  }

  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
  bool Decode(VariantTensorData data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    auto tensor_data = absl::make_unique<VariantTensorData>();
    std::swap(*tensor_data, data);
    data_ = std::move(tensor_data);
    return true;
  }

  // Returns a borrowed pointer to the underlying VariantTensorData.
  const VariantTensorData* GetData() const { return data_.get(); }

  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
                             ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<VariantTensorData> data_;
};

// Register the reader class in the global variant decode_fn registry so that
// a Variant containing a serialized representation of iterator state can be
// decoded using DecodeUnaryVariant. If we did not do this, we would need to
// manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

// A helper class that uses a list of IteratorStateVariant objects to
// represent the state of an iterator resource. It exposes methods that help
// with saving and restoring this state. Sample usage:
//
// Saving:
//
//   IteratorVariantSerializer serializer;
//   serializer.InitializeFromIterator(iterator_resource);
//   Tensor serialized_t;
//   serializer.Serialize(&serialized_t);
//
// Restoring:
//
//   IteratorVariantSerializer serializer;
//   serializer.InitFromTensor(ctx->input(0));
//   IteratorStateReader* reader = serializer.GetReader();
//   iterator_resource->Restore(ctx, reader);
class IteratorVariantSerializer {
 public:
  IteratorVariantSerializer() {}

  // Calls `Save` on the iterator_resource to build up the list of
  // IteratorStateVariant objects.
  Status InitializeFromIterator(SerializationContext* serialization_ctx,
                                IteratorResource* iterator_resource) {
    VariantTensorDataWriter writer;
    TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
    std::vector<std::unique_ptr<VariantTensorData>> data;
    writer.ReleaseData(&data);
    variants_.clear();
    variants_.reserve(data.size());
    for (auto& it : data) {
      IteratorStateVariant v;
      TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
      variants_.push_back(v);
    }
    num_tensors_ = variants_.size();
    can_serialize_ = true;
    return Status::OK();
  }

  // Initializes `this` from `serialized_t` while restoring the iterator
  // state.
  Status InitFromTensor(const Tensor* serialized_t) {
    int64 num_tensors = serialized_t->dim_size(0);
    auto serialized_vec = serialized_t->vec<Variant>();
    std::vector<const VariantTensorData*> data;
    data.reserve(num_tensors);
    for (int i = 0; i < num_tensors; ++i) {
      auto* w = serialized_vec(i).get<IteratorStateVariant>();
      if (!w) {
        return errors::Internal(
            "Cannot initialize an iterator from tensor ",
            serialized_vec(i).DebugString(),
            ". Expected a variant tensor of type IteratorStateVariant");
      }
      data.push_back(w->GetData());
    }
    reader_ = absl::make_unique<VariantTensorDataReader>(data);
    num_tensors_ = data.size();
    return Status::OK();
  }

  int64 NumTensors() { return num_tensors_; }

  // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
  // that InitializeFromIterator was called before.
  Status Serialize(Tensor* serialized) {
    if (!can_serialize_) {
      return errors::InvalidArgument(
          "Please call InitializeFromIterator before calling Serialize.");
    }
    int64 size = variants_.size();
    for (int64 i = 0; i < size; ++i) {
      if (variants_[i].GetData() == nullptr) {
        return errors::Internal(
            "Cannot serialize an empty IteratorStateVariant");
      }
      serialized->vec<Variant>()(i) = variants_[i];
    }
    return Status::OK();
  }

  // Returns an IteratorStateReader to restore iterator state. Expects that
  // InitFromTensor was called before.
  IteratorStateReader* GetReader() { return reader_.get(); }

 private:
  bool can_serialize_ = false;
  int64 num_tensors_;
  std::vector<IteratorStateVariant> variants_;
  std::unique_ptr<IteratorStateReader> reader_;
};

}  // namespace

// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

// The resource is deleted from the resource manager only when it is private
// to the kernel. Ideally the resource should be deleted when it is no longer
// held by anyone, but that would break backward compatibility.
IteratorHandleOp::~IteratorHandleOp() {
  if (resource_ != nullptr) {
    resource_->Unref();
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<IteratorResource>(cinfo_.container(),
                                                   cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }
}

void IteratorHandleOp::Compute(OpKernelContext* context)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* flr;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr, true));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, flr, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    std::move(device_mgr), std::move(flib_def), std::move(pflr),
                    flr);
                return Status::OK();
              }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              TypeIndex::Make<IteratorResource>()));
}

Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return Status::OK();
}

FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.

  *device_mgr =
      absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
          ctx->device()->name(), down_cast<Device*>(ctx->device()),
          /*owns_underlying=*/false, /*isolate_session_state=*/false));
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  const auto* config = ctx->function_library()->config_proto();
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(),
      /*config=*/config, graph_def_version_, flib_def->get(),
      config->graph_options().optimizer_options());

  return (*pflr)->GetFLR(ctx->device()->name());
}

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousResourceOp<IteratorResource>(context),
      graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
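  // Only the V2 op has a deleter output, so only it needs a deleter variant
  // to be created.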
  create_deleter_ = context->def().op() == kAnonymousIteratorV2;
}

string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }

Status AnonymousIteratorHandleOp::CreateResource(
    OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* lib, IteratorResource** resource) {
  std::unique_ptr<DeviceMgr> device_mgr(nullptr);
  *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
                                   std::move(device_mgr), std::move(flib_def),
                                   std::move(pflr), lib);
  return Status::OK();
}

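// HybridAsyncOpKernel subclasses implement their logic in DoCompute() and can
// run either synchronously on the calling thread (Compute) or asynchronously
// (ComputeAsync); the async form schedules DoCompute() on a dedicated
// background worker instead of the calling thread.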
HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
                                         const char* background_worker_name)
    : AsyncOpKernel(ctx),
      background_worker_(ctx->env(), background_worker_name) {}

void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
                                       DoneCallback done) {
  background_worker_.Schedule([this, ctx, done = std::move(done)]() {
    ctx->SetStatus(DoCompute(ctx));
    done();
  });
}

void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
  ctx->SetStatus(DoCompute(ctx));
}

Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  return iterator_resource->SetIteratorFromDataset(ctx, dataset);
}

Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
  // The iterator resource is guaranteed to exist because the variant tensor
  // wrapping the deleter is provided as an unused input to this op, which
  // guarantees that it has not run yet.
  return ctx->resource_manager()->Delete(handle);
}

namespace {

class ToSingleElementOp : public HybridAsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  Status DoCompute(OpKernelContext* ctx) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

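    // Build a short-lived IteratorContext whose function handle cache,
    // resource manager, and cancellation manager exist only for the duration
    // of this computation.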
    IteratorContext::Params params(ctx);
    FunctionHandleCache function_handle_cache(params.flr);
    params.function_handle_cache = &function_handle_cache;
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(
        &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));

    std::vector<Tensor> components;
    components.reserve(dataset->output_dtypes().size());
    bool end_of_sequence = false;

    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));

    if (end_of_sequence) {
      return errors::InvalidArgument("Dataset was empty.");
    }
    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
    for (int i = 0; i < components.size(); ++i) {
      ctx->set_output(i, components[i]);
    }

    // Request one more element to verify that the dataset contains exactly
    // one.
    components.clear();
    TF_RETURN_IF_ERROR(
        iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    if (!end_of_sequence) {
      return errors::InvalidArgument("Dataset had more than one element.");
    }
    return Status::OK();
  }

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class ReduceDatasetOp : public HybridAsyncOpKernel {
 public:
  explicit ReduceDatasetOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_reduce_dataset") {
    FunctionMetadata::Params params;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                     &params.use_inter_op_parallelism));
    params.use_default_device = false;
    OP_REQUIRES_OK(ctx,
                   FunctionMetadata::Create(ctx, "f", params, &func_metadata_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 protected:
  Status DoCompute(OpKernelContext* ctx) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    OpInputList inputs;
    TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
    std::vector<Tensor> state(inputs.begin(), inputs.end());

    std::unique_ptr<CapturedFunction> captured_func;
    TF_RETURN_IF_ERROR(CapturedFunction::Create(
        ctx, func_metadata_, "other_arguments", &captured_func));

    IteratorContext::Params params(ctx);
    auto function_handle_cache =
        absl::make_unique<FunctionHandleCache>(params.flr);
    params.function_handle_cache = function_handle_cache.get();
    ResourceMgr resource_mgr;
    params.resource_mgr = &resource_mgr;
    CancellationManager cancellation_manager(ctx->cancellation_manager());
    params.cancellation_manager = &cancellation_manager;

    IteratorContext iter_ctx(std::move(params));
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
    TF_RETURN_IF_ERROR(
        captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));

    std::unique_ptr<IteratorBase> iterator;
    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
                                             "ReduceIterator", &iterator));

    // Iterate through the input dataset.
    while (true) {
      if (ctx->cancellation_manager()->IsCancelled()) {
        return errors::Cancelled("Operation was cancelled");
      }
      std::vector<Tensor> next_input_element;
      bool end_of_input;
      TF_RETURN_IF_ERROR(
          iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
      if (end_of_input) {
        break;
      }

      // Run the reduce function to update the current state.
      std::vector<Tensor> args;
      args.reserve(state.size() + next_input_element.size());
      std::copy(state.begin(), state.end(), std::back_inserter(args));
      std::copy(next_input_element.begin(), next_input_element.end(),
                std::back_inserter(args));

      std::vector<Tensor> reduce_func_output;
      TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
          &iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
      if (reduce_func_output.size() != state.size()) {
        return errors::InvalidArgument(
            "The number of components of the initial state and the reduce "
            "function output does not match. (initial_state=",
            state.size(), ", output=", reduce_func_output.size(), ").");
      }
      std::swap(reduce_func_output, state);
    }

    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
    for (size_t i = 0; i < state.size(); ++i) {
      ctx->set_output(i, state[i]);
    }
    return Status::OK();
  }

 private:
  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version()) {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but
  // because `ResourceOpKernel<T>::CreateResource()` does not provide access
  // to the `OpKernelContext*` (which we need to invoke the factory function),
  // it is not possible to implement this kernel by overriding
  // `CreateResource()`. Furthermore, because this kernel might block when
  // running the initialization function, we must implement it as an async
  // kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                   ctx->op_kernel().type_string());
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, done);
  }

 private:
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* flr;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
                TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      /*device_mgr=*/nullptr, std::move(flib_def),
                      std::move(pflr), flr);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    std::vector<Tensor> return_values;
    TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
        std::move(opts), f_handle, {}, &return_values));
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ TF_GUARDED_BY(mu_) = false;
  Status initialization_status_ TF_GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      TF_GUARDED_BY(mu_);
  const int graph_def_version_;
};

}  // namespace

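// `IteratorGetNextSync` runs synchronously on the calling thread; all other
// registrations of this kernel run asynchronously on the background worker.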
AsyncOpKernel* IteratorGetNextOp::AsAsync() {
  return type_string() == "IteratorGetNextSync" ? nullptr : this;
}

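// Records the total byte size of `element` as TraceMe metadata. The metadata
// callback is evaluated lazily, so the cost is only paid when tracing is
// active.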
void RecordElementSize(const std::vector<Tensor>& element,
                       profiler::TraceMe* traceme) {
  traceme->AppendMetadata([&]() {
    int64 element_size = 0;
    for (const auto& component : element) {
      element_size += component.TotalBytes();
    }
    return profiler::TraceMeEncode({{"element_size", element_size}});
  });
}

Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        int64 mem_bw = port::GetMemoryBandwidthInfo().bw_used;

        if (mem_bw != INT64_MAX) {
          return profiler::TraceMeEncode(
              "IteratorGetNextOp::DoCompute",
              {{"id", ctx->step_id()},
               {"iter_num", ctx->frame_iter().iter_id},
               {"mem_bw_used_megabytes_per_sec", mem_bw}});
        }
        return profiler::TraceMeEncode(
            "IteratorGetNextOp::DoCompute",
            {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
  if (end_of_sequence) {
    return errors::OutOfRange("End of sequence");
  }
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
  RecordElementSize(components, &traceme);
  for (int i = 0; i < components.size(); ++i) {
    ctx->set_output(i, components[i]);
  }
  return Status::OK();
}

Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
        return strings::StrCat(
            "IteratorGetNextAsOptionalOp::DoCompute#id=", ctx->step_id(),
            ",iter_num=", ctx->frame_iter().iter_id, "#");
      },
      profiler::kInfo);
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  IteratorResource* iterator;
  TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));

  if (end_of_sequence) {
    return WriteOptionalNoneToOutput(ctx, 0);
  } else {
    RecordElementSize(components, &traceme);
    for (int i = 0; i < components.size(); ++i) {
      if (components[i].dtype() != output_types_[i]) {
        return errors::InvalidArgument(
            "The given optional does not match the expected type for "
            "component ",
            i, ". Expected: ", DataTypeString(output_types_[i]),
            ". Actual: ", DataTypeString(components[i].dtype()), ".");
      }
      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
        return errors::InvalidArgument(
            "The given optional does not match the expected shape "
            "for component ",
            i, ". Expected: ", output_shapes_[i].DebugString(),
            ". Actual: ", components[i].shape().DebugString(), ".");
      }
    }
    return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
  }
}

void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  iterator_resource->Unref();

  Tensor* string_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
  string_handle_t->scalar<tstring>()() =
      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}

IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
          output_dtypes_.size() == output_shapes_.size(),
      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                              "are set, they must have the same length."));
}

void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted to create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from a handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}

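// Reads the optional `external_state_policy` attr, which controls how
// serialization treats iterator state that depends on external systems: such
// state may be ignored, warned about, or treated as an error.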
SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  if (ctx->HasAttr(kExternalStatePolicy)) {
    int64 state_change_option;
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr(kExternalStatePolicy, &state_change_option));
    external_state_policy_ =
        SerializationContext::ExternalStatePolicy(state_change_option);
  }
}

void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  IteratorVariantSerializer serializer;
  SerializationContext::Params params;
  params.external_state_policy = external_state_policy_;
  SerializationContext serialization_ctx(params);
  OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
                                                        iterator_resource));
  Tensor* serialized_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
                                      &serialized_t));
  OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
}

void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
  tensorflow::ResourceTagger tag(kTFDataResourceTag,
                                 ctx->op_kernel().type_string());
  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  const Tensor* serialized_t;
  OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
  IteratorVariantSerializer serializer;
  OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
  OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader()));
}

namespace {

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");

}  // namespace

}  // namespace data
}  // namespace tensorflow