1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#============================================================================== 15"""Lookup operations.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import six 24 25from tensorflow.python.compat import compat as fwd_compat 26from tensorflow.python.eager import context 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import gen_lookup_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import string_ops 38# go/tf-wildcard-import 39# pylint: disable=wildcard-import 40from tensorflow.python.ops.gen_lookup_ops import * 41from tensorflow.python.training.saver import BaseSaverBuilder 42# pylint: enable=wildcard-import 43from tensorflow.python.training.tracking import base as trackable_base 44from tensorflow.python.training.tracking import tracking as trackable 45from tensorflow.python.util import compat 
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  return tables_initializer(name)


@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  See the [Low Level
  Intro](https://www.tensorflow.org/guide/low_level_intro#feature_columns)
  guide, for an example of usage.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


def _check_table_dtypes(table, key_dtype, value_dtype):
  """Check that the given key_dtype and value_dtype matches the table dtypes.

  Args:
    table: The table to check types against to.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
  """
  if key_dtype.base_dtype != table.key_dtype:
    raise TypeError("Invalid key dtype, expected %s but got %s." %
                    (table.key_dtype, key_dtype))
  if value_dtype.base_dtype != table.value_dtype:
    raise TypeError("Invalid value dtype, expected %s but got %s." %
                    (table.value_dtype, value_dtype))


class LookupInterface(trackable.TrackableResource):
  """Represent a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    # Bug fix: this previously did `return NotImplementedError`, which handed
    # callers the exception *class* as a value (so e.g. "%s_Size" % self.name
    # would silently format the class repr) instead of signalling that
    # subclasses must override this property, as the sibling abstract members
    # above and below do.
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup tables persist across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of `TableInitializerBase`).
    It provides the table key and value types, as well as the op to initialize
    the table. The caller is responsible to execute the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    # The default value must be a scalar.
    self._default_value.get_shape().merge_with(tensor_shape.scalar())
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(
          initializer, "_initializer")
    with ops.init_scope():
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()

  def _initialize(self):
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
    """
    key_tensor = keys
    if isinstance(keys, sparse_tensor.SparseTensor):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(
          self.resource_handle, key_tensor, self._default_value)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    else:
      return values


class InitializableLookupTableBaseV1(InitializableLookupTableBase):
  """V1-compatible base that exposes the table init op as `initializer`."""

  @property
  def initializer(self):
    return self._init_op


@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table implementation.

  Example usage:

  ```python
  table = tf.lookup.StaticHashTable(
      tf.KeyValueTensorInitializer(keys, values), -1)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```
  """

  def __init__(self, initializer, default_value, name=None):
    """Creates a non-initialized `HashTable` object.

    Creates a table, the type of its keys and values are specified by the
    initializer.
    Before using the table you will have to initialize it. After initialization
    the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).

    Returns:
      A `HashTable` object.
    """
    self._initializer = initializer
    self._default_value = default_value
    self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    table_ref = gen_lookup_ops.hash_table_v2(
        shared_name=self._shared_name,
        key_dtype=self._initializer.key_dtype,
        value_dtype=self._initializer.value_dtype,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values


@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """V1 variant of `StaticHashTable` that exposes the `initializer` op."""

  @property
  def initializer(self):
    return self._init_op


# For backwards compatibility. This will be removed in TF 2.0.
323class HashTable(StaticHashTableV1): 324 325 @property 326 def init(self): 327 return self.initializer 328 329 330class TableInitializerBase(trackable_base.Trackable): 331 """Base class for lookup table initializers.""" 332 333 def __init__(self, key_dtype, value_dtype): 334 """Construct a table initializer object. 335 336 Args: 337 key_dtype: Type of the table keys. 338 value_dtype: Type of the table values. 339 """ 340 self._key_dtype = dtypes.as_dtype(key_dtype) 341 self._value_dtype = dtypes.as_dtype(value_dtype) 342 343 @property 344 def key_dtype(self): 345 """The expected table key dtype.""" 346 return self._key_dtype 347 348 @property 349 def value_dtype(self): 350 """The expected table value dtype.""" 351 return self._value_dtype 352 353 def initialize(self, table): 354 """Returns the table initialization op.""" 355 raise NotImplementedError 356 357 @property 358 def _shared_name(self): 359 """Returns a shared name to be used by the table.""" 360 shared_name = "" 361 if context.executing_eagerly(): 362 # Ensure a unique name when eager execution is enabled to avoid spurious 363 # sharing issues. 364 # TODO(rohanj): Use context.shared_name() instead. 365 shared_name += str(ops.uid()) 366 return shared_name 367 368 369@tf_export("lookup.KeyValueTensorInitializer") 370class KeyValueTensorInitializer(TableInitializerBase): 371 """Table initializers given `keys` and `values` tensors.""" 372 373 def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): 374 """Constructs a table initializer object based on keys and values tensors. 375 376 Args: 377 keys: The tensor for the keys. 378 values: The tensor for the values. 379 key_dtype: The `keys` data type. Used when `keys` is a python array. 380 value_dtype: The `values` data type. Used when `values` is a python array. 381 name: A name for the operation (optional). 
382 """ 383 with ops.init_scope(): 384 self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") 385 self._values = ops.convert_to_tensor( 386 values, dtype=value_dtype, name="values") 387 self._name = name if name is not None else "key_value_init" 388 if context.executing_eagerly(): 389 # Ensure a unique name when eager execution is enabled to avoid spurious 390 # sharing issues. 391 # TODO(rohanj): Use context.shared_name() instead. 392 self._name += str(ops.uid()) 393 394 super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, 395 self._values.dtype) 396 397 def initialize(self, table): 398 """Initializes the given `table` with `keys` and `values` tensors. 399 400 Args: 401 table: The table to initialize. 402 403 Returns: 404 The operation that initializes the table. 405 406 Raises: 407 TypeError: when the keys and values data types do not match the table 408 key and value data types. 409 """ 410 _check_table_dtypes(table, self._keys.dtype, self._values.dtype) 411 with ops.name_scope( 412 self._name, values=(table.resource_handle, self._keys, self._values)): 413 if fwd_compat.forward_compatible(2018, 9, 19): 414 init_op = gen_lookup_ops.lookup_table_import_v2( 415 table.resource_handle, self._keys, self._values) 416 else: 417 # To maintain forward compatibiltiy, use the old implementation. 418 init_op = gen_lookup_ops.initialize_table_v2(table.resource_handle, 419 self._keys, self._values) 420 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 421 return init_op 422 423 424class TextFileIndex(object): 425 WHOLE_LINE = -2 426 LINE_NUMBER = -1 427 428 429@tf_export("lookup.TextFileInitializer") 430class TextFileInitializer(TableInitializerBase): 431 """Table initializers from a text file. 432 433 This initializer assigns one entry in the table for each line in the file. 434 435 The key and value type of the table to initialize is given by `key_dtype` and 436 `value_dtype`. 
437 438 The key and value content to get from each line is specified by 439 the `key_index` and `value_index`. 440 441 * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, 442 expects data type int64. 443 * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data 444 type string. 445 * A value `>=0` means use the index (starting at zero) of the split line based 446 on `delimiter`. 447 448 For example if we have a file with the following content: 449 450 ``` 451 emerson 10 452 lake 20 453 palmer 30 454 ``` 455 456 The following snippet initializes a table with the first column as keys and 457 second column as values: 458 459 * `emerson -> 10` 460 * `lake -> 20` 461 * `palmer -> 30` 462 463 ```python 464 table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer( 465 "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1) 466 ... 467 table.init.run() 468 ``` 469 470 Similarly to initialize the whole line as keys and the line number as values. 471 472 * `emerson 10 -> 0` 473 * `lake 20 -> 1` 474 * `palmer 30 -> 2` 475 476 ```python 477 table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer( 478 "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, 479 tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1) 480 ... 481 table.init.run() 482 ``` 483 """ 484 485 def __init__(self, 486 filename, 487 key_dtype, 488 key_index, 489 value_dtype, 490 value_index, 491 vocab_size=None, 492 delimiter="\t", 493 name=None): 494 """Constructs a table initializer object to populate from a text file. 495 496 It generates one key-value pair per line. The type of table key and 497 value are specified by `key_dtype` and `value_dtype`, respectively. 498 Similarly the content of the key and value are specified by the key_index 499 and value_index. 500 501 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 502 expects data type int64. 
503 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 504 type string. 505 - A value >=0 means use the index (starting at zero) of the split line based 506 on `delimiter`. 507 508 Args: 509 filename: The filename of the text file to be used for initialization. 510 The path must be accessible from wherever the graph is initialized 511 (eg. trainer or eval workers). The filename may be a scalar `Tensor`. 512 key_dtype: The `key` data type. 513 key_index: the index that represents information of a line to get the 514 table 'key' values from. 515 value_dtype: The `value` data type. 516 value_index: the index that represents information of a line to get the 517 table 'value' values from.' 518 vocab_size: The number of elements in the file, if known. 519 delimiter: The delimiter to separate fields in a line. 520 name: A name for the operation (optional). 521 522 Raises: 523 ValueError: when the filename is empty, or when the table key and value 524 data types do not match the expected data types. 525 """ 526 if not isinstance(filename, ops.Tensor) and not filename: 527 raise ValueError("Filename required for %s." % name) 528 529 self._filename_arg = filename 530 key_dtype = dtypes.as_dtype(key_dtype) 531 value_dtype = dtypes.as_dtype(value_dtype) 532 533 if key_index < -2: 534 raise ValueError("Invalid key index %s." % (key_index)) 535 536 if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: 537 raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." % 538 (dtypes.int64, key_dtype)) 539 if ((key_index == TextFileIndex.WHOLE_LINE) and 540 (not key_dtype.is_integer) and (key_dtype != dtypes.string)): 541 raise ValueError( 542 "Signature mismatch. Keys must be integer or string, got %s." % 543 key_dtype) 544 if value_index < -2: 545 raise ValueError("Invalid value index %s." % (value_index)) 546 547 if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: 548 raise ValueError("Signature mismatch. 
Values must be dtype %s, got %s." % 549 (dtypes.int64, value_dtype)) 550 if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string: 551 raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % 552 (dtypes.string, value_dtype)) 553 554 if (vocab_size is not None) and (vocab_size <= 0): 555 raise ValueError("Invalid vocab_size %s." % vocab_size) 556 557 self._key_index = key_index 558 self._value_index = value_index 559 self._vocab_size = vocab_size 560 self._delimiter = delimiter 561 self._name = name 562 self._filename = self._track_trackable( 563 trackable.TrackableAsset(filename), 564 "_filename") 565 566 super(TextFileInitializer, self).__init__(key_dtype, value_dtype) 567 568 def initialize(self, table): 569 """Initializes the table from a text file. 570 571 Args: 572 table: The table to be initialized. 573 574 Returns: 575 The operation that initializes the table. 576 577 Raises: 578 TypeError: when the keys and values data types do not match the table 579 key and value data types. 580 """ 581 _check_table_dtypes(table, self.key_dtype, self.value_dtype) 582 with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)): 583 filename = ops.convert_to_tensor( 584 self._filename, dtypes.string, name="asset_filepath") 585 init_op = gen_lookup_ops.initialize_table_from_text_file_v2( 586 table.resource_handle, filename, self._key_index, self._value_index, 587 -1 if self._vocab_size is None else self._vocab_size, self._delimiter) 588 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 589 # If the filename tensor is anything other than a string constant (e.g., 590 # if it is a placeholder) then it does not make sense to track it as an 591 # asset. 
592 if not context.executing_eagerly() and constant_op.is_constant(filename): 593 ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) 594 return init_op 595 596 @property 597 def _shared_name(self): 598 if self._vocab_size: 599 # Keep the shared_name: 600 # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index> 601 shared_name = "hash_table_%s_%d_%s_%s" % ( 602 self._filename_arg, self._vocab_size, self._key_index, 603 self._value_index) 604 else: 605 # Keep the shared_name 606 # <table_type>_<filename>_<key_index>_<value_index> 607 shared_name = "hash_table_%s_%s_%s" % (self._filename_arg, 608 self._key_index, self._value_index) 609 return shared_name 610 611 612class TextFileStringTableInitializer(TextFileInitializer): 613 """Table initializer for `int64` IDs to string tables from a text file.""" 614 615 def __init__(self, 616 filename, 617 key_column_index=TextFileIndex.LINE_NUMBER, 618 value_column_index=TextFileIndex.WHOLE_LINE, 619 vocab_size=None, 620 delimiter="\t", 621 name="text_file_string_table_init"): 622 """Constructs an initializer for an id-to-string table from a text file. 623 624 It populates a table that its key and value types are int64 and string, 625 respectively. It generates one key-value pair per line. 626 The content of the key and value are specified by `key_column_index` 627 and `value_column_index`. 628 629 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 630 expects data type int64. 631 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 632 type string. 633 - A value >=0 means use the index (starting at zero) of the split line based 634 on `delimiter`. 635 636 Args: 637 filename: The filename of the text file to be used for initialization. 638 The path must be accessible from wherever the graph is initialized 639 (eg. trainer or eval workers). The filename may be a scalar `Tensor`. 640 key_column_index: The column index from the text file to get the keys 641 from. 
The default is to use the line number, starting from zero. 642 value_column_index: The column index from the text file to get the 643 values from. The default is to use the whole line content. 644 vocab_size: The number of elements in the file, if known. 645 delimiter: The delimiter to separate fields in a line. 646 name: Optional name for the op. 647 648 Raises: 649 TypeError: when the filename is empty, or when the table key and value 650 data types do not match the expected data types. 651 """ 652 super(TextFileStringTableInitializer, self).__init__( 653 filename, 654 dtypes.int64, 655 key_column_index, 656 dtypes.string, 657 value_column_index, 658 vocab_size=vocab_size, 659 delimiter=delimiter, 660 name=name) 661 662 663class TextFileIdTableInitializer(TextFileInitializer): 664 """Table initializer for string to `int64` IDs tables from a text file.""" 665 666 def __init__(self, 667 filename, 668 key_column_index=TextFileIndex.WHOLE_LINE, 669 value_column_index=TextFileIndex.LINE_NUMBER, 670 vocab_size=None, 671 delimiter="\t", 672 name="text_file_id_table_init", 673 key_dtype=dtypes.string): 674 """Constructs an initializer for an string-to-id table from a text file. 675 676 It populates a table that its key and value types are string and int64, 677 respectively. It generates one key-value pair per line. 678 The content of the key and value are specified by the key_index 679 and value_index. 680 681 - TextFileIndex.LINE_NUMBER means use the line number starting from zero, 682 expects data type int64. 683 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data 684 type string. 685 - A value >=0 means use the index (starting at zero) of the split line based 686 on `delimiter`. 687 688 Args: 689 filename: The filename of the text file to be used for initialization. 690 The path must be accessible from wherever the graph is initialized 691 (eg. trainer or eval workers). The filename may be a scalar `Tensor`. 
692 key_column_index: The column index from the text file to get the `key` 693 values from. The default is to use the whole line content. 694 value_column_index: The column index from the text file to get the `value` 695 values from. The default is to use the line number, starting from zero. 696 vocab_size: The number of elements in the file, if known. 697 delimiter: The delimiter to separate fields in a line. 698 name: Optional name for the op. 699 key_dtype: The `key` data type. 700 701 Raises: 702 TypeError: when the filename is empty, or when the table key and value 703 data types do not match the expected data types. 704 """ 705 super(TextFileIdTableInitializer, self).__init__( 706 filename, 707 key_dtype, 708 key_column_index, 709 dtypes.int64, 710 value_column_index, 711 vocab_size=vocab_size, 712 delimiter=delimiter, 713 name=name) 714 715 716class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): 717 """A structure for the spec of the hashing function to use for hash buckets. 718 719 `hasher` is the name of the hashing function to use (eg. "fasthash", 720 "stronghash"). 721 `key` is optional and specify the key to use for the hash function if 722 supported, currently only used by a strong hash. 723 724 Fields: 725 hasher: The hasher name to use. 726 key: The key to be used by the hashing function, if required. 727 """ 728 __slots__ = () 729 730 731FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name 732 733 734class StrongHashSpec(HasherSpec): 735 """A structure to specify a key of the strong keyed hash spec. 736 737 The strong hash requires a `key`, which is a list of 2 unsigned integer 738 numbers. These should be non-zero; random numbers generated from random.org 739 would be a fine choice. 740 741 Fields: 742 key: The key to be used by the keyed hashing function. 743 """ 744 __slots__ = () 745 746 def __new__(cls, key): 747 if len(key) != 2: 748 raise ValueError("key must have size 2, got %s." 
% len(key)) 749 750 if not isinstance(key[0], compat.integral_types) or not isinstance( 751 key[1], compat.integral_types): 752 raise TypeError("Invalid key %s. Must be unsigned integer values." % key) 753 754 return super(cls, StrongHashSpec).__new__(cls, "stronghash", key) 755 756 757def _as_string(tensor): 758 if dtypes.string == tensor.dtype.base_dtype: 759 return tensor 760 return string_ops.as_string(tensor) 761 762 763class IdTableWithHashBuckets(LookupInterface): 764 """String to Id table wrapper that assigns out-of-vocabulary keys to buckets. 765 766 For example, if an instance of `IdTableWithHashBuckets` is initialized with a 767 string-to-id table that maps: 768 769 * `emerson -> 0` 770 * `lake -> 1` 771 * `palmer -> 2` 772 773 The `IdTableWithHashBuckets` object will performs the following mapping: 774 775 * `emerson -> 0` 776 * `lake -> 1` 777 * `palmer -> 2` 778 * `<other term> -> bucket_id`, where bucket_id will be between `3` and 779 `3 + num_oov_buckets - 1`, calculated by: 780 `hash(<term>) % num_oov_buckets + vocab_size` 781 782 If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, 783 the lookup result is `[0, 1, 2, 4, 7]`. 784 785 If `table` is None, only out-of-vocabulary buckets are used. 786 787 Example usage: 788 789 ```python 790 num_oov_buckets = 3 791 input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) 792 table = tf.IdTableWithHashBuckets( 793 tf.StaticHashTable(tf.TextFileIdTableInitializer(filename), 794 default_value), 795 num_oov_buckets) 796 out = table.lookup(input_tensor). 797 table.init.run() 798 print(out.eval()) 799 ``` 800 801 The hash function used for generating out-of-vocabulary buckets ID is handled 802 by `hasher_spec`. 803 """ 804 805 def __init__(self, 806 table, 807 num_oov_buckets, 808 hasher_spec=FastHashSpec, 809 name=None, 810 key_dtype=None): 811 """Construct a `IdTableWithHashBuckets` object. 
812 813 Args: 814 table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. 815 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. 816 hasher_spec: A `HasherSpec` to specify the hash function to use for 817 assignation of out-of-vocabulary buckets (optional). 818 name: A name for the operation (optional). 819 key_dtype: Data type of keys passed to `lookup`. Defaults to 820 `table.key_dtype` if `table` is specified, otherwise `tf.string`. 821 Must be string or integer, and must be castable to `table.key_dtype`. 822 823 Raises: 824 ValueError: when `table` in None and `num_oov_buckets` is not positive. 825 TypeError: when `hasher_spec` is invalid. 826 """ 827 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 828 # characters to use as table name. 829 if name: 830 name = name.rstrip("/") 831 if table: 832 if key_dtype is None: 833 key_dtype = table.key_dtype 834 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 835 if table.key_dtype not in supported_table_key_dtypes: 836 raise TypeError("Invalid key dtype, expected one of %s, but got %s." % 837 (supported_table_key_dtypes, key_dtype)) 838 if table.key_dtype.is_integer != key_dtype.is_integer: 839 raise TypeError("Invalid key dtype, expected %s but got %s." % 840 ("integer" if key_dtype.is_integer else "non-integer", 841 table.key_dtype)) 842 if table.value_dtype != dtypes.int64: 843 raise TypeError("Invalid value dtype, expected %s but got %s." % 844 (dtypes.int64, table.value_dtype)) 845 self._table = table 846 name = name or self._table.name 847 else: 848 if num_oov_buckets <= 0: 849 raise ValueError("oov_buckets must be > 0 if no table is supplied.") 850 key_dtype = dtypes.string if key_dtype is None else key_dtype 851 self._table = None 852 name = name or "hash_bucket" 853 if (not key_dtype.is_integer) and (dtypes.string != key_dtype): 854 raise TypeError( 855 "Invalid key_dtype, expected integer or string, got %s." 
          % key_dtype)
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError(
          "hasher_spec must be of type HasherSpec, got %s" % hasher_spec)
    self._hasher_spec = hasher_spec
    # Only the last path component of the (scope-stripped) name is used as
    # the table name.
    if name:
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    """Creates the wrapped table's resource; None when only buckets are used."""
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    """Initializes the wrapped table, or a no-op when there is no table."""
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    # Delegates to the wrapped table's init op; pure hash-bucket lookup needs
    # no initialization, so a no-op is returned in that case.
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    """Deprecated alias for `initializer`."""
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      # The out-of-vocabulary buckets always count towards the table size.
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      # Strong hashing requires the caller-provided key pair.
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError("Unknown hasher %s" % hasher_spec.hasher)

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))
    values = keys
    if isinstance(keys, sparse_tensor.SparseTensor):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
