1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/data/root_dataset.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function_handle_cache.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/kernels/initializable_lookup_table.h"
32 #include "tensorflow/core/kernels/lookup_util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/io/inputbuffer.h"
36 #include "tensorflow/core/lib/strings/numbers.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/refcount.h"
40
41 namespace tensorflow {
42 namespace data {
43 namespace experimental {
44 namespace {
45
46 using InitializerSerializer =
47 ::tensorflow::lookup::InitializableLookupTable::InitializerSerializer;
48
49 class DatasetIterator
50 : public lookup::InitializableLookupTable::InitTableIterator {
51 public:
DatasetIterator(data::DatasetBase * dataset)52 explicit DatasetIterator(data::DatasetBase* dataset) : dataset_(dataset) {}
53
~DatasetIterator()54 ~DatasetIterator() override {}
55
Init(OpKernelContext * ctx)56 Status Init(OpKernelContext* ctx) {
57 data::IteratorContext::Params params(ctx);
58 function_handle_cache_ = absl::make_unique<FunctionHandleCache>(params.flr);
59 params.function_handle_cache = function_handle_cache_.get();
60 params.resource_mgr = &resource_mgr_;
61 cancellation_manager_ =
62 absl::make_unique<CancellationManager>(ctx->cancellation_manager());
63 params.cancellation_manager = cancellation_manager_.get();
64 iterator_ctx_ = absl::make_unique<data::IteratorContext>(std::move(params));
65
66 DatasetBase* finalized_dataset;
67 TF_RETURN_IF_ERROR(
68 data::FinalizeDataset(ctx, dataset_, &finalized_dataset));
69 TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
70 iterator_ctx_.get(), nullptr, "LookupTable", &iterator_));
71 core::ScopedUnref unref(finalized_dataset);
72 Next();
73 return Status::OK();
74 }
75
Next()76 void Next() override {
77 bool end_of_input;
78 tensors_.clear();
79 status_ = iterator_->GetNext(iterator_ctx_.get(), &tensors_, &end_of_input);
80 if (status_.ok() && end_of_input) {
81 status_ = errors::OutOfRange("end of iterator");
82 }
83 }
84
Valid() const85 bool Valid() const override { return status_.ok(); }
86
keys() const87 const Tensor& keys() const override { return tensors_[0]; }
88
values() const89 const Tensor& values() const override { return tensors_[1]; }
90
status() const91 Status status() const override { return status_; }
92
total_size() const93 int64 total_size() const override {
94 int64_t size = dataset_->Cardinality();
95 if (size < 0) {
96 return 0;
97 }
98 return size;
99 }
100
101 private:
102 data::DatasetBase* dataset_; // owned.
103 std::unique_ptr<data::IteratorContext> iterator_ctx_;
104 std::unique_ptr<FunctionHandleCache> function_handle_cache_;
105 ResourceMgr resource_mgr_;
106 std::unique_ptr<CancellationManager> cancellation_manager_;
107 std::unique_ptr<data::IteratorBase> iterator_;
108 std::vector<Tensor> tensors_;
109 Status status_;
110 };
111
MakeDatasetInitializerSerializer(OpKernelContext * ctx,data::DatasetBase * dataset)112 std::unique_ptr<InitializerSerializer> MakeDatasetInitializerSerializer(
113 OpKernelContext* ctx, data::DatasetBase* dataset) {
114 dataset->Ref();
115 auto unref_dataset = [dataset] { dataset->Unref(); };
116 return absl::make_unique<InitializerSerializer>(
117 [dataset, resource_manager = ctx->resource_manager(),
118 device_name = ctx->device()->attributes().name()](
119 GraphDefBuilder* builder, Node* table, Node** out) {
120 data::DatasetBase::DatasetGraphDefBuilder db(builder);
121 data::SerializationContext::Params params;
122 params.resource_mgr = resource_manager;
123 params.device_name = device_name;
124 params.serialize_data_tensors = true;
125 data::SerializationContext serialization_ctx(params);
126 Node* dataset_node;
127 TF_RETURN_IF_ERROR(
128 db.AddInputDataset(&serialization_ctx, dataset, &dataset_node));
129 *out = ops::BinaryOp("InitializeTableFromDataset", table, dataset_node,
130 builder->opts());
131 if (*out == nullptr) {
132 return errors::Internal(
133 "Failed to create InitializeTableFromDataset op: ",
134 builder->opts().StatusToString());
135 }
136 return Status::OK();
137 },
138 /*cleanup=*/std::move(unref_dataset));
139 }
140
InitializeTableFromDataset(OpKernelContext * ctx,data::DatasetBase * dataset,lookup::InitializableLookupTable * table,AsyncOpKernel::DoneCallback done)141 void InitializeTableFromDataset(OpKernelContext* ctx,
142 data::DatasetBase* dataset,
143 lookup::InitializableLookupTable* table,
144 AsyncOpKernel::DoneCallback done) {
145 // Construct the cleanup before `iter` below so that `iter` is destroyed
146 // before calling `done`.
147 auto cleanup = gtl::MakeCleanup([done = std::move(done)]() { done(); });
148 // Assert that the dataset types match up to that expected in the table.
149 const auto& dataset_types = dataset->output_dtypes();
150 OP_REQUIRES(
151 ctx, dataset_types.size() == 2,
152 errors::InvalidArgument("Dataset should have two output types only"));
153 OP_REQUIRES(ctx, dataset_types[0] == table->key_dtype(),
154 errors::InvalidArgument(
155 "Key dtype expected: ", table->key_dtype(),
156 " but obtained: ", dataset_types[0], " from the dataset"));
157 OP_REQUIRES(ctx, dataset_types[1] == table->value_dtype(),
158 errors::InvalidArgument(
159 "Value dtype expected: ", table->value_dtype(),
160 " but obtained: ", dataset_types[1], " from the dataset"));
161 // Assert that the dataset output shapes are scalars.
162 const auto& dataset_shapes = dataset->output_shapes();
163 OP_REQUIRES(
164 ctx, dataset_shapes.size() == 2,
165 errors::InvalidArgument("Dataset should have two output shapes only"));
166 OP_REQUIRES(ctx, dataset_shapes[0].IsCompatibleWith(PartialTensorShape({})),
167 errors::InvalidArgument("Expected scalar for key. Obtained: ",
168 dataset_shapes[0].DebugString()));
169 OP_REQUIRES(ctx, dataset_shapes[1].IsCompatibleWith(PartialTensorShape({})),
170 errors::InvalidArgument("Expected scalar for key. Obtained: ",
171 dataset_shapes[1].DebugString()));
172 DatasetIterator iter(dataset);
173 OP_REQUIRES_OK(ctx, iter.Init(ctx));
174 Status s =
175 table->Initialize(iter, MakeDatasetInitializerSerializer(ctx, dataset));
176 if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
177 LOG(INFO) << "Table already initialized from dataset.";
178 return;
179 }
180 ctx->SetStatus(s);
181 }
182
183 class InitializeTableFromDatasetOp : public AsyncOpKernel {
184 public:
InitializeTableFromDatasetOp(OpKernelConstruction * ctx)185 explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
186 : AsyncOpKernel(ctx),
187 background_worker_(ctx->env(), "initialize_table_from_dataset") {}
188
ComputeAsync(OpKernelContext * ctx,DoneCallback done)189 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
190 lookup::InitializableLookupTable* table;
191 OP_REQUIRES_OK_ASYNC(
192 ctx, GetInitializableLookupTable("table_handle", ctx, &table), done);
193 core::ScopedUnref unref_me(table);
194 data::DatasetBase* dataset;
195 OP_REQUIRES_OK_ASYNC(
196 ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
197 background_worker_.Schedule([ctx, dataset, table, done]() {
198 InitializeTableFromDataset(ctx, dataset, table, done);
199 });
200 }
201
202 private:
203 TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
204
205 data::BackgroundWorker background_worker_;
206 };
207
208 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),
209 InitializeTableFromDatasetOp);
210
211 } // namespace
212 } // namespace experimental
213 } // namespace data
214 } // namespace tensorflow
215