1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
18
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/framework/lookup_interface.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/kernels/lookup_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31
32 namespace tensorflow {
33
34 // Lookup table op that supports different table implementations specified by
35 // the 'Container' template. Container must be derived from LookupInterface. The
36 // key and value are of the templated type "key_dtype" and "value_dtype"
37 // respectively.
38 template <class Container, class key_dtype, class value_dtype>
39 class LookupTableOp : public OpKernel {
40 public:
41 // ctx is not owned by this class.
LookupTableOp(OpKernelConstruction * ctx)42 explicit LookupTableOp(OpKernelConstruction* ctx)
43 : OpKernel(ctx), table_handle_set_(false) {
44 OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
45 tensorflow::TensorShape({2}),
46 &table_handle_, nullptr));
47 OP_REQUIRES_OK(
48 ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
49 }
50
51 // ctx is not owned by this function.
Compute(OpKernelContext * ctx)52 void Compute(OpKernelContext* ctx) override {
53 mutex_lock l(mu_);
54
55 if (!table_handle_set_) {
56 OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
57 use_node_name_sharing_));
58 }
59
60 auto creator = [ctx, this](lookup::LookupInterface** ret) {
61 lookup::LookupInterface* container = new Container(ctx, this);
62 if (!ctx->status().ok()) {
63 container->Unref();
64 return ctx->status();
65 }
66 if (ctx->track_allocations()) {
67 ctx->record_persistent_memory_allocation(
68 container->MemoryUsed() + table_handle_.AllocatedBytes());
69 }
70 *ret = container;
71 return Status::OK();
72 };
73
74 lookup::LookupInterface* table = nullptr;
75 OP_REQUIRES_OK(ctx,
76 cinfo_.resource_manager()
77 ->template LookupOrCreate<lookup::LookupInterface>(
78 cinfo_.container(), cinfo_.name(), &table, creator));
79 core::ScopedUnref unref_me(table);
80
81 OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
82 *table, DataTypeToEnum<key_dtype>::v(),
83 DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
84
85 if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
86 Tensor* handle;
87 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
88 handle->scalar<ResourceHandle>()() =
89 MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
90 cinfo_.name());
91 } else {
92 if (!table_handle_set_) {
93 auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
94 h(0) = cinfo_.container();
95 h(1) = cinfo_.name();
96 }
97 ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
98 }
99 table_handle_set_ = true;
100 }
101
~LookupTableOp()102 ~LookupTableOp() override {
103 // If the table object was not shared, delete it.
104 if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
105 if (!cinfo_.resource_manager()
106 ->template Delete<lookup::LookupInterface>(cinfo_.container(),
107 cinfo_.name())
108 .ok()) {
109 // Do nothing; the resource can have been deleted by session resets.
110 }
111 }
112 }
113
114 private:
115 mutex mu_;
116 PersistentTensor table_handle_ GUARDED_BY(mu_);
117 bool table_handle_set_ GUARDED_BY(mu_);
118 ContainerInfo cinfo_;
119 bool use_node_name_sharing_;
120
121 TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
122 };
123
124 namespace lookup {
125
126 // Ensure that the compiler cannot elide a copy into a local, for
127 // bounds checking on source tensors that might be updated asynchronously for
128 // integral types. However non-integer variables are not allowed and therefore
129 // the local copy is unnecessary.
130 template <typename T>
SubtleMustCopyIfIntegral(const T & value)131 T SubtleMustCopyIfIntegral(const T& value) {
132 return internal::SubtleMustCopy(value);
133 }
134
SubtleMustCopyIfIntegral(const string & value)135 inline const string& SubtleMustCopyIfIntegral(const string& value) {
136 return value;
137 }
138
// Non-integral overload: a float is returned as-is, bypassing
// internal::SubtleMustCopy. Returned as plain `float`; a top-level `const` on a
// by-value scalar return type has no effect and only triggers warnings.
inline float SubtleMustCopyIfIntegral(const float value) { return value; }
140
// Non-integral overload: a double is returned as-is, bypassing
// internal::SubtleMustCopy. Returned as plain `double`; a top-level `const` on
// a by-value scalar return type has no effect and only triggers warnings.
inline double SubtleMustCopyIfIntegral(const double value) { return value; }
144
SubtleMustCopyIfIntegral(const Variant & value)145 inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
146 return value;
147 }
148
149 // Lookup table that wraps an unordered_map, where the key and value data type
150 // is specified.
151 //
152 // This table is recommended for any variations to key values.
153 //
154 // For look up, the table is required to be initialized (allocated
155 // and populated). Once the table is marked as initialized it becomes read-only.
156 //
157 // Sample use case:
158 //
159 // HashTable<int64, int64> table; // int64 -> int64.
160 // table.Prepare(10); // Prepare the underlying data structure, the number of
161 // // elements is required by interface, but not used.
162 // // Populate the table, elements could be added in one or multiple calls.
163 // table.Insert(key_tensor, value_tensor); // Populate the table.
164 // ...
165 // table.set_is_initialized();
166 //
167 // table.Find(in_t, &out_t, default_t)
168 //
169 template <class K, class V>
170 class HashTable : public InitializableLookupTable {
171 public:
HashTable(OpKernelContext * ctx,OpKernel * kernel)172 HashTable(OpKernelContext* ctx, OpKernel* kernel) {}
173
size()174 size_t size() const override {
175 // return the size of the table only if it's initialized, otherwise 0.
176 if (!is_initialized_) {
177 return 0;
178 }
179 std::atomic_thread_fence(std::memory_order_acquire);
180 return table_ ? table_->size() : 0;
181 }
182
ExportValues(OpKernelContext * context)183 Status ExportValues(OpKernelContext* context) override {
184 if (!is_initialized_) {
185 return errors::Aborted("HashTable is not initialized.");
186 }
187
188 const int64 size = table_->size();
189
190 Tensor* keys;
191 Tensor* values;
192 TF_RETURN_IF_ERROR(
193 context->allocate_output("keys", TensorShape({size}), &keys));
194 TF_RETURN_IF_ERROR(
195 context->allocate_output("values", TensorShape({size}), &values));
196
197 auto keys_data = keys->flat<K>();
198 auto values_data = values->flat<V>();
199 int64 i = 0;
200 for (auto it = table_->begin(); it != table_->end(); ++it, ++i) {
201 keys_data(i) = it->first;
202 values_data(i) = it->second;
203 }
204 return Status::OK();
205 }
206
key_dtype()207 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
208
value_dtype()209 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
210
211 protected:
DoPrepare(size_t unused)212 Status DoPrepare(size_t unused) override {
213 if (is_initialized_) {
214 return errors::Aborted("HashTable already initialized.");
215 }
216 if (!table_) {
217 table_ = std::unique_ptr<std::unordered_map<K, V>>(
218 new std::unordered_map<K, V>());
219 }
220 return Status::OK();
221 };
222
DoLazyPrepare(std::function<int64 (void)> unused)223 Status DoLazyPrepare(std::function<int64(void)> unused) override {
224 constexpr size_t kUnusedSize = 0;
225 return DoPrepare(kUnusedSize);
226 }
227
DoInsert(const Tensor & keys,const Tensor & values)228 Status DoInsert(const Tensor& keys, const Tensor& values) override {
229 if (!table_) {
230 return errors::FailedPrecondition("HashTable is not prepared.");
231 }
232
233 const auto key_values = keys.flat<K>();
234 const auto value_values = values.flat<V>();
235 for (int64 i = 0; i < key_values.size(); ++i) {
236 const K key = SubtleMustCopyIfIntegral(key_values(i));
237 const V value = SubtleMustCopyIfIntegral(value_values(i));
238 const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
239 if (previous_value != value) {
240 return errors::FailedPrecondition(
241 "HashTable has different value for same key. Key ", key, " has ",
242 previous_value, " and trying to add value ", value);
243 }
244 }
245 return Status::OK();
246 }
247
DoFind(const Tensor & key,Tensor * value,const Tensor & default_value)248 Status DoFind(const Tensor& key, Tensor* value,
249 const Tensor& default_value) override {
250 const V default_val = default_value.flat<V>()(0);
251 const auto key_values = key.flat<K>();
252 auto value_values = value->flat<V>();
253
254 for (int64 i = 0; i < key_values.size(); ++i) {
255 value_values(i) = gtl::FindWithDefault(
256 *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
257 }
258 return Status::OK();
259 }
260
MemoryUsed()261 int64 MemoryUsed() const override {
262 if (table_) {
263 const int64 num_elements = table_->size();
264 return num_elements * (sizeof(K) + sizeof(V));
265 } else {
266 return 0;
267 }
268 }
269
270 private:
271 std::unique_ptr<std::unordered_map<K, V>> table_;
272 };
273
274 } // namespace lookup
275
276 } // namespace tensorflow
277
278 #endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
279