• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "tensorflow/core/framework/bounds_check.h"
21 #include "tensorflow/core/framework/lookup_interface.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/resource_mgr.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 #include "tensorflow/core/kernels/lookup_util.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 
35 namespace tensorflow {
36 
37 // Lookup table op that supports different table implementations specified by
38 // the 'Container' template. Container must be derived from LookupInterface. The
39 // key and value are of the templated type "key_dtype" and "value_dtype"
40 // respectively.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_set_(false) {
    // Pre-allocate the output tensor that will carry the table handle. For a
    // DT_RESOURCE output this is a scalar ResourceHandle; otherwise the
    // legacy ref-output form is used: a DT_STRING vector {container, name}.
    if (ctx->output_type(0) == DT_RESOURCE) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_RESOURCE,
                                        tensorflow::TensorShape({}), &table_));
    } else {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_STRING,
                                        tensorflow::TensorShape({2}), &table_));
    }
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // Resolve the resource container/name once, on the first invocation.
    if (!table_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(),
                                      use_node_name_sharing_));
    }

    // Creator callback invoked by LookupOrCreate only when no table with this
    // container/name exists yet. If Container's constructor failed (recorded
    // in ctx->status()), the fresh object is unref'd before the error
    // propagates, so it is not leaked.
    auto creator =
        [ctx, this](lookup::LookupInterface** ret)
            TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              lookup::LookupInterface* container = new Container(ctx, this);
              if (!ctx->status().ok()) {
                container->Unref();
                return ctx->status();
              }
              if (ctx->track_allocations()) {
                ctx->record_persistent_memory_allocation(
                    container->MemoryUsed() + table_.AllocatedBytes());
              }
              *ret = container;
              return Status::OK();
            };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    // LookupOrCreate handed back an extra reference; drop it at scope exit.
    core::ScopedUnref unref_me(table);

    // The table may pre-exist (shared resource); verify it matches this
    // kernel's templated key/value dtypes before exposing it.
    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Fill in the scalar resource handle only once; later calls reuse it.
      if (!table_set_) {
        auto h = table_.template scalar<ResourceHandle>();
        h() = MakeResourceHandle<lookup::LookupInterface>(
            ctx, cinfo_.container(), cinfo_.name());
      }
      ctx->set_output(0, table_);
    } else {
      // Legacy ref output: {container, name} strings, guarded by mu_.
      if (!table_set_) {
        auto h = table_.template flat<tstring>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, &table_);
    }
    table_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_set_ && cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<lookup::LookupInterface>(cinfo_.container(),
                                                          cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

 private:
  mutex mu_;
  // Cached output tensor holding the handle (or container/name strings).
  Tensor table_ TF_GUARDED_BY(mu_);
  // True once Compute has initialized cinfo_ and populated table_.
  bool table_set_ TF_GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};
