1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#============================================================================== 15"""Lookup operations.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import uuid 24 25import six 26 27from tensorflow.python.eager import context 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.framework import tensor_spec 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import gen_lookup_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import string_ops 40# go/tf-wildcard-import 41# pylint: disable=wildcard-import 42from tensorflow.python.ops.gen_lookup_ops import * 43from tensorflow.python.ops.ragged import ragged_tensor 44from tensorflow.python.training.saver import BaseSaverBuilder 45# pylint: enable=wildcard-import 46from tensorflow.python.training.tracking import base as trackable_base 47from tensorflow.python.training.tracking 
import tracking as trackable 48from tensorflow.python.util import compat 49from tensorflow.python.util.deprecation import deprecated 50from tensorflow.python.util.tf_export import tf_export 51 52 53@tf_export(v1=["initialize_all_tables"]) 54@deprecated(None, "Use `tf.tables_initializer` instead.") 55def initialize_all_tables(name="init_all_tables"): 56 """Returns an Op that initializes all tables of the default graph. 57 58 Args: 59 name: Optional name for the initialization op. 60 61 Returns: 62 An Op that initializes all tables. Note that if there are 63 not tables the returned Op is a NoOp. 64 """ 65 return tables_initializer(name) 66 67 68@tf_export(v1=["initializers.tables_initializer", "tables_initializer"]) 69def tables_initializer(name="init_all_tables"): 70 """Returns an Op that initializes all tables of the default graph. 71 72 See the [Low Level 73 Intro](https://www.tensorflow.org/guide/low_level_intro#feature_columns) 74 guide, for an example of usage. 75 76 Args: 77 name: Optional name for the initialization op. 78 79 Returns: 80 An Op that initializes all tables. Note that if there are 81 not tables the returned Op is a NoOp. 82 """ 83 initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) 84 if initializers: 85 return control_flow_ops.group(*initializers, name=name) 86 return control_flow_ops.no_op(name=name) 87 88 89def _check_table_dtypes(table, key_dtype, value_dtype): 90 """Check that the given key_dtype and value_dtype matches the table dtypes. 91 92 Args: 93 table: The table to check types against to. 94 key_dtype: The key data type to check. 95 value_dtype: The value data type to check. 96 97 Raises: 98 TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data 99 types. 100 """ 101 if key_dtype.base_dtype != table.key_dtype: 102 raise TypeError("Invalid key dtype, expected %s but got %s." 
% 103 (table.key_dtype, key_dtype)) 104 if value_dtype.base_dtype != table.value_dtype: 105 raise TypeError("Invalid value dtype, expected %s but got %s." % 106 (table.value_dtype, value_dtype)) 107 108 109class LookupInterface(trackable.TrackableResource): 110 """Represent a lookup table that persists across different steps.""" 111 112 def __init__(self, key_dtype, value_dtype): 113 """Construct a lookup table interface. 114 115 Args: 116 key_dtype: The table key type. 117 value_dtype: The table value type. 118 """ 119 self._key_dtype = dtypes.as_dtype(key_dtype) 120 self._value_dtype = dtypes.as_dtype(value_dtype) 121 super(LookupInterface, self).__init__() 122 123 def _create_resource(self): 124 raise NotImplementedError 125 126 @property 127 def key_dtype(self): 128 """The table key dtype.""" 129 return self._key_dtype 130 131 @property 132 def value_dtype(self): 133 """The table value dtype.""" 134 return self._value_dtype 135 136 @property 137 def name(self): 138 """The name of the table.""" 139 return NotImplementedError 140 141 def size(self, name=None): 142 """Compute the number of elements in this table.""" 143 raise NotImplementedError 144 145 def lookup(self, keys, name=None): 146 """Looks up `keys` in a table, outputs the corresponding values.""" 147 raise NotImplementedError 148 149 def __getitem__(self, keys): 150 """Looks up `keys` in a table, outputs the corresponding values.""" 151 return self.lookup(keys) 152 153 154class InitializableLookupTableBase(LookupInterface): 155 """Initializable lookup table interface. 156 157 An initializable lookup tables persist across different steps. 158 """ 159 160 def __init__(self, default_value, initializer): 161 """Construct a table object from a table reference. 162 163 If requires a table initializer object (subclass of `TableInitializerBase`). 164 It provides the table key and value types, as well as the op to initialize 165 the table. The caller is responsible to execute the initialization op. 
166 167 Args: 168 default_value: The value to use if a key is missing in the table. 169 initializer: The table initializer to use. 170 """ 171 super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, 172 initializer.value_dtype) 173 self._default_value = ops.convert_to_tensor( 174 default_value, dtype=self._value_dtype) 175 self._default_value.get_shape().merge_with(tensor_shape.TensorShape([])) 176 if isinstance(initializer, trackable_base.Trackable): 177 self._initializer = self._track_trackable(initializer, "_initializer") 178 with ops.init_scope(): 179 self._resource_handle = self._create_resource() 180 if (not context.executing_eagerly() and 181 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 182 with ops.init_scope(): 183 self._init_op = self._initialize() 184 else: 185 self._init_op = self._initialize() 186 187 def _initialize(self): 188 return self._initializer.initialize(self) 189 190 @property 191 def default_value(self): 192 """The default value of the table.""" 193 return self._default_value 194 195 def size(self, name=None): 196 """Compute the number of elements in this table. 197 198 Args: 199 name: A name for the operation (optional). 200 201 Returns: 202 A scalar tensor containing the number of elements in this table. 203 """ 204 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]): 205 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle) 206 207 def lookup(self, keys, name=None): 208 """Looks up `keys` in a table, outputs the corresponding values. 209 210 The `default_value` is used for keys not present in the table. 211 212 Args: 213 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 214 name: A name for the operation (optional). 215 216 Returns: 217 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 218 otherwise a dense `Tensor`. 
219 220 Raises: 221 TypeError: when `keys` or `default_value` doesn't match the table data 222 types. 223 """ 224 key_tensor = keys 225 if isinstance(keys, 226 (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)): 227 key_tensor = keys.values 228 229 if keys.dtype.base_dtype != self._key_dtype: 230 raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % 231 (self._key_dtype, keys.dtype)) 232 233 with ops.name_scope( 234 name, "%s_Lookup" % self.name, 235 (self.resource_handle, key_tensor, self._default_value)): 236 values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, 237 key_tensor, 238 self._default_value) 239 240 values.set_shape(key_tensor.get_shape()) 241 if isinstance(keys, sparse_tensor.SparseTensor): 242 return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) 243 elif isinstance(keys, ragged_tensor.RaggedTensor): 244 return keys.with_values(values) 245 else: 246 return values 247 248 249class InitializableLookupTableBaseV1(InitializableLookupTableBase): 250 251 @property 252 def initializer(self): 253 return self._init_op 254 255 256@tf_export("lookup.StaticHashTable", v1=[]) 257class StaticHashTable(InitializableLookupTableBase): 258 """A generic hash table that is immutable once initialized. 259 260 Example usage: 261 262 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 263 >>> vals_tensor = tf.constant([7, 8, 9]) 264 >>> input_tensor = tf.constant(['a', 'f']) 265 >>> table = tf.lookup.StaticHashTable( 266 ... tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), 267 ... 
default_value=-1) 268 >>> table.lookup(input_tensor).numpy() 269 array([ 7, -1], dtype=int32) 270 271 Or for more pythonic code: 272 273 >>> table[input_tensor].numpy() 274 array([ 7, -1], dtype=int32) 275 276 The result of a lookup operation has the same shape as the argument: 277 278 >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']]) 279 >>> table[input_tensor].numpy() 280 array([[ 7, 8], 281 [ 9, -1]], dtype=int32) 282 283 284 """ 285 286 def __init__(self, initializer, default_value, name=None): 287 """Creates a non-initialized `HashTable` object. 288 289 Creates a table, the type of its keys and values are specified by the 290 initializer. 291 Before using the table you will have to initialize it. After initialization 292 the table will be immutable. 293 294 Args: 295 initializer: The table initializer to use. See `HashTable` kernel for 296 supported key and value types. 297 default_value: The value to use if a key is missing in the table. 298 name: A name for the operation (optional). 299 300 Returns: 301 A `HashTable` object. 302 """ 303 self._initializer = initializer 304 self._default_value = default_value 305 self._shared_name = self._initializer._shared_name # pylint: disable=protected-access 306 if not self._shared_name: 307 # Force using a shared name so that StaticHashTable resources can be 308 # shared across different kernels. If no "shared_name" is set and 309 # "use_node_name_sharing" is False, then each kernel gets its own local 310 # resource. 
311 self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),) 312 self._name = name or "hash_table" 313 self._table_name = None 314 super(StaticHashTable, self).__init__(default_value, initializer) 315 self._value_shape = self._default_value.get_shape() 316 317 def _create_resource(self): 318 table_ref = gen_lookup_ops.hash_table_v2( 319 shared_name=self._shared_name, 320 key_dtype=self._initializer.key_dtype, 321 value_dtype=self._initializer.value_dtype, 322 name=self._name) 323 if context.executing_eagerly(): 324 self._table_name = None 325 else: 326 self._table_name = table_ref.op.name.split("/")[-1] 327 return table_ref 328 329 @property 330 def name(self): 331 return self._table_name 332 333 def export(self, name=None): 334 """Returns tensors of all keys and values in the table. 335 336 Args: 337 name: A name for the operation (optional). 338 339 Returns: 340 A pair of tensors with the first tensor containing all keys and the 341 second tensors containing all values in the table. 342 """ 343 with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]): 344 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( 345 self.resource_handle, self._key_dtype, self._value_dtype) 346 347 exported_values.set_shape(exported_keys.get_shape().concatenate( 348 self._value_shape)) 349 return exported_keys, exported_values 350 351 352@tf_export(v1=["lookup.StaticHashTable"]) 353class StaticHashTableV1(StaticHashTable): 354 """A generic hash table that is immutable once initialized. 355 356 When running in graph mode, you must evaluate the tensor returned by 357 `tf.tables_initializer()` before evaluating the tensor returned by 358 this class's `lookup()` method. 
Example usage in graph mode: 359 360 ```python 361 keys_tensor = tf.constant([1, 2]) 362 vals_tensor = tf.constant([3, 4]) 363 input_tensor = tf.constant([1, 5]) 364 table = tf.lookup.StaticHashTable( 365 tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) 366 out = table.lookup(input_tensor) 367 with tf.Session() as sess: 368 sess.run(tf.tables_initializer()) 369 print(sess.run(out)) 370 ``` 371 372 In eager mode, no special code is needed to initialize the table. 373 Example usage in eager mode: 374 375 ```python 376 tf.enable_eager_execution() 377 keys_tensor = tf.constant([1, 2]) 378 vals_tensor = tf.constant([3, 4]) 379 input_tensor = tf.constant([1, 5]) 380 table = tf.lookup.StaticHashTable( 381 tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) 382 print(table.lookup(input_tensor)) 383 ``` 384 """ 385 386 @property 387 def initializer(self): 388 return self._init_op 389 390 391# For backwards compatibility. This will be removed in TF 2.0. 392class HashTable(StaticHashTableV1): 393 394 @property 395 def init(self): 396 return self.initializer 397 398 399class TableInitializerBase(trackable_base.Trackable): 400 """Base class for lookup table initializers.""" 401 402 def __init__(self, key_dtype, value_dtype): 403 """Construct a table initializer object. 404 405 Args: 406 key_dtype: Type of the table keys. 407 value_dtype: Type of the table values. 
408 """ 409 self._key_dtype = dtypes.as_dtype(key_dtype) 410 self._value_dtype = dtypes.as_dtype(value_dtype) 411 412 @property 413 def key_dtype(self): 414 """The expected table key dtype.""" 415 return self._key_dtype 416 417 @property 418 def value_dtype(self): 419 """The expected table value dtype.""" 420 return self._value_dtype 421 422 def initialize(self, table): 423 """Returns the table initialization op.""" 424 raise NotImplementedError 425 426 @property 427 def _shared_name(self): 428 """Returns a shared name to be used by the table.""" 429 shared_name = "" 430 if context.executing_eagerly(): 431 # Ensure a unique name when eager execution is enabled to avoid spurious 432 # sharing issues. 433 # TODO(rohanj): Use context.shared_name() instead. 434 shared_name += str(ops.uid()) 435 return shared_name 436 437 438@tf_export("lookup.experimental.DatasetInitializer") 439class DatasetInitializer(TableInitializerBase): 440 """Creates a table initializer from a `tf.data.Dataset`. 441 442 Sample usage: 443 444 >>> keys = tf.data.Dataset.range(100) 445 >>> values = tf.data.Dataset.range(100).map( 446 ... lambda x: string_ops.as_string(x * 2)) 447 >>> ds = tf.data.Dataset.zip((keys, values)) 448 >>> init = tf.lookup.experimental.DatasetInitializer(ds) 449 >>> table = tf.lookup.StaticHashTable(init, "") 450 >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy() 451 array([b'0', b'2', b'4'], dtype=object) 452 453 Attributes: 454 dataset: A `tf.data.Dataset` object that produces tuples of scalars. The 455 first scalar is treated as a key and the second as value. 456 457 Raises: ValueError if `dataset` doesn't conform to specifications. 458 """ 459 460 def __init__(self, dataset): 461 """Creates a table initializer from a `tf.data.Dataset`. 462 463 Args: 464 dataset: A `tf.data.Dataset` object that produces tuples of scalars. The 465 first scalar is treated as a key and the second as value. 
466 467 Raises: ValueError if `dataset` doesn't conform to specifications. 468 Returns: A `DatasetInitializer` object 469 """ 470 # Assert that the dataset element spec is a tuple of TensorSpecs where 471 # each tensor is a scalar. 472 self.dataset = dataset 473 elem_spec = self.dataset.element_spec 474 if len(elem_spec) != 2: 475 raise ValueError("element spec size should be 2") 476 if not isinstance(elem_spec[0], tensor_spec.TensorSpec): 477 raise ValueError("elem_spec[0] should be of type TensorSpec") 478 if not isinstance(elem_spec[1], tensor_spec.TensorSpec): 479 raise ValueError("elem_spec[1] should be of type TensorSpec") 480 if elem_spec[0].shape.rank not in (None, 0): 481 raise ValueError("key tensor should be a scalar") 482 if elem_spec[1].shape.rank not in (None, 0): 483 raise ValueError("value tensor should be a scalar") 484 485 key_type = elem_spec[0].dtype 486 value_type = elem_spec[1].dtype 487 super(DatasetInitializer, self).__init__(key_type, value_type) 488 489 def initialize(self, table): 490 _check_table_dtypes(table, self._key_dtype, self._value_dtype) 491 init_op = gen_lookup_ops.initialize_table_from_dataset( 492 table.resource_handle, self.dataset._variant_tensor) # pylint: disable=protected-access 493 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 494 return init_op 495 496 497@tf_export("lookup.KeyValueTensorInitializer") 498class KeyValueTensorInitializer(TableInitializerBase): 499 """Table initializers given `keys` and `values` tensors. 500 501 >>> keys_tensor = tf.constant(['a', 'b', 'c']) 502 >>> vals_tensor = tf.constant([7, 8, 9]) 503 >>> input_tensor = tf.constant(['a', 'f']) 504 >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor) 505 >>> table = tf.lookup.StaticHashTable( 506 ... init, 507 ... 
default_value=-1) 508 >>> table.lookup(input_tensor).numpy() 509 array([ 7, -1], dtype=int32) 510 511 """ 512 513 def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): 514 """Constructs a table initializer object based on keys and values tensors. 515 516 Args: 517 keys: The tensor for the keys. 518 values: The tensor for the values. 519 key_dtype: The `keys` data type. Used when `keys` is a python array. 520 value_dtype: The `values` data type. Used when `values` is a python array. 521 name: A name for the operation (optional). 522 """ 523 if (not context.executing_eagerly() and 524 ops.get_default_graph()._get_control_flow_context() is not None): # pylint: disable=protected-access 525 with ops.init_scope(): 526 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 527 self._values = ops.convert_to_tensor( 528 values, dtype=value_dtype, name="values") 529 else: 530 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 531 self._values = ops.convert_to_tensor( 532 values, dtype=value_dtype, name="values") 533 self._name = name if name is not None else "key_value_init" 534 if context.executing_eagerly(): 535 # Ensure a unique name when eager execution is enabled to avoid spurious 536 # sharing issues. 537 # TODO(rohanj): Use context.shared_name() instead. 538 self._name += str(ops.uid()) 539 540 super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, 541 self._values.dtype) 542 543 def initialize(self, table): 544 """Initializes the given `table` with `keys` and `values` tensors. 545 546 Args: 547 table: The table to initialize. 548 549 Returns: 550 The operation that initializes the table. 551 552 Raises: 553 TypeError: when the keys and values data types do not match the table 554 key and value data types. 
555 """ 556 _check_table_dtypes(table, self._keys.dtype, self._values.dtype) 557 with ops.name_scope( 558 self._name, values=(table.resource_handle, self._keys, self._values)): 559 init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle, 560 self._keys, self._values) 561 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 562 return init_op 563 564 565@tf_export("lookup.TextFileIndex") 566class TextFileIndex(object): 567 """The key and value content to get from each line. 568 569 This class defines the key and value used for `tf.lookup.TextFileInitializer`. 570 571 The key and value content to get from each line is specified either 572 by the following, or a value `>=0`. 573 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 574 expects data type int64. 575 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 576 type string. 577 578 A value `>=0` means use the index (starting at zero) of the split line based 579 on `delimiter`. 580 """ 581 WHOLE_LINE = -2 582 LINE_NUMBER = -1 583 584 585@tf_export("lookup.TextFileInitializer") 586class TextFileInitializer(TableInitializerBase): 587 r"""Table initializers from a text file. 588 589 This initializer assigns one entry in the table for each line in the file. 590 591 The key and value type of the table to initialize is given by `key_dtype` and 592 `value_dtype`. 593 594 The key and value content to get from each line is specified by 595 the `key_index` and `value_index`. 596 597 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 598 expects data type int64. 599 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 600 type string. 601 * A value `>=0` means use the index (starting at zero) of the split line based 602 on `delimiter`. 
603 604 For example if we have a file with the following content: 605 606 >>> import tempfile 607 >>> f = tempfile.NamedTemporaryFile(delete=False) 608 >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",]) 609 >>> f.file.write(content.encode('utf-8')) 610 >>> f.file.close() 611 612 The following snippet initializes a table with the first column as keys and 613 second column as values: 614 615 * `emerson -> 10` 616 * `lake -> 20` 617 * `palmer -> 30` 618 619 >>> init= tf.lookup.TextFileInitializer( 620 ... filename=f.name, 621 ... key_dtype=tf.string, key_index=0, 622 ... value_dtype=tf.int64, value_index=1, 623 ... delimiter=" ") 624 >>> table = tf.lookup.StaticHashTable(init, default_value=-1) 625 >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy() 626 627 Similarly to initialize the whole line as keys and the line number as values. 628 629 * `emerson 10 -> 0` 630 * `lake 20 -> 1` 631 * `palmer 30 -> 2` 632 633 >>> init = tf.lookup.TextFileInitializer( 634 ... filename=f.name, 635 ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 636 ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) 637 >>> table = tf.lookup.StaticHashTable(init, -1) 638 >>> table.lookup(tf.constant('palmer 30')).numpy() 639 2 640 """ 641 642 def __init__(self, 643 filename, 644 key_dtype, 645 key_index, 646 value_dtype, 647 value_index, 648 vocab_size=None, 649 delimiter="\t", 650 name=None, 651 value_index_offset=0): 652 """Constructs a table initializer object to populate from a text file. 653 654 It generates one key-value pair per line. The type of table key and 655 value are specified by `key_dtype` and `value_dtype`, respectively. 656 Similarly the content of the key and value are specified by the key_index 657 and value_index. 658 659 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 660 expects data type int64. 
661 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 662 type string. 663 - A value >=0 means use the index (starting at zero) of the split line based 664 on `delimiter`. 665 666 Args: 667 filename: The filename of the text file to be used for initialization. The 668 path must be accessible from wherever the graph is initialized (eg. 669 trainer or eval workers). The filename may be a scalar `Tensor`. 670 key_dtype: The `key` data type. 671 key_index: the index that represents information of a line to get the 672 table 'key' values from. 673 value_dtype: The `value` data type. 674 value_index: the index that represents information of a line to get the 675 table 'value' values from.' 676 vocab_size: The number of elements in the file, if known. 677 delimiter: The delimiter to separate fields in a line. 678 name: A name for the operation (optional). 679 value_index_offset: A number to add to all indices extracted from the file 680 This is useful for cases where a user would like to reserve one or more 681 low index values for control characters. For instance, if you would 682 like to ensure that no vocabulary item is mapped to index 0 (so you can 683 reserve 0 for a masking value), you can set value_index_offset to 1; 684 this will mean that the first vocabulary element is mapped to 1 685 instead of 0. 686 687 Raises: 688 ValueError: when the filename is empty, or when the table key and value 689 data types do not match the expected data types. 690 """ 691 if not isinstance(filename, ops.Tensor) and not filename: 692 raise ValueError("Filename required for %s." % name) 693 694 self._filename_arg = filename 695 key_dtype = dtypes.as_dtype(key_dtype) 696 value_dtype = dtypes.as_dtype(value_dtype) 697 698 if key_index < -2: 699 raise ValueError("Invalid key index %s." % (key_index)) 700 701 if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: 702 raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." 
% 703 (dtypes.int64, key_dtype)) 704 if ((key_index == TextFileIndex.WHOLE_LINE) and 705 (not key_dtype.is_integer) and (key_dtype != dtypes.string)): 706 raise ValueError( 707 "Signature mismatch. Keys must be integer or string, got %s." % 708 key_dtype) 709 if value_index < -2: 710 raise ValueError("Invalid value index %s." % (value_index)) 711 712 if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: 713 raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % 714 (dtypes.int64, value_dtype)) 715 if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string: 716 raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % 717 (dtypes.string, value_dtype)) 718 719 if (vocab_size is not None) and (vocab_size <= 0): 720 raise ValueError("Invalid vocab_size %s." % vocab_size) 721 722 self._key_index = key_index 723 self._value_index = value_index 724 self._vocab_size = vocab_size 725 self._delimiter = delimiter 726 self._name = name 727 self._filename = self._track_trackable( 728 trackable.Asset(filename), "_filename") 729 self._offset = value_index_offset 730 731 super(TextFileInitializer, self).__init__(key_dtype, value_dtype) 732 733 def initialize(self, table): 734 """Initializes the table from a text file. 735 736 Args: 737 table: The table to be initialized. 738 739 Returns: 740 The operation that initializes the table. 741 742 Raises: 743 TypeError: when the keys and values data types do not match the table 744 key and value data types. 
745 """ 746 _check_table_dtypes(table, self.key_dtype, self.value_dtype) 747 with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)): 748 filename = ops.convert_to_tensor( 749 self._filename, dtypes.string, name="asset_filepath") 750 init_op = gen_lookup_ops.initialize_table_from_text_file_v2( 751 table.resource_handle, filename, self._key_index, self._value_index, 752 -1 if self._vocab_size is None else self._vocab_size, self._delimiter, 753 self._offset) 754 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 755 # If the filename tensor is anything other than a string constant (e.g., 756 # if it is a placeholder) then it does not make sense to track it as an 757 # asset. 758 if not context.executing_eagerly() and constant_op.is_constant(filename): 759 ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) 760 return init_op 761 762 @property 763 def _shared_name(self): 764 if self._vocab_size: 765 # Keep the shared_name: 766 # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index> 767 shared_name = "hash_table_%s_%d_%s_%s" % ( 768 self._filename_arg, self._vocab_size, self._key_index, 769 self._value_index) 770 else: 771 # Keep the shared_name 772 # <table_type>_<filename>_<key_index>_<value_index> 773 shared_name = "hash_table_%s_%s_%s" % (self._filename_arg, 774 self._key_index, self._value_index) 775 return shared_name 776 777 778class TextFileStringTableInitializer(TextFileInitializer): 779 """Table initializer for `int64` IDs to string tables from a text file.""" 780 781 def __init__(self, 782 filename, 783 key_column_index=TextFileIndex.LINE_NUMBER, 784 value_column_index=TextFileIndex.WHOLE_LINE, 785 vocab_size=None, 786 delimiter="\t", 787 name="text_file_string_table_init"): 788 """Constructs an initializer for an id-to-string table from a text file. 789 790 It populates a table that its key and value types are int64 and string, 791 respectively. It generates one key-value pair per line. 
792 The content of the key and value are specified by `key_column_index` 793 and `value_column_index`. 794 795 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 796 expects data type int64. 797 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 798 type string. 799 - A value >=0 means use the index (starting at zero) of the split line based 800 on `delimiter`. 801 802 Args: 803 filename: The filename of the text file to be used for initialization. The 804 path must be accessible from wherever the graph is initialized (eg. 805 trainer or eval workers). The filename may be a scalar `Tensor`. 806 key_column_index: The column index from the text file to get the keys 807 from. The default is to use the line number, starting from zero. 808 value_column_index: The column index from the text file to get the values 809 from. The default is to use the whole line content. 810 vocab_size: The number of elements in the file, if known. 811 delimiter: The delimiter to separate fields in a line. 812 name: Optional name for the op. 813 814 Raises: 815 TypeError: when the filename is empty, or when the table key and value 816 data types do not match the expected data types. 817 """ 818 super(TextFileStringTableInitializer, self).__init__( 819 filename, 820 dtypes.int64, 821 key_column_index, 822 dtypes.string, 823 value_column_index, 824 vocab_size=vocab_size, 825 delimiter=delimiter, 826 name=name) 827 828 829class TextFileIdTableInitializer(TextFileInitializer): 830 """Table initializer for string to `int64` IDs tables from a text file.""" 831 832 def __init__(self, 833 filename, 834 key_column_index=TextFileIndex.WHOLE_LINE, 835 value_column_index=TextFileIndex.LINE_NUMBER, 836 vocab_size=None, 837 delimiter="\t", 838 name="text_file_id_table_init", 839 key_dtype=dtypes.string): 840 """Constructs an initializer for an string-to-id table from a text file. 
841 842 It populates a table that its key and value types are string and int64, 843 respectively. It generates one key-value pair per line. 844 The content of the key and value are specified by the key_index 845 and value_index. 846 847 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 848 expects data type int64. 849 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 850 type string. 851 - A value >=0 means use the index (starting at zero) of the split line based 852 on `delimiter`. 853 854 Args: 855 filename: The filename of the text file to be used for initialization. The 856 path must be accessible from wherever the graph is initialized (eg. 857 trainer or eval workers). The filename may be a scalar `Tensor`. 858 key_column_index: The column index from the text file to get the `key` 859 values from. The default is to use the whole line content. 860 value_column_index: The column index from the text file to get the `value` 861 values from. The default is to use the line number, starting from zero. 862 vocab_size: The number of elements in the file, if known. 863 delimiter: The delimiter to separate fields in a line. 864 name: Optional name for the op. 865 key_dtype: The `key` data type. 866 867 Raises: 868 TypeError: when the filename is empty, or when the table key and value 869 data types do not match the expected data types. 870 """ 871 super(TextFileIdTableInitializer, self).__init__( 872 filename, 873 key_dtype, 874 key_column_index, 875 dtypes.int64, 876 value_column_index, 877 vocab_size=vocab_size, 878 delimiter=delimiter, 879 name=name) 880 881 882class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): 883 """A structure for the spec of the hashing function to use for hash buckets. 884 885 `hasher` is the name of the hashing function to use (eg. "fasthash", 886 "stronghash"). 
  `key` is optional and specifies the key to use for the hash function if
  supported, currently only used by a strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  # No per-instance __dict__; instances stay plain namedtuples.
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError("key must have size 2, got %s." % len(key))

    if not isinstance(key[0], compat.integral_types) or not isinstance(
        key[1], compat.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    # NOTE(review): argument order looks swapped vs. the conventional
    # `super(StrongHashSpec, cls)`; it is equivalent when `cls` is
    # StrongHashSpec itself -- confirm before subclassing.
    return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)