951 with ops.name_scope(name, "%s_Lookup" % self.name): 952 str_to_hash_bucket = self._get_string_to_hash_bucket_fn( 953 self._hasher_spec) 954 buckets = str_to_hash_bucket( 955 _as_string(values), 956 num_buckets=self._num_oov_buckets, 957 name="hash_bucket") 958 if self._table: 959 ids = self._table.lookup(values) 960 buckets = math_ops.add(buckets, self._table.size()) 961 is_id_non_default = math_ops.not_equal(ids, self._table.default_value) 962 ids = array_ops.where(is_id_non_default, ids, buckets) 963 else: 964 ids = buckets 965 if isinstance(keys, sparse_tensor.SparseTensor): 966 return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) 967 return ids 968 969 970@tf_export("lookup.StaticVocabularyTable", v1=[]) 971class StaticVocabularyTable(LookupInterface): 972 """String to Id table wrapper that assigns out-of-vocabulary keys to buckets. 973 974 For example, if an instance of `StaticVocabularyTable` is initialized with a 975 string-to-id initializer that maps: 976 977 * `emerson -> 0` 978 * `lake -> 1` 979 * `palmer -> 2` 980 981 The `Vocabulary` object will performs the following mapping: 982 983 * `emerson -> 0` 984 * `lake -> 1` 985 * `palmer -> 2` 986 * `<other term> -> bucket_id`, where bucket_id will be between `3` and 987 `3 + num_oov_buckets - 1`, calculated by: 988 `hash(<term>) % num_oov_buckets + vocab_size` 989 990 If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, 991 the lookup result is `[0, 1, 2, 4, 7]`. 992 993 If `initializer` is None, only out-of-vocabulary buckets are used. 994 995 Example usage: 996 997 ```python 998 num_oov_buckets = 3 999 input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) 1000 table = tf.lookup.StaticVocabularyTable( 1001 tf.TextFileIdTableInitializer(filename), num_oov_buckets) 1002 out = table.lookup(input_tensor). 
1003 table.init.run() 1004 print(out.eval()) 1005 ``` 1006 1007 The hash function used for generating out-of-vocabulary buckets ID is 1008 Fingerprint64. 1009 """ 1010 1011 def __init__(self, 1012 initializer, 1013 num_oov_buckets, 1014 lookup_key_dtype=None, 1015 name=None): 1016 """Construct a `StaticVocabularyTable` object. 1017 1018 Args: 1019 initializer: A TableInitializerBase object that contains the data used to 1020 initialize the table. If None, then we only use out-of-vocab buckets. 1021 num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must 1022 be greater than zero. 1023 lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to 1024 `initializer.key_dtype` if `initializer` is specified, otherwise 1025 `tf.string`. Must be string or integer, and must be castable to 1026 `initializer.key_dtype`. 1027 name: A name for the operation (optional). 1028 1029 Raises: 1030 ValueError: when `num_oov_buckets` is not positive. 1031 TypeError: when lookup_key_dtype or initializer.key_dtype are not 1032 integer or string. Also when initializer.value_dtype != int64. 1033 """ 1034 if num_oov_buckets <= 0: 1035 raise ValueError("oov_buckets must be > 0.") 1036 # If a name ends with a '/' it is a "name scope", remove all trailing '/' 1037 # characters to use as table name. 1038 if name: 1039 name = name.rstrip("/") 1040 if initializer: 1041 if lookup_key_dtype is None: 1042 lookup_key_dtype = initializer.key_dtype 1043 supported_table_key_dtypes = (dtypes.int64, dtypes.string) 1044 if initializer.key_dtype not in supported_table_key_dtypes: 1045 raise TypeError("Invalid key dtype, expected one of %s, but got %s." % 1046 (supported_table_key_dtypes, initializer.key_dtype)) 1047 if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer: 1048 raise TypeError( 1049 "Invalid key dtype, expected %s but got %s." 
% 1050 ("integer" if lookup_key_dtype.is_integer else "non-integer", 1051 initializer.key_dtype)) 1052 if initializer.value_dtype != dtypes.int64: 1053 raise TypeError("Invalid value dtype, expected %s but got %s." % 1054 (dtypes.int64, initializer.value_dtype)) 1055 self._table = HashTable(initializer, default_value=-1) 1056 name = name or self._table.name 1057 else: 1058 lookup_key_dtype = dtypes.string 1059 self._table = None 1060 name = name or "hash_bucket" 1061 if (not lookup_key_dtype.is_integer) and (dtypes.string != 1062 lookup_key_dtype): 1063 raise TypeError("Invalid key_dtype, expected integer or string, got %s." % 1064 lookup_key_dtype) 1065 self._num_oov_buckets = num_oov_buckets 1066 1067 self._table_name = None 1068 if name is not None: 1069 self._table_name = name.split("/")[-1] 1070 super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64) 1071 1072 def _create_resource(self): 1073 if self._table is not None: 1074 return self._table._create_resource() # pylint: disable=protected-access 1075 return None 1076 1077 def _initialize(self): 1078 if self._table is not None: 1079 return self._table._initialize() # pylint: disable=protected-access 1080 with ops.name_scope(None, "init"): 1081 return control_flow_ops.no_op() 1082 1083 @property 1084 def resource_handle(self): 1085 if self._table is not None: 1086 return self._table.resource_handle 1087 return None 1088 1089 @property 1090 def name(self): 1091 return self._table_name 1092 1093 def size(self, name=None): 1094 """Compute the number of elements in this table.""" 1095 with ops.name_scope(name, "%s_Size" % self.name): 1096 if self._table: 1097 tsize = self._table.size() 1098 else: 1099 tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) 1100 return tsize + self._num_oov_buckets 1101 1102 def lookup(self, keys, name=None): 1103 """Looks up `keys` in the table, outputs the corresponding values. 1104 1105 It assigns out-of-vocabulary keys to buckets based in their hashes. 
1106 1107 Args: 1108 keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. 1109 name: Optional name for the op. 1110 1111 Returns: 1112 A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`. 1113 1114 Raises: 1115 TypeError: when `keys` doesn't match the table key data type. 1116 """ 1117 if keys.dtype.base_dtype != self._key_dtype: 1118 raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % 1119 (self._key_dtype, keys.dtype)) 1120 values = keys 1121 if isinstance(keys, sparse_tensor.SparseTensor): 1122 values = keys.values 1123 if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): 1124 values = math_ops.cast(values, dtypes.int64) 1125 1126 # TODO(yleon): Consider moving this functionality to its own kernel. 1127 with ops.name_scope(name, "%s_Lookup" % self.name): 1128 buckets = string_ops.string_to_hash_bucket_fast( 1129 _as_string(values), 1130 num_buckets=self._num_oov_buckets, 1131 name="hash_bucket") 1132 if self._table: 1133 ids = self._table.lookup(values) 1134 buckets = math_ops.add(buckets, self._table.size()) 1135 is_id_non_default = math_ops.not_equal(ids, self._table.default_value) 1136 ids = array_ops.where(is_id_non_default, ids, buckets) 1137 else: 1138 ids = buckets 1139 if isinstance(keys, sparse_tensor.SparseTensor): 1140 return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) 1141 return ids 1142 1143 1144@tf_export(v1=["lookup.StaticVocabularyTable"]) 1145class StaticVocabularyTableV1(StaticVocabularyTable): 1146 1147 @property 1148 def initializer(self): 1149 if self._table is not None: 1150 return self._table._init_op # pylint: disable=protected-access 1151 with ops.name_scope(None, "init"): 1152 return control_flow_ops.no_op() 1153 1154 1155def index_table_from_file(vocabulary_file=None, 1156 num_oov_buckets=0, 1157 vocab_size=None, 1158 default_value=-1, 1159 hasher_spec=FastHashSpec, 1160 key_dtype=dtypes.string, 1161 name=None, 1162 
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer)` or `session.run(table.init)` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignation of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (
      isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                     % num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    # Include the vocabulary file name (or the tensor's constant value when
    # available) in the error message to ease debugging.
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("vocab_size must be greater than 0, got %d. "
                     "vocabulary_file: %s" % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the vocabulary table so OOV keys hash into the extra buckets.
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table


def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and corresponding index within the tensor
  is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer)` or `session.run(table.init)` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.
  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 4, 2]
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  if num_oov_buckets < 0:
    raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                     % num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError("Expected %s, got %s." %
                       ("integer"
                        if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
    num_elements = array_ops.size(keys)
    # Ids are simply the position of each key within the vocabulary tensor.
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the table so OOV keys hash into the extra buckets.
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table


def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer)` or `session.run(table.init)` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.
  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map string values associated to a given index `int64`
    `Tensors`.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (
      isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)

  with ops.name_scope(name, "index_to_string"):
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.

  Any input which does not have a corresponding index in 'vocabulary_list'
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer)` or `session.run(table.init)` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
  Returns:
    The lookup table to map string values associated to a given index `int64`
    `Tensors`.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """

  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  with ops.name_scope(name, "index_to_string"):
    vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    num_elements = array_ops.size(vocabulary_list)
    # Keys are the positions of each string within the vocabulary tensor.
    keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    init = KeyValueTensorInitializer(
        keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  Example usage:

  ```python
  table = tf.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
                                     default_value=-1)
  sess.run(table.insert(keys, values))
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints.
        If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      # Register a SaveableObject so graph-mode savers pick up the table
      # contents; eager checkpointing goes through
      # _gather_saveables_for_checkpoint instead.
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    if self._default_value.get_shape().ndims == 0:
      # Scalar values use the plain mutable hash table kernel.
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      # Non-scalar values need the "of tensors" kernel with an explicit
      # value_shape.
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {"table": functools.partial(MutableHashTable._Saveable, table=self)}

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name):
      # Export keys and values as two separate tensors in the checkpoint.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes, name=None):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope(name, "%s_table_restore" % self.name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(
              self.op.resource_handle, restored_tensors[0], restored_tensors[1])


