• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/lookup_util.h"
17 
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/io/inputbuffer.h"
22 
23 namespace tensorflow {
24 namespace lookup {
25 namespace {
26 
// Size, in bytes, of the read buffer used when scanning vocabulary files.
constexpr int kInputBufferSize = 1 * 1024 * 1024;
// Special key/value column indices: take the zero-based line number instead
// of a column.
constexpr int kLineNumber = -1;
// Special key/value column index: take the entire line instead of a column.
constexpr int kWholeLine = -2;
30 
GetNumLinesInTextFile(Env * env,const string & vocab_file,int64 * num_lines)31 Status GetNumLinesInTextFile(Env* env, const string& vocab_file,
32                              int64* num_lines) {
33   std::unique_ptr<RandomAccessFile> file;
34   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
35 
36   io::InputBuffer input_buffer(file.get(), kInputBufferSize);
37   string line;
38   Status s = input_buffer.ReadLine(&line);
39   int64 next_id = 0;
40   while (s.ok()) {
41     next_id++;
42     s = input_buffer.ReadLine(&line);
43   }
44   if (!errors::IsOutOfRange(s)) {
45     return s;
46   }
47   *num_lines = next_id;
48   return Status::OK();
49 }
50 
51 // Iterator that reads a text file. Each iteration process one line, it parses
52 // the line and populates the keys and values tensors used for initialization
53 // with a single key and corresponding value.
54 //
55 // What information of the line to populate the key or values is specified by
56 // providing key_index and value_index.
57 class TextFileLineIterator
58     : public InitializableLookupTable::InitTableIterator {
59  public:
TextFileLineIterator()60   TextFileLineIterator()
61       : valid_(false),
62         vocab_size_(-1),
63         status_(errors::FailedPrecondition("Not initialized")) {}
64 
65   // Initialize iterator.
66   //
67   // Prepares the file 'filename' and sets the data types to return the keys and
68   // values tensors. It requires the indices of the tokens in the line given a
69   // delimiter to specify where to pick the data from.
70   //
71   // - Index -2 means the entire line as string.
72   // - Index -1 means the line number stored in int64.
73   // - Index >= 0 represent index (starting at zero) of the split line based on
74   //   delimiter.
Init(const string & filename,int64 vocab_size,char delimiter,DataType key_dtype,int64 key_index,DataType value_dtype,int64 value_index,Env * env)75   Status Init(const string& filename, int64 vocab_size, char delimiter,
76               DataType key_dtype, int64 key_index, DataType value_dtype,
77               int64 value_index, Env* env) {
78     filename_ = filename;
79     vocab_size_ = vocab_size;
80     delimiter_ = delimiter;
81     key_ = Tensor(key_dtype, TensorShape({}));
82     value_ = Tensor(value_dtype, TensorShape({}));
83     key_index_ = key_index;
84     value_index_ = value_index;
85     env_ = env;
86 
87     status_ = env->NewRandomAccessFile(filename_, &file_);
88     if (!status_.ok()) return status_;
89 
90     input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
91     valid_ = true;
92     next_id_ = 0;
93     ignore_split_ = std::max(key_index_, value_index_) < 0;
94     Next();
95     return status_;
96   }
97 
Next()98   void Next() override {
99     if (!valid_) return;
100 
101     string line;
102     status_ = input_buffer_->ReadLine(&line);
103     if (!status_.ok()) {
104       if (errors::IsOutOfRange(status_) && vocab_size_ != -1 &&
105           next_id_ != vocab_size_) {
106         status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
107                                           ": expected ", vocab_size_,
108                                           " but got ", next_id_);
109       }
110       valid_ = false;
111       return;
112     }
113     if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
114       LOG(WARNING) << "Truncated " << filename_ << " before its end at "
115                    << vocab_size_ << " records.";
116       LOG(WARNING) << "next_id_  : " << next_id_;
117       status_ = errors::OutOfRange("Finished reading ", vocab_size_,
118                                    " of lines from ", filename_);
119       valid_ = false;
120       return;
121     }
122     if (line.empty()) {
123       status_ = errors::InvalidArgument("Invalid content in ", filename_,
124                                         ": empty line found at position ",
125                                         input_buffer_->Tell(), ".");
126       valid_ = false;
127       return;
128     }
129 
130     std::vector<string> tokens;
131     if (!ignore_split_) {
132       tokens = str_util::Split(line, delimiter_);
133       if (std::max(key_index_, value_index_) >= tokens.size()) {
134         status_ = errors::InvalidArgument(
135             "Invalid number of columns in ", filename_, " line ", next_id_,
136             " (", line, ") : expected ", std::max(key_index_, value_index_),
137             " got ", tokens.size());
138         valid_ = false;
139         return;
140       }
141     }
142     status_ = SetValue(line, tokens, key_index_, &key_);
143     if (!status_.ok()) {
144       valid_ = false;
145       return;
146     }
147     status_ = SetValue(line, tokens, value_index_, &value_);
148     if (!status_.ok()) {
149       valid_ = false;
150       return;
151     }
152 
153     next_id_++;
154   }
155 
Valid() const156   bool Valid() const override { return valid_; }
157 
keys() const158   const Tensor& keys() const override { return key_; }
159 
values() const160   const Tensor& values() const override { return value_; }
161 
status() const162   Status status() const override { return status_; }
163 
total_size() const164   int64 total_size() const override {
165     if (vocab_size_ == -1) {
166       int64 new_size = -1;
167       Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
168       if (!status.ok()) {
169         LOG(WARNING) << "Unable to get line count: " << status;
170         new_size = -1;
171       }
172       *const_cast<int64*>(&vocab_size_) = new_size;
173     }
174     return vocab_size_;
175   }
176 
177  private:
178   Tensor key_;
179   Tensor value_;
180   bool valid_;  // true if the iterator points to an existing range.
181   int64 key_index_;
182   int64 value_index_;
183   Env* env_;
184   int64 next_id_;
185   int64 vocab_size_;
186   string filename_;
187   char delimiter_;
188   Status status_;
189   bool ignore_split_;
190   std::unique_ptr<RandomAccessFile> file_;  // must outlive input_buffer_
191   std::unique_ptr<io::InputBuffer> input_buffer_;
192 
193   // Set the corresponding value from line or tokens based on 'index' into the
194   // tensor 't'. The value is transformed to the given data type 'dtype'.
SetValue(const string & line,const std::vector<string> & tokens,int64 index,Tensor * tensor)195   Status SetValue(const string& line, const std::vector<string>& tokens,
196                   int64 index, Tensor* tensor) {
197     if (index == kLineNumber) {
198       tensor->flat<int64>()(0) = next_id_;
199       return Status::OK();
200     }
201     const string& token = (index == kWholeLine) ? line : tokens[index];
202     const DataType& dtype = tensor->dtype();
203     switch (dtype) {
204       case DT_INT32: {
205         int32 value;
206         if (!strings::safe_strto32(token.c_str(), &value)) {
207           valid_ = false;
208           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
209                                          " is not a valid int32.");
210         }
211         tensor->flat<int32>()(0) = value;
212       } break;
213       case DT_INT64: {
214         int64 value;
215         if (!strings::safe_strto64(token.c_str(), &value)) {
216           valid_ = false;
217           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
218                                          " is not a valid int64.");
219         }
220         tensor->flat<int64>()(0) = value;
221       } break;
222       case DT_FLOAT: {
223         float value;
224         if (!strings::safe_strtof(token.c_str(), &value)) {
225           valid_ = false;
226           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
227                                          " is not a valid float.");
228         }
229         tensor->flat<float>()(0) = value;
230       } break;
231       case DT_DOUBLE: {
232         double value;
233         if (!strings::safe_strtod(token.c_str(), &value)) {
234           valid_ = false;
235           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
236                                          " is not a valid double.");
237         }
238         tensor->flat<double>()(0) = value;
239       } break;
240       case DT_STRING:
241         tensor->flat<string>()(0) = token;
242         break;
243       default:
244         valid_ = false;
245         return errors::InvalidArgument("Data type ", DataTypeString(dtype),
246                                        " not supported.");
247     }
248     return Status::OK();
249   }
250 
251   TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator);
252 };
253 
GetTableHandle(const string & input_name,OpKernelContext * ctx,string * container,string * table_handle)254 Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
255                       string* container, string* table_handle) {
256   {
257     mutex* mu;
258     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
259     mutex_lock l(*mu);
260     Tensor tensor;
261     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
262     if (tensor.NumElements() != 2) {
263       return errors::InvalidArgument(
264           "Lookup table handle must be scalar, but had shape: ",
265           tensor.shape().DebugString());
266     }
267     auto h = tensor.flat<string>();
268     *container = h(0);
269     *table_handle = h(1);
270   }
271   return Status::OK();
272 }
273 
274 }  // namespace
275 
GetLookupTable(const string & input_name,OpKernelContext * ctx,LookupInterface ** table)276 Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
277                       LookupInterface** table) {
278   string container;
279   string table_handle;
280   DataType handle_dtype;
281   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
282   if (handle_dtype == DT_RESOURCE) {
283     ResourceHandle handle;
284     TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
285     return LookupResource(ctx, handle, table);
286   } else {
287     TF_RETURN_IF_ERROR(
288         GetTableHandle(input_name, ctx, &container, &table_handle));
289     return ctx->resource_manager()->Lookup(container, table_handle, table);
290   }
291 }
292 
GetInitializableLookupTable(const string & input_name,OpKernelContext * ctx,InitializableLookupTable ** table)293 Status GetInitializableLookupTable(const string& input_name,
294                                    OpKernelContext* ctx,
295                                    InitializableLookupTable** table) {
296   LookupInterface* lookup_table;
297   DataType handle_dtype;
298   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
299   if (handle_dtype == DT_RESOURCE) {
300     ResourceHandle handle;
301     TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
302     TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
303     *table = lookup_table->GetInitializableLookupTable();
304     if (*table == nullptr) {
305       lookup_table->Unref();
306       return errors::InvalidArgument("Table ", handle.container(), " ",
307                                      handle.name(), " is not initializable");
308     }
309   } else {
310     string container;
311     string table_handle;
312     TF_RETURN_IF_ERROR(
313         GetTableHandle(input_name, ctx, &container, &table_handle));
314     TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
315                                                        &lookup_table));
316     *table = lookup_table->GetInitializableLookupTable();
317     if (*table == nullptr) {
318       lookup_table->Unref();
319       return errors::InvalidArgument("Table ", container, " ", table_handle,
320                                      " is not initializable");
321     }
322   }
323   return Status::OK();
324 }
325 
CheckTableDataTypes(const LookupInterface & table,DataType key_dtype,DataType value_dtype,const string & table_name)326 Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
327                            DataType value_dtype, const string& table_name) {
328   if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) {
329     return errors::InvalidArgument(
330         "Conflicting key/value dtypes ", DataTypeString(key_dtype), "->",
331         DataTypeString(value_dtype), " with ",
332         DataTypeString(table.key_dtype()), "-",
333         DataTypeString(table.value_dtype()), " for table ", table_name);
334   }
335   return Status::OK();
336 }
337 
338 // Helper function to initialize an InitializableLookupTable from a text file.
InitializeTableFromTextFile(const string & filename,int64 vocab_size,char delimiter,int32 key_index,int32 value_index,Env * env,InitializableLookupTable * table)339 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
340                                    char delimiter, int32 key_index,
341                                    int32 value_index, Env* env,
342                                    InitializableLookupTable* table) {
343   if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
344     return errors::InvalidArgument(
345         "Key index for line number requires table key dtype of int64, got ",
346         DataTypeString(table->key_dtype()));
347   }
348   const DataType& key_dtype = table->key_dtype();
349   const DataType& value_dtype = table->value_dtype();
350   if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) &&
351       key_dtype != DT_STRING) {
352     return errors::InvalidArgument(
353         "Key index for whole line requires string or integer table key, got ",
354         DataTypeString(table->key_dtype()));
355   }
356   if (value_index == kLineNumber && value_dtype != DT_INT64) {
357     return errors::InvalidArgument(
358         "Value index for line number requires table value dtype of int64, got ",
359         DataTypeString(table->value_dtype()));
360   }
361   if (value_index == kWholeLine && value_dtype != DT_STRING) {
362     return errors::InvalidArgument(
363         "Value index for whole line requires table value dtype of string, got ",
364         DataTypeString(table->value_dtype()));
365   }
366 
367   TextFileLineIterator iter;
368   TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
369                                key_index, value_dtype, value_index, env));
370   // For initialization from files, ignore if the table is already
371   // initialized. The table shared name should contain the filename to
372   // avoid trying to initialize the same table from the same file at the same
373   // time.
374   Status s = table->Initialize(iter);
375   if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
376     LOG(INFO) << "Table trying to initialize from file " << filename
377               << " is already initialized.";
378     return Status::OK();
379   }
380   return s;
381 }
382 
383 }  // namespace lookup
384 }  // namespace tensorflow
385