/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
18 
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/framework/lookup_interface.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/kernels/lookup_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 
namespace tensorflow {

// Lookup table op that supports different table implementations specified by
// the 'Container' template. Container must be derived from LookupInterface. The
// key and value are of the templated type "key_dtype" and "value_dtype"
// respectively.
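//
// A concrete kernel is usually registered per key/value dtype pair in the
// accompanying .cc file. A minimal sketch, for illustration only (the actual
// op names and type constraints live in lookup_table_op.cc):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("HashTableV2")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<int64>("key_dtype")
//           .TypeConstraint<int64>("value_dtype"),
//       LookupTableOp<lookup::HashTable<int64, int64>, int64, int64>);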
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_handle_set_(false) {
    OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
                                                 tensorflow::TensorShape({2}),
                                                 &table_handle_, nullptr));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    if (!table_handle_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
                                      use_node_name_sharing_));
    }

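    // The creator lambda below runs only if no table with this container/name
    // pair exists yet in the ResourceManager; otherwise LookupOrCreate returns
    // the existing table.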
    auto creator = [ctx, this](lookup::LookupInterface** ret) {
      lookup::LookupInterface* container = new Container(ctx, this);
      if (!ctx->status().ok()) {
        container->Unref();
        return ctx->status();
      }
      if (ctx->track_allocations()) {
        ctx->record_persistent_memory_allocation(
            container->MemoryUsed() + table_handle_.AllocatedBytes());
      }
      *ret = container;
      return Status::OK();
    };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    core::ScopedUnref unref_me(table);

    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

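    // Depending on the op's expected output dtype, expose the table either as
    // a DT_RESOURCE handle or as a ref to a 2-element string tensor holding
    // the container and table names.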
    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      Tensor* handle;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
      handle->scalar<ResourceHandle>()() =
          MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
                                                      cinfo_.name());
    } else {
      if (!table_handle_set_) {
        auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
    }
    table_handle_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<lookup::LookupInterface>(cinfo_.container(),
                                                          cinfo_.name())
               .ok()) {
        // Do nothing; the resource may already have been deleted by a session
        // reset.
      }
    }
  }

 private:
  mutex mu_;
  PersistentTensor table_handle_ GUARDED_BY(mu_);
  bool table_handle_set_ GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};

namespace lookup {

// Ensure that the compiler cannot elide a copy into a local, for
// bounds checking on source tensors that might be updated asynchronously for
// integral types. For non-integral types such asynchronous updates are not
// allowed, so the local copy is unnecessary and the value is passed through
// by reference.
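//
// For example, DoInsert below copies each integral key out of its source
// tensor before using it:
//   const K key = SubtleMustCopyIfIntegral(key_values(i));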
template <typename T>
T SubtleMustCopyIfIntegral(const T& value) {
  return internal::SubtleMustCopy(value);
}

inline const string& SubtleMustCopyIfIntegral(const string& value) {
  return value;
}

inline const float SubtleMustCopyIfIntegral(const float value) { return value; }

inline const double SubtleMustCopyIfIntegral(const double value) {
  return value;
}

inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
  return value;
}

// Lookup table that wraps an unordered_map, where the key and value data
// types are specified.
//
// This table is recommended for any variations to key values.
//
// For lookup, the table must be initialized (allocated and populated). Once
// the table is marked as initialized it becomes read-only.
//
// Sample use case:
//
// HashTable<int64, int64> table;  // int64 -> int64.
// table.Prepare(10); // Prepare the underlying data structure; the number of
//                    // elements is required by the interface, but not used.
// // Populate the table; elements can be added in one or multiple calls.
// table.Insert(key_tensor, value_tensor); // Populate the table.
// ...
// table.set_is_initialized();
//
// table.Find(in_t, &out_t, default_t)
//
template <class K, class V>
class HashTable : public InitializableLookupTable {
 public:
  HashTable(OpKernelContext* ctx, OpKernel* kernel) {}

  size_t size() const override {
    // Return the size of the table only if it's initialized; otherwise 0.
    if (!is_initialized_) {
      return 0;
    }
    std::atomic_thread_fence(std::memory_order_acquire);
    return table_ ? table_->size() : 0;
  }

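  // Exports all key/value pairs into two parallel 1-D output tensors named
  // "keys" and "values".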
  Status ExportValues(OpKernelContext* context) override {
    if (!is_initialized_) {
      return errors::Aborted("HashTable is not initialized.");
    }

    const int64 size = table_->size();

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        context->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(
        context->allocate_output("values", TensorShape({size}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->flat<V>();
    int64 i = 0;
    for (auto it = table_->begin(); it != table_->end(); ++it, ++i) {
      keys_data(i) = it->first;
      values_data(i) = it->second;
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

 protected:
  Status DoPrepare(size_t unused) override {
    if (is_initialized_) {
      return errors::Aborted("HashTable already initialized.");
    }
    if (!table_) {
      table_ = std::unique_ptr<std::unordered_map<K, V>>(
          new std::unordered_map<K, V>());
    }
    return Status::OK();
  }

  Status DoLazyPrepare(std::function<int64(void)> unused) override {
    constexpr size_t kUnusedSize = 0;
    return DoPrepare(kUnusedSize);
  }

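  // Inserts all key/value pairs from the given tensors. Re-inserting an
  // existing key with a different value is an error.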
  Status DoInsert(const Tensor& keys, const Tensor& values) override {
    if (!table_) {
      return errors::FailedPrecondition("HashTable is not prepared.");
    }

    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat<V>();
    for (int64 i = 0; i < key_values.size(); ++i) {
      const K key = SubtleMustCopyIfIntegral(key_values(i));
      const V value = SubtleMustCopyIfIntegral(value_values(i));
      const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
      if (previous_value != value) {
        return errors::FailedPrecondition(
            "HashTable has different value for same key. Key ", key, " has ",
            previous_value, " and trying to add value ", value);
      }
    }
    return Status::OK();
  }

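  // Looks up each key; keys that are absent from the table map to the scalar
  // default_value.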
  Status DoFind(const Tensor& key, Tensor* value,
                const Tensor& default_value) override {
    const V default_val = default_value.flat<V>()(0);
    const auto key_values = key.flat<K>();
    auto value_values = value->flat<V>();

    for (int64 i = 0; i < key_values.size(); ++i) {
      value_values(i) = gtl::FindWithDefault(
          *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
    }
    return Status::OK();
  }

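  // Rough estimate only: counts the key/value payload, but not the
  // unordered_map's bucket and node overhead.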
  int64 MemoryUsed() const override {
    if (table_) {
      const int64 num_elements = table_->size();
      return num_elements * (sizeof(K) + sizeof(V));
    } else {
      return 0;
    }
  }

 private:
  std::unique_ptr<std::unordered_map<K, V>> table_;
};

}  // namespace lookup

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_