• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/data/root_dataset.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function_handle_cache.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/kernels/initializable_lookup_table.h"
32 #include "tensorflow/core/kernels/lookup_util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/io/inputbuffer.h"
36 #include "tensorflow/core/lib/strings/numbers.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/refcount.h"
40 
41 namespace tensorflow {
42 namespace data {
43 namespace experimental {
44 namespace {
45 
46 using InitializerSerializer =
47     ::tensorflow::lookup::InitializableLookupTable::InitializerSerializer;
48 
49 class DatasetIterator
50     : public lookup::InitializableLookupTable::InitTableIterator {
51  public:
DatasetIterator(data::DatasetBase * dataset)52   explicit DatasetIterator(data::DatasetBase* dataset) : dataset_(dataset) {}
53 
~DatasetIterator()54   ~DatasetIterator() override {}
55 
Init(OpKernelContext * ctx)56   Status Init(OpKernelContext* ctx) {
57     data::IteratorContext::Params params(ctx);
58     function_handle_cache_ = absl::make_unique<FunctionHandleCache>(params.flr);
59     params.function_handle_cache = function_handle_cache_.get();
60     params.resource_mgr = &resource_mgr_;
61     cancellation_manager_ =
62         absl::make_unique<CancellationManager>(ctx->cancellation_manager());
63     params.cancellation_manager = cancellation_manager_.get();
64     iterator_ctx_ = absl::make_unique<data::IteratorContext>(std::move(params));
65 
66     DatasetBase* finalized_dataset;
67     TF_RETURN_IF_ERROR(
68         data::FinalizeDataset(ctx, dataset_, &finalized_dataset));
69     TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
70         iterator_ctx_.get(), nullptr, "LookupTable", &iterator_));
71     core::ScopedUnref unref(finalized_dataset);
72     Next();
73     return Status::OK();
74   }
75 
Next()76   void Next() override {
77     bool end_of_input;
78     tensors_.clear();
79     status_ = iterator_->GetNext(iterator_ctx_.get(), &tensors_, &end_of_input);
80     if (status_.ok() && end_of_input) {
81       status_ = errors::OutOfRange("end of iterator");
82     }
83   }
84 
Valid() const85   bool Valid() const override { return status_.ok(); }
86 
keys() const87   const Tensor& keys() const override { return tensors_[0]; }
88 
values() const89   const Tensor& values() const override { return tensors_[1]; }
90 
status() const91   Status status() const override { return status_; }
92 
total_size() const93   int64 total_size() const override {
94     int64_t size = dataset_->Cardinality();
95     if (size < 0) {
96       return 0;
97     }
98     return size;
99   }
100 
101  private:
102   data::DatasetBase* dataset_;  // owned.
103   std::unique_ptr<data::IteratorContext> iterator_ctx_;
104   std::unique_ptr<FunctionHandleCache> function_handle_cache_;
105   ResourceMgr resource_mgr_;
106   std::unique_ptr<CancellationManager> cancellation_manager_;
107   std::unique_ptr<data::IteratorBase> iterator_;
108   std::vector<Tensor> tensors_;
109   Status status_;
110 };
111 
MakeDatasetInitializerSerializer(OpKernelContext * ctx,data::DatasetBase * dataset)112 std::unique_ptr<InitializerSerializer> MakeDatasetInitializerSerializer(
113     OpKernelContext* ctx, data::DatasetBase* dataset) {
114   dataset->Ref();
115   auto unref_dataset = [dataset] { dataset->Unref(); };
116   return absl::make_unique<InitializerSerializer>(
117       [dataset, resource_manager = ctx->resource_manager(),
118        device_name = ctx->device()->attributes().name()](
119           GraphDefBuilder* builder, Node* table, Node** out) {
120         data::DatasetBase::DatasetGraphDefBuilder db(builder);
121         data::SerializationContext::Params params;
122         params.resource_mgr = resource_manager;
123         params.device_name = device_name;
124         params.serialize_data_tensors = true;
125         data::SerializationContext serialization_ctx(params);
126         Node* dataset_node;
127         TF_RETURN_IF_ERROR(
128             db.AddInputDataset(&serialization_ctx, dataset, &dataset_node));
129         *out = ops::BinaryOp("InitializeTableFromDataset", table, dataset_node,
130                              builder->opts());
131         if (*out == nullptr) {
132           return errors::Internal(
133               "Failed to create InitializeTableFromDataset op: ",
134               builder->opts().StatusToString());
135         }
136         return Status::OK();
137       },
138       /*cleanup=*/std::move(unref_dataset));
139 }
140 
InitializeTableFromDataset(OpKernelContext * ctx,data::DatasetBase * dataset,lookup::InitializableLookupTable * table,AsyncOpKernel::DoneCallback done)141 void InitializeTableFromDataset(OpKernelContext* ctx,
142                                 data::DatasetBase* dataset,
143                                 lookup::InitializableLookupTable* table,
144                                 AsyncOpKernel::DoneCallback done) {
145   // Construct the cleanup before `iter` below so that `iter` is destroyed
146   // before calling `done`.
147   auto cleanup = gtl::MakeCleanup([done = std::move(done)]() { done(); });
148   // Assert that the dataset types match up to that expected in the table.
149   const auto& dataset_types = dataset->output_dtypes();
150   OP_REQUIRES(
151       ctx, dataset_types.size() == 2,
152       errors::InvalidArgument("Dataset should have two output types only"));
153   OP_REQUIRES(ctx, dataset_types[0] == table->key_dtype(),
154               errors::InvalidArgument(
155                   "Key dtype expected: ", table->key_dtype(),
156                   " but obtained: ", dataset_types[0], " from the dataset"));
157   OP_REQUIRES(ctx, dataset_types[1] == table->value_dtype(),
158               errors::InvalidArgument(
159                   "Value dtype expected: ", table->value_dtype(),
160                   " but obtained: ", dataset_types[1], " from the dataset"));
161   // Assert that the dataset output shapes are scalars.
162   const auto& dataset_shapes = dataset->output_shapes();
163   OP_REQUIRES(
164       ctx, dataset_shapes.size() == 2,
165       errors::InvalidArgument("Dataset should have two output shapes only"));
166   OP_REQUIRES(ctx, dataset_shapes[0].IsCompatibleWith(PartialTensorShape({})),
167               errors::InvalidArgument("Expected scalar for key. Obtained: ",
168                                       dataset_shapes[0].DebugString()));
169   OP_REQUIRES(ctx, dataset_shapes[1].IsCompatibleWith(PartialTensorShape({})),
170               errors::InvalidArgument("Expected scalar for key. Obtained: ",
171                                       dataset_shapes[1].DebugString()));
172   DatasetIterator iter(dataset);
173   OP_REQUIRES_OK(ctx, iter.Init(ctx));
174   Status s =
175       table->Initialize(iter, MakeDatasetInitializerSerializer(ctx, dataset));
176   if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
177     LOG(INFO) << "Table already initialized from dataset.";
178     return;
179   }
180   ctx->SetStatus(s);
181 }
182 
183 class InitializeTableFromDatasetOp : public AsyncOpKernel {
184  public:
InitializeTableFromDatasetOp(OpKernelConstruction * ctx)185   explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
186       : AsyncOpKernel(ctx),
187         background_worker_(ctx->env(), "initialize_table_from_dataset") {}
188 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)189   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
190     lookup::InitializableLookupTable* table;
191     OP_REQUIRES_OK_ASYNC(
192         ctx, GetInitializableLookupTable("table_handle", ctx, &table), done);
193     core::ScopedUnref unref_me(table);
194     data::DatasetBase* dataset;
195     OP_REQUIRES_OK_ASYNC(
196         ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
197     background_worker_.Schedule([ctx, dataset, table, done]() {
198       InitializeTableFromDataset(ctx, dataset, table, done);
199     });
200   }
201 
202  private:
203   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
204 
205   data::BackgroundWorker background_worker_;
206 };
207 
208 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),
209                         InitializeTableFromDatasetOp);
210 
211 }  // namespace
212 }  // namespace experimental
213 }  // namespace data
214 }  // namespace tensorflow
215