1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#============================================================================== 15"""Lookup operations.""" 16# pylint: disable=g-bad-name 17import collections 18import functools 19import uuid 20 21from tensorflow.python.checkpoint import saveable_compat 22from tensorflow.python.eager import context 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_util 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_lookup_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import string_ops 34# go/tf-wildcard-import 35# pylint: disable=wildcard-import 36from tensorflow.python.ops.gen_lookup_ops import * 37from tensorflow.python.ops.ragged import ragged_tensor 38from tensorflow.python.saved_model import registration 39from tensorflow.python.trackable import asset 40# pylint: enable=wildcard-import 41from tensorflow.python.trackable import base as trackable_base 42from tensorflow.python.trackable import resource 43from tensorflow.python.training.saver import BaseSaverBuilder 44from tensorflow.python.util import compat as compat_util 45from tensorflow.python.util.deprecation import deprecated 46from tensorflow.python.util.tf_export import tf_export 47 48 49@tf_export(v1=["initialize_all_tables"]) 50@deprecated(None, "Use `tf.tables_initializer` instead.") 51def initialize_all_tables(name="init_all_tables"): 52 """Returns an Op that initializes all tables of the default graph. 53 54 Args: 55 name: Optional name for the initialization op. 56 57 Returns: 58 An Op that initializes all tables. Note that if there are 59 not tables the returned Op is a NoOp. 60 """ 61 return tables_initializer(name) 62 63 64@tf_export(v1=["initializers.tables_initializer", "tables_initializer"]) 65def tables_initializer(name="init_all_tables"): 66 """Returns an Op that initializes all tables of the default graph. 67 68 Args: 69 name: Optional name for the initialization op. 70 71 Returns: 72 An Op that initializes all tables. Note that if there are 73 not tables the returned Op is a NoOp. 74 75 @compatibility(TF2) 76 `tf.compat.v1.tables_initializer` is no longer needed with eager execution and 77 `tf.function`. In TF2, when creating an initializable table like a 78 `tf.lookup.StaticHashTable`, the table will automatically be initialized on 79 creation. 80 81 #### Before & After Usage Example 82 83 Before: 84 85 >>> with tf.compat.v1.Session(): 86 ... init = tf.compat.v1.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2]) 87 ... table = tf.compat.v1.lookup.StaticHashTable(init, default_value=-1) 88 ... tf.compat.v1.tables_initializer().run() 89 ... result = table.lookup(tf.constant(['a', 'c'])).eval() 90 >>> result 91 array([ 1, -1], dtype=int32) 92 93 After: 94 95 >>> init = tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2]) 96 >>> table = tf.lookup.StaticHashTable(init, default_value=-1) 97 >>> table.lookup(tf.constant(['a', 'c'])).numpy() 98 array([ 1, -1], dtype=int32) 99 100 @end_compatibility 101 """ 102 initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) 103 if initializers: 104 return control_flow_ops.group(*initializers, name=name) 105 return control_flow_ops.no_op(name=name) 106 107 108def check_table_dtypes(table, key_dtype, value_dtype): 109 """Check that the given key_dtype and value_dtype matches the table dtypes. 110 111 Args: 112 table: The table to check types against to. 113 key_dtype: The key data type to check. 114 value_dtype: The value data type to check. 115 116 Raises: 117 TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data 118 types. 119 """ 120 if key_dtype.base_dtype != table.key_dtype: 121 raise TypeError(f"Invalid key dtype for table, expected {table.key_dtype} " 122 f"but got {key_dtype}.") 123 if value_dtype.base_dtype != table.value_dtype: 124 raise TypeError("Invalid value dtype for table, expected " 125 f"{table.value_dtype} but got {value_dtype}.") 126 127 128class LookupInterface(resource.TrackableResource): 129 """Represent a lookup table that persists across different steps.""" 130 131 def __init__(self, key_dtype, value_dtype): 132 """Construct a lookup table interface. 133 134 Args: 135 key_dtype: The table key type. 136 value_dtype: The table value type. 137 """ 138 self._key_dtype = dtypes.as_dtype(key_dtype) 139 self._value_dtype = dtypes.as_dtype(value_dtype) 140 super(LookupInterface, self).__init__() 141 142 def _create_resource(self): 143 raise NotImplementedError 144 145 @property 146 def key_dtype(self): 147 """The table key dtype.""" 148 return self._key_dtype 149 150 @property 151 def value_dtype(self): 152 """The table value dtype.""" 153 return self._value_dtype 154 155 @property 156 def name(self): 157 """The name of the table.""" 158 return NotImplementedError 159 160 def size(self, name=None): 161 """Compute the number of elements in this table.""" 162 raise NotImplementedError 163 164 def lookup(self, keys, name=None): 165 """Looks up `keys` in a table, outputs the corresponding values.""" 166 raise NotImplementedError 167 168 def __getitem__(self, keys): 169 """Looks up `keys` in a table, outputs the corresponding values.""" 170 return self.lookup(keys) 171 172 173class InitializableLookupTableBase(LookupInterface): 174 """Initializable lookup table interface. 175 176 An initializable lookup tables persist across different steps. 177 """ 178 179 def __init__(self, default_value, initializer): 180 """Construct a table object from a table reference. 181 182 If requires a table initializer object (subclass of `TableInitializerBase`). 183 It provides the table key and value types, as well as the op to initialize 184 the table. The caller is responsible to execute the initialization op. 185 186 Args: 187 default_value: The value to use if a key is missing in the table. 188 initializer: The table initializer to use. 189 """ 190 super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, 191 initializer.value_dtype) 192 self._default_value = ops.convert_to_tensor( 193 default_value, dtype=self._value_dtype) 194 self._default_value.get_shape().merge_with(tensor_shape.TensorShape([])) 195 if isinstance(initializer, trackable_base.Trackable): 196 self._initializer = self._track_trackable(initializer, "_initializer") 197 with ops.init_scope(): 198 self._resource_handle = self._create_resource() 199 if (not context.executing_eagerly() and 200 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 201 with ops.init_scope(): 202 self._init_op = self._initialize() 203 else: 204 self._init_op = self._initialize() 205 206 def _initialize(self): 207 return self._initializer.initialize(self) 208 209 @property 210 def default_value(self): 211 """The default value of the table.""" 212 return self._default_value 213 214 def size(self, name=None): 215 """Compute the number of elements in this table. 216 217 Args: 218 name: A name for the operation (optional). 219 220 Returns: 221 A scalar tensor containing the number of elements in this table. 222 """ 223 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]): 224 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle) 225 226 def lookup(self, keys, name=None): 227 """Looks up `keys` in a table, outputs the corresponding values. 228 229 The `default_value` is used for keys not present in the table. 230 231 Args: 232 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 233 name: A name for the operation (optional). 234 235 Returns: 236 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 237 otherwise a dense `Tensor`. 238 239 Raises: 240 TypeError: when `keys` or `default_value` doesn't match the table data 241 types. 242 """ 243 key_tensor = keys 244 if isinstance(keys, 245 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 246 key_tensor = keys.values 247 248 if keys.dtype.base_dtype != self._key_dtype: 249 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 250 f"received: {keys.dtype}") 251 252 with ops.name_scope( 253 name, "%s_Lookup" % self.name, 254 (self.resource_handle, key_tensor, self._default_value)): 255 values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, 256 key_tensor, 257 self._default_value) 258 259 values.set_shape(key_tensor.get_shape()) 260 if isinstance(keys, sparse_tensor.SparseTensor): 261 return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) 262 elif isinstance(keys, ragged_tensor.RaggedTensor): 263 return keys.with_values(values) 264 else: 265 return values 266 267 268class InitializableLookupTableBaseV1(InitializableLookupTableBase): 269 270 @property 271 def initializer(self): 272 return self._init_op 273 274 275@registration.register_tf_serializable( 276 predicate=lambda obj: isinstance(obj, StaticHashTable)) 277@tf_export("lookup.StaticHashTable", v1=[]) 278class StaticHashTable(InitializableLookupTableBase): 279 """A generic hash table that is immutable once initialized. 280 281 Example usage: 282 283 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 284 >>> vals_tensor = tf.constant([7, 8, 9]) 285 >>> input_tensor = tf.constant(['a', 'f']) 286 >>> table = tf.lookup.StaticHashTable( 287 ... tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), 288 ... default_value=-1) 289 >>> table.lookup(input_tensor).numpy() 290 array([ 7, -1], dtype=int32) 291 292 Or for more pythonic code: 293 294 >>> table[input_tensor].numpy() 295 array([ 7, -1], dtype=int32) 296 297 The result of a lookup operation has the same shape as the argument: 298 299 >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']]) 300 >>> table[input_tensor].numpy() 301 array([[ 7, 8], 302 [ 9, -1]], dtype=int32) 303 304 305 """ 306 307 def __init__(self, 308 initializer, 309 default_value, 310 name=None, 311 experimental_is_anonymous=False): 312 """Creates a non-initialized `HashTable` object. 313 314 Creates a table, the type of its keys and values are specified by the 315 initializer. 316 Before using the table you will have to initialize it. After initialization 317 the table will be immutable. 318 319 Args: 320 initializer: The table initializer to use. See `HashTable` kernel for 321 supported key and value types. 322 default_value: The value to use if a key is missing in the table. 323 name: A name for the operation (optional). 324 experimental_is_anonymous: Whether to use anonymous mode for the 325 table (default is False). In anonymous mode, the table 326 resource can only be accessed via a resource handle. It can't 327 be looked up by a name. When all resource handles pointing to 328 that resource are gone, the resource will be deleted 329 automatically. 330 331 Returns: 332 A `HashTable` object. 333 """ 334 self._initializer = initializer 335 self._default_value = default_value 336 self._is_anonymous = experimental_is_anonymous 337 if not self._is_anonymous: 338 self._shared_name = self._initializer._shared_name # pylint: disable=protected-access 339 if not self._shared_name: 340 # Force using a shared name so that StaticHashTable resources can be 341 # shared across different kernels. If no "shared_name" is set and 342 # "use_node_name_sharing" is False, then each kernel gets its own local 343 # resource. 344 self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),) 345 self._name = name or "hash_table" 346 self._table_name = None 347 super(StaticHashTable, self).__init__(default_value, initializer) 348 self._value_shape = self._default_value.get_shape() 349 350 def _create_resource(self): 351 if self._is_anonymous: 352 table_ref = gen_lookup_ops.anonymous_hash_table( 353 key_dtype=self._initializer.key_dtype, 354 value_dtype=self._initializer.value_dtype, 355 name=self._name) 356 else: 357 table_ref = gen_lookup_ops.hash_table_v2( 358 shared_name=self._shared_name, 359 key_dtype=self._initializer.key_dtype, 360 value_dtype=self._initializer.value_dtype, 361 name=self._name) 362 if context.executing_eagerly(): 363 self._table_name = None 364 else: 365 self._table_name = table_ref.op.name.split("/")[-1] 366 return table_ref 367 368 @property 369 def name(self): 370 return self._table_name 371 372 def export(self, name=None): 373 """Returns tensors of all keys and values in the table. 374 375 Args: 376 name: A name for the operation (optional). 377 378 Returns: 379 A pair of tensors with the first tensor containing all keys and the 380 second tensors containing all values in the table. 381 """ 382 with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]): 383 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( 384 self.resource_handle, self._key_dtype, self._value_dtype) 385 386 exported_values.set_shape(exported_keys.get_shape().concatenate( 387 self._value_shape)) 388 return exported_keys, exported_values 389 390 def _serialize_to_proto(self, **unused_kwargs): 391 return None 392 393 def _add_trackable_child(self, name, value): 394 setattr(self, name, value) 395 if isinstance(value, trackable_base.Trackable): 396 self._track_trackable(value, name) # pylint:disable=protected-access 397 398 @classmethod 399 def _deserialize_from_proto(cls, **kwargs): 400 401 class _RestoredStaticHashTable(resource.RestoredResource): # pylint: disable=protected-access 402 403 @classmethod 404 def _resource_type(cls): 405 return "RestoredStaticHashTable" 406 407 return _RestoredStaticHashTable._deserialize_from_proto(**kwargs) # pylint: disable=protected-access 408 409 410@tf_export(v1=["lookup.StaticHashTable"]) 411class StaticHashTableV1(StaticHashTable): 412 """A generic hash table that is immutable once initialized. 413 414 When running in graph mode, you must evaluate the tensor returned by 415 `tf.tables_initializer()` before evaluating the tensor returned by 416 this class's `lookup()` method. Example usage in graph mode: 417 418 ```python 419 keys_tensor = tf.constant([1, 2]) 420 vals_tensor = tf.constant([3, 4]) 421 input_tensor = tf.constant([1, 5]) 422 table = tf.lookup.StaticHashTable( 423 tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) 424 out = table.lookup(input_tensor) 425 with tf.Session() as sess: 426 sess.run(tf.tables_initializer()) 427 print(sess.run(out)) 428 ``` 429 430 Note that in graph mode if you set `experimental_is_anonymous` to 431 `True`, you should only call `Session.run` once, otherwise each 432 `Session.run` will create (and destroy) a new table unrelated to 433 each other, leading to errors such as "Table not initialized". 434 You can do so like this: 435 436 ```python 437 keys_tensor = tf.constant([1, 2]) 438 vals_tensor = tf.constant([3, 4]) 439 input_tensor = tf.constant([1, 5]) 440 table = tf.lookup.StaticHashTable( 441 tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1, 442 experimental_is_anonymous=True) 443 with tf.control_dependencies([tf.tables_initializer()]): 444 out = table.lookup(input_tensor) 445 with tf.Session() as sess: 446 print(sess.run(out)) 447 ``` 448 449 In eager mode, no special code is needed to initialize the table. 450 Example usage in eager mode: 451 452 ```python 453 tf.enable_eager_execution() 454 keys_tensor = tf.constant([1, 2]) 455 vals_tensor = tf.constant([3, 4]) 456 input_tensor = tf.constant([1, 5]) 457 table = tf.lookup.StaticHashTable( 458 tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) 459 print(table.lookup(input_tensor)) 460 ``` 461 """ 462 463 @property 464 def initializer(self): 465 return self._init_op 466 467 468# For backwards compatibility. This will be removed in TF 2.0. 469class HashTable(StaticHashTableV1): 470 471 @property 472 def init(self): 473 return self.initializer 474 475 476class TableInitializerBase(trackable_base.Trackable): 477 """Base class for lookup table initializers.""" 478 479 def __init__(self, key_dtype, value_dtype): 480 """Construct a table initializer object. 481 482 Args: 483 key_dtype: Type of the table keys. 484 value_dtype: Type of the table values. 485 """ 486 self._key_dtype = dtypes.as_dtype(key_dtype) 487 self._value_dtype = dtypes.as_dtype(value_dtype) 488 489 @property 490 def key_dtype(self): 491 """The expected table key dtype.""" 492 return self._key_dtype 493 494 @property 495 def value_dtype(self): 496 """The expected table value dtype.""" 497 return self._value_dtype 498 499 def initialize(self, table): 500 """Returns the table initialization op.""" 501 raise NotImplementedError 502 503 @property 504 def _shared_name(self): 505 """Returns a shared name to be used by the table.""" 506 shared_name = "" 507 if context.executing_eagerly(): 508 # Ensure a unique name when eager execution is enabled to avoid spurious 509 # sharing issues. 510 # TODO(rohanj): Use context.anonymous_name() instead. 511 shared_name += str(ops.uid()) 512 return shared_name 513 514 515@tf_export("lookup.KeyValueTensorInitializer") 516class KeyValueTensorInitializer(TableInitializerBase): 517 """Table initializers given `keys` and `values` tensors. 518 519 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 520 >>> vals_tensor = tf.constant([7, 8, 9]) 521 >>> input_tensor = tf.constant(['a', 'f']) 522 >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor) 523 >>> table = tf.lookup.StaticHashTable( 524 ... init, 525 ... default_value=-1) 526 >>> table.lookup(input_tensor).numpy() 527 array([ 7, -1], dtype=int32) 528 529 """ 530 531 def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): 532 """Constructs a table initializer object based on keys and values tensors. 533 534 Args: 535 keys: The tensor for the keys. 536 values: The tensor for the values. 537 key_dtype: The `keys` data type. Used when `keys` is a python array. 538 value_dtype: The `values` data type. Used when `values` is a python array. 539 name: A name for the operation (optional). 540 """ 541 if (not context.executing_eagerly() and 542 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 543 with ops.init_scope(): 544 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 545 self._values = ops.convert_to_tensor( 546 values, dtype=value_dtype, name="values") 547 else: 548 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 549 self._values = ops.convert_to_tensor( 550 values, dtype=value_dtype, name="values") 551 self._name = name if name is not None else "key_value_init" 552 if context.executing_eagerly(): 553 # Ensure a unique name when eager execution is enabled to avoid spurious 554 # sharing issues. 555 # TODO(rohanj): Use context.anonymous_name() instead. 556 self._name += str(ops.uid()) 557 558 super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, 559 self._values.dtype) 560 561 def initialize(self, table): 562 """Initializes the given `table` with `keys` and `values` tensors. 563 564 Args: 565 table: The table to initialize. 566 567 Returns: 568 The operation that initializes the table. 569 570 Raises: 571 TypeError: when the keys and values data types do not match the table 572 key and value data types. 573 """ 574 check_table_dtypes(table, self._keys.dtype, self._values.dtype) 575 with ops.name_scope( 576 self._name, values=(table.resource_handle, self._keys, self._values)): 577 init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle, 578 self._keys, self._values) 579 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 580 return init_op 581 582 583@tf_export("lookup.TextFileIndex") 584class TextFileIndex: 585 """The key and value content to get from each line. 586 587 This class defines the key and value used for `tf.lookup.TextFileInitializer`. 588 589 The key and value content to get from each line is specified either 590 by the following, or a value `>=0`. 591 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 592 expects data type int64. 593 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 594 type string. 595 596 A value `>=0` means use the index (starting at zero) of the split line based 597 on `delimiter`. 598 """ 599 WHOLE_LINE = -2 600 LINE_NUMBER = -1 601 602 603@tf_export("lookup.TextFileInitializer") 604class TextFileInitializer(TableInitializerBase): 605 r"""Table initializers from a text file. 606 607 This initializer assigns one entry in the table for each line in the file. 608 609 The key and value type of the table to initialize is given by `key_dtype` and 610 `value_dtype`. 611 612 The key and value content to get from each line is specified by 613 the `key_index` and `value_index`. 614 615 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 616 expects data type int64. 617 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 618 type string. 619 * A value `>=0` means use the index (starting at zero) of the split line based 620 on `delimiter`. 621 622 For example if we have a file with the following content: 623 624 >>> import tempfile 625 >>> f = tempfile.NamedTemporaryFile(delete=False) 626 >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",]) 627 >>> f.file.write(content.encode('utf-8')) 628 >>> f.file.close() 629 630 The following snippet initializes a table with the first column as keys and 631 second column as values: 632 633 * `emerson -> 10` 634 * `lake -> 20` 635 * `palmer -> 30` 636 637 >>> init= tf.lookup.TextFileInitializer( 638 ... filename=f.name, 639 ... key_dtype=tf.string, key_index=0, 640 ... value_dtype=tf.int64, value_index=1, 641 ... delimiter=" ") 642 >>> table = tf.lookup.StaticHashTable(init, default_value=-1) 643 >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy() 644 645 Similarly to initialize the whole line as keys and the line number as values. 646 647 * `emerson 10 -> 0` 648 * `lake 20 -> 1` 649 * `palmer 30 -> 2` 650 651 >>> init = tf.lookup.TextFileInitializer( 652 ... filename=f.name, 653 ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 654 ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) 655 >>> table = tf.lookup.StaticHashTable(init, -1) 656 >>> table.lookup(tf.constant('palmer 30')).numpy() 657 2 658 """ 659 660 def __init__(self, 661 filename, 662 key_dtype, 663 key_index, 664 value_dtype, 665 value_index, 666 vocab_size=None, 667 delimiter="\t", 668 name=None, 669 value_index_offset=0): 670 """Constructs a table initializer object to populate from a text file. 671 672 It generates one key-value pair per line. The type of table key and 673 value are specified by `key_dtype` and `value_dtype`, respectively. 674 Similarly the content of the key and value are specified by the key_index 675 and value_index. 676 677 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 678 expects data type int64. 679 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 680 type string or int64. 681 - A value >=0 means use the index (starting at zero) of the split line based 682 on `delimiter`. 683 684 Args: 685 filename: The filename of the text file to be used for initialization. The 686 path must be accessible from wherever the graph is initialized (eg. 687 trainer or eval workers). The filename may be a scalar `Tensor`. 688 key_dtype: The `key` data type. 689 key_index: the index that represents information of a line to get the 690 table 'key' values from. 691 value_dtype: The `value` data type. 692 value_index: the index that represents information of a line to get the 693 table 'value' values from.' 694 vocab_size: The number of elements in the file, if known. 695 delimiter: The delimiter to separate fields in a line. 696 name: A name for the operation (optional). 697 value_index_offset: A number to add to all indices extracted from the file 698 This is useful for cases where a user would like to reserve one or more 699 low index values for control characters. For instance, if you would 700 like to ensure that no vocabulary item is mapped to index 0 (so you can 701 reserve 0 for a masking value), you can set value_index_offset to 1; 702 this will mean that the first vocabulary element is mapped to 1 703 instead of 0. 704 705 Raises: 706 ValueError: when the filename is empty, or when the table key and value 707 data types do not match the expected data types. 708 """ 709 if not isinstance(filename, ops.Tensor) and not filename: 710 raise ValueError("`filename` argument required for tf.lookup.TextFileInitializer") 711 712 self._filename_arg = filename 713 key_dtype = dtypes.as_dtype(key_dtype) 714 value_dtype = dtypes.as_dtype(value_dtype) 715 716 if key_index < -2: 717 raise ValueError(f"`key_index` should be >= -2, received: {key_index}.") 718 719 if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: 720 raise ValueError("`key_dtype` must be int64 if `key_index` is " 721 f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}") 722 if ((key_index == TextFileIndex.WHOLE_LINE) and 723 (not key_dtype.is_integer) and (key_dtype != dtypes.string)): 724 raise ValueError( 725 "`key_dtype` should be either integer or string for `key_index` " 726 f"{TextFileIndex.WHOLE_LINE}, received: {key_dtype}") 727 if value_index < -2: 728 raise ValueError("`value_index` should be >= -2, received: " 729 f"{value_index}") 730 731 if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: 732 raise ValueError("`value_dtype` must be int64 for `value_index` " 733 f"{TextFileIndex.LINE_NUMBER}, received: {value_dtype}") 734 if ((value_index == TextFileIndex.WHOLE_LINE) and 735 (not value_dtype.is_integer) and (value_dtype != dtypes.string)): 736 raise ValueError( 737 "`value_dtype` should be either integer or string for `value_index` " 738 f"{TextFileIndex.WHOLE_LINE}, received: {value_dtype}") 739 740 if (vocab_size is not None) and (vocab_size <= 0): 741 raise ValueError(f"`vocab_size` should be > 0, received: {vocab_size}") 742 743 self._key_index = key_index 744 self._value_index = value_index 745 self._vocab_size = vocab_size 746 self._delimiter = delimiter 747 self._name = name 748 self._filename = self._track_trackable( 749 asset.Asset(filename), "_filename") 750 self._offset = value_index_offset 751 752 super(TextFileInitializer, self).__init__(key_dtype, value_dtype) 753 754 def initialize(self, table): 755 """Initializes the table from a text file. 756 757 Args: 758 table: The table to be initialized. 759 760 Returns: 761 The operation that initializes the table. 762 763 Raises: 764 TypeError: when the keys and values data types do not match the table 765 key and value data types. 766 """ 767 check_table_dtypes(table, self.key_dtype, self.value_dtype) 768 with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)): 769 filename = ops.convert_to_tensor( 770 self._filename, dtypes.string, name="asset_filepath") 771 init_op = gen_lookup_ops.initialize_table_from_text_file_v2( 772 table.resource_handle, filename, self._key_index, self._value_index, 773 -1 if self._vocab_size is None else self._vocab_size, self._delimiter, 774 self._offset) 775 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 776 # If the filename tensor is anything other than a string constant (e.g., 777 # if it is a placeholder) then it does not make sense to track it as an 778 # asset. 779 if not context.executing_eagerly() and constant_op.is_constant(filename): 780 ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) 781 return init_op 782 783 @property 784 def _shared_name(self): 785 if self._vocab_size: 786 # Keep the shared_name: 787 # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>_<offset> 788 if self._offset: 789 shared_name = "hash_table_%s_%d_%s_%s_%s" % ( 790 self._filename_arg, self._vocab_size, self._key_index, 791 self._value_index, self._offset) 792 else: 793 shared_name = "hash_table_%s_%d_%s_%s" % ( 794 self._filename_arg, self._vocab_size, self._key_index, 795 self._value_index) 796 else: 797 # Keep the shared_name 798 # <table_type>_<filename>_<key_index>_<value_index>_<offset> 799 if self._offset: 800 shared_name = "hash_table_%s_%s_%s_%s" % ( 801 self._filename_arg, self._key_index, self._value_index, 802 self._offset) 803 else: 804 shared_name = "hash_table_%s_%s_%s" % ( 805 self._filename_arg, self._key_index, self._value_index) 806 807 return shared_name 808 809 810class TextFileStringTableInitializer(TextFileInitializer): 811 """Table initializer for `int64` IDs to string tables from a text file.""" 812 813 def __init__(self, 814 filename, 815 key_column_index=TextFileIndex.LINE_NUMBER, 816 value_column_index=TextFileIndex.WHOLE_LINE, 817 vocab_size=None, 818 delimiter="\t", 819 name="text_file_string_table_init"): 820 """Constructs an initializer for an id-to-string table from a text file. 821 822 It populates a table that its key and value types are int64 and string, 823 respectively. It generates one key-value pair per line. 824 The content of the key and value are specified by `key_column_index` 825 and `value_column_index`. 826 827 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 828 expects data type int64. 829 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 830 type string or int64. 831 - A value >=0 means use the index (starting at zero) of the split line based 832 on `delimiter`. 833 834 Args: 835 filename: The filename of the text file to be used for initialization. The 836 path must be accessible from wherever the graph is initialized (eg. 837 trainer or eval workers). The filename may be a scalar `Tensor`. 838 key_column_index: The column index from the text file to get the keys 839 from. The default is to use the line number, starting from zero. 840 value_column_index: The column index from the text file to get the values 841 from. The default is to use the whole line content. 842 vocab_size: The number of elements in the file, if known. 843 delimiter: The delimiter to separate fields in a line. 844 name: Optional name for the op. 845 846 Raises: 847 TypeError: when the filename is empty, or when the table key and value 848 data types do not match the expected data types. 849 """ 850 super(TextFileStringTableInitializer, self).__init__( 851 filename, 852 dtypes.int64, 853 key_column_index, 854 dtypes.string, 855 value_column_index, 856 vocab_size=vocab_size, 857 delimiter=delimiter, 858 name=name) 859 860 861class TextFileIdTableInitializer(TextFileInitializer): 862 """Table initializer for string to `int64` IDs tables from a text file.""" 863 864 def __init__(self, 865 filename, 866 key_column_index=TextFileIndex.WHOLE_LINE, 867 value_column_index=TextFileIndex.LINE_NUMBER, 868 vocab_size=None, 869 delimiter="\t", 870 name="text_file_id_table_init", 871 key_dtype=dtypes.string): 872 """Constructs an initializer for an string-to-id table from a text file. 873 874 It populates a table that its key and value types are string and int64, 875 respectively. It generates one key-value pair per line. 876 The content of the key and value are specified by the key_index 877 and value_index. 878 879 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 880 expects data type int64. 881 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 882 type string. 883 - A value >=0 means use the index (starting at zero) of the split line based 884 on `delimiter`. 885 886 Args: 887 filename: The filename of the text file to be used for initialization. The 888 path must be accessible from wherever the graph is initialized (eg. 889 trainer or eval workers). The filename may be a scalar `Tensor`. 890 key_column_index: The column index from the text file to get the `key` 891 values from. The default is to use the whole line content. 892 value_column_index: The column index from the text file to get the `value` 893 values from. The default is to use the line number, starting from zero. 894 vocab_size: The number of elements in the file, if known. 895 delimiter: The delimiter to separate fields in a line. 896 name: Optional name for the op. 897 key_dtype: The `key` data type. 898 899 Raises: 900 TypeError: when the filename is empty, or when the table key and value 901 data types do not match the expected data types. 902 """ 903 super(TextFileIdTableInitializer, self).__init__( 904 filename, 905 key_dtype, 906 key_column_index, 907 dtypes.int64, 908 value_column_index, 909 vocab_size=vocab_size, 910 delimiter=delimiter, 911 name=name) 912 913 914class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): 915 """A structure for the spec of the hashing function to use for hash buckets. 916 917 `hasher` is the name of the hashing function to use (eg. "fasthash", 918 "stronghash"). 919 `key` is optional and specify the key to use for the hash function if 920 supported, currently only used by a strong hash. 921 922 Fields: 923 hasher: The hasher name to use. 924 key: The key to be used by the hashing function, if required. 925 """ 926 __slots__ = () 927 928 929FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name 930 931 932class StrongHashSpec(HasherSpec): 933 """A structure to specify a key of the strong keyed hash spec. 934 935 The strong hash requires a `key`, which is a list of 2 unsigned integer 936 numbers. These should be non-zero; random numbers generated from random.org 937 would be a fine choice. 938 939 Fields: 940 key: The key to be used by the keyed hashing function. 941 """ 942 __slots__ = () 943 944 def __new__(cls, key): 945 if len(key) != 2: 946 raise ValueError(f"`key` must have size 2, received {len(key)}") 947 948 if not isinstance(key[0], compat_util.integral_types) or not isinstance( 949 key[1], compat_util.integral_types): 950 raise TypeError("Invalid key %s. Must be unsigned integer values." % key) 951 952 return super(cls, StrongHashSpec).__new__(cls, "stronghash", key) 953 954 955def _as_string(tensor): 956 if dtypes.string == tensor.dtype.base_dtype: 957 return tensor 958 return string_ops.as_string(tensor) 959 960 961class IdTableWithHashBuckets(LookupInterface): 962 r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets. 963 964 For example, if an instance of `IdTableWithHashBuckets` is initialized with a 965 string-to-id table that maps: 966 967 * `emerson -> 0` 968 * `lake -> 1` 969 * `palmer -> 2` 970 971 The `IdTableWithHashBuckets` object will performs the following mapping: 972 973 * `emerson -> 0` 974 * `lake -> 1` 975 * `palmer -> 2` 976 * `<other term> -> bucket_id`, where bucket_id will be between `3` and 977 `3 + num_oov_buckets - 1`, calculated by: 978 `hash(<term>) % num_oov_buckets + vocab_size` 979 980 If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, 981 the lookup result is `[0, 1, 2, 4, 7]`. 982 983 If `table` is None, only out-of-vocabulary buckets are used. 984 985 Example usage: 986 987 ```python 988 num_oov_buckets = 3 989 input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) 990 table = tf.IdTableWithHashBuckets( 991 tf.StaticHashTable( 992 tf.lookup.TextFileInitializer( 993 filename, 994 key_dtype=tf.string, 995 key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 996 value_dtype=tf.int64, 997 value_index=tf.lookup.TextFileIndex.LINE_NUMBER, 998 delimiter="\t"), 999 default_value), 1000 num_oov_buckets) 1001 out = table.lookup(input_tensor). 1002 table.init.run() 1003 print(out.eval()) 1004 ``` 1005 1006 The hash function used for generating out-of-vocabulary buckets ID is handled 1007 by `hasher_spec`. 1008 """ 1009 1010 def __init__(self, 1011 table, 1012 num_oov_buckets, 1013 hasher_spec=FastHashSpec, 1014 name=None, 1015 key_dtype=None): 1016 """Construct a `IdTableWithHashBuckets` object. 1017 1018 Args: 1019 table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. 1020 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. 1021 hasher_spec: A `HasherSpec` to specify the hash function to use for 1022 assignation of out-of-vocabulary buckets (optional). 1023 name: A name for the operation (optional). 1024 key_dtype: Data type of keys passed to `lookup`. Defaults to 1025 `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must 1026 be string or integer, and must be castable to `table.key_dtype`. 1027 1028 Raises: 1029 ValueError: when `table` in None and `num_oov_buckets` is not positive. 1030 TypeError: when `hasher_spec` is invalid. 1031 """ 1032 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 1033 # characters to use as table name. 1034 if name: 1035 name = name.rstrip("/") 1036 if table: 1037 if key_dtype is None: 1038 key_dtype = table.key_dtype 1039 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 1040 if table.key_dtype not in supported_table_key_dtypes: 1041 raise TypeError("Invalid `key_dtype`, expected one of " 1042 f"{supported_table_key_dtypes}, received {key_dtype}.") 1043 if table.key_dtype.is_integer != key_dtype.is_integer: 1044 raise TypeError("Invalid `key dtype`, expected %s but got %s." % 1045 ("integer" if key_dtype.is_integer else "non-integer", 1046 table.key_dtype)) 1047 if table.value_dtype != dtypes.int64: 1048 raise TypeError("Invalid `value_dtype`: expected int64 but got %s." % 1049 (table.value_dtype)) 1050 self._table = table 1051 name = name or self._table.name 1052 else: 1053 if num_oov_buckets <= 0: 1054 raise ValueError("`oov_buckets` must be > 0 if no `table` is supplied.") 1055 key_dtype = dtypes.string if key_dtype is None else key_dtype 1056 self._table = None 1057 name = name or "hash_bucket" 1058 if (not key_dtype.is_integer) and (dtypes.string != key_dtype): 1059 raise TypeError("Invalid `key_dtype`, expected integer or string, got " 1060 f"{key_dtype}.") 1061 self._num_oov_buckets = num_oov_buckets 1062 1063 if not isinstance(hasher_spec, HasherSpec): 1064 raise TypeError("`hasher_spec` must be of type HasherSpec, got " 1065 f"{type(hasher_spec)}.") 1066 self._hasher_spec = hasher_spec 1067 if name: 1068 self._table_name = name.split("/")[-1] 1069 else: 1070 self._table_name = None 1071 super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64) 1072 1073 def _create_resource(self): 1074 if self._table is not None: 1075 return self._table._create_resource() # pylint: disable=protected-access 1076 return None 1077 1078 def _initialize(self): 1079 if self._table is not None: 1080 return self._table._initialize() # pylint: disable=protected-access 1081 with ops.name_scope(None, "init"): 1082 return control_flow_ops.no_op() 1083 1084 @property 1085 def initializer(self): 1086 if self._table is not None: 1087 return self._table._init_op # pylint: disable=protected-access 1088 with ops.name_scope(None, "init"): 1089 return control_flow_ops.no_op() 1090 1091 @property 1092 @deprecated("2018-12-15", "Use `initializer` instead.") 1093 def init(self): 1094 return self.initializer 1095 1096 @property 1097 def resource_handle(self): 1098 if self._table is not None: 1099 return self._table.resource_handle 1100 return None 1101 1102 @property 1103 def name(self): 1104 return self._table_name 1105 1106 def size(self, name=None): 1107 """Compute the number of elements in this table.""" 1108 with ops.name_scope(name, "%s_Size" % self.name): 1109 if self._table: 1110 tsize = self._table.size() 1111 else: 1112 tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) 1113 return tsize + self._num_oov_buckets 1114 1115 def _get_string_to_hash_bucket_fn(self, hasher_spec): 1116 """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" 1117 if not isinstance(hasher_spec, HasherSpec): 1118 raise TypeError("`hasher_spec` must be of type HasherSpec, got " 1119 f"{type(hasher_spec)}.") 1120 if hasher_spec.hasher == "fasthash": 1121 return string_ops.string_to_hash_bucket_fast 1122 if hasher_spec.hasher == "legacy": 1123 return string_ops.string_to_hash_bucket 1124 if hasher_spec.hasher == "stronghash": 1125 return functools.partial( 1126 string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) 1127 raise ValueError( 1128 f"Found unknown hasher {hasher_spec.hasher} in `hasher_spec`") 1129 1130 def lookup(self, keys, name=None): 1131 """Looks up `keys` in the table, outputs the corresponding values. 1132 1133 It assigns out-of-vocabulary keys to buckets based in their hashes. 1134 1135 Args: 1136 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 1137 name: Optional name for the op. 1138 1139 Returns: 1140 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 1141 otherwise a dense `Tensor`. 1142 1143 Raises: 1144 TypeError: when `keys` doesn't match the table key data type. 1145 """ 1146 if keys.dtype.base_dtype != self._key_dtype: 1147 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 1148 f"received: {keys.dtype}") 1149 values = keys 1150 if isinstance(keys, 1151 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 1152 values = keys.values 1153 if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): 1154 values = math_ops.cast(values, dtypes.int64) 1155 1156 if self._num_oov_buckets == 0: 1157 ids = self._table.lookup(values, name=name) 1158 else: 1159 # TODO(yleon): Consider moving this functionality to its own kernel. 1160 with ops.name_scope(name, "%s_Lookup" % self.name): 1161 str_to_hash_bucket = self._get_string_to_hash_bucket_fn( 1162 self._hasher_spec) 1163 buckets = str_to_hash_bucket( 1164 _as_string(values), 1165 num_buckets=self._num_oov_buckets, 1166 name="hash_bucket") 1167 if self._table: 1168 ids = self._table.lookup(values) 1169 buckets = math_ops.add(buckets, self._table.size()) 1170 is_id_non_default = math_ops.not_equal(ids, self._table.default_value) 1171 ids = array_ops.where_v2(is_id_non_default, ids, buckets) 1172 else: 1173 ids = buckets 1174 if isinstance(keys, sparse_tensor.SparseTensor): 1175 return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) 1176 elif isinstance(keys, ragged_tensor.RaggedTensor): 1177 return keys.with_values(ids) 1178 return ids 1179 1180 1181@tf_export("lookup.StaticVocabularyTable", v1=[]) 1182class StaticVocabularyTable(LookupInterface): 1183 r"""String to Id table that assigns out-of-vocabulary keys to hash buckets. 1184 1185 For example, if an instance of `StaticVocabularyTable` is initialized with a 1186 string-to-id initializer that maps: 1187 1188 >>> init = tf.lookup.KeyValueTensorInitializer( 1189 ... keys=tf.constant(['emerson', 'lake', 'palmer']), 1190 ... values=tf.constant([0, 1, 2], dtype=tf.int64)) 1191 >>> table = tf.lookup.StaticVocabularyTable( 1192 ... init, 1193 ... num_oov_buckets=5) 1194 1195 The `Vocabulary` object will performs the following mapping: 1196 1197 * `emerson -> 0` 1198 * `lake -> 1` 1199 * `palmer -> 2` 1200 * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and 1201 `3 + num_oov_buckets - 1 = 7`, calculated by: 1202 `hash(<term>) % num_oov_buckets + vocab_size` 1203 1204 If input_tensor is: 1205 1206 >>> input_tensor = tf.constant(["emerson", "lake", "palmer", 1207 ... "king", "crimson"]) 1208 >>> table[input_tensor].numpy() 1209 array([0, 1, 2, 6, 7]) 1210 1211 If `initializer` is None, only out-of-vocabulary buckets are used. 1212 1213 Example usage: 1214 1215 >>> num_oov_buckets = 3 1216 >>> vocab = ["emerson", "lake", "palmer", "crimnson"] 1217 >>> import tempfile 1218 >>> f = tempfile.NamedTemporaryFile(delete=False) 1219 >>> f.write('\n'.join(vocab).encode('utf-8')) 1220 >>> f.close() 1221 1222 >>> init = tf.lookup.TextFileInitializer( 1223 ... f.name, 1224 ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 1225 ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) 1226 >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets) 1227 >>> table.lookup(tf.constant(["palmer", "crimnson" , "king", 1228 ... "tarkus", "black", "moon"])).numpy() 1229 array([2, 3, 5, 6, 6, 4]) 1230 1231 The hash function used for generating out-of-vocabulary buckets ID is 1232 Fingerprint64. 1233 1234 Note that the out-of-vocabulary bucket IDs always range from the table `size` 1235 up to `size + num_oov_buckets - 1` regardless of the table values, which could 1236 cause unexpected collisions: 1237 1238 >>> init = tf.lookup.KeyValueTensorInitializer( 1239 ... keys=tf.constant(["emerson", "lake", "palmer"]), 1240 ... values=tf.constant([1, 2, 3], dtype=tf.int64)) 1241 >>> table = tf.lookup.StaticVocabularyTable( 1242 ... init, 1243 ... num_oov_buckets=1) 1244 >>> input_tensor = tf.constant(["emerson", "lake", "palmer", "king"]) 1245 >>> table[input_tensor].numpy() 1246 array([1, 2, 3, 3]) 1247 """ 1248 1249 def __init__(self, 1250 initializer, 1251 num_oov_buckets, 1252 lookup_key_dtype=None, 1253 name=None, 1254 experimental_is_anonymous=False): 1255 """Construct a `StaticVocabularyTable` object. 1256 1257 Args: 1258 initializer: A `TableInitializerBase` object that contains the data used 1259 to initialize the table. If None, then we only use out-of-vocab buckets. 1260 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must 1261 be greater than zero. If out-of-vocab buckets are not required, use 1262 `StaticHashTable` instead. 1263 lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to 1264 `initializer.key_dtype` if `initializer` is specified, otherwise 1265 `tf.string`. Must be string or integer, and must be castable to 1266 `initializer.key_dtype`. 1267 name: A name for the operation (optional). 1268 experimental_is_anonymous: Whether to use anonymous mode for the 1269 table (default is False). In anonymous mode, the table 1270 resource can only be accessed via a resource handle. It can't 1271 be looked up by a name. When all resource handles pointing to 1272 that resource are gone, the resource will be deleted 1273 automatically. 1274 1275 Raises: 1276 ValueError: when `num_oov_buckets` is not positive. 1277 TypeError: when lookup_key_dtype or initializer.key_dtype are not 1278 integer or string. Also when initializer.value_dtype != int64. 1279 """ 1280 if num_oov_buckets <= 0: 1281 raise ValueError("`num_oov_buckets` must be > 0; use StaticHashTable.") 1282 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 1283 # characters to use as table name. 1284 if name: 1285 name = name.rstrip("/") 1286 if initializer: 1287 if lookup_key_dtype is None: 1288 lookup_key_dtype = initializer.key_dtype 1289 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 1290 if initializer.key_dtype not in supported_table_key_dtypes: 1291 raise TypeError("Invalid `key_dtype`, expected one of %s, but got %s." % 1292 (supported_table_key_dtypes, initializer.key_dtype)) 1293 if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer: 1294 raise TypeError( 1295 "Invalid `key_dtype`, expected %s but got %s." % 1296 ("integer" if lookup_key_dtype.is_integer else "non-integer", 1297 initializer.key_dtype)) 1298 if initializer.value_dtype != dtypes.int64: 1299 raise TypeError("Invalid `value_dtype`, expected %s but got %s." % 1300 (dtypes.int64, initializer.value_dtype)) 1301 if isinstance(initializer, trackable_base.Trackable): 1302 self._initializer = self._track_trackable(initializer, "_initializer") 1303 self._table = HashTable( 1304 initializer, 1305 default_value=-1, 1306 experimental_is_anonymous=experimental_is_anonymous) 1307 name = name or self._table.name 1308 else: 1309 lookup_key_dtype = dtypes.string 1310 self._table = None 1311 name = name or "hash_bucket" 1312 if (not lookup_key_dtype.is_integer) and (dtypes.string != 1313 lookup_key_dtype): 1314 raise TypeError("Invalid `key_dtype`, expected integer or string, got " 1315 f"{lookup_key_dtype}") 1316 self._num_oov_buckets = num_oov_buckets 1317 1318 self._table_name = None 1319 if name is not None: 1320 self._table_name = name.split("/")[-1] 1321 super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64) 1322 1323 def _create_resource(self): 1324 if self._table is not None: 1325 return self._table._create_resource() # pylint: disable=protected-access 1326 return None 1327 1328 def _initialize(self): 1329 if self._table is not None: 1330 return self._table._initialize() # pylint: disable=protected-access 1331 with ops.name_scope(None, "init"): 1332 return control_flow_ops.no_op() 1333 1334 @property 1335 def resource_handle(self): 1336 if self._table is not None: 1337 return self._table.resource_handle 1338 return None 1339 1340 @property 1341 def name(self): 1342 return self._table_name 1343 1344 def size(self, name=None): 1345 """Compute the number of elements in this table.""" 1346 with ops.name_scope(name, "%s_Size" % self.name): 1347 if self._table: 1348 tsize = self._table.size() 1349 else: 1350 tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) 1351 return tsize + self._num_oov_buckets 1352 1353 def lookup(self, keys, name=None): 1354 """Looks up `keys` in the table, outputs the corresponding values. 1355 1356 It assigns out-of-vocabulary keys to buckets based in their hashes. 1357 1358 Args: 1359 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 1360 name: Optional name for the op. 1361 1362 Returns: 1363 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 1364 otherwise a dense `Tensor`. 1365 1366 Raises: 1367 TypeError: when `keys` doesn't match the table key data type. 1368 """ 1369 if keys.dtype.base_dtype != self._key_dtype: 1370 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 1371 f"received: {keys.dtype}") 1372 values = keys 1373 if isinstance(keys, 1374 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 1375 values = keys.values 1376 if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): 1377 values = math_ops.cast(values, dtypes.int64) 1378 1379 # TODO(yleon): Consider moving this functionality to its own kernel. 1380 with ops.name_scope(name, "%s_Lookup" % self.name): 1381 buckets = string_ops.string_to_hash_bucket_fast( 1382 _as_string(values), 1383 num_buckets=self._num_oov_buckets, 1384 name="hash_bucket") 1385 if self._table: 1386 ids = self._table.lookup(values) 1387 buckets = math_ops.add(buckets, self._table.size()) 1388 is_id_non_default = math_ops.not_equal(ids, self._table.default_value) 1389 ids = array_ops.where_v2(is_id_non_default, ids, buckets) 1390 else: 1391 ids = buckets 1392 if isinstance(keys, sparse_tensor.SparseTensor): 1393 return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) 1394 elif isinstance(keys, ragged_tensor.RaggedTensor): 1395 return keys.with_values(ids) 1396 return ids 1397 1398 1399@tf_export(v1=["lookup.StaticVocabularyTable"]) 1400class StaticVocabularyTableV1(StaticVocabularyTable): 1401 1402 @property 1403 def initializer(self): 1404 if self._table is not None: 1405 return self._table._init_op # pylint: disable=protected-access 1406 with ops.name_scope(None, "init"): 1407 return control_flow_ops.no_op() 1408 1409 1410def index_table_from_file(vocabulary_file=None, 1411 num_oov_buckets=0, 1412 vocab_size=None, 1413 default_value=-1, 1414 hasher_spec=FastHashSpec, 1415 key_dtype=dtypes.string, 1416 name=None, 1417 key_column_index=TextFileIndex.WHOLE_LINE, 1418 value_column_index=TextFileIndex.LINE_NUMBER, 1419 delimiter="\t"): 1420 """Returns a lookup table that converts a string tensor into int64 IDs. 1421 1422 This operation constructs a lookup table to convert tensor of strings into 1423 int64 IDs. The mapping can be initialized from a vocabulary file specified in 1424 `vocabulary_file`, where the whole line is the key and the zero-based line 1425 number is the ID. 1426 1427 Any lookup of an out-of-vocabulary token will return a bucket ID based on its 1428 hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the 1429 `default_value`. 1430 The bucket ID range is 1431 `[vocabulary size, vocabulary size + num_oov_buckets - 1]`. 1432 1433 The underlying table must be initialized by calling 1434 `session.run(tf.compat.v1.tables_initializer())` or 1435 `session.run(table.init())` once. 1436 1437 To specify multi-column vocabulary files, use key_column_index and 1438 value_column_index and delimiter. 1439 1440 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 1441 expects data type int64. 1442 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 1443 type string. 1444 - A value >=0 means use the index (starting at zero) of the split line based 1445 on `delimiter`. 1446 1447 Sample Usages: 1448 1449 If we have a vocabulary file "test.txt" with the following content: 1450 1451 ``` 1452 emerson 1453 lake 1454 palmer 1455 ``` 1456 1457 ```python 1458 features = tf.constant(["emerson", "lake", "and", "palmer"]) 1459 table = tf.lookup.index_table_from_file( 1460 vocabulary_file="test.txt", num_oov_buckets=1) 1461 ids = table.lookup(features) 1462 ... 1463 tf.compat.v1.tables_initializer().run() 1464 1465 ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket 1466 ``` 1467 1468 Args: 1469 vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. 1470 num_oov_buckets: The number of out-of-vocabulary buckets. 1471 vocab_size: Number of the elements in the vocabulary, if known. 1472 default_value: The value to use for out-of-vocabulary feature values. 1473 Defaults to -1. 1474 hasher_spec: A `HasherSpec` to specify the hash function to use for 1475 assignation of out-of-vocabulary buckets. 1476 key_dtype: The `key` data type. 1477 name: A name for this op (optional). 1478 key_column_index: The column index from the text file to get the `key` 1479 values from. The default is to use the whole line content. 1480 value_column_index: The column index from the text file to get the `value` 1481 values from. The default is to use the line number, starting from zero. 1482 delimiter: The delimiter to separate fields in a line. 1483 1484 Returns: 1485 The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`. 1486 1487 Raises: 1488 ValueError: If `vocabulary_file` is not set. 1489 ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater 1490 than zero. 1491 """ 1492 if vocabulary_file is None or (isinstance(vocabulary_file, str) and 1493 not vocabulary_file): 1494 raise ValueError( 1495 "`vocabulary_file` must be specified and must not be empty.") 1496 if num_oov_buckets < 0: 1497 raise ValueError( 1498 "num_oov_buckets must be greater or equal than 0, got %d." % 1499 num_oov_buckets) 1500 if vocab_size is not None and vocab_size < 1: 1501 vocab_file_value = vocabulary_file 1502 if isinstance(vocabulary_file, ops.Tensor): 1503 vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?" 1504 raise ValueError("`vocab_size` must be greater than 0, got %d for " 1505 "vocabulary_file: %s." % (vocab_size, vocab_file_value)) 1506 if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype): 1507 raise TypeError("Dtype for `keys` should be either integer or string.") 1508 1509 with ops.name_scope(name, "string_to_index"): 1510 table = None 1511 with ops.name_scope(None, "hash_table"): 1512 init = TextFileIdTableInitializer( 1513 vocabulary_file, 1514 vocab_size=vocab_size, 1515 key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype, 1516 name="table_init", 1517 key_column_index=key_column_index, 1518 value_column_index=value_column_index, 1519 delimiter=delimiter) 1520 1521 table = StaticHashTableV1(init, default_value) 1522 if num_oov_buckets: 1523 table = IdTableWithHashBuckets( 1524 table, 1525 num_oov_buckets=num_oov_buckets, 1526 hasher_spec=hasher_spec, 1527 key_dtype=key_dtype) 1528 1529 return table 1530 1531 1532def index_table_from_tensor(vocabulary_list, 1533 num_oov_buckets=0, 1534 default_value=-1, 1535 hasher_spec=FastHashSpec, 1536 dtype=dtypes.string, 1537 name=None): 1538 """Returns a lookup table that converts a string tensor into int64 IDs. 1539 1540 This operation constructs a lookup table to convert tensor of strings into 1541 int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D 1542 tensor where each element is a key and corresponding index within the tensor 1543 is the value. 1544 1545 Any lookup of an out-of-vocabulary token will return a bucket ID based on its 1546 hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the 1547 `default_value`. The bucket ID range is 1548 `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`. 1549 1550 The underlying table must be initialized by calling 1551 `session.run(tf.compat.v1.tables_initializer())` or 1552 `session.run(table.init())` once. 1553 1554 Elements in `vocabulary_list` cannot have duplicates, otherwise when executing 1555 the table initializer op, it will throw a `FailedPreconditionError`. 1556 1557 Sample Usages: 1558 1559 ```python 1560 vocabulary_list = tf.constant(["emerson", "lake", "palmer"]) 1561 table = tf.lookup.index_table_from_tensor( 1562 vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1) 1563 features = tf.constant(["emerson", "lake", "and", "palmer"]) 1564 ids = table.lookup(features) 1565 ... 1566 tf.compat.v1.tables_initializer().run() 1567 1568 ids.eval() ==> [0, 1, 4, 2] 1569 ``` 1570 1571 Args: 1572 vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to 1573 indices. The type of this object must be castable to `dtype`. 1574 num_oov_buckets: The number of out-of-vocabulary buckets. 1575 default_value: The value to use for out-of-vocabulary feature values. 1576 Defaults to -1. 1577 hasher_spec: A `HasherSpec` to specify the hash function to use for 1578 assignment of out-of-vocabulary buckets. 1579 dtype: The type of values passed to `lookup`. Only string and integers are 1580 supported. 1581 name: A name for this op (optional). 1582 1583 Returns: 1584 The lookup table to map an input `Tensor` to index `int64` `Tensor`. 1585 1586 Raises: 1587 ValueError: If `vocabulary_list` is invalid. 1588 ValueError: If `num_oov_buckets` is negative. 1589 """ 1590 if vocabulary_list is None: 1591 raise ValueError("`vocabulary_list` must be specified.") 1592 1593 if num_oov_buckets < 0: 1594 raise ValueError( 1595 "`num_oov_buckets` must be greater or equal than 0, got %d." % 1596 num_oov_buckets) 1597 1598 if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype): 1599 raise TypeError("`dtype` must either be integer or string.") 1600 1601 with ops.name_scope(name, "string_to_index"): 1602 keys = ops.convert_to_tensor(vocabulary_list) 1603 if keys.dtype.is_integer != dtype.is_integer: 1604 raise ValueError( 1605 "Invalid `dtype`: Expected %s, got %s." % 1606 ("integer" if dtype.is_integer else "non-integer", keys.dtype)) 1607 if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype): 1608 raise ValueError("Invalid `dtype`: Expected %s, got %s." % 1609 (dtype, keys.dtype)) 1610 num_elements = array_ops.size(keys) 1611 values = math_ops.cast(math_ops.range(num_elements), dtypes.int64) 1612 1613 with ops.name_scope(None, "hash_table"): 1614 table_keys = math_ops.cast( 1615 keys, dtypes.int64) if keys.dtype.is_integer else keys 1616 init = KeyValueTensorInitializer( 1617 table_keys, 1618 values, 1619 table_keys.dtype.base_dtype, 1620 dtypes.int64, 1621 name="table_init") 1622 table = StaticHashTableV1(init, default_value) 1623 if num_oov_buckets: 1624 table = IdTableWithHashBuckets( 1625 table, 1626 num_oov_buckets=num_oov_buckets, 1627 hasher_spec=hasher_spec, 1628 key_dtype=dtype) 1629 return table 1630 1631 1632def index_to_string_table_from_file(vocabulary_file, 1633 vocab_size=None, 1634 default_value="UNK", 1635 name=None, 1636 key_column_index=TextFileIndex.LINE_NUMBER, 1637 value_column_index=TextFileIndex.WHOLE_LINE, 1638 delimiter="\t"): 1639 """Returns a lookup table that maps a `Tensor` of indices into strings. 1640 1641 This operation constructs a lookup table to map int64 indices into string 1642 values. The table is initialized from a vocabulary file specified in 1643 `vocabulary_file`, where the whole line is the value and the 1644 zero-based line number is the index. 1645 1646 Any input which does not have a corresponding index in the vocabulary file 1647 (an out-of-vocabulary entry) is assigned the `default_value` 1648 1649 The underlying table must be initialized by calling 1650 `session.run(tf.compat.v1.tables_initializer())` or 1651 `session.run(table.init())` once. 1652 1653 To specify multi-column vocabulary files, use key_column_index and 1654 value_column_index and delimiter. 1655 1656 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 1657 expects data type int64. 1658 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 1659 type string. 1660 - A value >=0 means use the index (starting at zero) of the split line based 1661 on `delimiter`. 1662 1663 Sample Usages: 1664 1665 If we have a vocabulary file "test.txt" with the following content: 1666 1667 ``` 1668 emerson 1669 lake 1670 palmer 1671 ``` 1672 1673 ```python 1674 indices = tf.constant([1, 5], tf.int64) 1675 table = tf.lookup.index_to_string_table_from_file( 1676 vocabulary_file="test.txt", default_value="UNKNOWN") 1677 values = table.lookup(indices) 1678 ... 1679 tf.compat.v1.tables_initializer().run() 1680 1681 values.eval() ==> ["lake", "UNKNOWN"] 1682 ``` 1683 1684 Args: 1685 vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. 1686 vocab_size: Number of the elements in the vocabulary, if known. 1687 default_value: The value to use for out-of-vocabulary indices. 1688 name: A name for this op (optional). 1689 key_column_index: The column index from the text file to get the `key` 1690 values from. The default is to use the line number, starting from zero. 1691 value_column_index: The column index from the text file to get the `value` 1692 values from. The default is to use the whole line content. 1693 delimiter: The delimiter to separate fields in a line. 1694 1695 Returns: 1696 The lookup table to map a string values associated to a given index `int64` 1697 `Tensors`. 1698 1699 Raises: 1700 ValueError: when `vocabulary_file` is empty. 1701 ValueError: when `vocab_size` is invalid. 1702 """ 1703 if vocabulary_file is None or (isinstance(vocabulary_file, str) and 1704 not vocabulary_file): 1705 raise ValueError( 1706 "`vocabulary_file` must be specified and must not be empty.") 1707 1708 if vocab_size is not None and vocab_size < 1: 1709 raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.") 1710 1711 with ops.name_scope(name, "index_to_string"): 1712 init = TextFileStringTableInitializer( 1713 vocabulary_file, 1714 vocab_size=vocab_size, 1715 name="table_init", 1716 key_column_index=key_column_index, 1717 value_column_index=value_column_index, 1718 delimiter=delimiter) 1719 1720 # TODO(yleon): Use a more efficient structure. 1721 return StaticHashTableV1(init, default_value) 1722 1723 1724def index_to_string_table_from_tensor(vocabulary_list, 1725 default_value="UNK", 1726 name=None): 1727 """Returns a lookup table that maps a `Tensor` of indices into strings. 1728 1729 This operation constructs a lookup table to map int64 indices into string 1730 values. The mapping is initialized from a string `vocabulary_list` 1-D 1731 `Tensor` where each element is a value and the corresponding index within the 1732 tensor is the key. 1733 1734 Any input which does not have a corresponding index in 'vocabulary_list' 1735 (an out-of-vocabulary entry) is assigned the `default_value` 1736 1737 The underlying table must be initialized by calling 1738 `session.run(tf.compat.v1.tables_initializer())` or 1739 `session.run(table.init())` once. 1740 1741 Elements in `vocabulary_list` cannot have duplicates, otherwise when executing 1742 the table initializer op, it will throw a `FailedPreconditionError`. 1743 1744 Sample Usages: 1745 1746 ```python 1747 vocabulary_list = tf.constant(["emerson", "lake", "palmer"]) 1748 indices = tf.constant([1, 5], tf.int64) 1749 table = tf.lookup.index_to_string_table_from_tensor( 1750 vocabulary_list, default_value="UNKNOWN") 1751 values = table.lookup(indices) 1752 ... 1753 tf.compat.v1.tables_initializer().run() 1754 1755 values.eval() ==> ["lake", "UNKNOWN"] 1756 ``` 1757 1758 Args: 1759 vocabulary_list: A 1-D string `Tensor` that specifies the strings to map 1760 from indices. 1761 default_value: The value to use for out-of-vocabulary indices. 1762 name: A name for this op (optional). 1763 1764 Returns: 1765 The lookup table to map a string values associated to a given index `int64` 1766 `Tensors`. 1767 1768 Raises: 1769 ValueError: when `vocabulary_list` is not set. 1770 """ 1771 1772 if vocabulary_list is None: 1773 raise ValueError("`vocabulary_list` argument must be specified.") 1774 1775 with ops.name_scope(name, "index_to_string"): 1776 vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string) 1777 num_elements = array_ops.size(vocabulary_list) 1778 keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64) 1779 1780 init = KeyValueTensorInitializer( 1781 keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init") 1782 # TODO(yleon): Use a more efficient structure. 1783 return StaticHashTableV1(init, default_value) 1784 1785 1786@tf_export("lookup.experimental.MutableHashTable") 1787@saveable_compat.legacy_saveable_name("table") 1788class MutableHashTable(LookupInterface): 1789 """A generic mutable hash table implementation. 1790 1791 Data can be inserted by calling the `insert` method and removed by calling the 1792 `remove` method. It does not support initialization via the init method. 1793 1794 `MutableHashTable` requires additional memory during checkpointing and restore 1795 operations to create temporary key and value tensors. 1796 1797 Example usage: 1798 1799 >>> table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.string, 1800 ... value_dtype=tf.int64, 1801 ... default_value=-1) 1802 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 1803 >>> vals_tensor = tf.constant([7, 8, 9], dtype=tf.int64) 1804 >>> input_tensor = tf.constant(['a', 'f']) 1805 >>> table.insert(keys_tensor, vals_tensor) 1806 >>> table.lookup(input_tensor).numpy() 1807 array([ 7, -1]) 1808 >>> table.remove(tf.constant(['c'])) 1809 >>> table.lookup(keys_tensor).numpy() 1810 array([ 7, 8, -1]) 1811 >>> sorted(table.export()[0].numpy()) 1812 [b'a', b'b'] 1813 >>> sorted(table.export()[1].numpy()) 1814 [7, 8] 1815 """ 1816 1817 def __init__(self, 1818 key_dtype, 1819 value_dtype, 1820 default_value, 1821 name="MutableHashTable", 1822 checkpoint=True, 1823 experimental_is_anonymous=False): 1824 """Creates an empty `MutableHashTable` object. 1825 1826 Creates a table, the type of its keys and values are specified by key_dtype 1827 and value_dtype, respectively. 1828 1829 Args: 1830 key_dtype: the type of the key tensors. 1831 value_dtype: the type of the value tensors. 1832 default_value: The value to use if a key is missing in the table. 1833 name: A name for the operation (optional). 1834 checkpoint: if True, the contents of the table are saved to and restored 1835 from checkpoints. If `shared_name` is empty for a checkpointed table, it 1836 is shared using the table node name. 1837 experimental_is_anonymous: Whether to use anonymous mode for the 1838 table (default is False). In anonymous mode, the table 1839 resource can only be accessed via a resource handle. It can't 1840 be looked up by a name. When all resource handles pointing to 1841 that resource are gone, the resource will be deleted 1842 automatically. 1843 1844 Returns: 1845 A `MutableHashTable` object. 1846 1847 Raises: 1848 ValueError: If checkpoint is True and no name was specified. 1849 """ 1850 self._default_value = ops.convert_to_tensor( 1851 default_value, dtype=value_dtype) 1852 self._value_shape = self._default_value.get_shape() 1853 self._checkpoint = checkpoint 1854 self._key_dtype = key_dtype 1855 self._value_dtype = value_dtype 1856 self._name = name 1857 self._is_anonymous = experimental_is_anonymous 1858 if not self._is_anonymous: 1859 self._shared_name = None 1860 if context.executing_eagerly(): 1861 # TODO(allenl): This will leak memory due to kernel caching by 1862 # the shared_name attribute value (but is better than the 1863 # alternative of sharing everything by default when executing 1864 # eagerly; hopefully creating tables in a loop is uncommon). 1865 self._shared_name = "table_%d" % (ops.uid(),) 1866 super(MutableHashTable, self).__init__(key_dtype, value_dtype) 1867 self._resource_handle = self._create_resource() 1868 if checkpoint: 1869 saveable = MutableHashTable._Saveable(self, name) 1870 if not context.executing_eagerly(): 1871 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 1872 1873 def _create_resource(self): 1874 if self._is_anonymous: 1875 if self._default_value.get_shape().ndims == 0: 1876 table_ref = gen_lookup_ops.anonymous_mutable_hash_table( 1877 key_dtype=self._key_dtype, 1878 value_dtype=self._value_dtype, 1879 name=self._name) 1880 else: 1881 table_ref = gen_lookup_ops.anonymous_mutable_hash_table_of_tensors( 1882 key_dtype=self._key_dtype, 1883 value_dtype=self._value_dtype, 1884 value_shape=self._default_value.get_shape(), 1885 name=self._name) 1886 else: 1887 # The table must be shared if checkpointing is requested for multi-worker 1888 # training to work correctly. Use the node name if no shared_name has been 1889 # explicitly specified. 1890 use_node_name_sharing = self._checkpoint and self._shared_name is None 1891 if self._default_value.get_shape().ndims == 0: 1892 table_ref = gen_lookup_ops.mutable_hash_table_v2( 1893 shared_name=self._shared_name, 1894 use_node_name_sharing=use_node_name_sharing, 1895 key_dtype=self._key_dtype, 1896 value_dtype=self._value_dtype, 1897 name=self._name) 1898 else: 1899 table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( 1900 shared_name=self._shared_name, 1901 use_node_name_sharing=use_node_name_sharing, 1902 key_dtype=self._key_dtype, 1903 value_dtype=self._value_dtype, 1904 value_shape=self._default_value.get_shape(), 1905 name=self._name) 1906 1907 if context.executing_eagerly(): 1908 self._table_name = None 1909 else: 1910 self._table_name = table_ref.op.name.split("/")[-1] 1911 return table_ref 1912 1913 @property 1914 def name(self): 1915 return self._table_name 1916 1917 def size(self, name=None): 1918 """Compute the number of elements in this table. 1919 1920 Args: 1921 name: A name for the operation (optional). 1922 1923 Returns: 1924 A scalar tensor containing the number of elements in this table. 1925 """ 1926 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]): 1927 with ops.colocate_with(self.resource_handle): 1928 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle) 1929 1930 def remove(self, keys, name=None): 1931 """Removes `keys` and its associated values from the table. 1932 1933 If a key is not present in the table, it is silently ignored. 1934 1935 Args: 1936 keys: Keys to remove. Can be a tensor of any shape. Must match the table's 1937 key type. 1938 name: A name for the operation (optional). 1939 1940 Returns: 1941 The created Operation. 1942 1943 Raises: 1944 TypeError: when `keys` do not match the table data types. 1945 """ 1946 if keys.dtype != self._key_dtype: 1947 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 1948 f"received: {keys.dtype}") 1949 1950 with ops.name_scope(name, "%s_lookup_table_remove" % self.name, 1951 (self.resource_handle, keys, self._default_value)): 1952 op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys) 1953 1954 return op 1955 1956 def lookup(self, keys, dynamic_default_values=None, name=None): 1957 """Looks up `keys` in a table, outputs the corresponding values. 1958 1959 The `default_value` is used for keys not present in the table. 1960 1961 Args: 1962 keys: Keys to look up. Can be a tensor of any shape. Must match the 1963 table's key_dtype. 1964 dynamic_default_values: The values to use if a key is missing in the 1965 table. If None (by default), the `table.default_value` will be used. 1966 Shape of `dynamic_default_values` must be same with 1967 `table.default_value` or the lookup result tensor. 1968 In the latter case, each key will have a different default value. 1969 1970 For example: 1971 1972 ```python 1973 keys = [0, 1, 3] 1974 dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]] 1975 1976 # The key '0' will use [1, 3, 4] as default value. 1977 # The key '1' will use [2, 3, 9] as default value. 1978 # The key '3' will use [8, 3, 0] as default value. 1979 ``` 1980 1981 name: A name for the operation (optional). 1982 1983 Returns: 1984 A tensor containing the values in the same shape as `keys` using the 1985 table's value type. 1986 1987 Raises: 1988 TypeError: when `keys` do not match the table data types. 1989 """ 1990 with ops.name_scope(name, "%s_lookup_table_find" % self.name, 1991 (self.resource_handle, keys, self._default_value)): 1992 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") 1993 with ops.colocate_with(self.resource_handle): 1994 values = gen_lookup_ops.lookup_table_find_v2( 1995 self.resource_handle, keys, dynamic_default_values 1996 if dynamic_default_values is not None else self._default_value) 1997 return values 1998 1999 def insert(self, keys, values, name=None): 2000 """Associates `keys` with `values`. 2001 2002 Args: 2003 keys: Keys to insert. Can be a tensor of any shape. Must match the table's 2004 key type. 2005 values: Values to be associated with keys. Must be a tensor of the same 2006 shape as `keys` and match the table's value type. 2007 name: A name for the operation (optional). 2008 2009 Returns: 2010 The created Operation. 2011 2012 Raises: 2013 TypeError: when `keys` or `values` doesn't match the table data 2014 types. 2015 """ 2016 with ops.name_scope(name, "%s_lookup_table_insert" % self.name, 2017 [self.resource_handle, keys, values]): 2018 keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") 2019 values = ops.convert_to_tensor(values, self._value_dtype, name="values") 2020 with ops.colocate_with(self.resource_handle): 2021 # pylint: disable=protected-access 2022 op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys, 2023 values) 2024 return op 2025 2026 def export(self, name=None): 2027 """Returns tensors of all keys and values in the table. 2028 2029 Args: 2030 name: A name for the operation (optional). 2031 2032 Returns: 2033 A pair of tensors with the first tensor containing all keys and the 2034 second tensors containing all values in the table. 2035 """ 2036 with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, 2037 [self.resource_handle]): 2038 with ops.colocate_with(self.resource_handle): 2039 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( 2040 self.resource_handle, self._key_dtype, self._value_dtype) 2041 return exported_keys, exported_values 2042 2043 def _serialize_to_tensors(self): 2044 """Implements checkpointing protocols for `Trackable`.""" 2045 tensors = self.export() 2046 return {"-keys": tensors[0], "-values": tensors[1]} 2047 2048 def _restore_from_tensors(self, restored_tensors): 2049 """Implements checkpointing protocols for `Trackable`.""" 2050 with ops.name_scope("%s_table_restore" % self._name): 2051 with ops.colocate_with(self.resource_handle): 2052 return gen_lookup_ops.lookup_table_import_v2( 2053 self.resource_handle, 2054 restored_tensors["-keys"], 2055 restored_tensors["-values"]) 2056 2057 # This class is needed for `MutableHashTable(checkpoint=True)`. 2058 class _Saveable(BaseSaverBuilder.SaveableObject): 2059 """SaveableObject implementation for DenseHashTable.""" 2060 2061 def __init__(self, table, name, table_name=None): 2062 tensors = table.export() 2063 specs = [ 2064 BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), 2065 BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") 2066 ] 2067 self.table_name = table_name or name 2068 # pylint: disable=protected-access 2069 super(MutableHashTable._Saveable, self).__init__(table, specs, name) 2070 2071 def restore(self, restored_tensors, restored_shapes): 2072 del restored_shapes # unused 2073 # pylint: disable=protected-access 2074 with ops.name_scope("%s_table_restore" % self.table_name): 2075 with ops.colocate_with(self.op.resource_handle): 2076 return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle, 2077 restored_tensors[0], 2078 restored_tensors[1]) 2079 2080 2081@tf_export("lookup.experimental.DenseHashTable") 2082@saveable_compat.legacy_saveable_name("table") 2083class DenseHashTable(LookupInterface): 2084 """A mutable hash table with faster lookups and higher memory usage. 2085 2086 Data can be inserted by calling the `insert` method and removed by calling the 2087 `remove` method. It does not support initialization via the init method. 2088 2089 Compared to `MutableHashTable`, `DenseHashTable` offers generally faster 2090 `insert`, `remove` and `lookup` operations, in exchange for a higher overall 2091 memory footprint. 2092 2093 It uses "open addressing" with quadratic reprobing to resolve collisions. This 2094 requires specifying two keys in the key space, `empty_key` and `deleted_key`, 2095 that can never inserted into the table. 2096 2097 Unlike `MutableHashTable`, `DenseHashTable` does not require additional memory 2098 for temporary tensors created during checkpointing and restore operations. 2099 2100 Example usage: 2101 2102 >>> table = tf.lookup.experimental.DenseHashTable( 2103 ... key_dtype=tf.string, 2104 ... value_dtype=tf.int64, 2105 ... default_value=-1, 2106 ... empty_key='', 2107 ... deleted_key='$') 2108 >>> keys = tf.constant(['a', 'b', 'c']) 2109 >>> values = tf.constant([0, 1, 2], dtype=tf.int64) 2110 >>> table.insert(keys, values) 2111 >>> table.remove(tf.constant(['c'])) 2112 >>> table.lookup(tf.constant(['a', 'b', 'c','d'])).numpy() 2113 array([ 0, 1, -1, -1]) 2114 """ 2115 2116 # TODO(andreasst): consider extracting common code with MutableHashTable into 2117 # a common superclass. 2118 def __init__(self, 2119 key_dtype, 2120 value_dtype, 2121 default_value, 2122 empty_key, 2123 deleted_key, 2124 initial_num_buckets=None, 2125 name="MutableDenseHashTable", 2126 checkpoint=True, 2127 experimental_is_anonymous=False): 2128 """Creates an empty `DenseHashTable` object. 2129 2130 Creates a table, the type of its keys and values are specified by key_dtype 2131 and value_dtype, respectively. 2132 2133 Args: 2134 key_dtype: the type of the key tensors. 2135 value_dtype: the type of the value tensors. 2136 default_value: The value to use if a key is missing in the table. 2137 empty_key: the key to use to represent empty buckets internally. Must not 2138 be used in insert, remove or lookup operations. 2139 deleted_key: the key to use to represent deleted buckets internally. Must 2140 not be used in insert, remove or lookup operations and be different from 2141 the empty_key. 2142 initial_num_buckets: the initial number of buckets (optional, 2143 default to 2^17=131072). Note that the default value is 2144 relatively large (~1MB), so if you are going to create many 2145 tables (likely the case when `experimental_is_anonymous` is 2146 `True`), you should set `initial_num_buckets` to a smaller 2147 value to reduce memory usage. 2148 name: A name for the operation (optional). 2149 checkpoint: if True, the contents of the table are saved to and restored 2150 from checkpoints. If `shared_name` is empty for a checkpointed table, it 2151 is shared using the table node name. 2152 experimental_is_anonymous: Whether to use anonymous mode for the 2153 table (default is False). In anonymous mode, the table 2154 resource can only be accessed via a resource handle. It can't 2155 be looked up by a name. When all resource handles pointing to 2156 that resource are gone, the resource will be deleted 2157 automatically. 2158 2159 Returns: 2160 A `DenseHashTable` object. 2161 2162 Raises: 2163 ValueError: If checkpoint is True and no name was specified. 2164 """ 2165 self._default_value = ops.convert_to_tensor( 2166 default_value, dtype=value_dtype, name="default_value") 2167 self._key_dtype = key_dtype 2168 self._value_dtype = value_dtype 2169 # TODO(b/201578996): Pick a good default for initial_num_buckets 2170 # other than 2^17. 2171 self._initial_num_buckets = initial_num_buckets 2172 self._value_shape = self._default_value.get_shape() 2173 self._checkpoint = checkpoint 2174 self._name = name 2175 self._empty_key = empty_key 2176 self._deleted_key = deleted_key 2177 self._is_anonymous = experimental_is_anonymous 2178 if not self._is_anonymous: 2179 self._shared_name = None 2180 if context.executing_eagerly(): 2181 # TODO(allenl): This will leak memory due to kernel caching by 2182 # the shared_name attribute value (but is better than the 2183 # alternative of sharing everything by default when executing 2184 # eagerly; hopefully creating tables in a loop is uncommon). 2185 self._shared_name = "table_%d" % (ops.uid(),) 2186 super(DenseHashTable, self).__init__(key_dtype, value_dtype) 2187 self._resource_handle = self._create_resource() 2188 if checkpoint: 2189 saveable = DenseHashTable._Saveable(self, name) 2190 if not context.executing_eagerly(): 2191 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 2192 2193 def _create_resource(self): 2194 empty_key = ops.convert_to_tensor( 2195 self._empty_key, dtype=self._key_dtype, name="empty_key") 2196 deleted_key = ops.convert_to_tensor( 2197 self._deleted_key, dtype=self._key_dtype, name="deleted_key") 2198 if self._is_anonymous: 2199 table_ref = gen_lookup_ops.anonymous_mutable_dense_hash_table( 2200 empty_key=empty_key, 2201 deleted_key=deleted_key, 2202 value_dtype=self._value_dtype, 2203 value_shape=self._value_shape, 2204 initial_num_buckets=self._initial_num_buckets, 2205 name=self._name) 2206 else: 2207 # The table must be shared if checkpointing is requested for multi-worker 2208 # training to work correctly. Use the node name if no shared_name has been 2209 # explicitly specified. 2210 use_node_name_sharing = self._checkpoint and self._shared_name is None 2211 table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( 2212 empty_key=empty_key, 2213 deleted_key=deleted_key, 2214 shared_name=self._shared_name, 2215 use_node_name_sharing=use_node_name_sharing, 2216 value_dtype=self._value_dtype, 2217 value_shape=self._value_shape, 2218 initial_num_buckets=self._initial_num_buckets, 2219 name=self._name) 2220 if context.executing_eagerly(): 2221 self._table_name = None 2222 else: 2223 self._table_name = table_ref.op.name.split("/")[-1] 2224 return table_ref 2225 2226 @property 2227 def name(self): 2228 return self._table_name 2229 2230 def size(self, name=None): 2231 """Compute the number of elements in this table. 2232 2233 Args: 2234 name: A name for the operation (optional). 2235 2236 Returns: 2237 A scalar tensor containing the number of elements in this table. 2238 """ 2239 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]): 2240 with ops.colocate_with(self.resource_handle): 2241 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle) 2242 2243 def lookup(self, keys, name=None): 2244 """Looks up `keys` in a table, outputs the corresponding values. 2245 2246 The `default_value` is used for keys not present in the table. 2247 2248 Args: 2249 keys: Keys to look up. Can be a tensor of any shape. Must match the 2250 table's key_dtype. 2251 name: A name for the operation (optional). 2252 2253 Returns: 2254 A tensor containing the values in the same shape as `keys` using the 2255 table's value type. 2256 2257 Raises: 2258 TypeError: when `keys` do not match the table data types. 2259 """ 2260 with ops.name_scope(name, "%s_lookup_table_find" % self.name, 2261 [self.resource_handle, keys]): 2262 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") 2263 with ops.colocate_with(self.resource_handle): 2264 values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys, 2265 self._default_value) 2266 2267 return values 2268 2269 def insert_or_assign(self, keys, values, name=None): 2270 """Associates `keys` with `values`. 2271 2272 Args: 2273 keys: Keys to insert. Can be a tensor of any shape. Must match the table's 2274 key type. 2275 values: Values to be associated with keys. Must be a tensor of the same 2276 shape as `keys` and match the table's value type. 2277 name: A name for the operation (optional). 2278 2279 Returns: 2280 The created Operation. 2281 2282 Raises: 2283 TypeError: when `keys` or `values` doesn't match the table data 2284 types. 2285 """ 2286 with ops.name_scope(name, "%s_lookup_table_insert" % self.name, 2287 [self.resource_handle, keys, values]): 2288 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") 2289 values = ops.convert_to_tensor( 2290 values, dtype=self._value_dtype, name="values") 2291 with ops.colocate_with(self.resource_handle): 2292 op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys, 2293 values) 2294 return op 2295 2296 def insert(self, keys, values, name=None): 2297 """Associates `keys` with `values`. 2298 2299 Args: 2300 keys: Keys to insert. Can be a tensor of any shape. Must match the table's 2301 key type. 2302 values: Values to be associated with keys. Must be a tensor of the same 2303 shape as `keys` and match the table's value type. 2304 name: A name for the operation (optional). 2305 2306 Returns: 2307 The created Operation. 2308 2309 Raises: 2310 TypeError: when `keys` or `values` doesn't match the table data 2311 types. 2312 """ 2313 return self.insert_or_assign(keys, values, name) 2314 2315 def erase(self, keys, name=None): 2316 """Removes `keys` and its associated values from the table. 2317 2318 If a key is not present in the table, it is silently ignored. 2319 2320 Args: 2321 keys: Keys to remove. Can be a tensor of any shape. Must match the table's 2322 key type. 2323 name: A name for the operation (optional). 2324 2325 Returns: 2326 The created Operation. 2327 2328 Raises: 2329 TypeError: when `keys` do not match the table data types. 2330 """ 2331 if keys.dtype != self._key_dtype: 2332 raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % 2333 (self._key_dtype, keys.dtype)) 2334 2335 with ops.name_scope(name, "%s_lookup_table_remove" % self.name, 2336 (self.resource_handle, keys, self._default_value)): 2337 # pylint: disable=protected-access 2338 op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys) 2339 2340 return op 2341 2342 def remove(self, keys, name=None): 2343 """Removes `keys` and its associated values from the table. 2344 2345 If a key is not present in the table, it is silently ignored. 2346 2347 Args: 2348 keys: Keys to remove. Can be a tensor of any shape. Must match the table's 2349 key type. 2350 name: A name for the operation (optional). 2351 2352 Returns: 2353 The created Operation. 2354 2355 Raises: 2356 TypeError: when `keys` do not match the table data types. 2357 """ 2358 return self.erase(keys, name) 2359 2360 def export(self, name=None): 2361 """Returns tensors of all keys and values in the table. 2362 2363 Args: 2364 name: A name for the operation (optional). 2365 2366 Returns: 2367 A pair of tensors with the first tensor containing all keys and the 2368 second tensors containing all values in the table. 2369 """ 2370 with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, 2371 [self.resource_handle]): 2372 with ops.colocate_with(self.resource_handle): 2373 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( 2374 self.resource_handle, self._key_dtype, self._value_dtype) 2375 2376 return exported_keys, exported_values 2377 2378 def _serialize_to_tensors(self): 2379 """Implements checkpointing interface in `Trackable`.""" 2380 tensors = self.export() 2381 return {"-keys": tensors[0], "-values": tensors[1]} 2382 2383 def _restore_from_tensors(self, restored_tensors): 2384 """Implements checkpointing interface in `Trackable`.""" 2385 with ops.name_scope("%s_table_restore" % self._name): 2386 with ops.colocate_with(self.resource_handle): 2387 return gen_lookup_ops.lookup_table_import_v2( 2388 self.resource_handle, 2389 restored_tensors["-keys"], 2390 restored_tensors["-values"]) 2391 2392 # This class is needed for `DenseHashTable(checkpoint=True)`. 2393 class _Saveable(BaseSaverBuilder.SaveableObject): 2394 """SaveableObject implementation for DenseHashTable.""" 2395 2396 def __init__(self, table, name, table_name=None): 2397 tensors = table.export() 2398 specs = [ 2399 BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), 2400 BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") 2401 ] 2402 self.table_name = table_name or name 2403 # pylint: disable=protected-access 2404 super(DenseHashTable._Saveable, self).__init__(table, specs, name) 2405 2406 def restore(self, restored_tensors, restored_shapes): 2407 del restored_shapes # unused 2408 # pylint: disable=protected-access 2409 with ops.name_scope("%s_table_restore" % self.table_name): 2410 with ops.colocate_with(self.op.resource_handle): 2411 return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle, 2412 restored_tensors[0], 2413 restored_tensors[1]) 2414 2415 2416ops.NotDifferentiable("LookupTableFind") 2417ops.NotDifferentiable("LookupTableFindV2") 2418ops.NotDifferentiable("LookupTableInsert") 2419ops.NotDifferentiable("LookupTableInsertV2") 2420ops.NotDifferentiable("LookupTableSize") 2421ops.NotDifferentiable("LookupTableSizeV2") 2422ops.NotDifferentiable("HashTable") 2423ops.NotDifferentiable("HashTableV2") 2424ops.NotDifferentiable("InitializeTable") 2425ops.NotDifferentiable("InitializeTableV2") 2426ops.NotDifferentiable("InitializeTableFromTextFile") 2427ops.NotDifferentiable("InitializeTableFromTextFileV2") 2428ops.NotDifferentiable("MutableDenseHashTable") 2429ops.NotDifferentiable("MutableDenseHashTableV2") 2430ops.NotDifferentiable("MutableHashTable") 2431ops.NotDifferentiable("MutableHashTableV2") 2432ops.NotDifferentiable("MutableHashTableOfTensors") 2433ops.NotDifferentiable("MutableHashTableOfTensorsV2") 2434