@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable` the insert, remove and lookup operations in a
  `DenseHashTable` are typically faster, but memory usage can be higher.
  However, `DenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  ```python
  table = tf.lookup.DenseHashTable(key_dtype=tf.int64,
                                   value_dtype=tf.int64,
                                   default_value=-1,
                                   empty_key=0,
                                   deleted_key=-1)

  sess.run(table.insert(keys, values))
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally.
Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations and be different from
        the empty_key.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    # The default value's static shape doubles as the table's value shape
    # (see _value_shape below, passed to the kernel in _create_resource).
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    # Sentinel keys the kernel uses to mark empty and deleted buckets; both
    # are converted to key_dtype up front so dtype mismatches fail early.
    self._empty_key = ops.convert_to_tensor(
        empty_key, dtype=key_dtype, name="empty_key")
    self._deleted_key = ops.convert_to_tensor(
        deleted_key, dtype=key_dtype, name="deleted_key")
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        # In graph mode, register the saveable so tf.train.Saver picks up the
        # table contents; in eager mode object-based checkpointing handles it.
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates and returns the mutable dense hash table resource handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=self._empty_key,
        deleted_key=self._deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      # Op names are not meaningful when executing eagerly.
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    # The unscoped op name of the table (graph mode) or None (eager mode),
    # as recorded by _create_resource.
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape.
Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
      table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      # Colocate the find op with the table resource so the lookup runs on
      # the device that owns the table.
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    # Thin alias for insert_or_assign, kept for API parity.
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    # Explicit dtype check (raised eagerly, before any op construction).
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    # Thin alias for erase, kept for API parity.
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensors containing all values in the table.
2010 """ 2011 with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, 2012 [self.resource_handle]): 2013 with ops.colocate_with(self.resource_handle): 2014 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( 2015 self.resource_handle, self._key_dtype, self._value_dtype) 2016 2017 return exported_keys, exported_values 2018 2019 def _gather_saveables_for_checkpoint(self): 2020 """For object-based checkpointing.""" 2021 return {"table": functools.partial(DenseHashTable._Saveable, table=self)} 2022 2023 class _Saveable(BaseSaverBuilder.SaveableObject): 2024 """SaveableObject implementation for DenseHashTable.""" 2025 2026 def __init__(self, table, name): 2027 tensors = table.export() 2028 specs = [ 2029 BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), 2030 BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") 2031 ] 2032 # pylint: disable=protected-access 2033 super(DenseHashTable._Saveable, self).__init__(table, specs, name) 2034 2035 def restore(self, restored_tensors, restored_shapes, name=None): 2036 del restored_shapes # unused 2037 # pylint: disable=protected-access 2038 with ops.name_scope(name, "%s_table_restore" % self.name): 2039 with ops.colocate_with(self.op.resource_handle): 2040 return gen_lookup_ops.lookup_table_import_v2( 2041 self.op.resource_handle, restored_tensors[0], restored_tensors[1]) 2042 2043 2044ops.NotDifferentiable("LookupTableFind") 2045ops.NotDifferentiable("LookupTableFindV2") 2046ops.NotDifferentiable("LookupTableInsert") 2047ops.NotDifferentiable("LookupTableInsertV2") 2048ops.NotDifferentiable("LookupTableSize") 2049ops.NotDifferentiable("LookupTableSizeV2") 2050ops.NotDifferentiable("HashTable") 2051ops.NotDifferentiable("HashTableV2") 2052ops.NotDifferentiable("InitializeTable") 2053ops.NotDifferentiable("InitializeTableV2") 2054ops.NotDifferentiable("InitializeTableFromTextFile") 2055ops.NotDifferentiable("InitializeTableFromTextFileV2") 
# The mutable-table ops (and their V2 variants) likewise have no gradient
# function; register each one as not differentiable.
for _mutable_table_op_name in (
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
):
  ops.NotDifferentiable(_mutable_table_op_name)
del _mutable_table_op_name  # avoid leaking the loop variable at module scope