def _as_string(tensor):
  # Pass string tensors through unchanged; stringify any other dtype.
  if dtypes.string == tensor.dtype.base_dtype:
    return tensor
  return string_ops.as_string(tensor)


class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
931 932 For example, if an instance of `IdTableWithHashBuckets` is initialized with a 933 string-to-id table that maps: 934 935 * `emerson -> 0` 936 * `lake -> 1` 937 * `palmer -> 2` 938 939 The `IdTableWithHashBuckets` object will performs the following mapping: 940 941 * `emerson -> 0` 942 * `lake -> 1` 943 * `palmer -> 2` 944 * `<other term> -> bucket_id`, where bucket_id will be between `3` and 945 `3 + num_oov_buckets - 1`, calculated by: 946 `hash(<term>) % num_oov_buckets + vocab_size` 947 948 If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, 949 the lookup result is `[0, 1, 2, 4, 7]`. 950 951 If `table` is None, only out-of-vocabulary buckets are used. 952 953 Example usage: 954 955 ```python 956 num_oov_buckets = 3 957 input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) 958 table = tf.IdTableWithHashBuckets( 959 tf.StaticHashTable( 960 tf.lookup.TextFileInitializer( 961 filename, 962 key_dtype=tf.string, 963 key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 964 value_dtype=tf.int64, 965 value_index=tf.lookup.TextFileIndex.LINE_NUMBER, 966 delimiter="\t"), 967 default_value), 968 num_oov_buckets) 969 out = table.lookup(input_tensor). 970 table.init.run() 971 print(out.eval()) 972 ``` 973 974 The hash function used for generating out-of-vocabulary buckets ID is handled 975 by `hasher_spec`. 976 """ 977 978 def __init__(self, 979 table, 980 num_oov_buckets, 981 hasher_spec=FastHashSpec, 982 name=None, 983 key_dtype=None): 984 """Construct a `IdTableWithHashBuckets` object. 985 986 Args: 987 table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. 988 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. 989 hasher_spec: A `HasherSpec` to specify the hash function to use for 990 assignation of out-of-vocabulary buckets (optional). 991 name: A name for the operation (optional). 992 key_dtype: Data type of keys passed to `lookup`. 
Defaults to 993 `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must 994 be string or integer, and must be castable to `table.key_dtype`. 995 996 Raises: 997 ValueError: when `table` in None and `num_oov_buckets` is not positive. 998 TypeError: when `hasher_spec` is invalid. 999 """ 1000 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 1001 # characters to use as table name. 1002 if name: 1003 name = name.rstrip("/") 1004 if table: 1005 if key_dtype is None: 1006 key_dtype = table.key_dtype 1007 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 1008 if table.key_dtype not in supported_table_key_dtypes: 1009 raise TypeError("Invalid key dtype, expected one of %s, but got %s." % 1010 (supported_table_key_dtypes, key_dtype)) 1011 if table.key_dtype.is_integer != key_dtype.is_integer: 1012 raise TypeError("Invalid key dtype, expected %s but got %s." % 1013 ("integer" if key_dtype.is_integer else "non-integer", 1014 table.key_dtype)) 1015 if table.value_dtype != dtypes.int64: 1016 raise TypeError("Invalid value dtype, expected %s but got %s." % 1017 (dtypes.int64, table.value_dtype)) 1018 self._table = table 1019 name = name or self._table.name 1020 else: 1021 if num_oov_buckets <= 0: 1022 raise ValueError("oov_buckets must be > 0 if no table is supplied.") 1023 key_dtype = dtypes.string if key_dtype is None else key_dtype 1024 self._table = None 1025 name = name or "hash_bucket" 1026 if (not key_dtype.is_integer) and (dtypes.string != key_dtype): 1027 raise TypeError("Invalid key_dtype, expected integer or string, got %s." 
% 1028 key_dtype) 1029 self._num_oov_buckets = num_oov_buckets 1030 1031 if not isinstance(hasher_spec, HasherSpec): 1032 raise TypeError("hasher_spec must be of type HasherSpec, got %s" % 1033 hasher_spec) 1034 self._hasher_spec = hasher_spec 1035 if name: 1036 self._table_name = name.split("/")[-1] 1037 else: 1038 self._table_name = None 1039 super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64) 1040 1041 def _create_resource(self): 1042 if self._table is not None: 1043 return self._table._create_resource() # pylint: disable=protected-access 1044 return None 1045 1046 def _initialize(self): 1047 if self._table is not None: 1048 return self._table._initialize() # pylint: disable=protected-access 1049 with ops.name_scope(None, "init"): 1050 return control_flow_ops.no_op() 1051 1052 @property 1053 def initializer(self): 1054 if self._table is not None: 1055 return self._table._init_op # pylint: disable=protected-access 1056 with ops.name_scope(None, "init"): 1057 return control_flow_ops.no_op() 1058 1059 @property 1060 @deprecated("2018-12-15", "Use `initializer` instead.") 1061 def init(self): 1062 return self.initializer 1063 1064 @property 1065 def resource_handle(self): 1066 if self._table is not None: 1067 return self._table.resource_handle 1068 return None 1069 1070 @property 1071 def name(self): 1072 return self._table_name 1073 1074 def size(self, name=None): 1075 """Compute the number of elements in this table.""" 1076 with ops.name_scope(name, "%s_Size" % self.name): 1077 if self._table: 1078 tsize = self._table.size() 1079 else: 1080 tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) 1081 return tsize + self._num_oov_buckets 1082 1083 def _get_string_to_hash_bucket_fn(self, hasher_spec): 1084 """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" 1085 if not isinstance(hasher_spec, HasherSpec): 1086 raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) 1087 if hasher_spec.hasher == "fasthash": 1088 
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      # Bind the 2-int key so the returned callable matches the other hashers.
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError("Unknown hasher %s" % hasher_spec.hasher)

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based in their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))
    # For sparse/ragged input, look up the flat values and rewrap at the end.
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name):
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
          # Shift bucket ids past the vocabulary and use them only where the
          # table returned its default value (i.e. the key was OOV).
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
        else:
          ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface):
  r"""String to Id table that assigns out-of-vocabulary keys to hash buckets.

  For example, if an instance of `StaticVocabularyTable` is initialized with a
  string-to-id initializer that maps:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(['emerson', 'lake', 'palmer']),
  ...     values=tf.constant([0, 1, 2], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...    init,
  ...    num_oov_buckets=5)

  The `Vocabulary` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and
    `3 + num_oov_buckets - 1 = 7`, calculated by:
    `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is:

  >>> input_tensor = tf.constant(["emerson", "lake", "palmer",
  ...                             "king", "crimson"])
  >>> table[input_tensor].numpy()
  array([0, 1, 2, 6, 7])

  If `initializer` is None, only out-of-vocabulary buckets are used.