135 
136 namespace lookup {
137 
// For integral types, ensure that the compiler cannot elide a copy into a
// local, so that bounds checking remains safe even when the source tensor is
// being updated asynchronously. Non-integral types are not bounds-checked
// this way, so their overloads below pass the value through without a copy.
142 template <typename T>
SubtleMustCopyIfIntegral(const T & value)143 T SubtleMustCopyIfIntegral(const T& value) {
144   return internal::SubtleMustCopy(value);
145 }
146 
// Strings are not integral, so no defensive copy is made; the reference is
// forwarded unchanged.
inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
  const tstring& passthrough = value;
  return passthrough;
}
150 
SubtleMustCopyIfIntegral(const float value)151 inline const float SubtleMustCopyIfIntegral(const float value) { return value; }
152 
// Doubles are not integral, so no defensive copy is made. Top-level `const`
// on a by-value return type is meaningless and is dropped (clang-tidy
// readability-const-return-type).
inline double SubtleMustCopyIfIntegral(const double value) { return value; }
156 
// Variant values are not integral, so no defensive copy is made; the
// reference is forwarded unchanged.
inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
  const Variant& passthrough = value;
  return passthrough;
}
160 
// Resource handles are not integral, so no defensive copy is made; the
// reference is forwarded unchanged.
inline const ResourceHandle& SubtleMustCopyIfIntegral(
    const ResourceHandle& value) {
  const ResourceHandle& passthrough = value;
  return passthrough;
}
165 
166 // Returns a unique node name starting with "base".
167 std::string UniqueNodeName(const std::string& base);
168 
169 // Lookup table that wraps an flat_hash_map, where the key and value data type
170 // is specified.
171 //
172 // This table is recommended for any variations to key values.
173 //
174 // For look up, the table is required to be initialized (allocated
175 // and populated). Once the table is marked as initialized it becomes read-only.
176 //
177 // Sample use case:
178 //
179 // HashTable<int64, int64> table;  // int64 -> int64.
180 // table.Initialize(...);
181 // table.Find(in_t, &out_t, default_t)
182 //
183 template <class K, class V>
184 class HashTable : public InitializableLookupTable {
185  public:
HashTable(OpKernelContext * ctx,OpKernel * kernel)186   HashTable(OpKernelContext* ctx, OpKernel* kernel) {}
187 
AsGraphDef(GraphDefBuilder * builder,Node ** out)188   Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
189     // We set use_node_name_sharing with a unique node name so that the resource
190     // can outlive the HashTableV2 kernel. This means that the lifetime of the
191     // HashTable resource will be tied to the lifetime of the resource manager
192     // it is created in.
193     // TODO(b/181695913): Provide a mechanism for deleting this resource
194     // earlier when appropriate.
195     Node* hash_table_node = ops::SourceOp(
196         "HashTableV2", builder->opts()
197                            .WithName(UniqueNodeName("HashTableFromGraphDef"))
198                            .WithAttr("key_dtype", key_dtype())
199                            .WithAttr("value_dtype", value_dtype())
200                            .WithAttr("use_node_name_sharing", true));
201     if (table_.empty()) {
202       *out = hash_table_node;
203       return Status::OK();
204     }
205 
206     if (initializer_serializer_ == nullptr) {
207       std::string message =
208           "Failed to serialize lookup table: no initialization function was "
209           "specified. Falling back to serializing a handle to the table.";
210       LOG(WARNING) << message;
211       return errors::Unimplemented(message);
212     }
213     Node* initializer;
214     TF_RETURN_IF_ERROR(initializer_serializer_->AsGraphDef(
215         builder, hash_table_node, &initializer));
216     *out = ops::UnaryOp("Identity", hash_table_node,
217                         builder->opts().WithControlInput(initializer));
218     return Status::OK();
219   }
220 
size()221   size_t size() const override {
222     if (!is_initialized())
223       return 0;
224     else
225       return table_.size();
226   }
227 
ExportValues(OpKernelContext * context)228   Status ExportValues(OpKernelContext* context) override {
229     if (!is_initialized()) {
230       return errors::Aborted("HashTable is not initialized.");
231     }
232 
233     const int64_t size = table_.size();
234 
235     Tensor* keys;
236     Tensor* values;
237     TF_RETURN_IF_ERROR(
238         context->allocate_output("keys", TensorShape({size}), &keys));
239     TF_RETURN_IF_ERROR(
240         context->allocate_output("values", TensorShape({size}), &values));
241 
242     auto keys_data = keys->flat<K>();
243     auto values_data = values->flat<V>();
244     int64_t i = 0;
245     for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
246       keys_data(i) = it->first;
247       values_data(i) = it->second;
248     }
249     return Status::OK();
250   }
251 
key_dtype()252   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
253 
value_dtype()254   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
255 
256  protected:
DoPrepare(size_t size)257   Status DoPrepare(size_t size) override {
258     if (is_initialized()) {
259       return errors::Aborted("HashTable already initialized.");
260     }
261     if (size > 0) {
262       table_.reserve(size);
263     }
264     return Status::OK();
265   };
266 
DoLazyPrepare(std::function<int64 (void)> size_fn)267   Status DoLazyPrepare(std::function<int64(void)> size_fn) override {
268     return DoPrepare(size_fn());
269   }
270 
DoInsert(const Tensor & keys,const Tensor & values)271   Status DoInsert(const Tensor& keys, const Tensor& values) override {
272     const auto key_values = keys.flat<K>();
273     const auto value_values = values.flat<V>();
274     for (int64_t i = 0; i < key_values.size(); ++i) {
275       auto&& key = SubtleMustCopyIfIntegral(key_values(i));
276       auto&& value = SubtleMustCopyIfIntegral(value_values(i));
277       auto result = table_.try_emplace(key, value);
278       if (!result.second && result.first->second != value) {
279         return errors::FailedPrecondition(
280             "HashTable has different value for same key. Key ", key, " has ",
281             result.first->second, " and trying to add value ", value);
282       }
283     }
284     return Status::OK();
285   }
286 
DoFind(const Tensor & key,Tensor * value,const Tensor & default_value)287   Status DoFind(const Tensor& key, Tensor* value,
288                 const Tensor& default_value) override {
289     const V default_val = default_value.flat<V>()(0);
290     const auto key_values = key.flat<K>();
291     auto value_values = value->flat<V>();
292 
293     for (int64_t i = 0; i < key_values.size(); ++i) {
294       value_values(i) = gtl::FindWithDefault(
295           table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
296     }
297     return Status::OK();
298   }
299 
MemoryUsed()300   int64 MemoryUsed() const override {
301     if (!is_initialized()) {
302       return 0;
303     }
304     const int64_t num_elements = table_.size();
305     return num_elements * (sizeof(K) + sizeof(V));
306   }
307 
308  private:
309   absl::flat_hash_map<K, V> table_;
310 };
311 
312 }  // namespace lookup
313 
314 }  // namespace tensorflow
315 
316 #endif  // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
317