• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
17 #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
18 
19 #include <atomic>
20 
21 #include "tensorflow/core/framework/lookup_interface.h"
22 #include "tensorflow/core/platform/macros.h"
23 
24 namespace tensorflow {
25 namespace lookup {
26 
27 // Base class for lookup tables that require initialization.
28 class InitializableLookupTable : public LookupInterface {
29  public:
30   class InitTableIterator;
31   class InitializerSerializer;
32 
33   // Performs batch lookups, for every element in the key tensor, Find returns
34   // the corresponding value into the values tensor.
35   // If an element is not present in the table, the given default value is used.
36   //
37   // For tables that require initialization, `Find` is available once the table
38   // is marked as initialized.
39   //
40   // Returns the following statuses:
41   // - OK: when the find finishes successfully.
42   // - FailedPrecondition: if the table is not initialized.
43   // - InvalidArgument: if any of the preconditions on the lookup key or value
44   //   fails.
45   // - In addition, other implementations may provide another non-OK status
46   //   specific to their failure modes.
47   Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
48               const Tensor& default_value) final;
49 
50   // Returns errors::Unimplemented.
Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)51   Status Insert(OpKernelContext* ctx, const Tensor& keys,
52                 const Tensor& values) final {
53     return errors::Unimplemented(
54         "Insert not supported by InitializableLookupTable implementations");
55   }
56 
57   // Returns errors::Unimplemented.
Remove(OpKernelContext * ctx,const Tensor & keys)58   Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
59     return errors::Unimplemented(
60         "Remove not supported by InitializableLookupTable implementations");
61   }
62 
ExportValues(OpKernelContext * context)63   Status ExportValues(OpKernelContext* context) override {
64     return errors::Unimplemented(
65         "ExportValues not supported by InitializableLookupTable "
66         "implementations");
67   }
68 
69   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
70                       const Tensor& values) final;
71 
key_shape()72   TensorShape key_shape() const final { return TensorShape(); }
73 
value_shape()74   TensorShape value_shape() const final { return TensorShape(); }
75 
76   // Returns whether the table was initialized and is ready to serve lookups.
is_initialized()77   bool is_initialized() const {
78     return is_initialized_.load(std::memory_order_acquire);
79   }
80 
81   // Initializes the table from the given init table iterator.
82   //
83   // Atomically, this operation prepares the table, populates it with the given
84   // iterator, and marks the table as initialized.
85   //
86   // Returns the following statuses:
87   // - OK: when the initialization was successful.
88   // - InvalidArgument: if any of the preconditions on the lookup key or value
89   //   fails.
90   // - FailedPrecondition: if the table is already initialized and
91   //   fail_if_initialized is set to true.
92   // - In addition, other implementations may provide another non-OK status
93   //   specific to their failure modes.
94   Status Initialize(InitTableIterator& iter);
95 
96   // Initializes the table from the given init table iterator. `serializer` may
97   // specify how to serialize the table initializer, so that the table can be
98   // serialized using its metadata (as opposed to serializing a handle to the
99   // table).
100   Status Initialize(InitTableIterator& iter,
101                     std::unique_ptr<InitializerSerializer> serializer);
102 
103   // Basic iterator to initialize lookup tables.
104   // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
105   // the consumer may insert key-value pairs in batches.
106   //
107   // Then the iterator is exhausted, valid returns false and status returns
108   // Status::OutOfRange.
109   //
110   // This class is Thread-unsafe.
111   class InitTableIterator {
112    public:
InitTableIterator()113     InitTableIterator() {}
114 
~InitTableIterator()115     virtual ~InitTableIterator() {}
116 
117     // Prepares the next batch of key and value tensors.
118     virtual void Next() = 0;
119 
120     // Returns true if keys and values point to valid tensors.
121     virtual bool Valid() const = 0;
122 
123     // Returns a tensor that contains the current batch of 'key' values.
124     virtual const Tensor& keys() const = 0;
125 
126     // Returns a tensor that contains the current batch of 'value' values.
127     virtual const Tensor& values() const = 0;
128 
129     // Returns an error if one has occurred, otherwise returns Status::OK.
130     virtual Status status() const = 0;
131 
132     // Returns the total number of elements that the iterator will produce.
133     // It might return -1 in case of error.
134     virtual int64 total_size() const = 0;
135 
136    private:
137     TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
138   };
139 
GetInitializableLookupTable()140   InitializableLookupTable* GetInitializableLookupTable() override {
141     return this;
142   }
143 
144   // Logic specifying how to represent an initializer as a GraphDef, so that a
145   // lookup table can be serialized using its metadata (as opposed to
146   // serializing the content of the table, or a handle to the table).
147   class InitializerSerializer {
148    public:
149     // A function which builds a graph so that executing `*out` will initialize
150     // `table`.
151     using SerializeFn = std::function<Status(GraphDefBuilder* builder,
152                                              Node* table, Node** out)>;
153     // A function which performs any necessary cleanup for the serializer.
154     using CleanupFn = std::function<void()>;
155 
156     // Wraps serialization logic that requires no cleanup.
InitializerSerializer(SerializeFn serialize)157     explicit InitializerSerializer(SerializeFn serialize)
158         : serialize_(std::move(serialize)), cleanup_([] {}) {}
159 
160     // Wraps serialization logic along with a cleanup function. `cleanup` will
161     // be run when the serializer is destroyed.
InitializerSerializer(SerializeFn serialize,CleanupFn cleanup)162     explicit InitializerSerializer(SerializeFn serialize, CleanupFn cleanup)
163         : serialize_(std::move(serialize)), cleanup_(std::move(cleanup)) {}
164 
~InitializerSerializer()165     ~InitializerSerializer() { cleanup_(); }
166 
167     // Builds a graph so that executing `*out` will initialize `table`.
AsGraphDef(GraphDefBuilder * builder,Node * table,Node ** out)168     Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) {
169       return serialize_(builder, table, out);
170     }
171 
172    private:
173     SerializeFn serialize_;
174     CleanupFn cleanup_;
175   };
176 
177  protected:
178   // Prepares and allocates the underlying data structure to store the given
179   // number of expected elements.
180   virtual Status DoPrepare(size_t expected_num_elements) = 0;
181 
182   // Same as DoPrepare() but derived implementations might choose to skip
183   // calling get_expected_num_elements if size is not needed for DoPrepare.
DoLazyPrepare(std::function<int64 (void)> get_expected_num_elements)184   virtual Status DoLazyPrepare(
185       std::function<int64(void)> get_expected_num_elements) {
186     int64_t expected_num_elements = get_expected_num_elements();
187     if (expected_num_elements < 0) {
188       return errors::FailedPrecondition("Got negative expected_num_elements.");
189     }
190     return DoPrepare(expected_num_elements);
191   }
192 
193   // Populates the table in batches given keys and values as tensors into the
194   // underlying data structure.
195   virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
196 
197   // Performs the batch find operation on the underlying data structure.
198   virtual Status DoFind(const Tensor& keys, Tensor* values,
199                         const Tensor& default_value) = 0;
200 
201   virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result);
202 
203   mutex mu_;
204 
205  protected:
206   // When set, provides a mechanism for serializing the table initializer as
207   // GraphDef.
208   std::unique_ptr<InitializerSerializer> initializer_serializer_;
209 
210  private:
211   std::atomic<bool> is_initialized_{false};
212 };
213 
214 // Iterator to initialize tables given 'keys' and 'values' tensors.
215 //
216 // The two tensors are returned in the first iteration. It doesn't loop
217 // over each element of the tensor since insertions in the lookup table can
218 // process batches.
219 class KeyValueTensorIterator
220     : public InitializableLookupTable::InitTableIterator {
221  public:
222   // keys and values are not owned by the iterator.
KeyValueTensorIterator(const Tensor * keys,const Tensor * values)223   explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
224       : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
225     TensorShape key_shape = keys_->shape();
226     if (!key_shape.IsSameSize(values_->shape())) {
227       valid_ = false;
228       status_ = errors::InvalidArgument(
229           "keys and values should have the same dimension.",
230           key_shape.DebugString(), " vs ", values_->shape().DebugString());
231     }
232     if (key_shape.num_elements() == 0) {
233       valid_ = false;
234       status_ =
235           errors::InvalidArgument("keys and values cannot be empty tensors.");
236     }
237   }
238 
Valid()239   bool Valid() const override { return valid_; }
240 
Next()241   void Next() override {
242     valid_ = false;
243     status_ = errors::OutOfRange("No more data.");
244   }
245 
keys()246   const Tensor& keys() const override { return *keys_; }
247 
values()248   const Tensor& values() const override { return *values_; }
249 
status()250   Status status() const override { return status_; }
251 
total_size()252   int64 total_size() const override {
253     return keys_ == nullptr ? -1 : keys_->NumElements();
254   }
255 
256  private:
257   TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
258 
259   const Tensor* keys_;    // Doesn't own it.
260   const Tensor* values_;  // Doesn't own it.
261   bool valid_;            // true if the iterator points to an existing range.
262   Status status_;
263 };
264 
265 }  // namespace lookup
266 }  // namespace tensorflow
267 
268 #endif  // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
269