/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kIteratorVariantTypeName[] = "tensorflow::Iterator";

}  // namespace

class IteratorResource : public ResourceBase {
 public:
  IteratorResource(Env* env, const DataTypeVector& output_dtypes,
                   const std::vector<PartialTensorShape>& output_shapes,
                   const int /*unused: graph_def_version*/,
                   std::unique_ptr<DeviceMgr> device_mgr,
                   std::unique_ptr<FunctionLibraryDefinition> flib_def,
                   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                   FunctionLibraryRuntime* lib)
      : unbounded_thread_pool_(env, "tf_data_iterator_resource"),
        device_mgr_(std::move(device_mgr)),
        iterator_state_(std::make_shared<State>(
            std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
        output_dtypes_(output_dtypes),
        output_shapes_(output_shapes) {}

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    std::shared_ptr<State> captured_state;
    {
      tf_shared_lock l(mu_);
      captured_state = iterator_state_;
    }
    if (captured_state->iterator) {
      IteratorContext::Params params(ctx);
      params.lib = captured_state->lib;
      params.function_handle_cache =
          captured_state->function_handle_cache.get();
      params.resource_mgr = &captured_state->resource_mgr;
      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
      return captured_state->iterator->GetNext(
          IteratorContext(std::move(params)), out_tensors, end_of_sequence);
    } else {
      return errors::FailedPrecondition(
          "GetNext() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before getting the next element.");
    }
  }

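  // This overload lets callers construct the IteratorContext inline, e.g.
  // `GetNext(IteratorContext(ctx), &components, &end_of_sequence)`, as the
  // GetNext kernels later in this file do.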
  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    std::shared_ptr<State> captured_state;
    {
      tf_shared_lock l(mu_);
      captured_state = iterator_state_;
    }
    if (captured_state->iterator) {
      SerializationContext::Params params;
      // The iterator state may contain functions that are not present
      // in ctx's function library. Namely, an iterator may be restored from
      // a serialized iterator with a modified function library (for example, as
      // a result of OptimizeDataset). These modified functions are needed
      // to serialize the iterator again.
      params.flib_def = captured_state->flib_def.get();
      params.input_list = ctx->input_list();
      params.optimization_only = ctx->optimization_only();
      SerializationContext ctx_with_functions(params);
      return captured_state->iterator->Save(&ctx_with_functions, writer);
    } else {
      return errors::FailedPrecondition(
          "Save() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before saving it.");
    }
  }

  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
    string serialized_graph_def;
    TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey,
                                          &serialized_graph_def));
    GraphDef graph_def;
    if (!graph_def.ParseFromString(serialized_graph_def)) {
      return errors::Internal("Error parsing dataset GraphDef.");
    }
    string output_node;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        DatasetBase::kDatasetGraphOutputNodeKey, &output_node));
    DatasetBase* dataset = nullptr;
    Graph graph(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
    std::vector<Tensor> outputs;
    GraphRunner graph_runner(ctx->env());

    // Build a new FLR that knows about the functions in the graph, and use
    // it for all operations on the restored iterator.
    // NOTE(mrry): We clone the existing FLR and use it in the GraphRunner
    // because some of the OpKernels in the graph might call functions that are
    // only defined in the loaded GraphDef.
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib));

    // Some function names may be duplicated (for example, if the serialized
    // graph has an optimized function that retains its original name). We
    // override functions in flib_def in the event of conflict. It is
    // safe to assume that any node in the serialized graph is referring to the
    // serialized function when there is a conflict.
    TF_RETURN_IF_ERROR(
        AddToFunctionLibrary(flib_def.get(), graph_def.library()));
    std::unique_ptr<State> new_state = absl::make_unique<State>(
        std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */);

    TF_RETURN_IF_ERROR(
        graph_runner.Run(&graph, new_state->lib, {}, {output_node}, &outputs));
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

    IteratorContext::Params params(ctx);
    params.lib = new_state->lib;
    params.function_handle_cache = new_state->function_handle_cache.get();
    params.resource_mgr = &new_state->resource_mgr;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();

    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             "Iterator", &new_state->iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, new_state->iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(VerifyShapesCompatible(
        output_shapes_, new_state->iterator->output_shapes()));

    {
      IteratorContext::Params params(ctx);
      params.lib = new_state->lib;
      params.function_handle_cache = new_state->function_handle_cache.get();
      params.resource_mgr = &new_state->resource_mgr;
      DeviceBase* device = new_state->lib->device();
      params.allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };
      params.thread_factory = unbounded_thread_pool_.get_thread_factory();
      IteratorContext iter_ctx(std::move(params));
      TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
    }

    mutex_lock l(mu_);
    iterator_state_ = std::move(new_state);
    return Status::OK();
  }

  Status AddLibrary(const FunctionLibraryDefinition& flib_def) {
    mutex_lock l(mu_);
    return iterator_state_->flib_def->AddLibrary(flib_def);
  }

  Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset) {
    std::shared_ptr<State> new_state;
    {
      tf_shared_lock l(mu_);
      new_state = std::make_shared<State>(
          iterator_state_->flib_def, iterator_state_->pflr,
          iterator_state_->lib, nullptr /* function_handle_cache */,
          nullptr /* iterator */);
    }

    // Ensure that the iterator has access to all functions in the current
    // subgraph, because some functions may have been defined after the resource
    // was initially created.
    Status s = new_state->flib_def->AddLibrary(
        *ctx->function_library()->GetFunctionLibraryDefinition());

    if (!s.ok()) {
      // Adding functions to `flib_def_` may fail if there are clashes between
      // the function names in (e.g.) a restored graph and the currently
      // executing graph. In that case, we create a new function runtime for
      // this iterator, based on the current `OpKernelContext`, which will have
      // the functions we need.
      FunctionLibraryRuntime* lib;
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      TF_RETURN_IF_ERROR(
          ctx->function_library()->Clone(&flib_def, &pflr, &lib));
      new_state->flib_def = std::move(flib_def);
      new_state->pflr = std::move(pflr);
      new_state->lib = lib;
    }

    new_state->function_handle_cache =
        absl::make_unique<FunctionHandleCache>(new_state->lib);
    // Create a new iterator.
    std::unique_ptr<IteratorBase> iterator;
    IteratorContext::Params params(ctx);
    params.lib = new_state->lib;
    params.function_handle_cache = new_state->function_handle_cache.get();
    params.resource_mgr = &new_state->resource_mgr;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
                                             "Iterator", &iterator));
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    std::swap(new_state->iterator, iterator);

    mutex_lock l(mu_);
    std::swap(iterator_state_, new_state);
    return Status::OK();
  }

  string DebugString() const override { return "Iterator resource"; }

  const DataTypeVector& output_dtypes() const { return output_dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

 private:
  struct State {
    State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
          std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
          FunctionLibraryRuntime* lib, std::unique_ptr<IteratorBase> iterator)
        : flib_def(flib_def),
          pflr(pflr),
          lib(lib),
          function_handle_cache(absl::make_unique<FunctionHandleCache>(lib)),
          iterator(std::move(iterator)) {}

    State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
          std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
          FunctionLibraryRuntime* lib,
          std::unique_ptr<FunctionHandleCache> function_handle_cache,
          std::unique_ptr<IteratorBase> iterator)
        : flib_def(flib_def),
          pflr(pflr),
          lib(lib),
          function_handle_cache(std::move(function_handle_cache)),
          iterator(std::move(iterator)) {}

    std::shared_ptr<FunctionLibraryDefinition> flib_def;
    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    std::unique_ptr<FunctionHandleCache> function_handle_cache;
    ResourceMgr resource_mgr;
    std::unique_ptr<IteratorBase> iterator;
  };

  UnboundedThreadPool unbounded_thread_pool_;
  mutex mu_;
  const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
  std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
};
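
// A minimal usage sketch (illustrative only; it mirrors what the GetNext
// kernels later in this file do): look the resource up in the ResourceMgr,
// then pull elements from it.
//
//   IteratorResource* iterator;
//   TF_RETURN_IF_ERROR(
//       LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
//   core::ScopedUnref unref_iterator(iterator);
//   std::vector<Tensor> components;
//   bool end_of_sequence = false;
//   TF_RETURN_IF_ERROR(iterator->GetNext(IteratorContext(ctx), &components,
//                                        &end_of_sequence));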

namespace {

// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The get() method returns an IteratorStateReader which can be used
// to restore iterator state.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   IteratorStateVariant v;
//   TF_RETURN_IF_ERROR(v.InitializeFromIterator(ctx, iterator_resource));
//   t.scalar<Variant>()() = v;
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   iterator_resource->Restore(ctx, wrapper->get());
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  // Initializes this object with the current state of the iterator so
  // that it can be written on the next call to Encode().
  Status InitializeFromIterator(OpKernelContext* ctx,
                                IteratorResource* iterator_resource) {
    SerializationContext::Params params;
    params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
    SerializationContext serialization_ctx(params);
    data_ = absl::make_unique<VariantTensorData>();
    data_->set_type_name(TypeName());
    VariantTensorDataWriter writer(data_.get());
    TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer));
    TF_RETURN_IF_ERROR(writer.Flush());
    return Status::OK();
  }
  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
  bool Decode(VariantTensorData data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    std::unique_ptr<VariantTensorData> tensor_data =
        absl::make_unique<VariantTensorData>();
    std::swap(*tensor_data, data);
    std::unique_ptr<VariantTensorDataReader> reader =
        absl::make_unique<VariantTensorDataReader>(tensor_data.get());
    data_ = std::move(tensor_data);
    reader_ = std::move(reader);
    return true;
  }
  IteratorStateReader* get() { return reader_.get(); }
  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
                             ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<IteratorStateReader> reader_;
  std::unique_ptr<VariantTensorData> data_;
};

// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this, we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

}  // namespace

// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}

// The resource is only deleted from the resource manager when it is private
// to the kernel. Ideally the resource should be deleted when it is no longer
// held by anyone, but that would break backward compatibility.
IteratorHandleOp::~IteratorHandleOp() {
  if (resource_ != nullptr) {
    resource_->Unref();
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<IteratorResource>(cinfo_.container(),
                                                   cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }
}

void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* lib;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared, then we construct a new FLR and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &lib));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, lib, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    graph_def_version_, std::move(device_mgr),
                    std::move(flib_def), std::move(pflr), lib);
                return Status::OK();
              }));

      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              MakeTypeIndex<IteratorResource>()));
}

Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return Status::OK();
}

FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
    OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
  // Wrap the existing device in order to see any captured resources
  // in its resource manager. The existing device will outlive the
  // IteratorResource, because we are storing the IteratorResource
  // in that device's resource manager.
  *device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
      ctx->device()->name(), down_cast<Device*>(ctx->device()),
      false /* owns_underlying */, false /* isolate_session_state */));
  *flib_def = absl::make_unique<FunctionLibraryDefinition>(
      *ctx->function_library()->GetFunctionLibraryDefinition());
  *pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
      OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */,
      nullptr /* TODO(mrry): ClusterFLR */);

  return (*pflr)->GetFLR(ctx->device()->name());
}

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : OpKernel(context), graph_def_version_(context->graph_def_version()) {
  OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
}

void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
  FunctionLibraryRuntime* lib;
  std::unique_ptr<DeviceMgr> device_mgr(nullptr);
  std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
  OP_REQUIRES_OK(context,
                 context->function_library()->Clone(&flib_def, &pflr, &lib));

  ResourceMgr* mgr = context->resource_manager();

  const string container_name = "AnonymousIterator";
  string unique_name;
  {
    mutex_lock l(static_resource_lookup_mutex_);
    while (true) {  // Find an unused name.
      IteratorResource* existing_resource = nullptr;
      unique_name = strings::StrCat("AnonymousIterator", current_id_++);
      Status status = mgr->Lookup<IteratorResource>(container_name, unique_name,
                                                    &existing_resource);
      if (status.code() == error::NOT_FOUND) {
        break;
      }
      OP_REQUIRES_OK(context, status);
      existing_resource->Unref();
    }
    IteratorResource* new_resource = new IteratorResource(
        context->env(), output_dtypes_, output_shapes_, graph_def_version_,
        std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
    // Create the resource with our chosen name under the resource lookup
    // mutex to avoid another kernel racily creating a resource with this
    // name.
    OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
                                container_name, unique_name, new_resource));
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, container_name, unique_name,
                              MakeTypeIndex<IteratorResource>()));
}

// Static initializers for AnonymousIteratorHandleOp id counting.
mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
    LINKER_INITIALIZED};
int64 AnonymousIteratorHandleOp::current_id_(0);
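
// Note: the name-generation loop in Compute() above yields container-scoped
// names "AnonymousIterator0", "AnonymousIterator1", ..., skipping any id
// that is already in use.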

void MakeIteratorOp::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
  core::ScopedUnref unref(iterator_resource);
  OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
}

namespace {

class ToSingleElementOp : public AsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_to_single_element") {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    background_worker_.Schedule([ctx, done]() {
      DatasetBase* dataset;
      OP_REQUIRES_OK_ASYNC(
          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
      std::unique_ptr<IteratorBase> iterator;
      IteratorContext::Params params(ctx);
      std::unique_ptr<FunctionHandleCache> function_handle_cache =
          absl::make_unique<FunctionHandleCache>(params.lib);
      params.function_handle_cache = function_handle_cache.get();
      std::unique_ptr<ResourceMgr> resource_mgr =
          absl::make_unique<ResourceMgr>();
      params.resource_mgr = resource_mgr.get();
      IteratorContext iter_ctx(std::move(params));

      OP_REQUIRES_OK_ASYNC(
          ctx,
          dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
          done);

      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
      // avoid destruction races.
      IteratorBase* raw_iterator = iterator.release();
      auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
        delete raw_iterator;
        done();
      });
      std::vector<Tensor> components;
      components.reserve(dataset->output_dtypes().size());
      bool end_of_sequence = false;

      Status s =
          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
      if (!s.ok()) {
        ctx->SetStatus(s);
        return;
      }
      if (end_of_sequence) {
        ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
        return;
      }
      for (int i = 0; i < components.size(); ++i) {
        // TODO(mrry): Check that the shapes match the shape attrs.
        ctx->set_output(i, components[i]);
      }

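      // Check that the dataset contains exactly one element: a second call to
      // GetNext() must report end_of_sequence.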
      components.clear();
      Status s2 =
          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
      if (!s2.ok()) {
        ctx->SetStatus(s2);
        return;
      }
      if (!end_of_sequence) {
        ctx->SetStatus(
            errors::InvalidArgument("Dataset had more than one element."));
        return;
      }
    });
  }

 private:
  BackgroundWorker background_worker_;
};

class ReduceDatasetOp : public AsyncOpKernel {
 public:
  explicit ReduceDatasetOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_reduce_dataset") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
                                     &use_inter_op_parallelism_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    background_worker_.Schedule([this, ctx, done]() {
      DatasetBase* dataset;
      OP_REQUIRES_OK_ASYNC(
          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
      OpInputList inputs;
      OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
                           done);
      std::vector<Tensor> state(inputs.begin(), inputs.end());

      std::unique_ptr<CapturedFunction> captured_func;
      OP_REQUIRES_OK_ASYNC(
          ctx,
          CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
                                   use_inter_op_parallelism_, &captured_func),
          done);

      IteratorContext::Params params(ctx);
      std::unique_ptr<FunctionHandleCache> function_handle_cache =
          absl::make_unique<FunctionHandleCache>(params.lib);
      params.function_handle_cache = function_handle_cache.get();
      std::unique_ptr<ResourceMgr> resource_mgr =
          absl::make_unique<ResourceMgr>();
      params.resource_mgr = resource_mgr.get();
      IteratorContext iter_ctx(std::move(params));
      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
      OP_REQUIRES_OK_ASYNC(
          ctx,
          captured_func->Instantiate(&iter_ctx, &instantiated_captured_func),
          done);

      std::unique_ptr<IteratorBase> iterator;
      OP_REQUIRES_OK_ASYNC(
          ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
          done);

      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
      // avoid destruction races.
      IteratorBase* raw_iterator = iterator.release();
      auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
        delete raw_iterator;
        done();
      });

      // Iterate through the input dataset.
      Status status;
      while (true) {
        std::vector<Tensor> next_input_element;
        bool end_of_input;
        status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
                                       &end_of_input);
        if (!status.ok() || end_of_input) {
          break;
        }

        // Run the reduce function to update the current state.
        std::vector<Tensor> args;
        args.reserve(state.size() + next_input_element.size());
        std::copy(state.begin(), state.end(), std::back_inserter(args));
        std::copy(next_input_element.begin(), next_input_element.end(),
                  std::back_inserter(args));
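        // `args` now holds the current state followed by the input element,
        // so the reduce function is invoked as f(state..., input_element...).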

        std::vector<Tensor> reduce_func_output;
        status = instantiated_captured_func->Run(&iter_ctx, std::move(args),
                                                 &reduce_func_output);
        if (!status.ok()) {
          break;
        }
        std::swap(reduce_func_output, state);
      }

      if (!status.ok()) {
        ctx->SetStatus(status);
        return;
      }
      for (int i = 0; i < state.size(); ++i) {
        OP_REQUIRES_ASYNC(
            ctx, state[i].dtype() == output_types_[i],
            errors::InvalidArgument(
                "The result does not match the expected type for component ", i,
                ". Expected: ", DataTypeString(output_types_[i]),
                ". Actual: ", DataTypeString(state[i].dtype()), "."),
            done);
        OP_REQUIRES_ASYNC(
            ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
            errors::InvalidArgument(
                "The result does not match the expected shape for component ",
                i, ". Expected: ", output_shapes_[i].DebugString(),
                ". Actual: ", state[i].shape().DebugString(), "."),
            done);
        ctx->set_output(i, state[i]);
      }
    });
  }

 private:
  NameAttrList reduce_func_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  bool use_inter_op_parallelism_;
  BackgroundWorker background_worker_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
        graph_def_version_(ctx->graph_def_version()) {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but
  // because `ResourceOpKernel<T>::CreateResource()` does not provide access
  // to the `OpKernelContext*` (which we need to invoke the factory function),
  // this kernel cannot be implemented by overriding `CreateResource()`.
  // Furthermore, because this kernel might block when running the
  // initialization function, it must be implemented as an async kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, done);
  }

 private:
  void Init(OpKernelContext* ctx, const DoneCallback& done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, ctx_done.second);
    }
    ProduceOutput(ctx, done);
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
                EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      ctx->env(), output_dtypes_, output_shapes_,
                      graph_def_version_, nullptr, std::move(flib_def),
                      std::move(pflr), lib);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    opts.step_id = -std::abs(static_cast<int64>(random::New64()));
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    Notification n;
    Status factory_status;
    std::vector<Tensor> return_values;
    ctx->function_library()->Run(opts, f_handle, {}, &return_values,
                                 [&n, &factory_status](Status s) {
                                   factory_status.Update(s);
                                   n.Notify();
                                 });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(factory_status);
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  BackgroundWorker background_worker_;

  mutex mu_;
  ContainerInfo cinfo_ GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ GUARDED_BY(mu_) = false;
  Status initialization_status_ GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      GUARDED_BY(mu_);
  const int graph_def_version_;
};

}  // namespace

void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  IteratorResource* iterator;
  OP_REQUIRES_OK_ASYNC(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
  // The call to `iterator->GetNext()` may block and depend on an
  // inter-op thread pool thread, so we issue the call from the
  // owned thread pool.
  background_worker_.Schedule(std::bind(
      [ctx, iterator](DoneCallback done) {
        std::vector<Tensor> components;
        bool end_of_sequence = false;

        Status s = iterator->GetNext(IteratorContext(ctx), &components,
                                     &end_of_sequence);
        // NOTE(mrry): We must unref the iterator before calling `done()`, to
        // avoid destruction races.
        iterator->Unref();

        if (!s.ok()) {
          ctx->SetStatus(s);
        } else if (end_of_sequence) {
          ctx->SetStatus(errors::OutOfRange("End of sequence"));
        } else {
          for (int i = 0; i < components.size(); ++i) {
            // TODO(mrry): Check that the shapes match the shape attrs.
            ctx->set_output(i, components[i]);
          }
        }
        done();
      },
      std::move(done)));
}

void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
  IteratorResource* iterator;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
  core::ScopedUnref unref_iterator(iterator);
  std::vector<Tensor> components;
  bool end_of_sequence = false;

  OP_REQUIRES_OK(ctx, iterator->GetNext(IteratorContext(ctx), &components,
                                        &end_of_sequence));
  OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));

  for (int i = 0; i < components.size(); ++i) {
    // TODO(mrry): Check that the shapes match the shape attrs.
    ctx->set_output(i, components[i]);
  }
}

void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
                                               DoneCallback done) {
  IteratorResource* iterator;
  OP_REQUIRES_OK_ASYNC(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
  // The call to `iterator->GetNext()` may block and depend on an
  // inter-op thread pool thread, so we issue the call from the
  // owned thread pool.
  background_worker_.Schedule(std::bind(
      [this, ctx, iterator](DoneCallback done) {
        std::vector<Tensor> components;
        bool end_of_sequence = false;

        Status s = iterator->GetNext(IteratorContext(ctx), &components,
                                     &end_of_sequence);
        // NOTE(mrry): We must unref the iterator before calling `done()`, to
        // avoid destruction races.
        iterator->Unref();

        if (!s.ok()) {
          ctx->SetStatus(s);
        } else if (end_of_sequence) {
          OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
        } else {
          for (int i = 0; i < components.size(); ++i) {
            OP_REQUIRES_ASYNC(
                ctx, components[i].dtype() == output_types_[i],
                errors::InvalidArgument(
                    "The given optional does not match the expected type for "
                    "component ",
                    i, ". Expected: ", DataTypeString(output_types_[i]),
                    ". Actual: ", DataTypeString(components[i].dtype()), "."),
                done);
            OP_REQUIRES_ASYNC(
                ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
                errors::InvalidArgument(
                    "The given optional does not match the expected shape "
                    "for component ",
                    i, ". Expected: ", output_shapes_[i].DebugString(),
                    ". Actual: ", components[i].shape().DebugString(), "."),
                done);
          }

          OP_REQUIRES_OK_ASYNC(
              ctx,
              WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
              done);
        }
        done();
      },
      std::move(done)));
}

void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& resource_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
              errors::InvalidArgument("resource_handle must be a scalar"));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(
      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
  iterator_resource->Unref();

  Tensor* string_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
  string_handle_t->scalar<string>()() =
      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
}

IteratorFromStringHandleOp::IteratorFromStringHandleOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  OP_REQUIRES(
      ctx,
      output_dtypes_.empty() || output_shapes_.empty() ||
          output_dtypes_.size() == output_shapes_.size(),
      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                              "are set, they must have the same length."));
}

void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
  const Tensor& string_handle_t = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
              errors::InvalidArgument("string_handle must be a scalar"));

  ResourceHandle resource_handle;
  OP_REQUIRES(
      ctx, resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
      errors::InvalidArgument(
          "Could not parse string_handle as a valid ResourceHandle"));

  OP_REQUIRES(
      ctx, resource_handle.device() == ctx->device()->attributes().name(),
      errors::InvalidArgument("Attempted to create an iterator on device \"",
                              ctx->device()->attributes().name(),
                              "\" from a handle defined on device \"",
                              resource_handle.device(), "\""));

  // Validate that the handle corresponds to a real resource, and
  // that it is an IteratorResource.
  IteratorResource* iterator_resource;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
  core::ScopedUnref unref_iterator(iterator_resource);
  if (!output_dtypes_.empty()) {
    OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                         iterator_resource->output_dtypes()));
  }
  if (!output_shapes_.empty()) {
    OP_REQUIRES_OK(ctx,
                   VerifyShapesCompatible(output_shapes_,
                                          iterator_resource->output_shapes()));
  }

  Tensor* resource_handle_t;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
  resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}

namespace {

class SerializeIteratorOp : public OpKernel {
 public:
  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    Tensor* variant_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
    IteratorStateVariant v;
    OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource));
    variant_t->scalar<Variant>()() = v;
  }
};

class DeserializeIteratorOp : public OpKernel {
 public:
  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    Variant variant = ctx->input(1).scalar<Variant>()();
    auto* wrapper = variant.get<IteratorStateVariant>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "DeserializeIteratorOp: Unable to parse variant tensor."));
    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
  }
};

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
                        ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace

}  // namespace data
}  // namespace tensorflow