1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 17 #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 18 19 #include "tensorflow/core/framework/lookup_interface.h" 20 #include "tensorflow/core/platform/macros.h" 21 22 namespace tensorflow { 23 namespace lookup { 24 25 // Base class for lookup tables that require initialization. 26 class InitializableLookupTable : public LookupInterface { 27 public: 28 class InitTableIterator; 29 30 // Performs batch lookups, for every element in the key tensor, Find returns 31 // the corresponding value into the values tensor. 32 // If an element is not present in the table, the given default value is used. 33 // 34 // For tables that require initialization, `Find` is available once the table 35 // is marked as initialized. 36 // 37 // Returns the following statuses: 38 // - OK: when the find finishes successfully. 39 // - FailedPrecondition: if the table is not initialized. 40 // - InvalidArgument: if any of the preconditions on the lookup key or value 41 // fails. 42 // - In addition, other implementations may provide another non-OK status 43 // specific to their failure modes. 44 Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, 45 const Tensor& default_value) final; 46 47 // Returns errors::Unimplemented. Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)48 Status Insert(OpKernelContext* ctx, const Tensor& keys, 49 const Tensor& values) final { 50 return errors::Unimplemented( 51 "Insert not supported by InitializableLookupTable implementations"); 52 } 53 54 // Returns errors::Unimplemented. Remove(OpKernelContext * ctx,const Tensor & keys)55 Status Remove(OpKernelContext* ctx, const Tensor& keys) final { 56 return errors::Unimplemented( 57 "Remove not supported by InitializableLookupTable implementations"); 58 } 59 ExportValues(OpKernelContext * context)60 Status ExportValues(OpKernelContext* context) override { 61 return errors::Unimplemented( 62 "ExportValues not supported by InitializableLookupTable " 63 "implementations"); 64 } 65 66 Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 67 const Tensor& values) final; 68 key_shape()69 TensorShape key_shape() const final { return TensorShape(); } 70 value_shape()71 TensorShape value_shape() const final { return TensorShape(); } 72 73 // Returns whether the table was initialized and is ready to serve lookups. is_initialized()74 bool is_initialized() const { return is_initialized_; } 75 76 // Initializes the table from the given init table iterator. 77 // 78 // Atomically, this operation prepares the table, populates it with the given 79 // iterator, and mark the table as initialized. 80 // 81 // Returns the following statuses: 82 // - OK: when the initialization was successful. 83 // - InvalidArgument: if any of the preconditions on the lookup key or value 84 // fails. 85 // - FailedPrecondition: if the table is already initialized and 86 // fail_if_initialized is set to true. 87 // - In addition, other implementations may provide another non-OK status 88 // specific to their failure modes. 89 Status Initialize(InitTableIterator& iter); 90 91 // Basic iterator to initialize lookup tables. 92 // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that 93 // the consumer may insert key-value pairs in batches. 94 // 95 // Then the iterator is exhausted, valid returns false and status returns 96 // Status::OutOfRange. 97 // 98 // This class is Thread-unsafe. 99 class InitTableIterator { 100 public: InitTableIterator()101 InitTableIterator() {} 102 ~InitTableIterator()103 virtual ~InitTableIterator() {} 104 105 // Prepares the next batch of key and value tensors. 106 virtual void Next() = 0; 107 108 // Returns true if keys and values point to valid tensors. 109 virtual bool Valid() const = 0; 110 111 // Returns a tensor that contains the current batch of 'key' values. 112 virtual const Tensor& keys() const = 0; 113 114 // Returns a tensor that contains the current batch of 'value' values. 115 virtual const Tensor& values() const = 0; 116 117 // Returns an error if one has occurred, otherwise returns Status::OK. 118 virtual Status status() const = 0; 119 120 // Returns the total number of elements that the iterator will produce. 121 // It might return -1 in case of error. 122 virtual int64 total_size() const = 0; 123 124 private: 125 TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator); 126 }; 127 GetInitializableLookupTable()128 InitializableLookupTable* GetInitializableLookupTable() override { 129 return this; 130 } 131 132 protected: 133 // Prepares and allocates the underlying data structure to store the given 134 // number of expected elements. 135 virtual Status DoPrepare(size_t expected_num_elements) = 0; 136 137 // Same as DoPrepare() but derived implementations might choose to skip 138 // calling get_expected_num_elements if size is not needed for DoPrepare. DoLazyPrepare(std::function<int64 (void)> get_expected_num_elements)139 virtual Status DoLazyPrepare( 140 std::function<int64(void)> get_expected_num_elements) { 141 int64 expected_num_elements = get_expected_num_elements(); 142 if (expected_num_elements < 0) { 143 return errors::FailedPrecondition("Got negative expected_num_elements."); 144 } 145 return DoPrepare(expected_num_elements); 146 } 147 148 // Populates the table in batches given keys and values as tensors into the 149 // underlying data structure. 150 virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0; 151 152 // Performs the batch find operation on the underlying data structure. 153 virtual Status DoFind(const Tensor& keys, Tensor* values, 154 const Tensor& default_value) = 0; 155 156 mutex mu_; 157 bool is_initialized_ = false; 158 }; 159 160 // Iterator to initialize tables given 'keys' and 'values' tensors. 161 // 162 // The two tensors are returned in the first iteration. It doesn't loop 163 // over each element of the tensor since insertions in the lookup table can 164 // process batches. 165 class KeyValueTensorIterator 166 : public InitializableLookupTable::InitTableIterator { 167 public: 168 // keys and values are not owned by the iterator. KeyValueTensorIterator(const Tensor * keys,const Tensor * values)169 explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) 170 : keys_(keys), values_(values), valid_(true), status_(Status::OK()) { 171 TensorShape key_shape = keys_->shape(); 172 if (!key_shape.IsSameSize(values_->shape())) { 173 valid_ = false; 174 status_ = errors::InvalidArgument( 175 "keys and values should have the same dimension.", 176 key_shape.DebugString(), " vs ", values_->shape().DebugString()); 177 } 178 if (key_shape.num_elements() == 0) { 179 valid_ = false; 180 status_ = 181 errors::InvalidArgument("keys and values cannot be empty tensors."); 182 } 183 } 184 Valid()185 bool Valid() const override { return valid_; } 186 Next()187 void Next() override { 188 valid_ = false; 189 status_ = errors::OutOfRange("No more data."); 190 } 191 keys()192 const Tensor& keys() const override { return *keys_; } 193 values()194 const Tensor& values() const override { return *values_; } 195 status()196 Status status() const override { return status_; } 197 total_size()198 int64 total_size() const override { 199 return keys_ == nullptr ? -1 : keys_->NumElements(); 200 } 201 202 private: 203 TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator); 204 205 const Tensor* keys_; // Doesn't own it. 206 const Tensor* values_; // Doesn't own it. 207 bool valid_; // true if the iterator points to an existing range. 208 Status status_; 209 }; 210 211 } // namespace lookup 212 } // namespace tensorflow 213 214 #endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ 215