• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include "tensorflow/core/kernels/lookup_table_init_op.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/kernels/lookup_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/io/inputbuffer.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/macros.h"
36 
37 namespace tensorflow {
38 
39 // Kernel to initialize a look table given a key and value tensors.
40 // After this operation, the table becomes read-only.
41 class InitializeTableOp : public OpKernel {
42  public:
InitializeTableOp(OpKernelConstruction * context)43   explicit InitializeTableOp(OpKernelConstruction* context)
44       : OpKernel(context) {}
45 
Compute(OpKernelContext * ctx)46   void Compute(OpKernelContext* ctx) override {
47     mutex_lock l(mu_);
48     lookup::InitializableLookupTable* table;
49     OP_REQUIRES_OK(ctx,
50                    GetInitializableLookupTable("table_handle", ctx, &table));
51     core::ScopedUnref unref_me(table);
52 
53     DataType expected_input_0 =
54         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
55     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
56                                       table->value_dtype()};
57     DataTypeVector expected_outputs = {};
58     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
59 
60     const Tensor& keys = ctx->input(1);
61     OP_REQUIRES(
62         ctx, TensorShapeUtils::IsVector(keys.shape()),
63         errors::InvalidArgument("Keys must be a vector, but received shape",
64                                 keys.shape().DebugString()));
65 
66     const Tensor& values = ctx->input(2);
67     OP_REQUIRES(
68         ctx, TensorShapeUtils::IsVector(values.shape()),
69         errors::InvalidArgument("Values must be a vector, but received shape",
70                                 values.shape().DebugString()));
71 
72     OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(),
73                 errors::InvalidArgument(
74                     "Keys and values must have the same size ",
75                     keys.NumElements(), " vs ", values.NumElements()));
76 
77     int memory_used_before = 0;
78     if (ctx->track_allocations()) {
79       memory_used_before = table->MemoryUsed();
80     }
81     OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
82     if (ctx->track_allocations()) {
83       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
84                                                memory_used_before);
85     }
86   }
87 
88  private:
89   mutex mu_;
90 };
91 
92 REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
93                         InitializeTableOp);
94 REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
95                         InitializeTableOp);
96 
97 // Kernel to initialize a lookup table from a text file.
98 //
99 // After this operation, the table becomes read-only.
100 class InitializeTableFromTextFileOp : public OpKernel {
101  public:
InitializeTableFromTextFileOp(OpKernelConstruction * ctx)102   explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx)
103       : OpKernel(ctx) {
104     OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
105     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
106     OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
107     string delimiter;
108     OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
109     OP_REQUIRES(ctx, delimiter.size() == 1,
110                 errors::InvalidArgument("delimiter should be only 1 char"));
111     delimiter_ = delimiter[0];
112   }
113 
Compute(OpKernelContext * ctx)114   void Compute(OpKernelContext* ctx) override {
115     mutex_lock l(mu_);
116     lookup::InitializableLookupTable* table;
117     OP_REQUIRES_OK(ctx,
118                    GetInitializableLookupTable("table_handle", ctx, &table));
119     core::ScopedUnref unref_me(table);
120 
121     DataType expected_input_0 =
122         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
123     DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
124     DataTypeVector expected_outputs = {};
125     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
126 
127     const Tensor& vocab_filename_tensor = ctx->input(1);
128     OP_REQUIRES(
129         ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()),
130         errors::InvalidArgument("filename should be a single string, but got ",
131                                 vocab_filename_tensor.shape().DebugString()));
132 
133     string vocab_filename = vocab_filename_tensor.scalar<string>()();
134     OP_REQUIRES(ctx, !vocab_filename.empty(),
135                 errors::InvalidArgument("filename cannot be empty."));
136 
137     int64 memory_used_before = 0;
138     if (ctx->track_allocations()) {
139       memory_used_before = table->MemoryUsed();
140     }
141     OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
142                             vocab_filename, vocab_size_, delimiter_, key_index_,
143                             value_index_, ctx->env(), table));
144     if (ctx->track_allocations()) {
145       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
146                                                memory_used_before);
147     }
148   }
149 
150  private:
151   mutex mu_;
152   int64 vocab_size_;
153   char delimiter_;
154   int64 key_index_;
155   int64 value_index_;
156 
157   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
158 };
159 
160 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
161                         InitializeTableFromTextFileOp);
162 REGISTER_KERNEL_BUILDER(
163     Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
164     InitializeTableFromTextFileOp);
165 
166 }  // namespace tensorflow
167