1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 17 #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 18 19 #include <atomic> 20 21 #include "tensorflow/core/framework/lookup_interface.h" 22 #include "tensorflow/core/platform/macros.h" 23 24 namespace tensorflow { 25 namespace lookup { 26 27 // Base class for lookup tables that require initialization. 28 class InitializableLookupTable : public LookupInterface { 29 public: 30 class InitTableIterator; 31 class InitializerSerializer; 32 33 // Performs batch lookups, for every element in the key tensor, Find returns 34 // the corresponding value into the values tensor. 35 // If an element is not present in the table, the given default value is used. 36 // 37 // For tables that require initialization, `Find` is available once the table 38 // is marked as initialized. 39 // 40 // Returns the following statuses: 41 // - OK: when the find finishes successfully. 42 // - FailedPrecondition: if the table is not initialized. 43 // - InvalidArgument: if any of the preconditions on the lookup key or value 44 // fails. 45 // - In addition, other implementations may provide another non-OK status 46 // specific to their failure modes. 47 Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, 48 const Tensor& default_value) final; 49 50 // Returns errors::Unimplemented. Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)51 Status Insert(OpKernelContext* ctx, const Tensor& keys, 52 const Tensor& values) final { 53 return errors::Unimplemented( 54 "Insert not supported by InitializableLookupTable implementations"); 55 } 56 57 // Returns errors::Unimplemented. Remove(OpKernelContext * ctx,const Tensor & keys)58 Status Remove(OpKernelContext* ctx, const Tensor& keys) final { 59 return errors::Unimplemented( 60 "Remove not supported by InitializableLookupTable implementations"); 61 } 62 ExportValues(OpKernelContext * context)63 Status ExportValues(OpKernelContext* context) override { 64 return errors::Unimplemented( 65 "ExportValues not supported by InitializableLookupTable " 66 "implementations"); 67 } 68 69 Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 70 const Tensor& values) final; 71 key_shape()72 TensorShape key_shape() const final { return TensorShape(); } 73 value_shape()74 TensorShape value_shape() const final { return TensorShape(); } 75 76 // Returns whether the table was initialized and is ready to serve lookups. is_initialized()77 bool is_initialized() const { 78 return is_initialized_.load(std::memory_order_acquire); 79 } 80 81 // Initializes the table from the given init table iterator. 82 // 83 // Atomically, this operation prepares the table, populates it with the given 84 // iterator, and marks the table as initialized. 85 // 86 // Returns the following statuses: 87 // - OK: when the initialization was successful. 88 // - InvalidArgument: if any of the preconditions on the lookup key or value 89 // fails. 90 // - FailedPrecondition: if the table is already initialized and 91 // fail_if_initialized is set to true. 92 // - In addition, other implementations may provide another non-OK status 93 // specific to their failure modes. 94 Status Initialize(InitTableIterator& iter); 95 96 // Initializes the table from the given init table iterator. `serializer` may 97 // specify how to serialize the table initializer, so that the table can be 98 // serialized using its metadata (as opposed to serializing a handle to the 99 // table). 100 Status Initialize(InitTableIterator& iter, 101 std::unique_ptr<InitializerSerializer> serializer); 102 103 // Basic iterator to initialize lookup tables. 104 // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that 105 // the consumer may insert key-value pairs in batches. 106 // 107 // Then the iterator is exhausted, valid returns false and status returns 108 // Status::OutOfRange. 109 // 110 // This class is Thread-unsafe. 111 class InitTableIterator { 112 public: InitTableIterator()113 InitTableIterator() {} 114 ~InitTableIterator()115 virtual ~InitTableIterator() {} 116 117 // Prepares the next batch of key and value tensors. 118 virtual void Next() = 0; 119 120 // Returns true if keys and values point to valid tensors. 121 virtual bool Valid() const = 0; 122 123 // Returns a tensor that contains the current batch of 'key' values. 124 virtual const Tensor& keys() const = 0; 125 126 // Returns a tensor that contains the current batch of 'value' values. 127 virtual const Tensor& values() const = 0; 128 129 // Returns an error if one has occurred, otherwise returns Status::OK. 130 virtual Status status() const = 0; 131 132 // Returns the total number of elements that the iterator will produce. 133 // It might return -1 in case of error. 134 virtual int64 total_size() const = 0; 135 136 private: 137 TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator); 138 }; 139 GetInitializableLookupTable()140 InitializableLookupTable* GetInitializableLookupTable() override { 141 return this; 142 } 143 144 // Logic specifying how to represent an initializer as a GraphDef, so that a 145 // lookup table can be serialized using its metadata (as opposed to 146 // serializing the content of the table, or a handle to the table). 147 class InitializerSerializer { 148 public: 149 // A function which builds a graph so that executing `*out` will initialize 150 // `table`. 151 using SerializeFn = std::function<Status(GraphDefBuilder* builder, 152 Node* table, Node** out)>; 153 // A function which performs any necessary cleanup for the serializer. 154 using CleanupFn = std::function<void()>; 155 156 // Wraps serialization logic that requires no cleanup. InitializerSerializer(SerializeFn serialize)157 explicit InitializerSerializer(SerializeFn serialize) 158 : serialize_(std::move(serialize)), cleanup_([] {}) {} 159 160 // Wraps serialization logic along with a cleanup function. `cleanup` will 161 // be run when the serializer is destroyed. InitializerSerializer(SerializeFn serialize,CleanupFn cleanup)162 explicit InitializerSerializer(SerializeFn serialize, CleanupFn cleanup) 163 : serialize_(std::move(serialize)), cleanup_(std::move(cleanup)) {} 164 ~InitializerSerializer()165 ~InitializerSerializer() { cleanup_(); } 166 167 // Builds a graph so that executing `*out` will initialize `table`. AsGraphDef(GraphDefBuilder * builder,Node * table,Node ** out)168 Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) { 169 return serialize_(builder, table, out); 170 } 171 172 private: 173 SerializeFn serialize_; 174 CleanupFn cleanup_; 175 }; 176 177 protected: 178 // Prepares and allocates the underlying data structure to store the given 179 // number of expected elements. 180 virtual Status DoPrepare(size_t expected_num_elements) = 0; 181 182 // Same as DoPrepare() but derived implementations might choose to skip 183 // calling get_expected_num_elements if size is not needed for DoPrepare. DoLazyPrepare(std::function<int64 (void)> get_expected_num_elements)184 virtual Status DoLazyPrepare( 185 std::function<int64(void)> get_expected_num_elements) { 186 int64_t expected_num_elements = get_expected_num_elements(); 187 if (expected_num_elements < 0) { 188 return errors::FailedPrecondition("Got negative expected_num_elements."); 189 } 190 return DoPrepare(expected_num_elements); 191 } 192 193 // Populates the table in batches given keys and values as tensors into the 194 // underlying data structure. 195 virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0; 196 197 // Performs the batch find operation on the underlying data structure. 198 virtual Status DoFind(const Tensor& keys, Tensor* values, 199 const Tensor& default_value) = 0; 200 201 virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result); 202 203 mutex mu_; 204 205 protected: 206 // When set, provides a mechanism for serializing the table initializer as 207 // GraphDef. 208 std::unique_ptr<InitializerSerializer> initializer_serializer_; 209 210 private: 211 std::atomic<bool> is_initialized_{false}; 212 }; 213 214 // Iterator to initialize tables given 'keys' and 'values' tensors. 215 // 216 // The two tensors are returned in the first iteration. It doesn't loop 217 // over each element of the tensor since insertions in the lookup table can 218 // process batches. 219 class KeyValueTensorIterator 220 : public InitializableLookupTable::InitTableIterator { 221 public: 222 // keys and values are not owned by the iterator. KeyValueTensorIterator(const Tensor * keys,const Tensor * values)223 explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) 224 : keys_(keys), values_(values), valid_(true), status_(Status::OK()) { 225 TensorShape key_shape = keys_->shape(); 226 if (!key_shape.IsSameSize(values_->shape())) { 227 valid_ = false; 228 status_ = errors::InvalidArgument( 229 "keys and values should have the same dimension.", 230 key_shape.DebugString(), " vs ", values_->shape().DebugString()); 231 } 232 if (key_shape.num_elements() == 0) { 233 valid_ = false; 234 status_ = 235 errors::InvalidArgument("keys and values cannot be empty tensors."); 236 } 237 } 238 Valid()239 bool Valid() const override { return valid_; } 240 Next()241 void Next() override { 242 valid_ = false; 243 status_ = errors::OutOfRange("No more data."); 244 } 245 keys()246 const Tensor& keys() const override { return *keys_; } 247 values()248 const Tensor& values() const override { return *values_; } 249 status()250 Status status() const override { return status_; } 251 total_size()252 int64 total_size() const override { 253 return keys_ == nullptr ? -1 : keys_->NumElements(); 254 } 255 256 private: 257 TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator); 258 259 const Tensor* keys_; // Doesn't own it. 260 const Tensor* values_; // Doesn't own it. 261 bool valid_; // true if the iterator points to an existing range. 262 Status status_; 263 }; 264 265 } // namespace lookup 266 } // namespace tensorflow 267 268 #endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 269