1178 1179 Example usage: 1180 1181 >>> num_oov_buckets = 3 1182 >>> vocab = ["emerson", "lake", "palmer", "crimnson"] 1183 >>> import tempfile 1184 >>> f = tempfile.NamedTemporaryFile(delete=False) 1185 >>> f.write('\n'.join(vocab).encode('utf-8')) 1186 >>> f.close() 1187 1188 >>> init = tf.lookup.TextFileInitializer( 1189 ... f.name, 1190 ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, 1191 ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) 1192 >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets) 1193 >>> table.lookup(tf.constant(["palmer", "crimnson" , "king", 1194 ... "tarkus", "black", "moon"])).numpy() 1195 array([2, 3, 5, 6, 6, 4]) 1196 1197 The hash function used for generating out-of-vocabulary buckets ID is 1198 Fingerprint64. 1199 """ 1200 1201 def __init__(self, 1202 initializer, 1203 num_oov_buckets, 1204 lookup_key_dtype=None, 1205 name=None): 1206 """Construct a `StaticVocabularyTable` object. 1207 1208 Args: 1209 initializer: A `TableInitializerBase` object that contains the data used 1210 to initialize the table. If None, then we only use out-of-vocab buckets. 1211 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must 1212 be greater than zero. 1213 lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to 1214 `initializer.key_dtype` if `initializer` is specified, otherwise 1215 `tf.string`. Must be string or integer, and must be castable to 1216 `initializer.key_dtype`. 1217 name: A name for the operation (optional). 1218 1219 Raises: 1220 ValueError: when `num_oov_buckets` is not positive. 1221 TypeError: when lookup_key_dtype or initializer.key_dtype are not 1222 integer or string. Also when initializer.value_dtype != int64. 1223 """ 1224 if num_oov_buckets <= 0: 1225 raise ValueError("oov_buckets must be > 0.") 1226 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 1227 # characters to use as table name. 
1228 if name: 1229 name = name.rstrip("/") 1230 if initializer: 1231 if lookup_key_dtype is None: 1232 lookup_key_dtype = initializer.key_dtype 1233 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 1234 if initializer.key_dtype not in supported_table_key_dtypes: 1235 raise TypeError("Invalid key dtype, expected one of %s, but got %s." % 1236 (supported_table_key_dtypes, initializer.key_dtype)) 1237 if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer: 1238 raise TypeError( 1239 "Invalid key dtype, expected %s but got %s." % 1240 ("integer" if lookup_key_dtype.is_integer else "non-integer", 1241 initializer.key_dtype)) 1242 if initializer.value_dtype != dtypes.int64: 1243 raise TypeError("Invalid value dtype, expected %s but got %s." % 1244 (dtypes.int64, initializer.value_dtype)) 1245 if isinstance(initializer, trackable_base.Trackable): 1246 self._initializer = self._track_trackable(initializer, "_initializer") 1247 self._table = HashTable(initializer, default_value=-1) 1248 name = name or self._table.name 1249 else: 1250 lookup_key_dtype = dtypes.string 1251 self._table = None 1252 name = name or "hash_bucket" 1253 if (not lookup_key_dtype.is_integer) and (dtypes.string != 1254 lookup_key_dtype): 1255 raise TypeError("Invalid key_dtype, expected integer or string, got %s." 
% 1256 lookup_key_dtype) 1257 self._num_oov_buckets = num_oov_buckets 1258 1259 self._table_name = None 1260 if name is not None: 1261 self._table_name = name.split("/")[-1] 1262 super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64) 1263 1264 def _create_resource(self): 1265 if self._table is not None: 1266 return self._table._create_resource() # pylint: disable=protected-access 1267 return None 1268 1269 def _initialize(self): 1270 if self._table is not None: 1271 return self._table._initialize() # pylint: disable=protected-access 1272 with ops.name_scope(None, "init"): 1273 return control_flow_ops.no_op() 1274 1275 @property 1276 def resource_handle(self): 1277 if self._table is not None: 1278 return self._table.resource_handle 1279 return None 1280 1281 @property 1282 def name(self): 1283 return self._table_name 1284 1285 def size(self, name=None): 1286 """Compute the number of elements in this table.""" 1287 with ops.name_scope(name, "%s_Size" % self.name): 1288 if self._table: 1289 tsize = self._table.size() 1290 else: 1291 tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) 1292 return tsize + self._num_oov_buckets 1293 1294 def lookup(self, keys, name=None): 1295 """Looks up `keys` in the table, outputs the corresponding values. 1296 1297 It assigns out-of-vocabulary keys to buckets based in their hashes. 1298 1299 Args: 1300 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 1301 name: Optional name for the op. 1302 1303 Returns: 1304 A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged, 1305 otherwise a dense `Tensor`. 1306 1307 Raises: 1308 TypeError: when `keys` doesn't match the table key data type. 1309 """ 1310 if keys.dtype.base_dtype != self._key_dtype: 1311 raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
                      % (self._key_dtype, keys.dtype))
    # For sparse/ragged input, look up the flat values and rewrap at the end.
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
        # Shift bucket ids past the vocabulary and use them only where the
        # table returned its default value (i.e. the key was OOV).
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):
  """V1 variant of `StaticVocabularyTable` exposing an `initializer` property."""

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()


