# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Lookup operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import uuid

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.saver import BaseSaverBuilder
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable
from tensorflow.python.util import compat as compat_util
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. Note that if there are
    no tables the returned Op is a NoOp.
  """
  return tables_initializer(name)


@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. Note that if there are
    no tables the returned Op is a NoOp.

  @compatibility(TF2)
  `tf.compat.v1.tables_initializer` is no longer needed with eager execution
  and `tf.function`. In TF2, when creating an initializable table like a
  `tf.lookup.StaticHashTable`, the table will automatically be initialized on
  creation.

  #### Before & After Usage Example

  Before:

  >>> with tf.compat.v1.Session():
  ...   init = tf.compat.v1.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  ...   table = tf.compat.v1.lookup.StaticHashTable(init, default_value=-1)
  ...   tf.compat.v1.tables_initializer().run()
  ...   result = table.lookup(tf.constant(['a', 'c'])).eval()
  >>> result
  array([ 1, -1], dtype=int32)

  After:

  >>> init = tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['a', 'c'])).numpy()
  array([ 1, -1], dtype=int32)

  @end_compatibility
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


def check_table_dtypes(table, key_dtype, value_dtype):
  """Check that the given key_dtype and value_dtype match the table dtypes.

  Args:
    table: The table to check types against.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
  """
  if key_dtype.base_dtype != table.key_dtype:
    raise TypeError(f"Invalid key dtype for table, expected {table.key_dtype} "
                    f"but got {key_dtype}.")
  if value_dtype.base_dtype != table.value_dtype:
    raise TypeError("Invalid value dtype for table, expected "
                    f"{table.value_dtype} but got {value_dtype}.")


class LookupInterface(trackable.TrackableResource):
  """Represents a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError

  def __getitem__(self, keys):
    """Looks up `keys` in a table, outputs the corresponding values."""
    return self.lookup(keys)


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  Initializable lookup tables persist across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of
    `TableInitializerBase`). The initializer provides the table key and value
    types, as well as the op that initializes the table. The caller is
    responsible for executing the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
192 """ 193 super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, 194 initializer.value_dtype) 195 self._default_value = ops.convert_to_tensor( 196 default_value, dtype=self._value_dtype) 197 self._default_value.get_shape().merge_with(tensor_shape.TensorShape([])) 198 if isinstance(initializer, trackable_base.Trackable): 199 self._initializer = self._track_trackable(initializer, "_initializer") 200 with ops.init_scope(): 201 self._resource_handle = self._create_resource() 202 if (not context.executing_eagerly() and 203 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 204 with ops.init_scope(): 205 self._init_op = self._initialize() 206 else: 207 self._init_op = self._initialize() 208 209 def _initialize(self): 210 return self._initializer.initialize(self) 211 212 @property 213 def default_value(self): 214 """The default value of the table.""" 215 return self._default_value 216 217 def size(self, name=None): 218 """Compute the number of elements in this table. 219 220 Args: 221 name: A name for the operation (optional). 222 223 Returns: 224 A scalar tensor containing the number of elements in this table. 225 """ 226 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]): 227 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle) 228 229 def lookup(self, keys, name=None): 230 """Looks up `keys` in a table, outputs the corresponding values. 231 232 The `default_value` is used for keys not present in the table. 233 234 Args: 235 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 236 name: A name for the operation (optional). 237 238 Returns: 239 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 240 otherwise a dense `Tensor`. 241 242 Raises: 243 TypeError: when `keys` or `default_value` doesn't match the table data 244 types. 245 """ 246 key_tensor = keys 247 if isinstance(keys, 248 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 249 key_tensor = keys.values 250 251 if keys.dtype.base_dtype != self._key_dtype: 252 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 253 f"received: {keys.dtype}") 254 255 with ops.name_scope( 256 name, "%s_Lookup" % self.name, 257 (self.resource_handle, key_tensor, self._default_value)): 258 values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, 259 key_tensor, 260 self._default_value) 261 262 values.set_shape(key_tensor.get_shape()) 263 if isinstance(keys, sparse_tensor.SparseTensor): 264 return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) 265 elif isinstance(keys, ragged_tensor.RaggedTensor): 266 return keys.with_values(values) 267 else: 268 return values 269 270 271class InitializableLookupTableBaseV1(InitializableLookupTableBase): 272 273 @property 274 def initializer(self): 275 return self._init_op 276 277 278@tf_export("lookup.StaticHashTable", v1=[]) 279class StaticHashTable(InitializableLookupTableBase): 280 """A generic hash table that is immutable once initialized. 281 282 Example usage: 283 284 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 285 >>> vals_tensor = tf.constant([7, 8, 9]) 286 >>> input_tensor = tf.constant(['a', 'f']) 287 >>> table = tf.lookup.StaticHashTable( 288 ... tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), 289 ... 
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  Or for more pythonic code:

  >>> table[input_tensor].numpy()
  array([ 7, -1], dtype=int32)

  The result of a lookup operation has the same shape as the argument:

  >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']])
  >>> table[input_tensor].numpy()
  array([[ 7,  8],
         [ 9, -1]], dtype=int32)
  """

  def __init__(self, initializer, default_value, name=None):
    """Creates a non-initialized `HashTable` object.

    Creates a table whose key and value types are specified by the
    initializer. Before using the table you will have to initialize it. After
    initialization the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).

    Returns:
      A `HashTable` object.
    """
    self._initializer = initializer
    self._default_value = default_value
    self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    if not self._shared_name:
      # Force using a shared name so that StaticHashTable resources can be
      # shared across different kernels. If no "shared_name" is set and
      # "use_node_name_sharing" is False, then each kernel gets its own local
      # resource.
      self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),)
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    table_ref = gen_lookup_ops.hash_table_v2(
        shared_name=self._shared_name,
        key_dtype=self._initializer.key_dtype,
        value_dtype=self._initializer.value_dtype,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
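
    For example (a sketch, reusing the `table` built in the class docstring
    above; the order in which entries are exported is not guaranteed):

    ```python
    exported_keys, exported_values = table.export()
    sorted(exported_keys.numpy())    # [b'a', b'b', b'c']
    sorted(exported_values.numpy())  # [7, 8, 9]
    ```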
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values


@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """A generic hash table that is immutable once initialized.

  When running in graph mode, you must evaluate the tensor returned by
  `tf.tables_initializer()` before evaluating the tensor returned by
  this class's `lookup()` method. Example usage in graph mode:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  out = table.lookup(input_tensor)
  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(out))
  ```

  In eager mode, no special code is needed to initialize the table.
  Example usage in eager mode:

  ```python
  tf.enable_eager_execution()
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  print(table.lookup(input_tensor))
  ```
  """

  @property
  def initializer(self):
    return self._init_op


# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):

  @property
  def init(self):
    return self.initializer


class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys.
      value_dtype: Type of the table values.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table."""
    shared_name = ""
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.shared_name() instead.
      shared_name += str(ops.uid())
    return shared_name


@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializer given `keys` and `values` tensors.

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
  >>> table = tf.lookup.StaticHashTable(
  ...     init,
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  """

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python array.
      name: A name for the operation (optional).
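
    For example (a sketch): when `keys` and `values` are plain Python lists,
    the optional dtypes pin down the table types explicitly:

    ```python
    init = tf.lookup.KeyValueTensorInitializer(
        keys=[0, 1, 2], values=[10.0, 11.0, 12.0],
        key_dtype=tf.int64, value_dtype=tf.float32)
    ```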
485 """ 486 if (not context.executing_eagerly() and 487 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 488 with ops.init_scope(): 489 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 490 self._values = ops.convert_to_tensor( 491 values, dtype=value_dtype, name="values") 492 else: 493 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 494 self._values = ops.convert_to_tensor( 495 values, dtype=value_dtype, name="values") 496 self._name = name if name is not None else "key_value_init" 497 if context.executing_eagerly(): 498 # Ensure a unique name when eager execution is enabled to avoid spurious 499 # sharing issues. 500 # TODO(rohanj): Use context.shared_name() instead. 501 self._name += str(ops.uid()) 502 503 super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, 504 self._values.dtype) 505 506 def initialize(self, table): 507 """Initializes the given `table` with `keys` and `values` tensors. 508 509 Args: 510 table: The table to initialize. 511 512 Returns: 513 The operation that initializes the table. 514 515 Raises: 516 TypeError: when the keys and values data types do not match the table 517 key and value data types. 518 """ 519 check_table_dtypes(table, self._keys.dtype, self._values.dtype) 520 with ops.name_scope( 521 self._name, values=(table.resource_handle, self._keys, self._values)): 522 init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle, 523 self._keys, self._values) 524 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 525 return init_op 526 527 528@tf_export("lookup.TextFileIndex") 529class TextFileIndex(object): 530 """The key and value content to get from each line. 531 532 This class defines the key and value used for `tf.lookup.TextFileInitializer`. 533 534 The key and value content to get from each line is specified either 535 by the following, or a value `>=0`. 536 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 537 expects data type int64. 538 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 539 type string. 540 541 A value `>=0` means use the index (starting at zero) of the split line based 542 on `delimiter`. 543 """ 544 WHOLE_LINE = -2 545 LINE_NUMBER = -1 546 547 548@tf_export("lookup.TextFileInitializer") 549class TextFileInitializer(TableInitializerBase): 550 r"""Table initializers from a text file. 551 552 This initializer assigns one entry in the table for each line in the file. 553 554 The key and value type of the table to initialize is given by `key_dtype` and 555 `value_dtype`. 556 557 The key and value content to get from each line is specified by 558 the `key_index` and `value_index`. 559 560 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 561 expects data type int64. 562 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 563 type string. 564 * A value `>=0` means use the index (starting at zero) of the split line based 565 on `delimiter`. 

  For example if we have a file with the following content:

  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",])
  >>> f.file.write(content.encode('utf-8'))
  >>> f.file.close()

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  >>> init = tf.lookup.TextFileInitializer(
  ...     filename=f.name,
  ...     key_dtype=tf.string, key_index=0,
  ...     value_dtype=tf.int64, value_index=1,
  ...     delimiter=" ")
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy()
  array([30, 20, -1])

  Similarly, to initialize the whole line as keys and the line number as
  values:

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  >>> init = tf.lookup.TextFileInitializer(
  ...     filename=f.name,
  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticHashTable(init, -1)
  >>> table.lookup(tf.constant('palmer 30')).numpy()
  2
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None,
               value_index_offset=0):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly the content of the key and value are specified by `key_index`
    and `value_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).
      value_index_offset: A number to add to all indices extracted from the
        file. This is useful for cases where a user would like to reserve one
        or more low index values for control characters. For instance, if you
        would like to ensure that no vocabulary item is mapped to index 0 (so
        you can reserve 0 for a masking value), you can set value_index_offset
        to 1; this will mean that the first vocabulary element is mapped to 1
        instead of 0.

    Raises:
      ValueError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
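
    For example (a sketch, assuming the three-line file from the class
    docstring above), `value_index_offset=1` keeps id 0 free for masking:

    ```python
    init = tf.lookup.TextFileInitializer(
        filename=f.name,
        key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
        value_index_offset=1)
    table = tf.lookup.StaticHashTable(init, default_value=0)
    # 'emerson 10' -> 1, 'lake 20' -> 2, 'palmer 30' -> 3
    ```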
653 """ 654 if not isinstance(filename, ops.Tensor) and not filename: 655 raise ValueError("`filename` argument required for tf.lookup.TextFileInitializer") 656 657 self._filename_arg = filename 658 key_dtype = dtypes.as_dtype(key_dtype) 659 value_dtype = dtypes.as_dtype(value_dtype) 660 661 if key_index < -2: 662 raise ValueError("`key_index` should be >= -2, received: {key_index}.") 663 664 if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: 665 raise ValueError("`key_dtype` must be int64 if `key_index` is " 666 f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}") 667 if ((key_index == TextFileIndex.WHOLE_LINE) and 668 (not key_dtype.is_integer) and (key_dtype != dtypes.string)): 669 raise ValueError( 670 "`key_dtype` should be either integer or string for `key_index` " 671 f"{TextFileIndex.WHOLE_LINE}, received: {key_dtype}") 672 if value_index < -2: 673 raise ValueError("`value_index` should be >= -2, received: " 674 f"{value_index}") 675 676 if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: 677 raise ValueError("`value_dtype` must be int64 for `value_index` " 678 f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}") 679 if ((value_index == TextFileIndex.WHOLE_LINE) and 680 (not value_dtype.is_integer) and (value_dtype != dtypes.string)): 681 raise ValueError( 682 "`value_dtype` should be either integer or string for `value_index` " 683 f"{TextFileIndex.WHOLE_LINE}, received: {value_dtype}") 684 685 if (vocab_size is not None) and (vocab_size <= 0): 686 raise ValueError(f"`vocab_size` should be > 0, received: {vocab_size}") 687 688 self._key_index = key_index 689 self._value_index = value_index 690 self._vocab_size = vocab_size 691 self._delimiter = delimiter 692 self._name = name 693 self._filename = self._track_trackable( 694 trackable.Asset(filename), "_filename") 695 self._offset = value_index_offset 696 697 super(TextFileInitializer, self).__init__(key_dtype, value_dtype) 698 699 def initialize(self, table): 700 """Initializes the table from a text file. 701 702 Args: 703 table: The table to be initialized. 704 705 Returns: 706 The operation that initializes the table. 707 708 Raises: 709 TypeError: when the keys and values data types do not match the table 710 key and value data types. 711 """ 712 check_table_dtypes(table, self.key_dtype, self.value_dtype) 713 with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)): 714 filename = ops.convert_to_tensor( 715 self._filename, dtypes.string, name="asset_filepath") 716 init_op = gen_lookup_ops.initialize_table_from_text_file_v2( 717 table.resource_handle, filename, self._key_index, self._value_index, 718 -1 if self._vocab_size is None else self._vocab_size, self._delimiter, 719 self._offset) 720 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 721 # If the filename tensor is anything other than a string constant (e.g., 722 # if it is a placeholder) then it does not make sense to track it as an 723 # asset. 
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%d_%s_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index, self._offset)
      else:
        shared_name = "hash_table_%s_%d_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index)
    else:
      # Keep the shared_name:
      # <table_type>_<filename>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index,
            self._offset)
      else:
        shared_name = "hash_table_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index)

    return shared_name


class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table whose key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the values
        from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
    """
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        dtypes.int64,
        key_column_index,
        dtypes.string,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key and value types are string and int64,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the `value`
        values from. The default is to use the line number, starting from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
    """
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """A structure for the spec of the hashing function to use for hash buckets.

  `hasher` is the name of the hashing function to use (e.g. "fasthash",
  "stronghash").
  `key` is optional and specifies the key to use for the hash function if
  supported, currently only used by a strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
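
  For example, a sketch of a spec keyed with two fixed (arbitrarily chosen)
  integers:

  ```python
  spec = StrongHashSpec(key=(137, 173))
  # spec.hasher == "stronghash"; `key` seeds the keyed hash function.
  ```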
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError(f"`key` must have size 2, received {len(key)}")

    if not isinstance(key[0], compat_util.integral_types) or not isinstance(
        key[1], compat_util.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)


def _as_string(tensor):
  if dtypes.string == tensor.dtype.base_dtype:
    return tensor
  return string_ops.as_string(tensor)


class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `IdTableWithHashBuckets` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
    `3 + num_oov_buckets - 1`, calculated by:
    `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.

  If `table` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
  table = tf.IdTableWithHashBuckets(
      tf.StaticHashTable(
          tf.lookup.TextFileInitializer(
              filename,
              key_dtype=tf.string,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
              delimiter="\t"),
          default_value),
      num_oov_buckets)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```

  The hash function used for generating out-of-vocabulary bucket IDs is
  controlled by `hasher_spec`.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct a `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assignment of out-of-vocabulary buckets (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must
        be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` is None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if table:
      if key_dtype is None:
        key_dtype = table.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if table.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of "
                        f"{supported_table_key_dtypes}, received "
                        f"{table.key_dtype}.")
      if table.key_dtype.is_integer != key_dtype.is_integer:
        raise TypeError("Invalid `key_dtype`, expected %s but got %s." %
                        ("integer" if key_dtype.is_integer else "non-integer",
                         table.key_dtype))
      if table.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`: expected int64 but got %s."
                        % (table.value_dtype))
      self._table = table
      name = name or self._table.name
    else:
      if num_oov_buckets <= 0:
        raise ValueError(
            "`num_oov_buckets` must be > 0 if no `table` is supplied.")
      key_dtype = dtypes.string if key_dtype is None else key_dtype
      self._table = None
      name = name or "hash_bucket"
    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{key_dtype}.")
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    self._hasher_spec = hasher_spec
    if name:
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError(
        f"Found unknown hasher {hasher_spec.hasher} in `hasher_spec`")

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
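
    For example (a sketch, assuming a string-to-id `table` that maps
    "emerson"/"lake"/"palmer" to 0/1/2, as in the class docstring):

    ```python
    wrapped = IdTableWithHashBuckets(table, num_oov_buckets=3)
    ids = wrapped.lookup(tf.constant(["emerson", "unknown-term"]))
    # "emerson" resolves through `table` to 0; "unknown-term" is hashed into
    # one of the out-of-vocabulary bucket ids 3, 4, or 5.
    ```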
1090 """ 1091 if keys.dtype.base_dtype != self._key_dtype: 1092 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, " 1093 f"received: {keys.dtype}") 1094 values = keys 1095 if isinstance(keys, 1096 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 1097 values = keys.values 1098 if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): 1099 values = math_ops.cast(values, dtypes.int64) 1100 1101 if self._num_oov_buckets == 0: 1102 ids = self._table.lookup(values, name=name) 1103 else: 1104 # TODO(yleon): Consider moving this functionality to its own kernel. 1105 with ops.name_scope(name, "%s_Lookup" % self.name): 1106 str_to_hash_bucket = self._get_string_to_hash_bucket_fn( 1107 self._hasher_spec) 1108 buckets = str_to_hash_bucket( 1109 _as_string(values), 1110 num_buckets=self._num_oov_buckets, 1111 name="hash_bucket") 1112 if self._table: 1113 ids = self._table.lookup(values) 1114 buckets = math_ops.add(buckets, self._table.size()) 1115 is_id_non_default = math_ops.not_equal(ids, self._table.default_value) 1116 ids = array_ops.where_v2(is_id_non_default, ids, buckets) 1117 else: 1118 ids = buckets 1119 if isinstance(keys, sparse_tensor.SparseTensor): 1120 return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) 1121 elif isinstance(keys, ragged_tensor.RaggedTensor): 1122 return keys.with_values(ids) 1123 return ids 1124 1125 1126@tf_export("lookup.StaticVocabularyTable", v1=[]) 1127class StaticVocabularyTable(LookupInterface): 1128 r"""String to Id table that assigns out-of-vocabulary keys to hash buckets. 1129 1130 For example, if an instance of `StaticVocabularyTable` is initialized with a 1131 string-to-id initializer that maps: 1132 1133 >>> init = tf.lookup.KeyValueTensorInitializer( 1134 ... keys=tf.constant(['emerson', 'lake', 'palmer']), 1135 ... values=tf.constant([0, 1, 2], dtype=tf.int64)) 1136 >>> table = tf.lookup.StaticVocabularyTable( 1137 ... init, 1138 ... num_oov_buckets=5) 1139 1140 The `Vocabulary` object will performs the following mapping: 1141 1142 * `emerson -> 0` 1143 * `lake -> 1` 1144 * `palmer -> 2` 1145 * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and 1146 `3 + num_oov_buckets - 1 = 7`, calculated by: 1147 `hash(<term>) % num_oov_buckets + vocab_size` 1148 1149 If input_tensor is: 1150 1151 >>> input_tensor = tf.constant(["emerson", "lake", "palmer", 1152 ... "king", "crimson"]) 1153 >>> table[input_tensor].numpy() 1154 array([0, 1, 2, 6, 7]) 1155 1156 If `initializer` is None, only out-of-vocabulary buckets are used. 1157 1158 Example usage: 1159 1160 >>> num_oov_buckets = 3 1161 >>> vocab = ["emerson", "lake", "palmer", "crimnson"] 1162 >>> import tempfile 1163 >>> f = tempfile.NamedTemporaryFile(delete=False) 1164 >>> f.write('\n'.join(vocab).encode('utf-8')) 1165 >>> f.close() 1166 1167 >>> init = tf.lookup.TextFileInitializer( 1168 ... f.name, 1169 ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 1170 ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) 1171 >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets) 1172 >>> table.lookup(tf.constant(["palmer", "crimnson" , "king", 1173 ... "tarkus", "black", "moon"])).numpy() 1174 array([2, 3, 5, 6, 6, 4]) 1175 1176 The hash function used for generating out-of-vocabulary buckets ID is 1177 Fingerprint64. 

  Note that the out-of-vocabulary bucket IDs always range from the table `size`
  up to `size + num_oov_buckets - 1` regardless of the table values, which could
  cause unexpected collisions:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(["emerson", "lake", "palmer"]),
  ...     values=tf.constant([1, 2, 3], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...     init,
  ...     num_oov_buckets=1)
  >>> input_tensor = tf.constant(["emerson", "lake", "palmer", "king"])
  >>> table[input_tensor].numpy()
  array([1, 2, 3, 3])
  """

  def __init__(self,
               initializer,
               num_oov_buckets,
               lookup_key_dtype=None,
               name=None):
    """Construct a `StaticVocabularyTable` object.

    Args:
      initializer: A `TableInitializerBase` object that contains the data used
        to initialize the table. If None, then we only use out-of-vocab buckets.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
        be greater than zero.
      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
        `initializer.key_dtype` if `initializer` is specified, otherwise
        `tf.string`. Must be string or integer, and must be castable to
        `initializer.key_dtype`.
      name: A name for the operation (optional).

    Raises:
      ValueError: when `num_oov_buckets` is not positive.
      TypeError: when lookup_key_dtype or initializer.key_dtype are not
        integer or string. Also when initializer.value_dtype != int64.
    """
    if num_oov_buckets <= 0:
      raise ValueError("`num_oov_buckets` must be > 0.")
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if initializer:
      if lookup_key_dtype is None:
        lookup_key_dtype = initializer.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if initializer.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of %s, but got %s." %
                        (supported_table_key_dtypes, initializer.key_dtype))
      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
        raise TypeError(
            "Invalid `key_dtype`, expected %s but got %s." %
            ("integer" if lookup_key_dtype.is_integer else "non-integer",
             initializer.key_dtype))
      if initializer.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`, expected %s but got %s."
                        % (dtypes.int64, initializer.value_dtype))
      if isinstance(initializer, trackable_base.Trackable):
        self._initializer = self._track_trackable(initializer, "_initializer")
      self._table = HashTable(initializer, default_value=-1)
      name = name or self._table.name
    else:
      lookup_key_dtype = dtypes.string
      self._table = None
      name = name or "hash_bucket"
    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
                                              lookup_key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{lookup_key_dtype}")
    self._num_oov_buckets = num_oov_buckets

    self._table_name = None
    if name is not None:
      self._table_name = name.split("/")[-1]
    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()


def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.initializer)` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index` and `delimiter`.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval() ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```
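
  A multi-column vocabulary file can be used by selecting a column explicitly
  (a sketch; assumes "vocab.tsv" holds tab-separated lines like `emerson\t10`):

  ```python
  table = tf.lookup.index_table_from_file(
      vocabulary_file="vocab.tsv",
      key_column_index=0,  # use the first column as keys
      delimiter="\t")      # values default to the zero-based line number
  ```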

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("`vocab_size` must be greater than 0, got %d for "
                     "vocabulary_file: %s." % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Dtype for `keys` should be either integer or string.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table


def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and the corresponding index within the
  tensor is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.initializer)` once.

  Elements in `vocabulary_list` must not have duplicates; otherwise, when
  executing the table initializer op, it will throw a `FailedPreconditionError`.
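
  For example, the following sketch fails once the initializer op is executed,
  because "emerson" appears twice:

  ```python
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=tf.constant(["emerson", "emerson"]))
  ...
  tf.compat.v1.tables_initializer().run()  # raises FailedPreconditionError
  ```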

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval() ==> [0, 1, 4, 2]
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` must be specified.")

  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater than or equal to 0, got %d." %
        num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("`dtype` must either be integer or string.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError(
          "Invalid `dtype`: Expected %s, got %s." %
          ("integer" if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Invalid `dtype`: Expected %s, got %s." %
                       (dtype, keys.dtype))
    num_elements = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table


def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index`, and `delimiter`.

  - `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  - `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar
      `Tensor`.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table mapping an `int64` index `Tensor` to the corresponding
    string values.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")

  with ops.name_scope(name, "index_to_string"):
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.
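
  For example, `vocabulary_list = tf.constant(["emerson", "lake", "palmer"])`
  maps index 0 to "emerson", 1 to "lake", and 2 to "palmer".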

  Any input which does not have a corresponding index in `vocabulary_list`
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` cannot have duplicates; otherwise, when the
  table initializer op is executed, it will raise a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table mapping an `int64` index `Tensor` to the corresponding
    string values.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """

  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` argument must be specified.")

  with ops.name_scope(name, "index_to_string"):
    vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    num_elements = array_ops.size(vocabulary_list)
    keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    init = KeyValueTensorInitializer(
        keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


@tf_export("lookup.experimental.MutableHashTable")
class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via the init method.

  `MutableHashTable` requires additional memory during checkpointing and
  restore operations to create temporary key and value tensors.

  Example usage:

  >>> table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.string,
  ...                                                 value_dtype=tf.int64,
  ...                                                 default_value=-1)
  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9], dtype=tf.int64)
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table.insert(keys_tensor, vals_tensor)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1])
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(keys_tensor).numpy()
  array([ 7,  8, -1])
  >>> sorted(table.export()[0].numpy())
  [b'a', b'b']
  >>> sorted(table.export()[1].numpy())
  [7, 8]
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table,
        it is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully
      # creating tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    if self._default_value.get_shape().ndims == 0:
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the
        table's key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
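
    A minimal eager-mode sketch (mirroring the class-level example above;
    illustrative only):

    >>> table = tf.lookup.experimental.MutableHashTable(
    ...     key_dtype=tf.string, value_dtype=tf.int64, default_value=-1)
    >>> table.insert(tf.constant(['a', 'b']),
    ...              tf.constant([1, 2], dtype=tf.int64))
    >>> table.remove(tf.constant(['b', 'x']))  # Absent key 'x' is ignored.
    >>> table.size().numpy()
    1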
    """
    if keys.dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, dynamic_default_values=None, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      dynamic_default_values: The values to use if a key is missing in the
        table. If None (by default), the `table.default_value` will be used.
        The shape of `dynamic_default_values` must be the same as the shape of
        `table.default_value` or that of the lookup result tensor.
        In the latter case, each key will have a different default value.

        For example:

        ```python
        keys = [0, 1, 3]
        dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

        # The key '0' will use [1, 3, 4] as default value.
        # The key '1' will use [2, 3, 9] as default value.
        # The key '3' will use [8, 3, 0] as default value.
        ```

      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(
            self.resource_handle, keys, dynamic_default_values
            if dynamic_default_values is not None else self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the
        table's key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
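
    A short eager-mode sketch (following the class-level example above;
    illustrative only):

    >>> table = tf.lookup.experimental.MutableHashTable(
    ...     key_dtype=tf.string, value_dtype=tf.int64, default_value=-1)
    >>> table.insert(tf.constant(['a']), tf.constant([7], dtype=tf.int64))
    >>> sorted(table.export()[0].numpy())
    [b'a']
    >>> sorted(table.export()[1].numpy())
    [7]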
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                MutableHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A mutable hash table with faster lookups and higher memory usage.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via the init method.

  Compared to `MutableHashTable`, `DenseHashTable` offers generally faster
  `insert`, `remove` and `lookup` operations, in exchange for a higher overall
  memory footprint.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  This requires specifying two keys in the key space, `empty_key` and
  `deleted_key`, that can never be inserted into the table.

  Unlike `MutableHashTable`, `DenseHashTable` does not require additional
  memory for temporary tensors created during checkpointing and restore
  operations.

  Example usage:

  >>> table = tf.lookup.experimental.DenseHashTable(
  ...     key_dtype=tf.string,
  ...     value_dtype=tf.int64,
  ...     default_value=-1,
  ...     empty_key='',
  ...     deleted_key='$')
  >>> keys = tf.constant(['a', 'b', 'c'])
  >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
  >>> table.insert(keys, values)
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(tf.constant(['a', 'b', 'c', 'd'])).numpy()
  array([ 0,  1, -1, -1])
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable
  # into a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations, and must be
        different from `empty_key`.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table,
        it is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully
      # creating tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=empty_key,
        deleted_key=deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
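
    A short eager-mode sketch (reusing the constructor arguments from the
    class-level example; illustrative only):

    >>> table = tf.lookup.experimental.DenseHashTable(
    ...     key_dtype=tf.string, value_dtype=tf.int64, default_value=-1,
    ...     empty_key='', deleted_key='$')
    >>> table.insert(tf.constant(['a', 'b']),
    ...              tf.constant([1, 2], dtype=tf.int64))
    >>> table.size().numpy()
    2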
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle,
                                                     keys, self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the
        table's key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the
        table's key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the
        table's key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the
        table's key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                DenseHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


ops.NotDifferentiable("LookupTableFind")
ops.NotDifferentiable("LookupTableFindV2")
ops.NotDifferentiable("LookupTableInsert")
ops.NotDifferentiable("LookupTableInsertV2")
ops.NotDifferentiable("LookupTableSize")
ops.NotDifferentiable("LookupTableSizeV2")
ops.NotDifferentiable("HashTable")
ops.NotDifferentiable("HashTableV2")
ops.NotDifferentiable("InitializeTable")
ops.NotDifferentiable("InitializeTableV2")
ops.NotDifferentiable("InitializeTableFromTextFile")
ops.NotDifferentiable("InitializeTableFromTextFileV2")
ops.NotDifferentiable("MutableDenseHashTable")
ops.NotDifferentiable("MutableDenseHashTableV2")
ops.NotDifferentiable("MutableHashTable")
ops.NotDifferentiable("MutableHashTableV2")
ops.NotDifferentiable("MutableHashTableOfTensors")
ops.NotDifferentiable("MutableHashTableOfTensorsV2")