1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/lookup_util.h"
17
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/io/inputbuffer.h"
22
23 namespace tensorflow {
24 namespace lookup {
25 namespace {
26
// Size, in bytes, of the buffer handed to io::InputBuffer when reading files.
constexpr int kInputBufferSize = 1 << 20;
// Column index meaning "use the zero-based line number" instead of a token.
constexpr int kLineNumber = -1;
// Column index meaning "use the entire line" instead of a token.
constexpr int kWholeLine = -2;
30
GetNumLinesInTextFile(Env * env,const string & vocab_file,int64 * num_lines)31 Status GetNumLinesInTextFile(Env* env, const string& vocab_file,
32 int64* num_lines) {
33 std::unique_ptr<RandomAccessFile> file;
34 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
35
36 io::InputBuffer input_buffer(file.get(), kInputBufferSize);
37 string line;
38 Status s = input_buffer.ReadLine(&line);
39 int64 next_id = 0;
40 while (s.ok()) {
41 next_id++;
42 s = input_buffer.ReadLine(&line);
43 }
44 if (!errors::IsOutOfRange(s)) {
45 return s;
46 }
47 *num_lines = next_id;
48 return Status::OK();
49 }
50
51 // Iterator that reads a text file. Each iteration process one line, it parses
52 // the line and populates the keys and values tensors used for initialization
53 // with a single key and corresponding value.
54 //
55 // What information of the line to populate the key or values is specified by
56 // providing key_index and value_index.
57 class TextFileLineIterator
58 : public InitializableLookupTable::InitTableIterator {
59 public:
TextFileLineIterator()60 TextFileLineIterator()
61 : valid_(false),
62 vocab_size_(-1),
63 status_(errors::FailedPrecondition("Not initialized")) {}
64
65 // Initialize iterator.
66 //
67 // Prepares the file 'filename' and sets the data types to return the keys and
68 // values tensors. It requires the indices of the tokens in the line given a
69 // delimiter to specify where to pick the data from.
70 //
71 // - Index -2 means the entire line as string.
72 // - Index -1 means the line number stored in int64.
73 // - Index >= 0 represent index (starting at zero) of the split line based on
74 // delimiter.
Init(const string & filename,int64 vocab_size,char delimiter,DataType key_dtype,int64 key_index,DataType value_dtype,int64 value_index,Env * env)75 Status Init(const string& filename, int64 vocab_size, char delimiter,
76 DataType key_dtype, int64 key_index, DataType value_dtype,
77 int64 value_index, Env* env) {
78 filename_ = filename;
79 vocab_size_ = vocab_size;
80 delimiter_ = delimiter;
81 key_ = Tensor(key_dtype, TensorShape({}));
82 value_ = Tensor(value_dtype, TensorShape({}));
83 key_index_ = key_index;
84 value_index_ = value_index;
85 env_ = env;
86
87 status_ = env->NewRandomAccessFile(filename_, &file_);
88 if (!status_.ok()) return status_;
89
90 input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
91 valid_ = true;
92 next_id_ = 0;
93 ignore_split_ = std::max(key_index_, value_index_) < 0;
94 Next();
95 return status_;
96 }
97
Next()98 void Next() override {
99 if (!valid_) return;
100
101 string line;
102 status_ = input_buffer_->ReadLine(&line);
103 if (!status_.ok()) {
104 if (errors::IsOutOfRange(status_) && vocab_size_ != -1 &&
105 next_id_ != vocab_size_) {
106 status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
107 ": expected ", vocab_size_,
108 " but got ", next_id_);
109 }
110 valid_ = false;
111 return;
112 }
113 if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
114 LOG(WARNING) << "Truncated " << filename_ << " before its end at "
115 << vocab_size_ << " records.";
116 LOG(WARNING) << "next_id_ : " << next_id_;
117 status_ = errors::OutOfRange("Finished reading ", vocab_size_,
118 " of lines from ", filename_);
119 valid_ = false;
120 return;
121 }
122 if (line.empty()) {
123 status_ = errors::InvalidArgument("Invalid content in ", filename_,
124 ": empty line found at position ",
125 input_buffer_->Tell(), ".");
126 valid_ = false;
127 return;
128 }
129
130 std::vector<string> tokens;
131 if (!ignore_split_) {
132 tokens = str_util::Split(line, delimiter_);
133 if (std::max(key_index_, value_index_) >= tokens.size()) {
134 status_ = errors::InvalidArgument(
135 "Invalid number of columns in ", filename_, " line ", next_id_,
136 " (", line, ") : expected ", std::max(key_index_, value_index_),
137 " got ", tokens.size());
138 valid_ = false;
139 return;
140 }
141 }
142 status_ = SetValue(line, tokens, key_index_, &key_);
143 if (!status_.ok()) {
144 valid_ = false;
145 return;
146 }
147 status_ = SetValue(line, tokens, value_index_, &value_);
148 if (!status_.ok()) {
149 valid_ = false;
150 return;
151 }
152
153 next_id_++;
154 }
155
Valid() const156 bool Valid() const override { return valid_; }
157
keys() const158 const Tensor& keys() const override { return key_; }
159
values() const160 const Tensor& values() const override { return value_; }
161
status() const162 Status status() const override { return status_; }
163
total_size() const164 int64 total_size() const override {
165 if (vocab_size_ == -1) {
166 int64 new_size = -1;
167 Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
168 if (!status.ok()) {
169 LOG(WARNING) << "Unable to get line count: " << status;
170 new_size = -1;
171 }
172 *const_cast<int64*>(&vocab_size_) = new_size;
173 }
174 return vocab_size_;
175 }
176
177 private:
178 Tensor key_;
179 Tensor value_;
180 bool valid_; // true if the iterator points to an existing range.
181 int64 key_index_;
182 int64 value_index_;
183 Env* env_;
184 int64 next_id_;
185 int64 vocab_size_;
186 string filename_;
187 char delimiter_;
188 Status status_;
189 bool ignore_split_;
190 std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_
191 std::unique_ptr<io::InputBuffer> input_buffer_;
192
193 // Set the corresponding value from line or tokens based on 'index' into the
194 // tensor 't'. The value is transformed to the given data type 'dtype'.
SetValue(const string & line,const std::vector<string> & tokens,int64 index,Tensor * tensor)195 Status SetValue(const string& line, const std::vector<string>& tokens,
196 int64 index, Tensor* tensor) {
197 if (index == kLineNumber) {
198 tensor->flat<int64>()(0) = next_id_;
199 return Status::OK();
200 }
201 const string& token = (index == kWholeLine) ? line : tokens[index];
202 const DataType& dtype = tensor->dtype();
203 switch (dtype) {
204 case DT_INT32: {
205 int32 value;
206 if (!strings::safe_strto32(token.c_str(), &value)) {
207 valid_ = false;
208 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
209 " is not a valid int32.");
210 }
211 tensor->flat<int32>()(0) = value;
212 } break;
213 case DT_INT64: {
214 int64 value;
215 if (!strings::safe_strto64(token.c_str(), &value)) {
216 valid_ = false;
217 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
218 " is not a valid int64.");
219 }
220 tensor->flat<int64>()(0) = value;
221 } break;
222 case DT_FLOAT: {
223 float value;
224 if (!strings::safe_strtof(token.c_str(), &value)) {
225 valid_ = false;
226 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
227 " is not a valid float.");
228 }
229 tensor->flat<float>()(0) = value;
230 } break;
231 case DT_DOUBLE: {
232 double value;
233 if (!strings::safe_strtod(token.c_str(), &value)) {
234 valid_ = false;
235 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
236 " is not a valid double.");
237 }
238 tensor->flat<double>()(0) = value;
239 } break;
240 case DT_STRING:
241 tensor->flat<string>()(0) = token;
242 break;
243 default:
244 valid_ = false;
245 return errors::InvalidArgument("Data type ", DataTypeString(dtype),
246 " not supported.");
247 }
248 return Status::OK();
249 }
250
251 TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator);
252 };
253
GetTableHandle(const string & input_name,OpKernelContext * ctx,string * container,string * table_handle)254 Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
255 string* container, string* table_handle) {
256 {
257 mutex* mu;
258 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
259 mutex_lock l(*mu);
260 Tensor tensor;
261 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
262 if (tensor.NumElements() != 2) {
263 return errors::InvalidArgument(
264 "Lookup table handle must be scalar, but had shape: ",
265 tensor.shape().DebugString());
266 }
267 auto h = tensor.flat<string>();
268 *container = h(0);
269 *table_handle = h(1);
270 }
271 return Status::OK();
272 }
273
274 } // namespace
275
GetLookupTable(const string & input_name,OpKernelContext * ctx,LookupInterface ** table)276 Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
277 LookupInterface** table) {
278 string container;
279 string table_handle;
280 DataType handle_dtype;
281 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
282 if (handle_dtype == DT_RESOURCE) {
283 ResourceHandle handle;
284 TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
285 return LookupResource(ctx, handle, table);
286 } else {
287 TF_RETURN_IF_ERROR(
288 GetTableHandle(input_name, ctx, &container, &table_handle));
289 return ctx->resource_manager()->Lookup(container, table_handle, table);
290 }
291 }
292
GetInitializableLookupTable(const string & input_name,OpKernelContext * ctx,InitializableLookupTable ** table)293 Status GetInitializableLookupTable(const string& input_name,
294 OpKernelContext* ctx,
295 InitializableLookupTable** table) {
296 LookupInterface* lookup_table;
297 DataType handle_dtype;
298 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
299 if (handle_dtype == DT_RESOURCE) {
300 ResourceHandle handle;
301 TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
302 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
303 *table = lookup_table->GetInitializableLookupTable();
304 if (*table == nullptr) {
305 lookup_table->Unref();
306 return errors::InvalidArgument("Table ", handle.container(), " ",
307 handle.name(), " is not initializable");
308 }
309 } else {
310 string container;
311 string table_handle;
312 TF_RETURN_IF_ERROR(
313 GetTableHandle(input_name, ctx, &container, &table_handle));
314 TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
315 &lookup_table));
316 *table = lookup_table->GetInitializableLookupTable();
317 if (*table == nullptr) {
318 lookup_table->Unref();
319 return errors::InvalidArgument("Table ", container, " ", table_handle,
320 " is not initializable");
321 }
322 }
323 return Status::OK();
324 }
325
CheckTableDataTypes(const LookupInterface & table,DataType key_dtype,DataType value_dtype,const string & table_name)326 Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
327 DataType value_dtype, const string& table_name) {
328 if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) {
329 return errors::InvalidArgument(
330 "Conflicting key/value dtypes ", DataTypeString(key_dtype), "->",
331 DataTypeString(value_dtype), " with ",
332 DataTypeString(table.key_dtype()), "-",
333 DataTypeString(table.value_dtype()), " for table ", table_name);
334 }
335 return Status::OK();
336 }
337
338 // Helper function to initialize an InitializableLookupTable from a text file.
InitializeTableFromTextFile(const string & filename,int64 vocab_size,char delimiter,int32 key_index,int32 value_index,Env * env,InitializableLookupTable * table)339 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
340 char delimiter, int32 key_index,
341 int32 value_index, Env* env,
342 InitializableLookupTable* table) {
343 if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
344 return errors::InvalidArgument(
345 "Key index for line number requires table key dtype of int64, got ",
346 DataTypeString(table->key_dtype()));
347 }
348 const DataType& key_dtype = table->key_dtype();
349 const DataType& value_dtype = table->value_dtype();
350 if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) &&
351 key_dtype != DT_STRING) {
352 return errors::InvalidArgument(
353 "Key index for whole line requires string or integer table key, got ",
354 DataTypeString(table->key_dtype()));
355 }
356 if (value_index == kLineNumber && value_dtype != DT_INT64) {
357 return errors::InvalidArgument(
358 "Value index for line number requires table value dtype of int64, got ",
359 DataTypeString(table->value_dtype()));
360 }
361 if (value_index == kWholeLine && value_dtype != DT_STRING) {
362 return errors::InvalidArgument(
363 "Value index for whole line requires table value dtype of string, got ",
364 DataTypeString(table->value_dtype()));
365 }
366
367 TextFileLineIterator iter;
368 TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
369 key_index, value_dtype, value_index, env));
370 // For initialization from files, ignore if the table is already
371 // initialized. The table shared name should contain the filename to
372 // avoid trying to initialize the same table from the same file at the same
373 // time.
374 Status s = table->Initialize(iter);
375 if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
376 LOG(INFO) << "Table trying to initialize from file " << filename
377 << " is already initialized.";
378 return Status::OK();
379 }
380 return s;
381 }
382
383 } // namespace lookup
384 } // namespace tensorflow
385