/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kIteratorVariantTypeName[] = "tensorflow::Iterator";

}  // namespace

class IteratorResource : public ResourceBase {
 public:
  IteratorResource(Env* env, const DataTypeVector& output_dtypes,
                   const std::vector<PartialTensorShape>& output_shapes,
                   const int /*unused: graph_def_version*/,
                   std::unique_ptr<DeviceMgr> device_mgr,
                   std::unique_ptr<FunctionLibraryDefinition> flib_def,
                   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                   FunctionLibraryRuntime* lib)
      : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
        device_mgr_(std::move(device_mgr)),
        iterator_state_(std::make_shared<State>(
            std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
        output_dtypes_(output_dtypes),
        output_shapes_(output_shapes) {}

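  // Returns the next element produced by the wrapped iterator. The current
  // `State` is captured as a shared_ptr under a shared lock and then used
  // without holding `mu_`, so a (possibly blocking) GetNext() call does not
  // serialize against a concurrent re-initialization of this resource.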
  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    std::shared_ptr<State> captured_state;
    {
      tf_shared_lock l(mu_);
      captured_state = iterator_state_;
    }
    if (captured_state->iterator) {
      IteratorContext::Params params(ctx);
      params.lib = captured_state->lib;
      params.function_handle_cache =
          captured_state->function_handle_cache.get();
      params.resource_mgr = &captured_state->resource_mgr;
      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
      return captured_state->iterator->GetNext(
          IteratorContext(std::move(params)), out_tensors, end_of_sequence);
    } else {
      return errors::FailedPrecondition(
          "GetNext() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before getting the next element.");
    }
  }

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    std::shared_ptr<State> captured_state;
    {
      tf_shared_lock l(mu_);
      captured_state = iterator_state_;
    }
    if (captured_state->iterator) {
      SerializationContext::Params params;
      // The iterator state may contain functions that are not present in
      // ctx's function library. Namely, an iterator may be restored from a
      // serialized iterator with a modified function library (for example, as
      // a result of OptimizeDataset). These modified functions are needed to
      // serialize the iterator again.
      params.flib_def = captured_state->flib_def.get();
      params.input_list = ctx->input_list();
      params.optimization_only = ctx->optimization_only();
      SerializationContext ctx_with_functions(params);
      return captured_state->iterator->Save(&ctx_with_functions, writer);
    } else {
      return errors::FailedPrecondition(
          "Save() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before saving it.");
    }
  }

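  // Restores this resource from `reader`. The serialized dataset graph is
  // parsed and re-run with a GraphRunner to rebuild the DatasetBase, a fresh
  // iterator is created from it, and that iterator's internal state is then
  // restored before the new state is swapped in under `mu_`.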
  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
    string serialized_graph_def;
    TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey,
                                          &serialized_graph_def));
    GraphDef graph_def;
    if (!graph_def.ParseFromString(serialized_graph_def)) {
      return errors::Internal("Error parsing dataset GraphDef.");
    }
    string output_node;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        DatasetBase::kDatasetGraphOutputNodeKey, &output_node));
    DatasetBase* dataset = nullptr;
    Graph graph(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
    std::vector<Tensor> outputs;
    GraphRunner graph_runner(ctx->env());

    // Build a new FLR that knows about the functions in the graph, and use
    // it for all operations on the restored iterator.
    // NOTE(mrry): We clone the existing FLR and use it in the GraphRunner
    // because some of the OpKernels in the graph might call functions that are
    // only defined in the loaded GraphDef.
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib));

    // Some function names may be duplicated (for example, if the serialized
    // graph has an optimized function that retains its original name). We
    // override functions in flib_def in the event of conflict. It is
    // safe to assume that any node in the serialized graph is referring to the
    // serialized function when there is a conflict.
    TF_RETURN_IF_ERROR(
        AddToFunctionLibrary(flib_def.get(), graph_def.library()));
    std::unique_ptr<State> new_state = absl::make_unique<State>(
        std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */);

    TF_RETURN_IF_ERROR(
        graph_runner.Run(&graph, new_state->lib, {}, {output_node}, &outputs));
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

    IteratorContext::Params params(ctx);
    params.lib = new_state->lib;
    params.function_handle_cache = new_state->function_handle_cache.get();
    params.resource_mgr = &new_state->resource_mgr;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();

    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             "Iterator", &new_state->iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, new_state->iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(
        output_shapes_, new_state->iterator->output_shapes()));

    {
      IteratorContext::Params params(ctx);
      params.lib = new_state->lib;
      params.function_handle_cache = new_state->function_handle_cache.get();
      params.resource_mgr = &new_state->resource_mgr;
      DeviceBase* device = new_state->lib->device();
      params.allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };
      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
      IteratorContext iter_ctx(std::move(params));
      TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
    }

    mutex_lock l(mu_);
    iterator_state_ = std::move(new_state);
    return Status::OK();
  }

  Status AddLibrary(const FunctionLibraryDefinition& flib_def) {
    mutex_lock l(mu_);
    return iterator_state_->flib_def->AddLibrary(flib_def);
  }

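  // Replaces the currently stored iterator with a new iterator created from
  // `dataset`. A new `State` that shares the existing function library is
  // populated first, and it is only swapped in under `mu_` once the new
  // iterator has been created and its output signature verified.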
  Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset) {
    std::shared_ptr<State> new_state;
    {
      tf_shared_lock l(mu_);
      new_state = std::make_shared<State>(
          iterator_state_->flib_def, iterator_state_->pflr,
          iterator_state_->lib, nullptr /* function_handle_cache */,
          nullptr /* iterator */);
    }

    // Ensure that the iterator has access to all functions in the current
    // subgraph, because some functions may have been defined after the
    // resource was initially created.
    Status s = new_state->flib_def->AddLibrary(
        *ctx->function_library()->GetFunctionLibraryDefinition());

    if (!s.ok()) {
      // Adding functions to `flib_def_` may fail, if there are clashes between
      // the function names in (e.g.) a restored graph and the currently
      // executing graph. In that case, we create a new function runtime for
      // this iterator, based on the current `OpKernelContext`, which will have
      // the functions we need.
      FunctionLibraryRuntime* lib;
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      TF_RETURN_IF_ERROR(
          ctx->function_library()->Clone(&flib_def, &pflr, &lib));
      new_state->flib_def = std::move(flib_def);
      new_state->pflr = std::move(pflr);
      new_state->lib = lib;
    }

    new_state->function_handle_cache =
        absl::make_unique<FunctionHandleCache>(new_state->lib);
    // Create new iterator.
    std::unique_ptr<IteratorBase> iterator;
    IteratorContext::Params params(ctx);
    params.lib = new_state->lib;
    params.function_handle_cache = new_state->function_handle_cache.get();
    params.resource_mgr = &new_state->resource_mgr;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             "Iterator", &iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    std::swap(new_state->iterator, iterator);

    mutex_lock l(mu_);
    std::swap(iterator_state_, new_state);
    return Status::OK();
  }

  string DebugString() const override { return "Iterator resource"; }

  const DataTypeVector& output_dtypes() const { return output_dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

 private:
  struct State {
    State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
          std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
          FunctionLibraryRuntime* lib, std::unique_ptr<IteratorBase> iterator)
        : flib_def(flib_def),
          pflr(pflr),
          lib(lib),
          function_handle_cache(absl::make_unique<FunctionHandleCache>(lib)),
          iterator(std::move(iterator)) {}

    State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
          std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
          FunctionLibraryRuntime* lib,
          std::unique_ptr<FunctionHandleCache> function_handle_cache,
          std::unique_ptr<IteratorBase> iterator)
        : flib_def(flib_def),
          pflr(pflr),
          lib(lib),
          function_handle_cache(std::move(function_handle_cache)),
          iterator(std::move(iterator)) {}

    std::shared_ptr<FunctionLibraryDefinition> flib_def;
    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    std::unique_ptr<FunctionHandleCache> function_handle_cache;
    ResourceMgr resource_mgr;
    std::unique_ptr<IteratorBase> iterator;
  };

  UnboundedThreadPool unbounded_thread_pool_;
  mutex mu_;
  const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
  std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
};

namespace {

// Wrapper for encoding/decoding the iterator state stored in a Variant
// tensor. The get() method returns an IteratorStateReader which can be used
// to restore iterator state.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant(iterator_resource);
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   iterator_resource->Restore(ctx, wrapper->get())
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  // Initializes this object with the current state of the iterator so
  // that it can be written on the next call to Encode().
  Status InitializeFromIterator(OpKernelContext* ctx,
                                IteratorResource* iterator_resource) {
    SerializationContext::Params params;
    params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
    SerializationContext serialization_ctx(params);
    data_ = absl::make_unique<VariantTensorData>();
    data_->set_type_name(TypeName());
    VariantTensorDataWriter writer(data_.get());
    TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer));
    TF_RETURN_IF_ERROR(writer.Flush());
    return Status::OK();
  }
  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
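  // Decodes `data` in place: ownership of the VariantTensorData is taken by
  // swapping it into heap storage, and a reader over that storage is retained
  // so that get() can later hand out an IteratorStateReader for Restore().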
  bool Decode(VariantTensorData data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    std::unique_ptr<VariantTensorData> tensor_data =
        absl::make_unique<VariantTensorData>();
    std::swap(*tensor_data, data);
    std::unique_ptr<VariantTensorDataReader> reader =
        absl::make_unique<VariantTensorDataReader>(tensor_data.get());
    data_ = std::move(tensor_data);
    reader_ = std::move(reader);
    return true;
  }
  IteratorStateReader* get() { return reader_.get(); }
  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
                             ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<IteratorStateReader> reader_;
  std::unique_ptr<VariantTensorData> data_;
};

// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this, we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

}  // namespace

// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

// The resource is deleted from the resource manager only when it is private
// to the kernel. Ideally the resource should be deleted when it is no longer
// held by anyone, but that would break backward compatibility.
IteratorHandleOp::~IteratorHandleOp() {
  if (resource_ != nullptr) {
    resource_->Unref();
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<IteratorResource>(cinfo_.container(),
                                                   cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }
}

void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* lib;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &lib));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, lib, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    graph_def_version_, std::move(device_mgr),
                    std::move(flib_def), std::move(pflr), lib);
                return Status::OK();
              }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              MakeTypeIndex<IteratorResource>()));
}

Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return Status::OK();
}

FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.
  *device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
      ctx->device()->name(), down_cast<Device*>(ctx->device()),
      false /* owns_underlying */, false /* isolate_session_state */));
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
      OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */,
      nullptr /* TODO(mrry): ClusterFLR */);

  return (*pflr)->GetFLR(ctx->device()->name());
}

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : OpKernel(context), graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
}

void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
  FunctionLibraryRuntime* lib;
  std::unique_ptr<DeviceMgr> device_mgr(nullptr);
  std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
  OP_REQUIRES_OK(context,
                 context->function_library()->Clone(&flib_def, &pflr, &lib));

  ResourceMgr* mgr = context->resource_manager();

  const string container_name = "AnonymousIterator";
  string unique_name;
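  // Probe for a container-unique name of the form "AnonymousIteratorN" under
  // the static lookup mutex, then create the resource under that same mutex
  // so that no other kernel can claim the chosen name first.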
  {
    mutex_lock l(static_resource_lookup_mutex_);
    while (true) {  // Find an unused name
      IteratorResource* existing_resource = nullptr;
      unique_name = strings::StrCat("AnonymousIterator", current_id_++);
      Status status = mgr->Lookup<IteratorResource>(container_name, unique_name,
                                                    &existing_resource);
      if (status.code() == error::NOT_FOUND) {
        break;
      }
      OP_REQUIRES_OK(context, status);
      existing_resource->Unref();
    }
    IteratorResource* new_resource = new IteratorResource(
        context->env(), output_dtypes_, output_shapes_, graph_def_version_,
        std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
    // Create the resource with our chosen name under the resource lookup
    // mutex to avoid another kernel racily creating a resource with this
    // name.
    OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
                                container_name, unique_name, new_resource));
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, container_name, unique_name,
                              MakeTypeIndex<IteratorResource>()));
}

// Static initializers for AnonymousIteratorHandleOp id counting.
mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
    LINKER_INITIALIZED};
int64 AnonymousIteratorHandleOp::current_id_(0);

void MakeIteratorOp::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  core::ScopedUnref unref(iterator_resource);
  OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
}

namespace {

class ToSingleElementOp : public AsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_to_single_element") {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    background_worker_.Schedule([ctx, done]() {
      DatasetBase* dataset;
      OP_REQUIRES_OK_ASYNC(
          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
      std::unique_ptr<IteratorBase> iterator;
      IteratorContext::Params params(ctx);
      std::unique_ptr<FunctionHandleCache> function_handle_cache =
          absl::make_unique<FunctionHandleCache>(params.lib);
      params.function_handle_cache = function_handle_cache.get();
      std::unique_ptr<ResourceMgr> resource_mgr =
          absl::make_unique<ResourceMgr>();
      params.resource_mgr = resource_mgr.get();
      IteratorContext iter_ctx(std::move(params));

      OP_REQUIRES_OK_ASYNC(
          ctx,
          dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
          done);

      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
      // avoid destruction races.
      IteratorBase* raw_iterator = iterator.release();
      auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
        delete raw_iterator;
        done();
      });
      std::vector<Tensor> components;
      components.reserve(dataset->output_dtypes().size());
      bool end_of_sequence = false;

      Status s =
          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
      if (!s.ok()) {
        ctx->SetStatus(s);
        return;
      }
      if (end_of_sequence) {
        ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
        return;
      }
      for (int i = 0; i < components.size(); ++i) {
        // TODO(mrry): Check that the shapes match the shape attrs.
        ctx->set_output(i, components[i]);
      }

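      // Pull one more element: for a valid single-element dataset this second
      // call must report end_of_sequence, otherwise the input contained more
      // than one element and we fail with InvalidArgument below.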
      components.clear();
      Status s2 =
          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
      if (!s2.ok()) {
        ctx->SetStatus(s2);
        return;
      }
      if (!end_of_sequence) {
        ctx->SetStatus(
            errors::InvalidArgument("Dataset had more than one element."));
        return;
      }
    });
  }

 private:
  BackgroundWorker background_worker_;
};

class ReduceDatasetOp : public AsyncOpKernel {
 public:
  explicit ReduceDatasetOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_reduce_dataset") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                     &use_inter_op_parallelism_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    background_worker_.Schedule([this, ctx, done]() {
      DatasetBase* dataset;
      OP_REQUIRES_OK_ASYNC(
          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
      OpInputList inputs;
      OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
                           done);
      std::vector<Tensor> state(inputs.begin(), inputs.end());

      std::unique_ptr<CapturedFunction> captured_func;
      OP_REQUIRES_OK_ASYNC(
          ctx,
          CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
                                   use_inter_op_parallelism_, &captured_func),
          done);

      IteratorContext::Params params(ctx);
      std::unique_ptr<FunctionHandleCache> function_handle_cache =
          absl::make_unique<FunctionHandleCache>(params.lib);
      params.function_handle_cache = function_handle_cache.get();
      std::unique_ptr<ResourceMgr> resource_mgr =
          absl::make_unique<ResourceMgr>();
      params.resource_mgr = resource_mgr.get();
      IteratorContext iter_ctx(std::move(params));
      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
      OP_REQUIRES_OK_ASYNC(
          ctx,
          captured_func->Instantiate(&iter_ctx, &instantiated_captured_func),
          done);

      std::unique_ptr<IteratorBase> iterator;
      OP_REQUIRES_OK_ASYNC(
          ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
          done);

      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
      // avoid destruction races.
      IteratorBase* raw_iterator = iterator.release();
      auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
        delete raw_iterator;
        done();
      });

      // Iterate through the input dataset.
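      // At each step, the reduce function is invoked as f(state..., element...)
      // and its outputs replace the current state.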
      Status status;
      while (true) {
        std::vector<Tensor> next_input_element;
        bool end_of_input;
        status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
                                       &end_of_input);
        if (!status.ok() || end_of_input) {
          break;
        }

        // Run the reduce function to update the current state.
        std::vector<Tensor> args;
        args.reserve(state.size() + next_input_element.size());
        std::copy(state.begin(), state.end(), std::back_inserter(args));
        std::copy(next_input_element.begin(), next_input_element.end(),
                  std::back_inserter(args));

        std::vector<Tensor> reduce_func_output;
        status = instantiated_captured_func->Run(&iter_ctx, std::move(args),
                                                 &reduce_func_output);
        if (!status.ok()) {
          break;
        }
        std::swap(reduce_func_output, state);
      }

      if (!status.ok()) {
        ctx->SetStatus(status);
        return;
      }
      for (int i = 0; i < state.size(); ++i) {
        OP_REQUIRES_ASYNC(
            ctx, state[i].dtype() == output_types_[i],
            errors::InvalidArgument(
                "The result does not match the expected type for component ", i,
                ". Expected: ", DataTypeString(output_types_[i]),
                ". Actual: ", DataTypeString(state[i].dtype()), "."),
            done);
        OP_REQUIRES_ASYNC(
            ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
            errors::InvalidArgument(
                "The result does not match the expected shape for component ",
                i, ". Expected: ", output_shapes_[i].DebugString(),
                ". Actual: ", state[i].shape().DebugString(), "."),
            done);
        ctx->set_output(i, state[i]);
      }
    });
  }

 private:
  NameAttrList reduce_func_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  bool use_inter_op_parallelism_;
  BackgroundWorker background_worker_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version()) {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but because
  // `ResourceOpKernel<T>::CreateResource()` does not provide access to the
  // `OpKernelContext*` that we need to invoke the factory function, this
  // kernel cannot be implemented by overriding `CreateResource()` alone.
  // Furthermore, because the initialization function might block, this kernel
  // must be implemented as an async kernel.
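  // Calls that arrive while initialization is still in flight queue their
  // callbacks in `done_callbacks_`; Init() completes them once the shared
  // iterator resource is ready (or initialization has failed).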
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, done);
  }

 private:
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
                EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      graph_def_version_, nullptr, std::move(flib_def),
                      std::move(pflr), lib);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    opts.step_id = -std::abs(static_cast<int64>(random::New64()));
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    Notification n;
    Status factory_status;
    std::vector<Tensor> return_values;
    ctx->function_library()->Run(opts, f_handle, {}, &return_values,
                                 [&n, &factory_status](Status s) {
                                   factory_status.Update(s);
                                   n.Notify();
                                 });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(factory_status);
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ GUARDED_BY(mu_) = false;
  Status initialization_status_ GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      GUARDED_BY(mu_);
  const int graph_def_version_;
};

}  // namespace

void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  IteratorResource* iterator;
  OP_REQUIRES_OK_ASYNC(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
  // The call to `iterator->GetNext()` may block and depend on an
  // inter-op thread pool thread, so we issue the call from the
  // owned thread pool.
  background_worker_.Schedule(std::bind(
      [ctx, iterator](DoneCallback done) {
        std::vector<Tensor> components;
        bool end_of_sequence = false;

        Status s = iterator->GetNext(IteratorContext(ctx), &components,
                                     &end_of_sequence);
        // NOTE(mrry): We must unref the iterator before calling `done()`, to
        // avoid destruction races.
        iterator->Unref();

        if (!s.ok()) {
          ctx->SetStatus(s);
        } else if (end_of_sequence) {
          ctx->SetStatus(errors::OutOfRange("End of sequence"));
        } else {
          for (int i = 0; i < components.size(); ++i) {
            // TODO(mrry): Check that the shapes match the shape attrs.
            ctx->set_output(i, components[i]);
          }
        }
        done();
      },
      std::move(done)));
}

void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
  IteratorResource* iterator;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  OP_REQUIRES_OK(ctx, iterator->GetNext(IteratorContext(ctx), &components,
                                        &end_of_sequence));
  OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));

  for (int i = 0; i < components.size(); ++i) {
    // TODO(mrry): Check that the shapes match the shape attrs.
    ctx->set_output(i, components[i]);
  }
}

void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
                                               DoneCallback done) {
  IteratorResource* iterator;
  OP_REQUIRES_OK_ASYNC(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
  // The call to `iterator->GetNext()` may block and depend on an
  // inter-op thread pool thread, so we issue the call from the
  // owned thread pool.
  background_worker_.Schedule(std::bind(
      [this, ctx, iterator](DoneCallback done) {
        std::vector<Tensor> components;
        bool end_of_sequence = false;

        Status s = iterator->GetNext(IteratorContext(ctx), &components,
                                     &end_of_sequence);
        // NOTE(mrry): We must unref the iterator before calling `done()`, to
        // avoid destruction races.
        iterator->Unref();

        if (!s.ok()) {
          ctx->SetStatus(s);
        } else if (end_of_sequence) {
          OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
        } else {
          for (int i = 0; i < components.size(); ++i) {
            OP_REQUIRES_ASYNC(
                ctx, components[i].dtype() == output_types_[i],
                errors::InvalidArgument(
                    "The given optional does not match the expected type for "
                    "component ",
                    i, ". Expected: ", DataTypeString(output_types_[i]),
                    ". Actual: ", DataTypeString(components[i].dtype()), "."),
                done);
            OP_REQUIRES_ASYNC(
                ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
                errors::InvalidArgument(
                    "The given optional does not match the expected shape "
                    "for component ",
                    i, ". Expected: ", output_shapes_[i].DebugString(),
                    ". Actual: ", components[i].shape().DebugString(), "."),
                done);
          }

          OP_REQUIRES_OK_ASYNC(
              ctx,
              WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
              done);
        }
        done();
      },
      std::move(done)));
}

void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  iterator_resource->Unref();

  Tensor* string_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
  string_handle_t->scalar<string>()() =
      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}

IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
          output_dtypes_.size() == output_shapes_.size(),
      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                              "are set, they must have the same length."));
}

void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted to create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from a handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}

namespace {

class SerializeIteratorOp : public OpKernel {
 public:
  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    Tensor* variant_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
    IteratorStateVariant v;
    OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource));
    variant_t->scalar<Variant>()() = v;
  }
};

class DeserializeIteratorOp : public OpKernel {
 public:
  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    Variant variant = ctx->input(1).scalar<Variant>()();
    auto* wrapper = variant.get<IteratorStateVariant>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "DeserializeIteratorOp: Unable to parse variant tensor."));
    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
  }
};

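// For the ops registered on both CPU and GPU below, the CPU kernel carries the
// higher priority, which biases kernel selection toward the host
// implementation when the op is not explicitly pinned to a GPU device.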
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace

}  // namespace data
}  // namespace tensorflow