def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs.
  The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval() ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignation of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater or equal than 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      # Best-effort: recover the file name for the error message.
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("vocab_size must be greater than 0, got %d. "
                     "vocabulary_file: %s" % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the vocabulary table so OOV keys hash into extra buckets.
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table


def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and corresponding index within the tensor
  is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval() ==> [0, 1, 4, 2]
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater or equal than 0, got %d." %
        num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError(
          "Expected %s, got %s." %
          ("integer" if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
    # Values are the zero-based positions of each vocabulary entry.
    num_elements = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the vocabulary table so OOV keys hash into extra buckets.
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table


def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a string values associated to a given index `int64`
    `Tensors`.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)

  with ops.name_scope(name, "index_to_string"):
    # Keys are int64 line numbers by default; values are the line contents.
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.

  Any input which does not have a corresponding index in 'vocabulary_list'
  (an out-of-vocabulary entry) is assigned the `default_value`

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
1701 1702 Returns: 1703 The lookup table to map a string values associated to a given index `int64` 1704 `Tensors`. 1705 1706 Raises: 1707 ValueError: when `vocabulary_list` is not set. 1708 """ 1709 1710 if vocabulary_list is None: 1711 raise ValueError("vocabulary_list must be specified.") 1712 1713 with ops.name_scope(name, "index_to_string"): 1714 vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string) 1715 num_elements = array_ops.size(vocabulary_list) 1716 keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64) 1717 1718 init = KeyValueTensorInitializer( 1719 keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init") 1720 # TODO(yleon): Use a more efficient structure. 1721 return StaticHashTableV1(init, default_value) 1722 1723 1724class MutableHashTable(LookupInterface): 1725 """A generic mutable hash table implementation. 1726 1727 Data can be inserted by calling the insert method and removed by calling the 1728 remove method. It does not support initialization via the init method. 1729 1730 Example usage: 1731 1732 ```python 1733 table = tf.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64, 1734 default_value=-1) 1735 sess.run(table.insert(keys, values)) 1736 out = table.lookup(query_keys) 1737 print(out.eval()) 1738 ``` 1739 """ 1740 1741 def __init__(self, 1742 key_dtype, 1743 value_dtype, 1744 default_value, 1745 name="MutableHashTable", 1746 checkpoint=True): 1747 """Creates an empty `MutableHashTable` object. 1748 1749 Creates a table, the type of its keys and values are specified by key_dtype 1750 and value_dtype, respectively. 1751 1752 Args: 1753 key_dtype: the type of the key tensors. 1754 value_dtype: the type of the value tensors. 1755 default_value: The value to use if a key is missing in the table. 1756 name: A name for the operation (optional). 1757 checkpoint: if True, the contents of the table are saved to and restored 1758 from checkpoints. 
        from checkpoints. If `shared_name` is empty for a checkpointed
        table, it is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    # In graph mode shared_name stays None; _create_resource then relies on
    # use_node_name_sharing so checkpointed tables are shared by node name.
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      # Register a SaveableObject so name-based (Saver) checkpoints include
      # the table contents; object-based checkpointing goes through
      # _gather_saveables_for_checkpoint instead.
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the mutable hash table op and returns its resource handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    # Scalar values use the plain hash-table kernel; non-scalar values need
    # the "of_tensors" kernel, which stores a full tensor per key.
    if self._default_value.get_shape().ndims == 0:
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      # Eager ops have no stable graph node name to report.
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    # Short node name of the table op (None when executing eagerly).
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s."
                      % (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, dynamic_default_values=None, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      dynamic_default_values: The values to use if a key is missing in the
        table. If None (by default), the `table.default_value` will be used.
        Shape of `dynamic_default_values` must be same with
        `table.default_value` or the lookup result tensor.
        In the latter case, each key will have a different default value.

        For example:

        ```python
        keys = [0, 1, 3]
        dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

        # The key '0' will use [1, 3, 4] as default value.
        # The key '1' will use [2, 3, 9] as default value.
        # The key '3' will use [8, 3, 0] as default value.
        ```

      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        # Fall back to the table-wide default when no per-key defaults given.
        values = gen_lookup_ops.lookup_table_find_v2(
            self.resource_handle, keys, dynamic_default_values
            if dynamic_default_values is not None else self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                MutableHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name, table_name=None):
      # Export keys/values once; the SaveSpecs reference those tensors.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          # Bulk-import the saved keys/values into the live table resource.
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable` the insert, remove and lookup operations in a
  `DenseHashTable` are typically faster, but memory usage can be higher.
  However, `DenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  >>> table = tf.lookup.experimental.DenseHashTable(
  ...     key_dtype=tf.string,
  ...     value_dtype=tf.int64,
  ...     default_value=-1,
  ...     empty_key='',
  ...     deleted_key='$')
  >>> keys = tf.constant(['a', 'b', 'c'])
  >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
  >>> table.insert(keys, values)
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(tf.constant(['a', 'b', 'c','d'])).numpy()
  array([ 0,  1, -1, -1])
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations and be different from
        the empty_key.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed
        table, it is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    # empty_key/deleted_key are kept as-is and converted to tensors in
    # _create_resource, so the resource can be recreated from the raw inputs.
    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      # Register for name-based (Saver) checkpoints; object-based
      # checkpointing goes through _gather_saveables_for_checkpoint.
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the dense hash table op and returns its resource handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=empty_key,
        deleted_key=deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      # Eager ops have no stable graph node name to report.
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    # Short node name of the table op (None when executing eagerly).
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Existing entries for `keys` are overwritten.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Alias for `insert_or_assign`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    Alias for `erase`. If a key is not present in the table, it is silently
    ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                DenseHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name, table_name=None):
      # Export keys/values once; the SaveSpecs reference those tensors.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          # Bulk-import the saved keys/values into the live table resource.
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


# Lookup table ops are not differentiable.
ops.NotDifferentiable("LookupTableFind")
ops.NotDifferentiable("LookupTableFindV2")
ops.NotDifferentiable("LookupTableInsert")
ops.NotDifferentiable("LookupTableInsertV2")
ops.NotDifferentiable("LookupTableSize")
ops.NotDifferentiable("LookupTableSizeV2")
ops.NotDifferentiable("HashTable")
ops.NotDifferentiable("HashTableV2")
ops.NotDifferentiable("InitializeTable")
ops.NotDifferentiable("InitializeTableV2")
# Register the remaining table op types as non-differentiable.
for _table_op_type in (
    "InitializeTableFromTextFile",
    "InitializeTableFromTextFileV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
):
  ops.NotDifferentiable(_table_op_type)