# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction.

FeatureColumns provide a high level abstraction for ingesting and representing
features. FeatureColumns are also the primary way of encoding features for
canned `tf.estimator.Estimator`s.

When using FeatureColumns with `Estimators`, the type of feature column you
should choose depends on (1) the feature type and (2) the model type.

1. Feature type:

  * Continuous features can be represented by `numeric_column`.
  * Categorical features can be represented by any `categorical_column_with_*`
    column:
    - `categorical_column_with_vocabulary_list`
    - `categorical_column_with_vocabulary_file`
    - `categorical_column_with_hash_bucket`
    - `categorical_column_with_identity`
    - `weighted_categorical_column`

2. Model type:

  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).

    Continuous features can be directly fed into deep neural network models.

      age_column = numeric_column("age")

    To feed sparse features into DNN models, wrap the column with
    `embedding_column` or `indicator_column`. `indicator_column` is recommended
    for features with only a few possible values. For features with many
    possible values, to reduce the size of your model, `embedding_column` is
    recommended.

      embedded_dept_column = embedding_column(
          categorical_column_with_vocabulary_list(
              "department", ["math", "philosophy", ...]), dimension=10)

  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).

    Sparse features can be fed directly into linear models. They behave like an
    indicator column but with an efficient implementation.

      dept_column = categorical_column_with_vocabulary_list("department",
          ["math", "philosophy", "english"])

    It is recommended that continuous features be bucketized before being
    fed into linear models.

      bucketized_age_column = bucketized_column(
          source_column=age_column,
          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

    Sparse features can be crossed (also known as conjuncted or combined) in
    order to form non-linearities, and then fed into linear models.

      cross_dept_age_column = crossed_column(
          columns=["department", bucketized_age_column],
          hash_bucket_size=1000)

Example of building canned `Estimator`s using FeatureColumns:

  ```python
  # Define features and transformations
  deep_feature_columns = [age_column, embedded_dept_column]
  wide_feature_columns = [dept_column, bucketized_age_column,
                          cross_dept_age_column]

  # Build deep model
  estimator = DNNClassifier(
      feature_columns=deep_feature_columns,
      hidden_units=[500, 250, 50])
  estimator.train(...)

  # Or build a wide model
  estimator = LinearClassifier(
      feature_columns=wide_feature_columns)
  estimator.train(...)

  # Or build a wide and deep model!
  estimator = DNNLinearCombinedClassifier(
      linear_feature_columns=wide_feature_columns,
      dnn_feature_columns=deep_feature_columns,
      dnn_hidden_units=[500, 250, 50])
  estimator.train(...)
  ```


FeatureColumns can also be transformed into a generic input layer for
custom models using `input_layer`.

Example of building a model using FeatureColumns; this can be used in a
`model_fn`, which is given to the `tf.estimator.Estimator`:

  ```python
  # Building model via layers

  deep_feature_columns = [age_column, embedded_dept_column]
  columns_to_tensor = parse_feature_columns_from_examples(
      serialized=my_data,
      feature_columns=deep_feature_columns)
  first_layer = input_layer(
      features=columns_to_tensor,
      feature_columns=deep_feature_columns)
  second_layer = fully_connected(first_layer, ...)
  ```

NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation.
  The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def has_resource(self, feature_column, name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.has_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')


@tf_export('__internal__.feature_column.StateManager', v1=[])
class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def __init__(self, layer, trainable):
    """Creates a `_StateManagerImpl` object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether, by default, created variables are trainable or not.
290 """ 291 self._trainable = trainable 292 self._layer = layer 293 if self._layer is not None and not hasattr(self._layer, '_resources'): 294 self._layer._resources = data_structures.Mapping() # pylint: disable=protected-access 295 self._cols_to_vars_map = collections.defaultdict(lambda: {}) 296 self._cols_to_resources_map = collections.defaultdict(lambda: {}) 297 298 def create_variable(self, 299 feature_column, 300 name, 301 shape, 302 dtype=None, 303 trainable=True, 304 use_resource=True, 305 initializer=None): 306 """Creates a new variable. 307 308 Args: 309 feature_column: A `FeatureColumn` object this variable corresponds to. 310 name: variable name. 311 shape: variable shape. 312 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 313 trainable: Whether this variable is trainable or not. 314 use_resource: If true, we use resource variables. Otherwise we use 315 RefVariable. 316 initializer: initializer instance (callable). 317 318 Returns: 319 The created variable. 320 """ 321 if name in self._cols_to_vars_map[feature_column]: 322 raise ValueError('Variable already exists.') 323 324 # We explicitly track these variables since `name` is not guaranteed to be 325 # unique and disable manual tracking that the add_weight call does. 326 with trackable.no_manual_dependency_tracking_scope(self._layer): 327 var = self._layer.add_weight( 328 name=name, 329 shape=shape, 330 dtype=dtype, 331 initializer=initializer, 332 trainable=self._trainable and trainable, 333 use_resource=use_resource, 334 # TODO(rohanj): Get rid of this hack once we have a mechanism for 335 # specifying a default partitioner for an entire layer. In that case, 336 # the default getter for Layers should work. 337 getter=variable_scope.get_variable) 338 if isinstance(var, variables.PartitionedVariable): 339 for v in var: 340 part_name = name + '/' + str(v._get_save_slice_info().var_offset[0]) # pylint: disable=protected-access 341 self._layer._track_trackable(v, feature_column.name + '/' + part_name) # pylint: disable=protected-access 342 else: 343 if isinstance(var, trackable.Trackable): 344 self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access 345 346 self._cols_to_vars_map[feature_column][name] = var 347 return var 348 349 def get_variable(self, feature_column, name): 350 """Returns an existing variable. 351 352 Args: 353 feature_column: A `FeatureColumn` object this variable corresponds to. 354 name: variable name. 355 """ 356 if name in self._cols_to_vars_map[feature_column]: 357 return self._cols_to_vars_map[feature_column][name] 358 raise ValueError('Variable does not exist.') 359 360 def add_resource(self, feature_column, resource_name, resource): 361 """Creates a new resource. 362 363 Resources can be things such as tables, variables, trackables, etc. 364 365 Args: 366 feature_column: A `FeatureColumn` object this resource corresponds to. 367 resource_name: Name of the resource. 368 resource: The resource. 369 370 Returns: 371 The created resource. 372 """ 373 self._cols_to_resources_map[feature_column][resource_name] = resource 374 # pylint: disable=protected-access 375 if self._layer is not None and isinstance(resource, trackable.Trackable): 376 # Add trackable resources to the layer for serialization. 
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    return resource_name in self._cols_to_resources_map[feature_column]

  def get_resource(self, feature_column, resource_name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    if (feature_column not in self._cols_to_resources_map or
        resource_name not in self._cols_to_resources_map[feature_column]):
      raise ValueError('Resource does not exist.')
    return self._cols_to_resources_map[feature_column][resource_name]


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on the feature columns passed in.

  Please note that you most likely will not need to use this function. Please
  check `input_layer` and `linear_model` to see whether they will satisfy your
  use case or not.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor`, depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs


@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as arg 'features' in
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9), ...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9), ...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which
      to restore column weights. Required if `tensor_name_in_ckpt` is not
      `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some
  or all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse
      IDs that are inputs to the embedding lookup. All columns must be of the
      same type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which
      to restore column weights. Required if `tensor_name_in_ckpt` is not
      `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(
      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some
  or all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse
      IDs that are inputs to the embedding lookup. All columns must be of the
      same type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these
      columns. If not given, a reasonable name will be chosen based on the
      names of `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which
      to restore column weights. Required if `tensor_name_in_ckpt` is not
      `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse.
      safe_embedding_lookup_sparse ensures there are no empty rows and all
      weights and ids are positive at the expense of extra compute cost. This
      only applies to rank 2 (NxM) shaped input tensors. Defaults to true,
      consider turning off if the above checks are not needed. Note that having
      empty rows will not trigger any error though the output result might be 0
      or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0.num_buckets
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. '
          'Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result


@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  Assume we have data with two features `a` and `b`.

  >>> data = {'a': [15, 9, 17, 19, 21, 18, 25, 30],
  ...         'b': [5.0, 6.4, 10.5, 13.6, 15.7, 19.9, 20.3, 0.0]}

  Let us represent the features `a` and `b` as numerical features.

  >>> a = tf.feature_column.numeric_column('a')
  >>> b = tf.feature_column.numeric_column('b')

  Feature columns describe a set of transformations to the inputs.

  For example, to "bucketize" feature `a`, wrap the `a` column in a
  `feature_column.bucketized_column`.
  Providing `5` bucket boundaries, the bucketized_column api
  will bucket this feature into a total of `6` buckets.

  >>> a_buckets = tf.feature_column.bucketized_column(a,
  ...   boundaries=[10, 15, 20, 25, 30])

  Create a `DenseFeatures` layer which will apply the transformations
  described by the set of `tf.feature_column` objects:

  >>> feature_layer = tf.keras.layers.DenseFeatures([a_buckets, b])
  >>> print(feature_layer(data))
  tf.Tensor(
  [[ 0.   0.   1.   0.   0.   0.   5. ]
   [ 1.   0.   0.   0.   0.   0.   6.4]
   [ 0.   0.   1.   0.   0.   0.  10.5]
   [ 0.   0.   1.   0.   0.   0.  13.6]
   [ 0.   0.   0.   1.   0.   0.  15.7]
   [ 0.   0.   1.   0.   0.   0.  19.9]
   [ 0.   0.   0.   0.   1.   0.  20.3]
   [ 0.   0.   0.   0.   0.   1.   0. ]], shape=(8, 7), dtype=float32)

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers specifies the shape of the `Tensor`. An
      integer can be given which means a single dimension `Tensor` with given
      width. The `Tensor` representing the column will have the shape of
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization,
      it can be used for any kind of TensorFlow transformation.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int
    ValueError: if any dimension in shape is not a positive integer
    TypeError: if `default_value` is an iterable but not compatible with
      `shape`
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)


@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input bucketed by `boundaries`.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000]
                  [150,   10]
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3]
            [3, 2]
            [1, 3]]
  ```

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
    raise ValueError('boundaries must be a sorted list.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))


@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents a sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing:
  `output_id = Hash(input_feature_string) % bucket_size` for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  columns = [keywords]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  linear_prediction, _, _ = tf.compat.v1.feature_column.linear_model(features,
                                                                     columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int >= 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is less than 1.
    ValueError: `dtype` is neither string nor integer.
1231 """ 1232 if hash_bucket_size is None: 1233 raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key)) 1234 1235 if hash_bucket_size < 1: 1236 raise ValueError('hash_bucket_size must be at least 1. ' 1237 'hash_bucket_size: {}, key: {}'.format( 1238 hash_bucket_size, key)) 1239 1240 fc_utils.assert_key_is_string(key) 1241 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1242 1243 return HashedCategoricalColumn(key, hash_bucket_size, dtype) 1244 1245 1246@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file']) 1247def categorical_column_with_vocabulary_file(key, 1248 vocabulary_file, 1249 vocabulary_size=None, 1250 num_oov_buckets=0, 1251 default_value=None, 1252 dtype=dtypes.string): 1253 """A `CategoricalColumn` with a vocabulary file. 1254 1255 Use this when your inputs are in string or integer format, and you have a 1256 vocabulary file that maps each value to an integer ID. By default, 1257 out-of-vocabulary values are ignored. Use either (but not both) of 1258 `num_oov_buckets` and `default_value` to specify how to include 1259 out-of-vocabulary values. 1260 1261 For input dictionary `features`, `features[key]` is either `Tensor` or 1262 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1263 and `''` for string, which will be dropped by this feature column. 1264 1265 Example with `num_oov_buckets`: 1266 File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state 1267 abbreviation. All inputs with values in that file are assigned an ID 0-49, 1268 corresponding to its line number. All other values are hashed and assigned an 1269 ID 50-54. 1270 1271 ```python 1272 import tensorflow as tf 1273 states = tf.feature_column.categorical_column_with_vocabulary_file( 1274 key='states', vocabulary_file='states.txt', vocabulary_size=5, 1275 num_oov_buckets=1) 1276 columns = [states] 1277 features = {'states':tf.constant([['california', 'georgia', 'michigan', 1278 'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan', 1279 'texas']])} 1280 linear_prediction = tf.compat.v1.feature_column.linear_model(features, 1281 columns) 1282 ``` 1283 1284 Example with `default_value`: 1285 File '/us/states.txt' contains 51 lines - the first line is 'XX', and the 1286 other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' 1287 in input, and other values missing from the file, will be assigned ID 0. All 1288 others are assigned the corresponding line number 1-50. 

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.txt', vocabulary_size=6,
      default_value=0)
  columns = [states]
  features = {'states': tf.constant([
      ['california', 'georgia', 'michigan', 'texas', 'new york'],
      ['new york', 'georgia', 'california', 'michigan', 'texas']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)
  ```

  And to make an embedding with either:

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.txt', vocabulary_size=5,
      num_oov_buckets=1)
  columns = [tf.feature_column.embedding_column(states, 3)]
  features = {'states': tf.constant([
      ['california', 'georgia', 'michigan', 'texas', 'new york'],
      ['new york', 'georgia', 'california', 'michigan', 'texas']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: The number of elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if less than the length,
      later values are ignored. If None, it is set to the length of
      `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  return categorical_column_with_vocabulary_file_v2(
      key, vocabulary_file, vocabulary_size,
      dtype, default_value,
      num_oov_buckets)


@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
def categorical_column_with_vocabulary_file_v2(key,
                                               vocabulary_file,
                                               vocabulary_size=None,
                                               dtype=dtypes.string,
                                               default_value=None,
                                               num_oov_buckets=0,
                                               file_format=None):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`.
  If `Tensor`, missing values can be represented by `-1` for int and `''` for
  string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S.
  state abbreviation. All inputs with values in that file are assigned an ID
  0-49, corresponding to their line numbers. All other values are hashed and
  assigned an ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and
  the other 50 each have a 2-character U.S. state abbreviation. Both a
  literal `'XX'` in input, and other values missing from the file, will be
  assigned ID 0. All others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if it is less, later
      values in the file are ignored. If None, it is set to the length of
      `vocabulary_file`.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary
      feature values, defaults to `-1`. This can not be specified with a
      positive `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    file_format: The format of the vocabulary file. The format is 'text' by
      default unless `vocabulary_file` is a string which ends in
      'tfrecord.gz'. The accepted alternative value for `file_format` is
      'tfrecord_gzip' (see the sketch below).

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
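
  For example, a minimal sketch with a GZIP-compressed TFRecord vocabulary.
  The file name `states.tfrecord.gz` is an assumption for illustration; each
  record is expected to hold one vocabulary entry:

  ```python
  # 'tfrecord_gzip' is inferred from the '.tfrecord.gz' suffix; passing
  # file_format='tfrecord_gzip' explicitly is equivalent. With
  # vocabulary_size left unset, it is inferred by counting the records.
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.tfrecord.gz', num_oov_buckets=1)
  ```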
1440 """ 1441 if not vocabulary_file: 1442 raise ValueError('Missing vocabulary_file in {}.'.format(key)) 1443 1444 if file_format is None and vocabulary_file.endswith('tfrecord.gz'): 1445 file_format = 'tfrecord_gzip' 1446 1447 if vocabulary_size is None: 1448 if not gfile.Exists(vocabulary_file): 1449 raise ValueError('vocabulary_file in {} does not exist.'.format(key)) 1450 1451 if file_format == 'tfrecord_gzip': 1452 ds = readers.TFRecordDataset(vocabulary_file, 'GZIP') 1453 vocabulary_size = ds.reduce(0, lambda x, _: x + 1) 1454 if context.executing_eagerly(): 1455 vocabulary_size = vocabulary_size.numpy() 1456 else: 1457 with gfile.GFile(vocabulary_file, mode='rb') as f: 1458 vocabulary_size = sum(1 for _ in f) 1459 logging.info( 1460 'vocabulary_size = %d in %s is inferred from the number of elements ' 1461 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) 1462 1463 # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. 1464 if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1: 1465 raise ValueError('Invalid vocabulary_size in {}.'.format(key)) 1466 if num_oov_buckets: 1467 if default_value is not None: 1468 raise ValueError( 1469 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( 1470 key)) 1471 if num_oov_buckets < 0: 1472 raise ValueError('Invalid num_oov_buckets {} in {}.'.format( 1473 num_oov_buckets, key)) 1474 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1475 fc_utils.assert_key_is_string(key) 1476 return VocabularyFileCategoricalColumn( 1477 key=key, 1478 vocabulary_file=vocabulary_file, 1479 vocabulary_size=vocabulary_size, 1480 num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, 1481 default_value=-1 if default_value is None else default_value, 1482 dtype=dtype, 1483 file_format=file_format) 1484 1485 1486@tf_export('feature_column.categorical_column_with_vocabulary_list') 1487def categorical_column_with_vocabulary_list(key, 1488 vocabulary_list, 1489 dtype=None, 1490 default_value=-1, 1491 num_oov_buckets=0): 1492 """A `CategoricalColumn` with in-memory vocabulary. 1493 1494 Use this when your inputs are in string or integer format, and you have an 1495 in-memory vocabulary mapping each value to an integer ID. By default, 1496 out-of-vocabulary values are ignored. Use either (but not both) of 1497 `num_oov_buckets` and `default_value` to specify how to include 1498 out-of-vocabulary values. 1499 1500 For input dictionary `features`, `features[key]` is either `Tensor` or 1501 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1502 and `''` for string, which will be dropped by this feature column. 1503 1504 Example with `num_oov_buckets`: 1505 In the following example, each input in `vocabulary_list` is assigned an ID 1506 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other 1507 inputs are hashed and assigned an ID 4-5. 1508 1509 ```python 1510 colors = categorical_column_with_vocabulary_list( 1511 key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), 1512 num_oov_buckets=2) 1513 columns = [colors, ...] 1514 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1515 linear_prediction, _, _ = linear_model(features, columns) 1516 ``` 1517 1518 Example with `default_value`: 1519 In the following example, each input in `vocabulary_list` is assigned an ID 1520 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other 1521 inputs are assigned `default_value` 0. 
  ```python
  colors = categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
  columns = [colors, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(colors, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
      is mapped to the index of its value (if present) in `vocabulary_list`.
      Must be castable to `dtype`.
    dtype: The type of features. Only string and integer types are supported.
      If `None`, it will be inferred from `vocabulary_list` (see the example
      below).
    default_value: The integer ID value to return for out-of-vocabulary
      feature values, defaults to `-1`. This can not be specified with a
      positive `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on
      a hash of the input value. A positive `num_oov_buckets` can not be
      specified with `default_value`.

  Returns:
    A `CategoricalColumn` with in-memory vocabulary.

  Raises:
    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: if `dtype` is not integer or string.
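
  For example, a minimal sketch of `dtype` inference with an integer
  vocabulary (the 'postal_code' feature and its values are assumed for
  illustration):

  ```python
  # vocabulary_list holds Python ints, so dtype is inferred as an integer
  # type; out-of-vocabulary codes fall into the single hashed OOV bucket,
  # which is assigned ID 3 here.
  postal_codes = categorical_column_with_vocabulary_list(
      key='postal_code', vocabulary_list=(94043, 10001, 60601),
      num_oov_buckets=1)
  ```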
  """
  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
    raise ValueError(
        'vocabulary_list {} must be non-empty, column_name: {}'.format(
            vocabulary_list, key))
  if len(set(vocabulary_list)) != len(vocabulary_list):
    raise ValueError(
        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
            vocabulary_list, key))
  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
  if num_oov_buckets:
    if default_value != -1:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
  fc_utils.assert_string_or_int(
      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
  if dtype is None:
    dtype = vocabulary_dtype
  elif dtype.is_integer != vocabulary_dtype.is_integer:
    raise ValueError(
        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
            dtype, vocabulary_dtype, key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)

  return VocabularyListCategoricalColumn(
      key=key,
      vocabulary_list=tuple(vocabulary_list),
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)


@tf_export('feature_column.categorical_column_with_identity')
def categorical_column_with_identity(key, num_buckets, default_value=None):
  """A `CategoricalColumn` that returns identity values.

  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values
  outside this range are mapped to `default_value` if it is specified;
  otherwise the column will fail.

  Typically, this is used for contiguous ranges of integer indexes, but
  it doesn't have to be. This might be inefficient, however, if many IDs
  are unused. Consider `categorical_column_with_hash_bucket` in that case.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  In the following examples, each input in the range `[0, 1000000)` is
  assigned its own value as the category ID. All other inputs are assigned
  `default_value` 0. Note that a literal 0 in inputs will result in the same
  default ID.
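
  As an illustrative sketch of that mapping (example values assumed):

  ```python
  # With num_buckets=1000000 and default_value=0:
  #   input 33      -> ID 33  (identity)
  #   input 1000005 -> ID 0   (out of range, replaced with default_value)
  #   input 0       -> ID 0   (identity, indistinguishable from the default)
  ```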
  Linear model:

  ```python
  import tensorflow as tf
  video_id = tf.feature_column.categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [video_id]
  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
                                                [33, 78, 2, 73, 1]])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)
  ```

  Embedding for a DNN model:

  ```python
  import tensorflow as tf
  video_id = tf.feature_column.categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [tf.feature_column.embedding_column(video_id, 9)]
  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
                                                [33, 78, 2, 73, 1]])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
    default_value: If set, values outside of range `[0, num_buckets)` will
      be replaced with this value. If not set, values >= num_buckets will
      cause a failure while values < 0 will be dropped.

  Returns:
    A `CategoricalColumn` that returns identity values.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  if num_buckets < 1:
    raise ValueError(
        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
  if (default_value is not None) and (
      (default_value < 0) or (default_value >= num_buckets)):
    raise ValueError(
        'default_value {} not in range [0, {}), column_name {}'.format(
            default_value, num_buckets, key))
  fc_utils.assert_key_is_string(key)
  return IdentityCategoricalColumn(
      key=key, number_buckets=num_buckets, default_value=default_value)


@tf_export('feature_column.indicator_column')
def indicator_column(categorical_column):
  """Represents the multi-hot representation of a given categorical column.

  - For DNN models, `indicator_column` can be used to wrap any
    `categorical_column_*` (e.g., to feed to a DNN). Consider using
    `embedding_column` if the number of buckets/unique values is large.

  - For Wide (aka linear) models, `indicator_column` is the internal
    representation for a categorical column when the categorical column is
    passed directly (as any element in feature_columns) to `linear_model`.
    See `linear_model` for details.

  ```python
  name = indicator_column(categorical_column_with_vocabulary_list(
      'name', ['bob', 'george', 'wanda']))
  columns = [name, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
  ```

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` or `crossed_column` functions.

  Returns:
    An `IndicatorColumn`.
1710 1711 Raises: 1712 ValueError: If `categorical_column` is not CategoricalColumn type. 1713 """ 1714 if not isinstance(categorical_column, 1715 (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access 1716 raise ValueError( 1717 'Unsupported input type. Input must be a CategoricalColumn. ' 1718 'Given: {}'.format(categorical_column)) 1719 return IndicatorColumn(categorical_column) 1720 1721 1722@tf_export('feature_column.weighted_categorical_column') 1723def weighted_categorical_column(categorical_column, 1724 weight_feature_key, 1725 dtype=dtypes.float32): 1726 """Applies weight values to a `CategoricalColumn`. 1727 1728 Use this when each of your sparse inputs has both an ID and a value. For 1729 example, if you're representing text documents as a collection of word 1730 frequencies, you can provide 2 parallel sparse input features ('terms' and 1731 'frequencies' below). 1732 1733 Example: 1734 1735 Input `tf.Example` objects: 1736 1737 ```proto 1738 [ 1739 features { 1740 feature { 1741 key: "terms" 1742 value {bytes_list {value: "very" value: "model"}} 1743 } 1744 feature { 1745 key: "frequencies" 1746 value {float_list {value: 0.3 value: 0.1}} 1747 } 1748 }, 1749 features { 1750 feature { 1751 key: "terms" 1752 value {bytes_list {value: "when" value: "course" value: "human"}} 1753 } 1754 feature { 1755 key: "frequencies" 1756 value {float_list {value: 0.4 value: 0.1 value: 0.2}} 1757 } 1758 } 1759 ] 1760 ``` 1761 1762 ```python 1763 categorical_column = categorical_column_with_hash_bucket( 1764 column_name='terms', hash_bucket_size=1000) 1765 weighted_column = weighted_categorical_column( 1766 categorical_column=categorical_column, weight_feature_key='frequencies') 1767 columns = [weighted_column, ...] 1768 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1769 linear_prediction, _, _ = linear_model(features, columns) 1770 ``` 1771 1772 This assumes the input dictionary contains a `SparseTensor` for key 1773 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have 1774 the same indices and dense shape. 1775 1776 Args: 1777 categorical_column: A `CategoricalColumn` created by 1778 `categorical_column_with_*` functions. 1779 weight_feature_key: String key for weight values. 1780 dtype: Type of weights, such as `tf.float32`. Only float and integer weights 1781 are supported. 1782 1783 Returns: 1784 A `CategoricalColumn` composed of two sparse features: one represents id, 1785 the other represents weight (value) of the id feature in that example. 1786 1787 Raises: 1788 ValueError: if `dtype` is not convertible to float. 1789 """ 1790 if (dtype is None) or not (dtype.is_integer or dtype.is_floating): 1791 raise ValueError('dtype {} is not convertible to float.'.format(dtype)) 1792 return WeightedCategoricalColumn( 1793 categorical_column=categorical_column, 1794 weight_feature_key=weight_feature_key, 1795 dtype=dtype) 1796 1797 1798@tf_export('feature_column.crossed_column') 1799def crossed_column(keys, hash_bucket_size, hash_key=None): 1800 """Returns a column for performing crosses of categorical features. 1801 1802 Crossed features will be hashed according to `hash_bucket_size`. 
  Conceptually, the transformation can be thought of as:
    Hash(cartesian product of features) % `hash_bucket_size`

  For example, if the input features are:

  * SparseTensor referred by first key:

    ```python
    shape = [2, 2]
    {
        [0, 0]: "a"
        [1, 0]: "b"
        [1, 1]: "c"
    }
    ```

  * SparseTensor referred by second key:

    ```python
    shape = [2, 1]
    {
        [0, 0]: "d"
        [1, 0]: "e"
    }
    ```

  then the crossed feature will look like:

  ```python
  shape = [2, 2]
  {
      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
  }
  ```

  Here is an example to create a linear model with crosses of string features:

  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
  columns = [keywords_x_doc_terms, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  You could also use vocabulary lookup before crossing:

  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
  columns = [keywords_x_doc_terms, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  If an input feature is of numeric type, you can use
  `categorical_column_with_identity`, or `bucketized_column`, as in the
  example:

  ```python
  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10000)
  price = numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
  columns = [vertical_id_x_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  To use a crossed column in a DNN model, you need to wrap it in an embedding
  column, as in this example:

  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
  ```

  Args:
    keys: An iterable identifying the features to be crossed. Each element can
      be either:
      * string: Will use the corresponding feature which must be of string
        type.
      * `CategoricalColumn`: Will use the transformed tensor produced by this
        column. Does not support hashed categorical column.
    hash_bucket_size: An int > 1. The number of buckets.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseCrossOp
      (optional).

  Returns:
    A `CrossedColumn`.

  Raises:
    ValueError: If `len(keys) < 2`.
    ValueError: If any of the keys is neither a string nor
      `CategoricalColumn`.
    ValueError: If any of the keys is `HashedCategoricalColumn`.
    ValueError: If `hash_bucket_size < 1`.
  """
  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
  if not keys or len(keys) < 2:
    raise ValueError(
        'keys must be a list with length > 1. Given: {}'.format(keys))
  for key in keys:
    if (not isinstance(key, six.string_types) and
        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
      raise ValueError(
          'Unsupported key type. All keys must be either string, or '
          'categorical column except HashedCategoricalColumn. '
          'Given: {}'.format(key))
    if isinstance(key,
                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'categorical_column_with_hash_bucket is not supported for crossing. '
          'Hashing before crossing will increase probability of collision. '
          'Instead, use the feature name as a string. Given: {}'.format(key))
  return CrossedColumn(
      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)


# TODO(b/181853833): Add a tf.type for instance type checking.
@tf_export('__internal__.feature_column.FeatureColumn', v1=[])
@six.add_metaclass(abc.ABCMeta)
class FeatureColumn(object):
  """Represents a feature column abstraction.

  WARNING: Do not subclass this layer unless you know what you are doing:
  the API is subject to future changes.

  To distinguish between the concept of a feature family and a specific binary
  feature within a family, we refer to a feature family like "country" as a
  feature column. For example, we can have a feature in a `tf.Example` format:
    {key: "country", value: [ "US" ]}
  In this example the value of the feature is "US" and "country" refers to
  the column of the feature.

  This class is an abstract class. Users should not create instances of this.
  """

  @abc.abstractproperty
  def name(self):
    """Returns string. Used for naming."""
    pass

  def __lt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    In CPython, `__lt__` must be defined for all objects in the
    sequence being sorted.

    If any objects in the sequence being sorted do not have an `__lt__` method
    compatible with feature column objects (such as strings), then CPython
    will fall back to using the `__gt__` method below.
    https://docs.python.org/3/library/stdtypes.html#list.sort

    Args:
      other: The other object to compare to.

    Returns:
      True if the string representation of this object is lexicographically
      less than the string representation of `other`. For FeatureColumn
      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
    """
    return str(self) < str(other)

  def __gt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    `__gt__` is called when the "other" object being compared during the sort
    does not have `__lt__` defined.
1982 Example: 1983 ``` 1984 # __lt__ only class 1985 class A(): 1986 def __lt__(self, other): return str(self) < str(other) 1987 1988 a = A() 1989 a < "b" # True 1990 "0" < a # Error 1991 1992 # __lt__ and __gt__ class 1993 class B(): 1994 def __lt__(self, other): return str(self) < str(other) 1995 def __gt__(self, other): return str(self) > str(other) 1996 1997 b = B() 1998 b < "c" # True 1999 "0" < b # True 2000 ``` 2001 2002 Args: 2003 other: The other object to compare to. 2004 2005 Returns: 2006 True if the string representation of this object is lexicographically 2007 greater than the string representation of `other`. For FeatureColumn 2008 objects, this looks like "<__main__.FeatureColumn object at 0xa>". 2009 """ 2010 return str(self) > str(other) 2011 2012 @abc.abstractmethod 2013 def transform_feature(self, transformation_cache, state_manager): 2014 """Returns intermediate representation (usually a `Tensor`). 2015 2016 Uses `transformation_cache` to create an intermediate representation 2017 (usually a `Tensor`) that other feature columns can use. 2018 2019 Example usage of `transformation_cache`: 2020 Let's say a Feature column depends on raw feature ('raw') and another 2021 `FeatureColumn` (input_fc). To access corresponding `Tensor`s, 2022 transformation_cache will be used as follows: 2023 2024 ```python 2025 raw_tensor = transformation_cache.get('raw', state_manager) 2026 fc_tensor = transformation_cache.get(input_fc, state_manager) 2027 ``` 2028 2029 Args: 2030 transformation_cache: A `FeatureTransformationCache` object to access 2031 features. 2032 state_manager: A `StateManager` to create / access resources such as 2033 lookup tables. 2034 2035 Returns: 2036 Transformed feature `Tensor`. 2037 """ 2038 pass 2039 2040 @abc.abstractproperty 2041 def parse_example_spec(self): 2042 """Returns a `tf.Example` parsing spec as dict. 2043 2044 It is used for get_parsing_spec for `tf.io.parse_example`. Returned spec is 2045 a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other 2046 supported objects. Please check documentation of `tf.io.parse_example` for 2047 all supported spec objects. 2048 2049 Let's say a Feature column depends on raw feature ('raw') and another 2050 `FeatureColumn` (input_fc). One possible implementation of 2051 parse_example_spec is as follows: 2052 2053 ```python 2054 spec = {'raw': tf.io.FixedLenFeature(...)} 2055 spec.update(input_fc.parse_example_spec) 2056 return spec 2057 ``` 2058 """ 2059 pass 2060 2061 def create_state(self, state_manager): 2062 """Uses the `state_manager` to create state for the FeatureColumn. 2063 2064 Args: 2065 state_manager: A `StateManager` to create / access resources such as 2066 lookup tables and variables. 2067 """ 2068 pass 2069 2070 @abc.abstractproperty 2071 def _is_v2_column(self): 2072 """Returns whether this FeatureColumn is fully conformant to the new API. 2073 2074 This is needed for composition type cases where an EmbeddingColumn etc. 2075 might take in old categorical columns as input and then we want to use the 2076 old API. 2077 """ 2078 pass 2079 2080 @abc.abstractproperty 2081 def parents(self): 2082 """Returns a list of immediate raw feature and FeatureColumn dependencies. 2083 2084 For example: 2085 # For the following feature columns 2086 a = numeric_column('f1') 2087 c = crossed_column(a, 'f2') 2088 # The expected parents are: 2089 a.parents = ['f1'] 2090 c.parents = [a, 'f2'] 2091 """ 2092 pass 2093 2094 def get_config(self): 2095 """Returns the config of the feature column. 
2096 2097 A FeatureColumn config is a Python dictionary (serializable) containing the 2098 configuration of a FeatureColumn. The same FeatureColumn can be 2099 reinstantiated later from this configuration. 2100 2101 The config of a feature column does not include information about feature 2102 columns depending on it nor the FeatureColumn class name. 2103 2104 Example with (de)serialization practices followed in this file: 2105 ```python 2106 class SerializationExampleFeatureColumn( 2107 FeatureColumn, collections.namedtuple( 2108 'SerializationExampleFeatureColumn', 2109 ('dimension', 'parent', 'dtype', 'normalizer_fn'))): 2110 2111 def get_config(self): 2112 # Create a dict from the namedtuple. 2113 # Python attribute literals can be directly copied from / to the config. 2114 # For example 'dimension', assuming it is an integer literal. 2115 config = dict(zip(self._fields, self)) 2116 2117 # (De)serialization of parent FeatureColumns should use the provided 2118 # (de)serialize_feature_column() methods that take care of de-duping. 2119 config['parent'] = serialize_feature_column(self.parent) 2120 2121 # Many objects provide custom (de)serialization e.g: for tf.DType 2122 # tf.DType.name, tf.as_dtype() can be used. 2123 config['dtype'] = self.dtype.name 2124 2125 # Non-trivial dependencies should be Keras-(de)serializable. 2126 config['normalizer_fn'] = generic_utils.serialize_keras_object( 2127 self.normalizer_fn) 2128 2129 return config 2130 2131 @classmethod 2132 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2133 # This should do the inverse transform from `get_config` and construct 2134 # the namedtuple. 2135 kwargs = config.copy() 2136 kwargs['parent'] = deserialize_feature_column( 2137 config['parent'], custom_objects, columns_by_name) 2138 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2139 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object( 2140 config['normalizer_fn'], custom_objects=custom_objects) 2141 return cls(**kwargs) 2142 2143 ``` 2144 Returns: 2145 A serializable Dict that can be used to deserialize the object with 2146 from_config. 2147 """ 2148 return self._get_config() 2149 2150 def _get_config(self): 2151 raise NotImplementedError('Must be implemented in subclasses.') 2152 2153 @classmethod 2154 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2155 """Creates a FeatureColumn from its config. 2156 2157 This method should be the reverse of `get_config`, capable of instantiating 2158 the same FeatureColumn from the config dictionary. See `get_config` for an 2159 example of common (de)serialization practices followed in this file. 2160 2161 TODO(b/118939620): This is a private method until consensus is reached on 2162 supporting object deserialization deduping within Keras. 2163 2164 Args: 2165 config: A Dict config acquired with `get_config`. 2166 custom_objects: Optional dictionary mapping names (strings) to custom 2167 classes or functions to be considered during deserialization. 2168 columns_by_name: A Dict[String, FeatureColumn] of existing columns in 2169 order to avoid duplication. Should be passed to any calls to 2170 deserialize_feature_column(). 2171 2172 Returns: 2173 A FeatureColumn for the input config. 
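
    A hypothetical round-trip sketch (the `column` variable is assumed, and
    the concrete column must implement both `get_config` and `from_config`):

    ```python
    config = column.get_config()
    restored = type(column).from_config(config)
    assert restored.name == column.name
    ```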
    """
    return cls._from_config(config, custom_objects, columns_by_name)

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    raise NotImplementedError('Must be implemented in subclasses.')


# TODO(b/181853833): Add a tf.type for instance type checking.
@tf_export('__internal__.feature_column.DenseColumn', v1=[])
class DenseColumn(FeatureColumn):
  """Represents a column which can be represented as a `Tensor`.

  Some examples of this type are: numeric_column, embedding_column,
  indicator_column.
  """

  @abc.abstractproperty
  def variable_shape(self):
    """`TensorShape` of `get_dense_tensor`, without batch dimension."""
    pass

  @abc.abstractmethod
  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `Tensor`.

    The output of this function will be used by model-builder functions. For
    example, the pseudo code of `input_layer` will be like:

    ```python
    def input_layer(features, feature_columns, ...):
      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
      return tf.concat(outputs)
    ```

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      `Tensor` of shape [batch_size] + `variable_shape`.
    """
    pass


def is_feature_column_v2(feature_columns):
  """Returns True if all feature columns are V2."""
  for feature_column in feature_columns:
    if not isinstance(feature_column, FeatureColumn):
      return False
    if not feature_column._is_v2_column:  # pylint: disable=protected-access
      return False
  return True


def _create_weighted_sum(column, transformation_cache, state_manager,
                         sparse_combiner, weight_var):
  """Creates a weighted sum for a dense/categorical column for linear_model."""
  if isinstance(column, CategoricalColumn):
    return _create_categorical_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        sparse_combiner=sparse_combiner,
        weight_var=weight_var)
  else:
    return _create_dense_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        weight_var=weight_var)


def _create_dense_column_weighted_sum(column, transformation_cache,
                                      state_manager, weight_var):
  """Create a weighted sum of a dense column for linear_model."""
  tensor = column.get_dense_tensor(transformation_cache, state_manager)
  num_elements = column.variable_shape.num_elements()
  batch_size = array_ops.shape(tensor)[0]
  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
  return math_ops.matmul(tensor, weight_var, name='weighted_sum')


class CategoricalColumn(FeatureColumn):
  """Represents a categorical feature.

  A categorical feature is typically handled with a `tf.sparse.SparseTensor`
  of IDs.
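
  For example, a batch of two examples, `["a"]` and `["b", "c"]`, might be
  represented as an id `SparseTensor` like the following sketch (the ids are
  assumed):

  ```python
  tf.sparse.SparseTensor(
      indices=[[0, 0], [1, 0], [1, 1]], values=[0, 1, 2],
      dense_shape=[2, 2])
  ```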
  """

  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
      'IdWeightPair', ('id_tensor', 'weight_tensor'))

  @abc.abstractproperty
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    pass

  @abc.abstractmethod
  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    pass


def _create_categorical_column_weighted_sum(
    column, transformation_cache, state_manager, sparse_combiner, weight_var):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Create a weighted sum of a categorical column for linear_model.

  Note to maintainer: As an implementation detail, the weighted sum is
  implemented via embedding_lookup_sparse for efficiency. Mathematically,
  they are the same.

  To be specific, conceptually a categorical column can be treated as a
  multi-hot vector. Say:

  ```python
  x = [0 0 1]  # categorical column input
  w = [a b c]  # weights
  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.

  Another example is

  ```python
  x = [0 1 1]  # categorical column input
  w = [a b c]  # weights
  ```
  The weighted sum is `b + c` in this case, which is the same as
  `w[1] + w[2]`.

  For both cases, we can implement the weighted sum via embedding_lookup with
  sparse_combiner = "sum".
  """

  sparse_tensors = column.get_sparse_tensors(transformation_cache,
                                             state_manager)
  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
      array_ops.shape(sparse_tensors.id_tensor)[0], -1
  ])
  weight_tensor = sparse_tensors.weight_tensor
  if weight_tensor is not None:
    weight_tensor = sparse_ops.sparse_reshape(
        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])

  return embedding_ops.safe_embedding_lookup_sparse(
      weight_var,
      id_tensor,
      sparse_weights=weight_tensor,
      combiner=sparse_combiner,
      name='weighted_sum')


# TODO(b/181853833): Add a tf.type for instance type checking.
@tf_export('__internal__.feature_column.SequenceDenseColumn', v1=[])
class SequenceDenseColumn(FeatureColumn):
  """Represents dense sequence data."""

  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))

  @abc.abstractmethod
  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `TensorSequenceLengthPair`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    pass


@tf_export('__internal__.feature_column.FeatureTransformationCache', v1=[])
class FeatureTransformationCache(object):
  """Handles caching of transformations while building the model.

  `FeatureColumn` specifies how to digest an input column to the network. Some
  feature columns require data transformations. This class caches those
  transformations.

  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and also in a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.

  Example:
  We're trying to use the following `FeatureColumn`s:

  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
  ```

  If we transform each column independently, then we'll get duplication of
  bucketization (one for the cross, one for bucketization itself).
  The `FeatureTransformationCache` eliminates this duplication.
  """

  def __init__(self, features):
    """Creates a `FeatureTransformationCache`.

    Args:
      features: A mapping from feature column to objects that are `Tensor` or
        `SparseTensor`, or can be converted to same via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
        signifies a base feature (not-transformed). A `FeatureColumn` key
        means that this `Tensor` is the output of an existing `FeatureColumn`
        which can be reused.
    """
    self._features = features.copy()
    self._feature_tensors = {}

  def get(self, key, state_manager, training=None):
    """Returns a `Tensor` for the given key.

    A `str` key is used to access a base feature (not-transformed). When a
    `FeatureColumn` is passed, the transformed feature is returned if it
    already exists, otherwise the given `FeatureColumn` is asked to provide
    its transformed output, which is then cached.

    Args:
      key: a `str` or a `FeatureColumn`.
      state_manager: A StateManager object that holds the FeatureColumn state.
      training: Boolean indicating whether the column is being used in
        training mode. This argument is passed to the transform_feature
        method of any `FeatureColumn` that takes a `training` argument. For
        example, if a `FeatureColumn` performed dropout, it could expose a
        `training` argument to control whether the dropout should be applied.

    Returns:
      The transformed `Tensor` corresponding to the `key`.

    Raises:
      ValueError: if key is not found or a transformed `Tensor` cannot be
        computed.
    """
    if key in self._feature_tensors:
      # FeatureColumn is already transformed or converted.
      return self._feature_tensors[key]

    if key in self._features:
      feature_tensor = self._get_raw_feature_as_tensor(key)
      self._feature_tensors[key] = feature_tensor
      return feature_tensor

    if isinstance(key, six.string_types):
      raise ValueError('Feature {} is not in features dictionary.'.format(key))

    if not isinstance(key, FeatureColumn):
      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
                      'Provided: {}'.format(key))

    column = key
    logging.debug('Transforming feature_column %s.', column)

    # Some columns may need information about whether the transformation is
    # happening in training or prediction mode, but not all columns expose
    # this argument.
    try:
      transformed = column.transform_feature(
          self, state_manager, training=training)
    except TypeError:
      transformed = column.transform_feature(self, state_manager)
    if transformed is None:
      raise ValueError('Column {} is not supported.'.format(column.name))
    self._feature_tensors[column] = transformed
    return transformed

  def _get_raw_feature_as_tensor(self, key):
    """Gets the raw feature (keyed by `key`) as a `tensor`.

    The raw feature is converted to a (sparse) tensor, and its dimensions may
    be expanded.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. This also supports dynamic rank. Rank 0 raw features are
    not supported and will raise an error.

    Args:
      key: A `str` key to access the raw feature.

    Returns:
      A `Tensor` or `SparseTensor`.

    Raises:
      ValueError: if the raw feature has rank 0.
    """
    raw_feature = self._features[key]
    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        raw_feature)

    def expand_dims(input_tensor):
      # Input_tensor must have rank 1.
      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
        return sparse_ops.sparse_reshape(
            input_tensor, [array_ops.shape(input_tensor)[0], 1])
      else:
        return array_ops.expand_dims(input_tensor, -1)

    rank = feature_tensor.get_shape().ndims
    if rank is not None:
      if rank == 0:
        raise ValueError(
            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))
      return feature_tensor if rank != 1 else expand_dims(feature_tensor)

    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_positive(
            array_ops.rank(feature_tensor),
            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))]):
      return control_flow_ops.cond(
          math_ops.equal(1, array_ops.rank(feature_tensor)),
          lambda: expand_dims(feature_tensor),
          lambda: feature_tensor)


# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.

  If `input_tensor` is already a `SparseTensor`, just return it.

  Args:
    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `dense_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, default value of
      `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).

  Returns:
    A `SparseTensor` with the same shape as `input_tensor`.
  Raises:
    ValueError: when `input_tensor`'s rank is `None`.
  """
  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
      input_tensor)
  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
    return input_tensor
  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
    if ignore_value is None:
      if input_tensor.dtype == dtypes.string:
        # Treated as a special case because TF strings are converted to numpy
        # objects by default.
        ignore_value = ''
      elif input_tensor.dtype.is_integer:
        ignore_value = -1  # -1 has a special meaning of missing feature
      else:
        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this
        # is constructing a new numpy object of the given type, which yields
        # the default value for that type.
        ignore_value = input_tensor.dtype.as_numpy_dtype()
    ignore_value = math_ops.cast(
        ignore_value, input_tensor.dtype, name='ignore_value')
    indices = array_ops.where_v2(
        math_ops.not_equal(input_tensor, ignore_value), name='indices')
    return sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=array_ops.gather_nd(input_tensor, indices, name='values'),
        dense_shape=array_ops.shape(
            input_tensor, out_type=dtypes.int64, name='dense_shape'))


def _normalize_feature_columns(feature_columns):
  """Normalizes the `feature_columns` input.

  This method converts the `feature_columns` to list type as best it can. In
  addition, it verifies the type and other parts of feature_columns that are
  required by the downstream library.

  Args:
    feature_columns: The raw feature columns, usually passed by users.

  Returns:
    The normalized feature column list.

  Raises:
    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
  """
  if isinstance(feature_columns, FeatureColumn):
    feature_columns = [feature_columns]

  if isinstance(feature_columns, collections_abc.Iterator):
    feature_columns = list(feature_columns)

  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('Items of feature_columns must be a FeatureColumn. '
                       'Given (type {}): {}.'.format(type(column), column))
  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')
  name_to_column = {}
  for column in feature_columns:
    if column.name in name_to_column:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer '
                       'to the same base feature. Either one must be '
                       'discarded or a duplicated but renamed item must be '
                       'inserted in features dict.'.format(
                           column, name_to_column[column.name]))
    name_to_column[column.name] = column

  return sorted(feature_columns, key=lambda x: x.name)


class NumericColumn(
    DenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'NumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """See `numeric_column`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {
        self.key:
            parsing_ops.FixedLenFeature(self.shape, self.dtype,
                                        self.default_value)
    }

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'The corresponding Tensor of a numeric column must be a dense '
          'Tensor. SparseTensor is not supported. key: {}'.format(self.key))
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return math_ops.cast(input_tensor, dtypes.float32)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    return self._transform_input_tensor(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.
    Raises:
      ValueError: If a SparseTensor is passed in.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(self.shape)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing numeric feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.
    """
    # Feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
2697 return transformation_cache.get(self, state_manager) 2698 2699 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2700 _FEATURE_COLUMN_DEPRECATION) 2701 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2702 del weight_collections 2703 del trainable 2704 return inputs.get(self) 2705 2706 @property 2707 def parents(self): 2708 """See 'FeatureColumn` base class.""" 2709 return [self.key] 2710 2711 def get_config(self): 2712 """See 'FeatureColumn` base class.""" 2713 config = dict(zip(self._fields, self)) 2714 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2715 config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access 2716 self.normalizer_fn) 2717 config['dtype'] = self.dtype.name 2718 return config 2719 2720 @classmethod 2721 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2722 """See 'FeatureColumn` base class.""" 2723 _check_config_keys(config, cls._fields) 2724 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2725 kwargs = _standardize_and_copy_config(config) 2726 kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access 2727 config['normalizer_fn'], custom_objects=custom_objects) 2728 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2729 2730 return cls(**kwargs) 2731 2732 2733class BucketizedColumn( 2734 DenseColumn, 2735 CategoricalColumn, 2736 fc_old._DenseColumn, # pylint: disable=protected-access 2737 fc_old._CategoricalColumn, # pylint: disable=protected-access 2738 collections.namedtuple('BucketizedColumn', 2739 ('source_column', 'boundaries'))): 2740 """See `bucketized_column`.""" 2741 2742 @property 2743 def _is_v2_column(self): 2744 return (isinstance(self.source_column, FeatureColumn) and 2745 self.source_column._is_v2_column) # pylint: disable=protected-access 2746 2747 @property 2748 def name(self): 2749 """See `FeatureColumn` base class.""" 2750 return '{}_bucketized'.format(self.source_column.name) 2751 2752 @property 2753 def parse_example_spec(self): 2754 """See `FeatureColumn` base class.""" 2755 return self.source_column.parse_example_spec 2756 2757 @property 2758 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2759 _FEATURE_COLUMN_DEPRECATION) 2760 def _parse_example_spec(self): 2761 return self.source_column._parse_example_spec # pylint: disable=protected-access 2762 2763 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2764 _FEATURE_COLUMN_DEPRECATION) 2765 def _transform_feature(self, inputs): 2766 """Returns bucketized categorical `source_column` tensor.""" 2767 source_tensor = inputs.get(self.source_column) 2768 return math_ops._bucketize( # pylint: disable=protected-access 2769 source_tensor, 2770 boundaries=self.boundaries) 2771 2772 def transform_feature(self, transformation_cache, state_manager): 2773 """Returns bucketized categorical `source_column` tensor.""" 2774 source_tensor = transformation_cache.get(self.source_column, state_manager) 2775 return math_ops._bucketize( # pylint: disable=protected-access 2776 source_tensor, 2777 boundaries=self.boundaries) 2778 2779 @property 2780 def variable_shape(self): 2781 """See `DenseColumn` base class.""" 2782 return tensor_shape.TensorShape( 2783 tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) 2784 2785 @property 2786 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2787 _FEATURE_COLUMN_DEPRECATION) 2788 def 
_variable_shape(self): 2789 return self.variable_shape 2790 2791 def _get_dense_tensor_for_input_tensor(self, input_tensor): 2792 return array_ops.one_hot( 2793 indices=math_ops.cast(input_tensor, dtypes.int64), 2794 depth=len(self.boundaries) + 1, 2795 on_value=1., 2796 off_value=0.) 2797 2798 def get_dense_tensor(self, transformation_cache, state_manager): 2799 """Returns one hot encoded dense `Tensor`.""" 2800 input_tensor = transformation_cache.get(self, state_manager) 2801 return self._get_dense_tensor_for_input_tensor(input_tensor) 2802 2803 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2804 _FEATURE_COLUMN_DEPRECATION) 2805 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2806 del weight_collections 2807 del trainable 2808 input_tensor = inputs.get(self) 2809 return self._get_dense_tensor_for_input_tensor(input_tensor) 2810 2811 @property 2812 def num_buckets(self): 2813 """See `CategoricalColumn` base class.""" 2814 # By construction, source_column is always one-dimensional. 2815 return (len(self.boundaries) + 1) * self.source_column.shape[0] 2816 2817 @property 2818 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2819 _FEATURE_COLUMN_DEPRECATION) 2820 def _num_buckets(self): 2821 return self.num_buckets 2822 2823 def _get_sparse_tensors_for_input_tensor(self, input_tensor): 2824 batch_size = array_ops.shape(input_tensor)[0] 2825 # By construction, source_column is always one-dimensional. 2826 source_dimension = self.source_column.shape[0] 2827 2828 i1 = array_ops.reshape( 2829 array_ops.tile( 2830 array_ops.expand_dims(math_ops.range(0, batch_size), 1), 2831 [1, source_dimension]), 2832 (-1,)) 2833 i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size]) 2834 # Flatten the bucket indices and unique them across dimensions 2835 # E.g. 
2nd dimension indices will range from k to 2*k-1 with k buckets 2836 bucket_indices = ( 2837 array_ops.reshape(input_tensor, (-1,)) + 2838 (len(self.boundaries) + 1) * i2) 2839 2840 indices = math_ops.cast( 2841 array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64) 2842 dense_shape = math_ops.cast( 2843 array_ops.stack([batch_size, source_dimension]), dtypes.int64) 2844 sparse_tensor = sparse_tensor_lib.SparseTensor( 2845 indices=indices, 2846 values=bucket_indices, 2847 dense_shape=dense_shape) 2848 return CategoricalColumn.IdWeightPair(sparse_tensor, None) 2849 2850 def get_sparse_tensors(self, transformation_cache, state_manager): 2851 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2852 input_tensor = transformation_cache.get(self, state_manager) 2853 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2854 2855 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2856 _FEATURE_COLUMN_DEPRECATION) 2857 def _get_sparse_tensors(self, inputs, weight_collections=None, 2858 trainable=None): 2859 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2860 del weight_collections 2861 del trainable 2862 input_tensor = inputs.get(self) 2863 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2864 2865 @property 2866 def parents(self): 2867 """See 'FeatureColumn` base class.""" 2868 return [self.source_column] 2869 2870 def get_config(self): 2871 """See 'FeatureColumn` base class.""" 2872 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 2873 config = dict(zip(self._fields, self)) 2874 config['source_column'] = serialize_feature_column(self.source_column) 2875 return config 2876 2877 @classmethod 2878 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2879 """See 'FeatureColumn` base class.""" 2880 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 2881 _check_config_keys(config, cls._fields) 2882 kwargs = _standardize_and_copy_config(config) 2883 kwargs['source_column'] = deserialize_feature_column( 2884 config['source_column'], custom_objects, columns_by_name) 2885 return cls(**kwargs) 2886 2887 2888class EmbeddingColumn( 2889 DenseColumn, 2890 SequenceDenseColumn, 2891 fc_old._DenseColumn, # pylint: disable=protected-access 2892 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 2893 collections.namedtuple( 2894 'EmbeddingColumn', 2895 ('categorical_column', 'dimension', 'combiner', 'initializer', 2896 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable', 2897 'use_safe_embedding_lookup'))): 2898 """See `embedding_column`.""" 2899 2900 def __new__(cls, 2901 categorical_column, 2902 dimension, 2903 combiner, 2904 initializer, 2905 ckpt_to_load_from, 2906 tensor_name_in_ckpt, 2907 max_norm, 2908 trainable, 2909 use_safe_embedding_lookup=True): 2910 return super(EmbeddingColumn, cls).__new__( 2911 cls, 2912 categorical_column=categorical_column, 2913 dimension=dimension, 2914 combiner=combiner, 2915 initializer=initializer, 2916 ckpt_to_load_from=ckpt_to_load_from, 2917 tensor_name_in_ckpt=tensor_name_in_ckpt, 2918 max_norm=max_norm, 2919 trainable=trainable, 2920 use_safe_embedding_lookup=use_safe_embedding_lookup) 2921 2922 @property 2923 def _is_v2_column(self): 2924 return (isinstance(self.categorical_column, FeatureColumn) and 2925 self.categorical_column._is_v2_column) # pylint: disable=protected-access 2926 2927 
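  # Illustrative sketch (hypothetical feature name and dimension): instances
  # of this class are normally created through the public `embedding_column`
  # API rather than constructed directly, e.g.
  #
  #   dept = categorical_column_with_vocabulary_list(
  #       "department", ["math", "philosophy", "english"])
  #   dept_embedded = embedding_column(dept, dimension=8)
  #   # dept_embedded.name == "department_embedding"
  #   # dept_embedded.variable_shape == TensorShape([8])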
@property 2928 def name(self): 2929 """See `FeatureColumn` base class.""" 2930 return '{}_embedding'.format(self.categorical_column.name) 2931 2932 @property 2933 def parse_example_spec(self): 2934 """See `FeatureColumn` base class.""" 2935 return self.categorical_column.parse_example_spec 2936 2937 @property 2938 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2939 _FEATURE_COLUMN_DEPRECATION) 2940 def _parse_example_spec(self): 2941 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 2942 2943 def transform_feature(self, transformation_cache, state_manager): 2944 """Transforms underlying `categorical_column`.""" 2945 return transformation_cache.get(self.categorical_column, state_manager) 2946 2947 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2948 _FEATURE_COLUMN_DEPRECATION) 2949 def _transform_feature(self, inputs): 2950 return inputs.get(self.categorical_column) 2951 2952 @property 2953 def variable_shape(self): 2954 """See `DenseColumn` base class.""" 2955 return tensor_shape.TensorShape([self.dimension]) 2956 2957 @property 2958 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2959 _FEATURE_COLUMN_DEPRECATION) 2960 def _variable_shape(self): 2961 return self.variable_shape 2962 2963 def create_state(self, state_manager): 2964 """Creates the embedding lookup variable.""" 2965 default_num_buckets = (self.categorical_column.num_buckets 2966 if self._is_v2_column 2967 else self.categorical_column._num_buckets) # pylint: disable=protected-access 2968 num_buckets = getattr(self.categorical_column, 'num_buckets', 2969 default_num_buckets) 2970 embedding_shape = (num_buckets, self.dimension) 2971 state_manager.create_variable( 2972 self, 2973 name='embedding_weights', 2974 shape=embedding_shape, 2975 dtype=dtypes.float32, 2976 trainable=self.trainable, 2977 use_resource=True, 2978 initializer=self.initializer) 2979 2980 def _get_dense_tensor_internal_helper(self, sparse_tensors, 2981 embedding_weights): 2982 sparse_ids = sparse_tensors.id_tensor 2983 sparse_weights = sparse_tensors.weight_tensor 2984 2985 if self.ckpt_to_load_from is not None: 2986 to_restore = embedding_weights 2987 if isinstance(to_restore, variables.PartitionedVariable): 2988 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access 2989 checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, { 2990 self.tensor_name_in_ckpt: to_restore 2991 }) 2992 2993 sparse_id_rank = tensor_shape.dimension_value( 2994 sparse_ids.dense_shape.get_shape()[0]) 2995 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse 2996 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and 2997 sparse_id_rank <= 2): 2998 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 2999 # Return embedding lookup result. 
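    # `safe_embedding_lookup_sparse` is the default because it tolerates
    # empty rows and prunes invalid (negative) ids; the plain
    # `embedding_lookup_sparse_v2` path is used only when the caller opted
    # out via `use_safe_embedding_lookup` and the id tensor is statically
    # known to be at most rank 2, which that op requires.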
3000 return embedding_lookup_sparse( 3001 embedding_weights, 3002 sparse_ids, 3003 sparse_weights, 3004 combiner=self.combiner, 3005 name='%s_weights' % self.name, 3006 max_norm=self.max_norm) 3007 3008 def _get_dense_tensor_internal(self, sparse_tensors, state_manager): 3009 """Private method that follows the signature of get_dense_tensor.""" 3010 embedding_weights = state_manager.get_variable( 3011 self, name='embedding_weights') 3012 return self._get_dense_tensor_internal_helper(sparse_tensors, 3013 embedding_weights) 3014 3015 def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections, 3016 trainable): 3017 """Private method that follows the signature of _get_dense_tensor.""" 3018 embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access 3019 if (weight_collections and 3020 ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections): 3021 weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) 3022 embedding_weights = variable_scope.get_variable( 3023 name='embedding_weights', 3024 shape=embedding_shape, 3025 dtype=dtypes.float32, 3026 initializer=self.initializer, 3027 trainable=self.trainable and trainable, 3028 collections=weight_collections) 3029 return self._get_dense_tensor_internal_helper(sparse_tensors, 3030 embedding_weights) 3031 3032 def get_dense_tensor(self, transformation_cache, state_manager): 3033 """Returns tensor after doing the embedding lookup. 3034 3035 Args: 3036 transformation_cache: A `FeatureTransformationCache` object to access 3037 features. 3038 state_manager: A `StateManager` to create / access resources such as 3039 lookup tables. 3040 3041 Returns: 3042 Embedding lookup tensor. 3043 3044 Raises: 3045 ValueError: `categorical_column` is SequenceCategoricalColumn. 3046 """ 3047 if isinstance(self.categorical_column, SequenceCategoricalColumn): 3048 raise ValueError( 3049 'In embedding_column: {}. ' 3050 'categorical_column must not be of type SequenceCategoricalColumn. ' 3051 'Suggested fix A: If you wish to use DenseFeatures, use a ' 3052 'non-sequence categorical_column_with_*. ' 3053 'Suggested fix B: If you wish to create sequence input, use ' 3054 'SequenceFeatures instead of DenseFeatures. ' 3055 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3056 self.categorical_column)) 3057 # Get sparse IDs and weights. 3058 sparse_tensors = self.categorical_column.get_sparse_tensors( 3059 transformation_cache, state_manager) 3060 return self._get_dense_tensor_internal(sparse_tensors, state_manager) 3061 3062 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3063 _FEATURE_COLUMN_DEPRECATION) 3064 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 3065 if isinstance( 3066 self.categorical_column, 3067 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 3068 raise ValueError( 3069 'In embedding_column: {}. ' 3070 'categorical_column must not be of type _SequenceCategoricalColumn. ' 3071 'Suggested fix A: If you wish to use DenseFeatures, use a ' 3072 'non-sequence categorical_column_with_*. ' 3073 'Suggested fix B: If you wish to create sequence input, use ' 3074 'SequenceFeatures instead of DenseFeatures. 
' 3075 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3076 self.categorical_column)) 3077 sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access 3078 inputs, weight_collections, trainable) 3079 return self._old_get_dense_tensor_internal(sparse_tensors, 3080 weight_collections, trainable) 3081 3082 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 3083 """See `SequenceDenseColumn` base class.""" 3084 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 3085 raise ValueError( 3086 'In embedding_column: {}. ' 3087 'categorical_column must be of type SequenceCategoricalColumn ' 3088 'to use SequenceFeatures. ' 3089 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 3090 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3091 self.categorical_column)) 3092 sparse_tensors = self.categorical_column.get_sparse_tensors( 3093 transformation_cache, state_manager) 3094 dense_tensor = self._get_dense_tensor_internal(sparse_tensors, 3095 state_manager) 3096 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3097 sparse_tensors.id_tensor) 3098 return SequenceDenseColumn.TensorSequenceLengthPair( 3099 dense_tensor=dense_tensor, sequence_length=sequence_length) 3100 3101 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3102 _FEATURE_COLUMN_DEPRECATION) 3103 def _get_sequence_dense_tensor(self, 3104 inputs, 3105 weight_collections=None, 3106 trainable=None): 3107 if not isinstance( 3108 self.categorical_column, 3109 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 3110 raise ValueError( 3111 'In embedding_column: {}. ' 3112 'categorical_column must be of type SequenceCategoricalColumn ' 3113 'to use SequenceFeatures. ' 3114 'Suggested fix: Use one of sequence_categorical_column_with_*. 
' 3115 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3116 self.categorical_column)) 3117 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 3118 dense_tensor = self._old_get_dense_tensor_internal( 3119 sparse_tensors, 3120 weight_collections=weight_collections, 3121 trainable=trainable) 3122 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3123 sparse_tensors.id_tensor) 3124 return SequenceDenseColumn.TensorSequenceLengthPair( 3125 dense_tensor=dense_tensor, sequence_length=sequence_length) 3126 3127 @property 3128 def parents(self): 3129 """See 'FeatureColumn` base class.""" 3130 return [self.categorical_column] 3131 3132 def get_config(self): 3133 """See 'FeatureColumn` base class.""" 3134 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 3135 config = dict(zip(self._fields, self)) 3136 config['categorical_column'] = serialization.serialize_feature_column( 3137 self.categorical_column) 3138 config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access 3139 self.initializer) 3140 return config 3141 3142 @classmethod 3143 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3144 """See 'FeatureColumn` base class.""" 3145 if 'use_safe_embedding_lookup' not in config: 3146 config['use_safe_embedding_lookup'] = True 3147 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 3148 _check_config_keys(config, cls._fields) 3149 kwargs = _standardize_and_copy_config(config) 3150 kwargs['categorical_column'] = serialization.deserialize_feature_column( 3151 config['categorical_column'], custom_objects, columns_by_name) 3152 all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass)) 3153 kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access 3154 config['initializer'], 3155 module_objects=all_initializers, 3156 custom_objects=custom_objects) 3157 return cls(**kwargs) 3158 3159 3160def _raise_shared_embedding_column_error(): 3161 raise ValueError('SharedEmbeddingColumns are not supported in ' 3162 '`linear_model` or `input_layer`. Please use ' 3163 '`DenseFeatures` or `LinearModel` instead.') 3164 3165 3166class SharedEmbeddingColumnCreator(tracking.AutoTrackable): 3167 3168 def __init__(self, 3169 dimension, 3170 initializer, 3171 ckpt_to_load_from, 3172 tensor_name_in_ckpt, 3173 num_buckets, 3174 trainable, 3175 name='shared_embedding_column_creator', 3176 use_safe_embedding_lookup=True): 3177 self._dimension = dimension 3178 self._initializer = initializer 3179 self._ckpt_to_load_from = ckpt_to_load_from 3180 self._tensor_name_in_ckpt = tensor_name_in_ckpt 3181 self._num_buckets = num_buckets 3182 self._trainable = trainable 3183 self._name = name 3184 self._use_safe_embedding_lookup = use_safe_embedding_lookup 3185 # Map from graph keys to embedding_weight variables. 
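    # A single creator instance is shared by every SharedEmbeddingColumn
    # produced from one `shared_embedding_columns` call, so all of those
    # columns resolve to the same `embedding_weights` variable. The dict
    # below keys that variable by graph, giving each tf.Graph its own copy.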
3186 self._embedding_weights = {} 3187 3188 def __call__(self, categorical_column, combiner, max_norm): 3189 return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm, 3190 self._use_safe_embedding_lookup) 3191 3192 @property 3193 def embedding_weights(self): 3194 key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 3195 if key not in self._embedding_weights: 3196 embedding_shape = (self._num_buckets, self._dimension) 3197 var = variable_scope.get_variable( 3198 name=self._name, 3199 shape=embedding_shape, 3200 dtype=dtypes.float32, 3201 initializer=self._initializer, 3202 trainable=self._trainable) 3203 3204 if self._ckpt_to_load_from is not None: 3205 to_restore = var 3206 if isinstance(to_restore, variables.PartitionedVariable): 3207 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access 3208 checkpoint_utils.init_from_checkpoint( 3209 self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore}) 3210 self._embedding_weights[key] = var 3211 return self._embedding_weights[key] 3212 3213 @property 3214 def dimension(self): 3215 return self._dimension 3216 3217 3218class SharedEmbeddingColumn( 3219 DenseColumn, 3220 SequenceDenseColumn, 3221 fc_old._DenseColumn, # pylint: disable=protected-access 3222 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 3223 collections.namedtuple( 3224 'SharedEmbeddingColumn', 3225 ('categorical_column', 'shared_embedding_column_creator', 'combiner', 3226 'max_norm', 'use_safe_embedding_lookup'))): 3227 """See `embedding_column`.""" 3228 3229 def __new__(cls, 3230 categorical_column, 3231 shared_embedding_column_creator, 3232 combiner, 3233 max_norm, 3234 use_safe_embedding_lookup=True): 3235 return super(SharedEmbeddingColumn, cls).__new__( 3236 cls, 3237 categorical_column=categorical_column, 3238 shared_embedding_column_creator=shared_embedding_column_creator, 3239 combiner=combiner, 3240 max_norm=max_norm, 3241 use_safe_embedding_lookup=use_safe_embedding_lookup) 3242 3243 @property 3244 def _is_v2_column(self): 3245 return True 3246 3247 @property 3248 def name(self): 3249 """See `FeatureColumn` base class.""" 3250 return '{}_shared_embedding'.format(self.categorical_column.name) 3251 3252 @property 3253 def parse_example_spec(self): 3254 """See `FeatureColumn` base class.""" 3255 return self.categorical_column.parse_example_spec 3256 3257 @property 3258 def _parse_example_spec(self): 3259 return _raise_shared_embedding_column_error() 3260 3261 def transform_feature(self, transformation_cache, state_manager): 3262 """See `FeatureColumn` base class.""" 3263 return transformation_cache.get(self.categorical_column, state_manager) 3264 3265 def _transform_feature(self, inputs): 3266 return _raise_shared_embedding_column_error() 3267 3268 @property 3269 def variable_shape(self): 3270 """See `DenseColumn` base class.""" 3271 return tensor_shape.TensorShape( 3272 [self.shared_embedding_column_creator.dimension]) 3273 3274 @property 3275 def _variable_shape(self): 3276 return _raise_shared_embedding_column_error() 3277 3278 def _get_dense_tensor_internal(self, transformation_cache, state_manager): 3279 """Private method that follows the signature of _get_dense_tensor.""" 3280 # This method is called from a variable_scope with name _var_scope_name, 3281 # which is shared among all shared embeddings. Open a name_scope here, so 3282 # that the ops for different columns have distinct names. 3283 with ops.name_scope(None, default_name=self.name): 3284 # Get sparse IDs and weights. 
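      # The lookup below mirrors
      # EmbeddingColumn._get_dense_tensor_internal_helper, except that the
      # embedding weights come from the shared creator instead of a
      # per-column variable owned by the state manager.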
3285 sparse_tensors = self.categorical_column.get_sparse_tensors( 3286 transformation_cache, state_manager) 3287 sparse_ids = sparse_tensors.id_tensor 3288 sparse_weights = sparse_tensors.weight_tensor 3289 3290 embedding_weights = self.shared_embedding_column_creator.embedding_weights 3291 3292 sparse_id_rank = tensor_shape.dimension_value( 3293 sparse_ids.dense_shape.get_shape()[0]) 3294 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse 3295 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and 3296 sparse_id_rank <= 2): 3297 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 3298 # Return embedding lookup result. 3299 return embedding_lookup_sparse( 3300 embedding_weights, 3301 sparse_ids, 3302 sparse_weights, 3303 combiner=self.combiner, 3304 name='%s_weights' % self.name, 3305 max_norm=self.max_norm) 3306 3307 def get_dense_tensor(self, transformation_cache, state_manager): 3308 """Returns the embedding lookup result.""" 3309 if isinstance(self.categorical_column, SequenceCategoricalColumn): 3310 raise ValueError( 3311 'In embedding_column: {}. ' 3312 'categorical_column must not be of type SequenceCategoricalColumn. ' 3313 'Suggested fix A: If you wish to use DenseFeatures, use a ' 3314 'non-sequence categorical_column_with_*. ' 3315 'Suggested fix B: If you wish to create sequence input, use ' 3316 'SequenceFeatures instead of DenseFeatures. ' 3317 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3318 self.categorical_column)) 3319 return self._get_dense_tensor_internal(transformation_cache, state_manager) 3320 3321 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 3322 return _raise_shared_embedding_column_error() 3323 3324 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 3325 """See `SequenceDenseColumn` base class.""" 3326 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 3327 raise ValueError( 3328 'In embedding_column: {}. ' 3329 'categorical_column must be of type SequenceCategoricalColumn ' 3330 'to use SequenceFeatures. ' 3331 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 3332 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3333 self.categorical_column)) 3334 dense_tensor = self._get_dense_tensor_internal(transformation_cache, 3335 state_manager) 3336 sparse_tensors = self.categorical_column.get_sparse_tensors( 3337 transformation_cache, state_manager) 3338 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3339 sparse_tensors.id_tensor) 3340 return SequenceDenseColumn.TensorSequenceLengthPair( 3341 dense_tensor=dense_tensor, sequence_length=sequence_length) 3342 3343 def _get_sequence_dense_tensor(self, 3344 inputs, 3345 weight_collections=None, 3346 trainable=None): 3347 return _raise_shared_embedding_column_error() 3348 3349 @property 3350 def parents(self): 3351 """See 'FeatureColumn` base class.""" 3352 return [self.categorical_column] 3353 3354 3355def _check_shape(shape, key): 3356 """Returns shape if it's valid, raises error otherwise.""" 3357 assert shape is not None 3358 if not nest.is_sequence(shape): 3359 shape = [shape] 3360 shape = tuple(shape) 3361 for dimension in shape: 3362 if not isinstance(dimension, int): 3363 raise TypeError('shape dimensions must be integer. ' 3364 'shape: {}, key: {}'.format(shape, key)) 3365 if dimension < 1: 3366 raise ValueError('shape dimensions must be greater than 0. 
' 3367 'shape: {}, key: {}'.format(shape, key)) 3368 return shape 3369 3370 3371class HashedCategoricalColumn( 3372 CategoricalColumn, 3373 fc_old._CategoricalColumn, # pylint: disable=protected-access 3374 collections.namedtuple('HashedCategoricalColumn', 3375 ('key', 'hash_bucket_size', 'dtype'))): 3376 """see `categorical_column_with_hash_bucket`.""" 3377 3378 @property 3379 def _is_v2_column(self): 3380 return True 3381 3382 @property 3383 def name(self): 3384 """See `FeatureColumn` base class.""" 3385 return self.key 3386 3387 @property 3388 def parse_example_spec(self): 3389 """See `FeatureColumn` base class.""" 3390 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3391 3392 @property 3393 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3394 _FEATURE_COLUMN_DEPRECATION) 3395 def _parse_example_spec(self): 3396 return self.parse_example_spec 3397 3398 def _transform_input_tensor(self, input_tensor): 3399 """Hashes the values in the feature_column.""" 3400 if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 3401 raise ValueError('SparseColumn input must be a SparseTensor.') 3402 3403 fc_utils.assert_string_or_int( 3404 input_tensor.dtype, 3405 prefix='column_name: {} input_tensor'.format(self.key)) 3406 3407 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3408 raise ValueError( 3409 'Column dtype and SparseTensors dtype must be compatible. ' 3410 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3411 self.key, self.dtype, input_tensor.dtype)) 3412 3413 if self.dtype == dtypes.string: 3414 sparse_values = input_tensor.values 3415 else: 3416 sparse_values = string_ops.as_string(input_tensor.values) 3417 3418 sparse_id_values = string_ops.string_to_hash_bucket_fast( 3419 sparse_values, self.hash_bucket_size, name='lookup') 3420 return sparse_tensor_lib.SparseTensor( 3421 input_tensor.indices, sparse_id_values, input_tensor.dense_shape) 3422 3423 def transform_feature(self, transformation_cache, state_manager): 3424 """Hashes the values in the feature_column.""" 3425 input_tensor = _to_sparse_input_and_drop_ignore_values( 3426 transformation_cache.get(self.key, state_manager)) 3427 return self._transform_input_tensor(input_tensor) 3428 3429 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3430 _FEATURE_COLUMN_DEPRECATION) 3431 def _transform_feature(self, inputs): 3432 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3433 return self._transform_input_tensor(input_tensor) 3434 3435 @property 3436 def num_buckets(self): 3437 """Returns number of buckets in this sparse feature.""" 3438 return self.hash_bucket_size 3439 3440 @property 3441 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3442 _FEATURE_COLUMN_DEPRECATION) 3443 def _num_buckets(self): 3444 return self.num_buckets 3445 3446 def get_sparse_tensors(self, transformation_cache, state_manager): 3447 """See `CategoricalColumn` base class.""" 3448 return CategoricalColumn.IdWeightPair( 3449 transformation_cache.get(self, state_manager), None) 3450 3451 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3452 _FEATURE_COLUMN_DEPRECATION) 3453 def _get_sparse_tensors(self, inputs, weight_collections=None, 3454 trainable=None): 3455 del weight_collections 3456 del trainable 3457 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3458 3459 @property 3460 def parents(self): 3461 """See 'FeatureColumn` base class.""" 3462 return [self.key] 3463 3464 def get_config(self): 3465 """See 'FeatureColumn` base class.""" 3466 config = 
dict(zip(self._fields, self)) 3467 config['dtype'] = self.dtype.name 3468 return config 3469 3470 @classmethod 3471 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3472 """See 'FeatureColumn` base class.""" 3473 _check_config_keys(config, cls._fields) 3474 kwargs = _standardize_and_copy_config(config) 3475 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3476 return cls(**kwargs) 3477 3478 3479class VocabularyFileCategoricalColumn( 3480 CategoricalColumn, 3481 fc_old._CategoricalColumn, # pylint: disable=protected-access 3482 collections.namedtuple( 3483 'VocabularyFileCategoricalColumn', 3484 ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 3485 'dtype', 'default_value', 'file_format'))): 3486 """See `categorical_column_with_vocabulary_file`.""" 3487 3488 def __new__(cls, 3489 key, 3490 vocabulary_file, 3491 vocabulary_size, 3492 num_oov_buckets, 3493 dtype, 3494 default_value, 3495 file_format=None): 3496 return super(VocabularyFileCategoricalColumn, cls).__new__( 3497 cls, 3498 key=key, 3499 vocabulary_file=vocabulary_file, 3500 vocabulary_size=vocabulary_size, 3501 num_oov_buckets=num_oov_buckets, 3502 dtype=dtype, 3503 default_value=default_value, 3504 file_format=file_format) 3505 3506 @property 3507 def _is_v2_column(self): 3508 return True 3509 3510 @property 3511 def name(self): 3512 """See `FeatureColumn` base class.""" 3513 return self.key 3514 3515 @property 3516 def parse_example_spec(self): 3517 """See `FeatureColumn` base class.""" 3518 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3519 3520 @property 3521 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3522 _FEATURE_COLUMN_DEPRECATION) 3523 def _parse_example_spec(self): 3524 return self.parse_example_spec 3525 3526 def _make_table_from_tfrecord_gzip_file(self, key_dtype, name): 3527 dataset = readers.TFRecordDataset( 3528 self.vocabulary_file, compression_type='GZIP') 3529 3530 def key_dtype_fn(key): 3531 return key if key_dtype is dtypes.string else string_ops.string_to_number( 3532 key, out_type=key_dtype) 3533 3534 return data_lookup_ops.index_table_from_dataset( 3535 dataset.map(key_dtype_fn), 3536 num_oov_buckets=self.num_oov_buckets, 3537 vocab_size=self.vocabulary_size, 3538 default_value=self.default_value, 3539 key_dtype=key_dtype, 3540 name=name) 3541 3542 def _make_table(self, key_dtype, state_manager): 3543 name = '{}_lookup'.format(self.key) 3544 if state_manager is None or not state_manager.has_resource(self, name): 3545 with ops.init_scope(): 3546 if self.file_format == 'tfrecord_gzip': 3547 table = self._make_table_from_tfrecord_gzip_file(key_dtype, name) 3548 else: 3549 table = lookup_ops.index_table_from_file( 3550 vocabulary_file=self.vocabulary_file, 3551 num_oov_buckets=self.num_oov_buckets, 3552 vocab_size=self.vocabulary_size, 3553 default_value=self.default_value, 3554 key_dtype=key_dtype, 3555 name=name) 3556 if state_manager is not None: 3557 state_manager.add_resource(self, name, table) 3558 else: 3559 # Reuse the table from the previous run. 3560 table = state_manager.get_resource(self, name) 3561 return table 3562 3563 def _transform_input_tensor(self, input_tensor, state_manager=None): 3564 """Creates a lookup table for the vocabulary.""" 3565 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3566 raise ValueError( 3567 'Column dtype and SparseTensors dtype must be compatible. 
' 3568 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3569 self.key, self.dtype, input_tensor.dtype)) 3570 3571 fc_utils.assert_string_or_int( 3572 input_tensor.dtype, 3573 prefix='column_name: {} input_tensor'.format(self.key)) 3574 3575 key_dtype = self.dtype 3576 if input_tensor.dtype.is_integer: 3577 # `index_table_from_file` requires 64-bit integer keys. 3578 key_dtype = dtypes.int64 3579 input_tensor = math_ops.cast(input_tensor, dtypes.int64) 3580 return self._make_table(key_dtype, state_manager).lookup(input_tensor) 3581 3582 def transform_feature(self, transformation_cache, state_manager): 3583 """Creates a lookup table for the vocabulary.""" 3584 input_tensor = _to_sparse_input_and_drop_ignore_values( 3585 transformation_cache.get(self.key, state_manager)) 3586 return self._transform_input_tensor(input_tensor, state_manager) 3587 3588 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3589 _FEATURE_COLUMN_DEPRECATION) 3590 def _transform_feature(self, inputs): 3591 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3592 return self._transform_input_tensor(input_tensor) 3593 3594 @property 3595 def num_buckets(self): 3596 """Returns number of buckets in this sparse feature.""" 3597 return self.vocabulary_size + self.num_oov_buckets 3598 3599 @property 3600 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3601 _FEATURE_COLUMN_DEPRECATION) 3602 def _num_buckets(self): 3603 return self.num_buckets 3604 3605 def get_sparse_tensors(self, transformation_cache, state_manager): 3606 """See `CategoricalColumn` base class.""" 3607 return CategoricalColumn.IdWeightPair( 3608 transformation_cache.get(self, state_manager), None) 3609 3610 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3611 _FEATURE_COLUMN_DEPRECATION) 3612 def _get_sparse_tensors(self, inputs, weight_collections=None, 3613 trainable=None): 3614 del weight_collections 3615 del trainable 3616 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3617 3618 @property 3619 def parents(self): 3620 """See 'FeatureColumn` base class.""" 3621 return [self.key] 3622 3623 def get_config(self): 3624 """See 'FeatureColumn` base class.""" 3625 config = dict(zip(self._fields, self)) 3626 config['dtype'] = self.dtype.name 3627 return config 3628 3629 @classmethod 3630 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3631 """See 'FeatureColumn` base class.""" 3632 _check_config_keys(config, cls._fields) 3633 kwargs = _standardize_and_copy_config(config) 3634 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3635 return cls(**kwargs) 3636 3637 3638class VocabularyListCategoricalColumn( 3639 CategoricalColumn, 3640 fc_old._CategoricalColumn, # pylint: disable=protected-access 3641 collections.namedtuple( 3642 'VocabularyListCategoricalColumn', 3643 ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets')) 3644): 3645 """See `categorical_column_with_vocabulary_list`.""" 3646 3647 @property 3648 def _is_v2_column(self): 3649 return True 3650 3651 @property 3652 def name(self): 3653 """See `FeatureColumn` base class.""" 3654 return self.key 3655 3656 @property 3657 def parse_example_spec(self): 3658 """See `FeatureColumn` base class.""" 3659 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3660 3661 @property 3662 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3663 _FEATURE_COLUMN_DEPRECATION) 3664 def _parse_example_spec(self): 3665 return self.parse_example_spec 3666 3667 def _transform_input_tensor(self, input_tensor, 
state_manager=None): 3668 """Creates a lookup table for the vocabulary list.""" 3669 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3670 raise ValueError( 3671 'Column dtype and SparseTensors dtype must be compatible. ' 3672 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3673 self.key, self.dtype, input_tensor.dtype)) 3674 3675 fc_utils.assert_string_or_int( 3676 input_tensor.dtype, 3677 prefix='column_name: {} input_tensor'.format(self.key)) 3678 3679 key_dtype = self.dtype 3680 if input_tensor.dtype.is_integer: 3681 # `index_table_from_tensor` requires 64-bit integer keys. 3682 key_dtype = dtypes.int64 3683 input_tensor = math_ops.cast(input_tensor, dtypes.int64) 3684 3685 name = '{}_lookup'.format(self.key) 3686 if state_manager is None or not state_manager.has_resource(self, name): 3687 with ops.init_scope(): 3688 table = lookup_ops.index_table_from_tensor( 3689 vocabulary_list=tuple(self.vocabulary_list), 3690 default_value=self.default_value, 3691 num_oov_buckets=self.num_oov_buckets, 3692 dtype=key_dtype, 3693 name=name) 3694 if state_manager is not None: 3695 state_manager.add_resource(self, name, table) 3696 else: 3697 # Reuse the table from the previous run. 3698 table = state_manager.get_resource(self, name) 3699 return table.lookup(input_tensor) 3700 3701 def transform_feature(self, transformation_cache, state_manager): 3702 """Creates a lookup table for the vocabulary list.""" 3703 input_tensor = _to_sparse_input_and_drop_ignore_values( 3704 transformation_cache.get(self.key, state_manager)) 3705 return self._transform_input_tensor(input_tensor, state_manager) 3706 3707 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3708 _FEATURE_COLUMN_DEPRECATION) 3709 def _transform_feature(self, inputs): 3710 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3711 return self._transform_input_tensor(input_tensor) 3712 3713 @property 3714 def num_buckets(self): 3715 """Returns number of buckets in this sparse feature.""" 3716 return len(self.vocabulary_list) + self.num_oov_buckets 3717 3718 @property 3719 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3720 _FEATURE_COLUMN_DEPRECATION) 3721 def _num_buckets(self): 3722 return self.num_buckets 3723 3724 def get_sparse_tensors(self, transformation_cache, state_manager): 3725 """See `CategoricalColumn` base class.""" 3726 return CategoricalColumn.IdWeightPair( 3727 transformation_cache.get(self, state_manager), None) 3728 3729 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3730 _FEATURE_COLUMN_DEPRECATION) 3731 def _get_sparse_tensors(self, inputs, weight_collections=None, 3732 trainable=None): 3733 del weight_collections 3734 del trainable 3735 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3736 3737 @property 3738 def parents(self): 3739 """See 'FeatureColumn` base class.""" 3740 return [self.key] 3741 3742 def get_config(self): 3743 """See 'FeatureColumn` base class.""" 3744 config = dict(zip(self._fields, self)) 3745 config['dtype'] = self.dtype.name 3746 return config 3747 3748 @classmethod 3749 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3750 """See 'FeatureColumn` base class.""" 3751 _check_config_keys(config, cls._fields) 3752 kwargs = _standardize_and_copy_config(config) 3753 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3754 return cls(**kwargs) 3755 3756 3757class IdentityCategoricalColumn( 3758 CategoricalColumn, 3759 fc_old._CategoricalColumn, # pylint: disable=protected-access 3760 
collections.namedtuple('IdentityCategoricalColumn', 3761 ('key', 'number_buckets', 'default_value'))): 3762 3763 """See `categorical_column_with_identity`.""" 3764 3765 @property 3766 def _is_v2_column(self): 3767 return True 3768 3769 @property 3770 def name(self): 3771 """See `FeatureColumn` base class.""" 3772 return self.key 3773 3774 @property 3775 def parse_example_spec(self): 3776 """See `FeatureColumn` base class.""" 3777 return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} 3778 3779 @property 3780 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3781 _FEATURE_COLUMN_DEPRECATION) 3782 def _parse_example_spec(self): 3783 return self.parse_example_spec 3784 3785 def _transform_input_tensor(self, input_tensor): 3786 """Returns a SparseTensor with identity values.""" 3787 if not input_tensor.dtype.is_integer: 3788 raise ValueError( 3789 'Invalid input, not integer. key: {} dtype: {}'.format( 3790 self.key, input_tensor.dtype)) 3791 values = input_tensor.values 3792 if input_tensor.values.dtype != dtypes.int64: 3793 values = math_ops.cast(values, dtypes.int64, name='values') 3794 if self.default_value is not None: 3795 values = math_ops.cast(input_tensor.values, dtypes.int64, name='values') 3796 num_buckets = math_ops.cast( 3797 self.num_buckets, dtypes.int64, name='num_buckets') 3798 zero = math_ops.cast(0, dtypes.int64, name='zero') 3799 # Assign default for out-of-range values. 3800 values = array_ops.where_v2( 3801 math_ops.logical_or( 3802 values < zero, values >= num_buckets, name='out_of_range'), 3803 array_ops.fill( 3804 dims=array_ops.shape(values), 3805 value=math_ops.cast(self.default_value, dtypes.int64), 3806 name='default_values'), values) 3807 3808 return sparse_tensor_lib.SparseTensor( 3809 indices=input_tensor.indices, 3810 values=values, 3811 dense_shape=input_tensor.dense_shape) 3812 3813 def transform_feature(self, transformation_cache, state_manager): 3814 """Returns a SparseTensor with identity values.""" 3815 input_tensor = _to_sparse_input_and_drop_ignore_values( 3816 transformation_cache.get(self.key, state_manager)) 3817 return self._transform_input_tensor(input_tensor) 3818 3819 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3820 _FEATURE_COLUMN_DEPRECATION) 3821 def _transform_feature(self, inputs): 3822 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3823 return self._transform_input_tensor(input_tensor) 3824 3825 @property 3826 def num_buckets(self): 3827 """Returns number of buckets in this sparse feature.""" 3828 return self.number_buckets 3829 3830 @property 3831 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3832 _FEATURE_COLUMN_DEPRECATION) 3833 def _num_buckets(self): 3834 return self.num_buckets 3835 3836 def get_sparse_tensors(self, transformation_cache, state_manager): 3837 """See `CategoricalColumn` base class.""" 3838 return CategoricalColumn.IdWeightPair( 3839 transformation_cache.get(self, state_manager), None) 3840 3841 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3842 _FEATURE_COLUMN_DEPRECATION) 3843 def _get_sparse_tensors(self, inputs, weight_collections=None, 3844 trainable=None): 3845 del weight_collections 3846 del trainable 3847 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3848 3849 @property 3850 def parents(self): 3851 """See 'FeatureColumn` base class.""" 3852 return [self.key] 3853 3854 def get_config(self): 3855 """See 'FeatureColumn` base class.""" 3856 return dict(zip(self._fields, self)) 3857 3858 @classmethod 3859 def 
from_config(cls, config, custom_objects=None, columns_by_name=None): 3860 """See 'FeatureColumn` base class.""" 3861 _check_config_keys(config, cls._fields) 3862 kwargs = _standardize_and_copy_config(config) 3863 return cls(**kwargs) 3864 3865 3866class WeightedCategoricalColumn( 3867 CategoricalColumn, 3868 fc_old._CategoricalColumn, # pylint: disable=protected-access 3869 collections.namedtuple( 3870 'WeightedCategoricalColumn', 3871 ('categorical_column', 'weight_feature_key', 'dtype'))): 3872 """See `weighted_categorical_column`.""" 3873 3874 @property 3875 def _is_v2_column(self): 3876 return (isinstance(self.categorical_column, FeatureColumn) and 3877 self.categorical_column._is_v2_column) # pylint: disable=protected-access 3878 3879 @property 3880 def name(self): 3881 """See `FeatureColumn` base class.""" 3882 return '{}_weighted_by_{}'.format( 3883 self.categorical_column.name, self.weight_feature_key) 3884 3885 @property 3886 def parse_example_spec(self): 3887 """See `FeatureColumn` base class.""" 3888 config = self.categorical_column.parse_example_spec 3889 if self.weight_feature_key in config: 3890 raise ValueError('Parse config {} already exists for {}.'.format( 3891 config[self.weight_feature_key], self.weight_feature_key)) 3892 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) 3893 return config 3894 3895 @property 3896 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3897 _FEATURE_COLUMN_DEPRECATION) 3898 def _parse_example_spec(self): 3899 config = self.categorical_column._parse_example_spec # pylint: disable=protected-access 3900 if self.weight_feature_key in config: 3901 raise ValueError('Parse config {} already exists for {}.'.format( 3902 config[self.weight_feature_key], self.weight_feature_key)) 3903 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) 3904 return config 3905 3906 @property 3907 def num_buckets(self): 3908 """See `DenseColumn` base class.""" 3909 return self.categorical_column.num_buckets 3910 3911 @property 3912 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3913 _FEATURE_COLUMN_DEPRECATION) 3914 def _num_buckets(self): 3915 return self.categorical_column._num_buckets # pylint: disable=protected-access 3916 3917 def _transform_weight_tensor(self, weight_tensor): 3918 if weight_tensor is None: 3919 raise ValueError('Missing weights {}.'.format(self.weight_feature_key)) 3920 weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( 3921 weight_tensor) 3922 if self.dtype != weight_tensor.dtype.base_dtype: 3923 raise ValueError('Bad dtype, expected {}, but got {}.'.format( 3924 self.dtype, weight_tensor.dtype)) 3925 if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor): 3926 # The weight tensor can be a regular Tensor. In this case, sparsify it. 
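      # Illustrative sketch (hypothetical feature keys): a weighted
      # categorical column pairs an id feature with a parallel weight
      # feature, e.g.
      #
      #   terms = categorical_column_with_vocabulary_list(
      #       "terms", ["cat", "dog"])
      #   weighted = weighted_categorical_column(terms, "frequencies")
      #
      # where features["frequencies"] supplies one float weight per id.
      # Dense weight inputs are sparsified here, with 0.0 treated as a
      # missing entry.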
3927 weight_tensor = _to_sparse_input_and_drop_ignore_values( 3928 weight_tensor, ignore_value=0.0) 3929 if not weight_tensor.dtype.is_floating: 3930 weight_tensor = math_ops.cast(weight_tensor, dtypes.float32) 3931 return weight_tensor 3932 3933 def transform_feature(self, transformation_cache, state_manager): 3934 """Applies weights to tensor generated from `categorical_column`'.""" 3935 weight_tensor = transformation_cache.get(self.weight_feature_key, 3936 state_manager) 3937 sparse_weight_tensor = self._transform_weight_tensor(weight_tensor) 3938 sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values( 3939 transformation_cache.get(self.categorical_column, state_manager)) 3940 return (sparse_categorical_tensor, sparse_weight_tensor) 3941 3942 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3943 _FEATURE_COLUMN_DEPRECATION) 3944 def _transform_feature(self, inputs): 3945 """Applies weights to tensor generated from `categorical_column`'.""" 3946 weight_tensor = inputs.get(self.weight_feature_key) 3947 weight_tensor = self._transform_weight_tensor(weight_tensor) 3948 return (inputs.get(self.categorical_column), weight_tensor) 3949 3950 def get_sparse_tensors(self, transformation_cache, state_manager): 3951 """See `CategoricalColumn` base class.""" 3952 tensors = transformation_cache.get(self, state_manager) 3953 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) 3954 3955 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3956 _FEATURE_COLUMN_DEPRECATION) 3957 def _get_sparse_tensors(self, inputs, weight_collections=None, 3958 trainable=None): 3959 del weight_collections 3960 del trainable 3961 tensors = inputs.get(self) 3962 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) 3963 3964 @property 3965 def parents(self): 3966 """See 'FeatureColumn` base class.""" 3967 return [self.categorical_column, self.weight_feature_key] 3968 3969 def get_config(self): 3970 """See 'FeatureColumn` base class.""" 3971 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 3972 config = dict(zip(self._fields, self)) 3973 config['categorical_column'] = serialize_feature_column( 3974 self.categorical_column) 3975 config['dtype'] = self.dtype.name 3976 return config 3977 3978 @classmethod 3979 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3980 """See 'FeatureColumn` base class.""" 3981 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 3982 _check_config_keys(config, cls._fields) 3983 kwargs = _standardize_and_copy_config(config) 3984 kwargs['categorical_column'] = deserialize_feature_column( 3985 config['categorical_column'], custom_objects, columns_by_name) 3986 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3987 return cls(**kwargs) 3988 3989 3990class CrossedColumn( 3991 CategoricalColumn, 3992 fc_old._CategoricalColumn, # pylint: disable=protected-access 3993 collections.namedtuple('CrossedColumn', 3994 ('keys', 'hash_bucket_size', 'hash_key'))): 3995 """See `crossed_column`.""" 3996 3997 @property 3998 def _is_v2_column(self): 3999 for key in _collect_leaf_level_keys(self): 4000 if isinstance(key, six.string_types): 4001 continue 4002 if not isinstance(key, FeatureColumn): 4003 return False 4004 if not key._is_v2_column: # pylint: disable=protected-access 4005 return False 4006 return True 4007 4008 @property 4009 def name(self): 4010 """See `FeatureColumn` base class.""" 4011 
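    # Illustrative sketch: the name sorts and joins the leaf-level keys, so
    # crossing the hypothetical columns "department" and a bucketized "age"
    # column yields the name "age_bucketized_X_department".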
feature_names = [] 4012 for key in _collect_leaf_level_keys(self): 4013 if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access 4014 feature_names.append(key.name) 4015 else: # key must be a string 4016 feature_names.append(key) 4017 return '_X_'.join(sorted(feature_names)) 4018 4019 @property 4020 def parse_example_spec(self): 4021 """See `FeatureColumn` base class.""" 4022 config = {} 4023 for key in self.keys: 4024 if isinstance(key, FeatureColumn): 4025 config.update(key.parse_example_spec) 4026 elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access 4027 config.update(key._parse_example_spec) # pylint: disable=protected-access 4028 else: # key must be a string 4029 config.update({key: parsing_ops.VarLenFeature(dtypes.string)}) 4030 return config 4031 4032 @property 4033 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4034 _FEATURE_COLUMN_DEPRECATION) 4035 def _parse_example_spec(self): 4036 return self.parse_example_spec 4037 4038 def transform_feature(self, transformation_cache, state_manager): 4039 """Generates a hashed sparse cross from the input tensors.""" 4040 feature_tensors = [] 4041 for key in _collect_leaf_level_keys(self): 4042 if isinstance(key, six.string_types): 4043 feature_tensors.append(transformation_cache.get(key, state_manager)) 4044 elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access 4045 ids_and_weights = key.get_sparse_tensors(transformation_cache, 4046 state_manager) 4047 if ids_and_weights.weight_tensor is not None: 4048 raise ValueError( 4049 'crossed_column does not support weight_tensor, but the given ' 4050 'column populates weight_tensor. ' 4051 'Given column: {}'.format(key.name)) 4052 feature_tensors.append(ids_and_weights.id_tensor) 4053 else: 4054 raise ValueError('Unsupported column type. Given: {}'.format(key)) 4055 return sparse_ops.sparse_cross_hashed( 4056 inputs=feature_tensors, 4057 num_buckets=self.hash_bucket_size, 4058 hash_key=self.hash_key) 4059 4060 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4061 _FEATURE_COLUMN_DEPRECATION) 4062 def _transform_feature(self, inputs): 4063 """Generates a hashed sparse cross from the input tensors.""" 4064 feature_tensors = [] 4065 for key in _collect_leaf_level_keys(self): 4066 if isinstance(key, six.string_types): 4067 feature_tensors.append(inputs.get(key)) 4068 elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access 4069 ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access 4070 if ids_and_weights.weight_tensor is not None: 4071 raise ValueError( 4072 'crossed_column does not support weight_tensor, but the given ' 4073 'column populates weight_tensor. ' 4074 'Given column: {}'.format(key.name)) 4075 feature_tensors.append(ids_and_weights.id_tensor) 4076 else: 4077 raise ValueError('Unsupported column type. 
Given: {}'.format(key)) 4078 return sparse_ops.sparse_cross_hashed( 4079 inputs=feature_tensors, 4080 num_buckets=self.hash_bucket_size, 4081 hash_key=self.hash_key) 4082 4083 @property 4084 def num_buckets(self): 4085 """Returns number of buckets in this sparse feature.""" 4086 return self.hash_bucket_size 4087 4088 @property 4089 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4090 _FEATURE_COLUMN_DEPRECATION) 4091 def _num_buckets(self): 4092 return self.num_buckets 4093 4094 def get_sparse_tensors(self, transformation_cache, state_manager): 4095 """See `CategoricalColumn` base class.""" 4096 return CategoricalColumn.IdWeightPair( 4097 transformation_cache.get(self, state_manager), None) 4098 4099 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4100 _FEATURE_COLUMN_DEPRECATION) 4101 def _get_sparse_tensors(self, inputs, weight_collections=None, 4102 trainable=None): 4103 """See `CategoricalColumn` base class.""" 4104 del weight_collections 4105 del trainable 4106 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 4107 4108 @property 4109 def parents(self): 4110 """See 'FeatureColumn` base class.""" 4111 return list(self.keys) 4112 4113 def get_config(self): 4114 """See 'FeatureColumn` base class.""" 4115 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 4116 config = dict(zip(self._fields, self)) 4117 config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys]) 4118 return config 4119 4120 @classmethod 4121 def from_config(cls, config, custom_objects=None, columns_by_name=None): 4122 """See 'FeatureColumn` base class.""" 4123 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 4124 _check_config_keys(config, cls._fields) 4125 kwargs = _standardize_and_copy_config(config) 4126 kwargs['keys'] = tuple([ 4127 deserialize_feature_column(c, custom_objects, columns_by_name) 4128 for c in config['keys'] 4129 ]) 4130 return cls(**kwargs) 4131 4132 4133def _collect_leaf_level_keys(cross): 4134 """Collects base keys by expanding all nested crosses. 4135 4136 Args: 4137 cross: A `CrossedColumn`. 4138 4139 Returns: 4140 A list of strings or `CategoricalColumn` instances. 
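
  For example, a nested cross such as
  `CrossedColumn(keys=('a', CrossedColumn(keys=('b', 'c'), ...)), ...)`
  expands to the leaf keys `['a', 'b', 'c']`.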
4141 """ 4142 leaf_level_keys = [] 4143 for k in cross.keys: 4144 if isinstance(k, CrossedColumn): 4145 leaf_level_keys.extend(_collect_leaf_level_keys(k)) 4146 else: 4147 leaf_level_keys.append(k) 4148 return leaf_level_keys 4149 4150 4151def _prune_invalid_ids(sparse_ids, sparse_weights): 4152 """Prune invalid IDs (< 0) from the input ids and weights.""" 4153 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) 4154 if sparse_weights is not None: 4155 is_id_valid = math_ops.logical_and( 4156 is_id_valid, 4157 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) 4158 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) 4159 if sparse_weights is not None: 4160 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) 4161 return sparse_ids, sparse_weights 4162 4163 4164def _prune_invalid_weights(sparse_ids, sparse_weights): 4165 """Prune invalid weights (< 0) from the input ids and weights.""" 4166 if sparse_weights is not None: 4167 is_weights_valid = math_ops.greater(sparse_weights.values, 0) 4168 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) 4169 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) 4170 return sparse_ids, sparse_weights 4171 4172 4173class IndicatorColumn( 4174 DenseColumn, 4175 SequenceDenseColumn, 4176 fc_old._DenseColumn, # pylint: disable=protected-access 4177 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 4178 collections.namedtuple('IndicatorColumn', ('categorical_column'))): 4179 """Represents a one-hot column for use in deep networks. 4180 4181 Args: 4182 categorical_column: A `CategoricalColumn` which is created by 4183 `categorical_column_with_*` function. 4184 """ 4185 4186 @property 4187 def _is_v2_column(self): 4188 return (isinstance(self.categorical_column, FeatureColumn) and 4189 self.categorical_column._is_v2_column) # pylint: disable=protected-access 4190 4191 @property 4192 def name(self): 4193 """See `FeatureColumn` base class.""" 4194 return '{}_indicator'.format(self.categorical_column.name) 4195 4196 def _transform_id_weight_pair(self, id_weight_pair, size): 4197 id_tensor = id_weight_pair.id_tensor 4198 weight_tensor = id_weight_pair.weight_tensor 4199 4200 # If the underlying column is weighted, return the input as a dense tensor. 4201 if weight_tensor is not None: 4202 weighted_column = sparse_ops.sparse_merge( 4203 sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size)) 4204 # Remove (?, -1) index. 4205 weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], 4206 weighted_column.dense_shape) 4207 # Use scatter_nd to merge duplicated indices if existed, 4208 # instead of sparse_tensor_to_dense. 4209 return array_ops.scatter_nd(weighted_column.indices, 4210 weighted_column.values, 4211 weighted_column.dense_shape) 4212 4213 dense_id_tensor = sparse_ops.sparse_tensor_to_dense( 4214 id_tensor, default_value=-1) 4215 4216 # One hot must be float for tf.concat reasons since all other inputs to 4217 # input_layer are float32. 4218 one_hot_id_tensor = array_ops.one_hot( 4219 dense_id_tensor, depth=size, on_value=1.0, off_value=0.0) 4220 4221 # Reduce to get a multi-hot per example. 4222 return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2]) 4223 4224 def transform_feature(self, transformation_cache, state_manager): 4225 """Returns dense `Tensor` representing feature. 4226 4227 Args: 4228 transformation_cache: A `FeatureTransformationCache` object to access 4229 features. 
4230 state_manager: A `StateManager` to create / access resources such as 4231 lookup tables. 4232 4233 Returns: 4234 Transformed feature `Tensor`. 4235 4236 Raises: 4237 ValueError: if input rank is not known at graph building time. 4238 """ 4239 id_weight_pair = self.categorical_column.get_sparse_tensors( 4240 transformation_cache, state_manager) 4241 return self._transform_id_weight_pair(id_weight_pair, 4242 self.variable_shape[-1]) 4243 4244 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4245 _FEATURE_COLUMN_DEPRECATION) 4246 def _transform_feature(self, inputs): 4247 id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4248 return self._transform_id_weight_pair(id_weight_pair, 4249 self._variable_shape[-1]) 4250 4251 @property 4252 def parse_example_spec(self): 4253 """See `FeatureColumn` base class.""" 4254 return self.categorical_column.parse_example_spec 4255 4256 @property 4257 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4258 _FEATURE_COLUMN_DEPRECATION) 4259 def _parse_example_spec(self): 4260 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 4261 4262 @property 4263 def variable_shape(self): 4264 """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" 4265 if isinstance(self.categorical_column, FeatureColumn): 4266 return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) 4267 else: 4268 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4269 4270 @property 4271 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4272 _FEATURE_COLUMN_DEPRECATION) 4273 def _variable_shape(self): 4274 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4275 4276 def get_dense_tensor(self, transformation_cache, state_manager): 4277 """Returns dense `Tensor` representing feature. 4278 4279 Args: 4280 transformation_cache: A `FeatureTransformationCache` object to access 4281 features. 4282 state_manager: A `StateManager` to create / access resources such as 4283 lookup tables. 4284 4285 Returns: 4286 Dense `Tensor` created within `transform_feature`. 4287 4288 Raises: 4289 ValueError: If `categorical_column` is a `SequenceCategoricalColumn`. 4290 """ 4291 if isinstance(self.categorical_column, SequenceCategoricalColumn): 4292 raise ValueError( 4293 'In indicator_column: {}. ' 4294 'categorical_column must not be of type SequenceCategoricalColumn. ' 4295 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4296 'non-sequence categorical_column_with_*. ' 4297 'Suggested fix B: If you wish to create sequence input, use ' 4298 'SequenceFeatures instead of DenseFeatures. ' 4299 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4300 self.categorical_column)) 4301 # Feature has been already transformed. Return the intermediate 4302 # representation created by transform_feature. 4303 return transformation_cache.get(self, state_manager) 4304 4305 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4306 _FEATURE_COLUMN_DEPRECATION) 4307 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 4308 del weight_collections 4309 del trainable 4310 if isinstance( 4311 self.categorical_column, 4312 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4313 raise ValueError( 4314 'In indicator_column: {}. 

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by transform_feature, and recover the sequence
    # length from the underlying sparse ids.
    dense_tensor = transformation_cache.get(self, state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)
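
  # A minimal sequence sketch (illustrative only), assuming the public
  # `sequence_categorical_column_with_identity` constructor from the
  # sequence feature column API:
  #
  #   tokens = sequence_categorical_column_with_identity(
  #       'tokens', num_buckets=3)
  #   indicator = indicator_column(tokens)
  #
  # Consumed through SequenceFeatures, `get_sequence_dense_tensor` returns a
  # `[batch_size, max_seq_len, num_buckets]` dense tensor paired with a
  # `sequence_length` tensor recovered from the sparse ids.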

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable, since no variables
    # are created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
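

# A minimal serialization sketch (illustrative only; `indicator` is a
# hypothetical IndicatorColumn instance): `get_config` and `from_config`
# round-trip a column through a plain dict, with the nested categorical
# column handled by serialize_feature_column / deserialize_feature_column:
#
#   config = indicator.get_config()
#   revived = IndicatorColumn.from_config(config)
#   assert revived == indicator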


def _verify_static_batch_size_equality(tensors, columns):
  """Verifies that the static batch sizes of all input tensors are equal.

  Args:
    tensors: An iterable of input tensors.
    columns: The corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # batch_size is a Dimension object.
  expected_batch_size = None
  for i in range(len(tensors)):
    batch_size = tensor_shape.Dimension(
        tensor_shape.dimension_value(tensors[i].shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be the same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))


class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column',))):
  """Represents sequences of categorical data."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns the number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expand the third dimension, if necessary, so that embeddings are not
    # combined during the embedding lookup. If the tensor is already 3D,
    # leave it as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # -1 does not work for dynamically shaped tensors with 0-length slices at
    # runtime, which happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
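
  # A shape sketch for `_get_sparse_tensors_helper` (illustrative only): ids
  # parsed from a `VarLenFeature` arrive with dense shape
  # `[batch_size, max_seq_len]` and are reshaped to
  # `[batch_size, max_seq_len, 1]` (reduce_prod over an empty slice is 1):
  #
  #   ids = SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
  #                      values=[5, 9, 2], dense_shape=[2, 2])
  #   # After the helper, the dense shape is [2, 2, 1], so each timestep is
  #   # looked up as its own length-1 bag of ids rather than being combined
  #   # across the sequence. Trailing dimensions of higher-rank inputs are
  #   # flattened into the third dimension.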

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an `IdWeightPair`.

    An `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


def _check_config_keys(config, expected_keys):
  """Checks that a config has exactly the expected keys."""
  if set(config.keys()) != set(expected_keys):
    raise ValueError('Invalid config: {}, expected keys: {}'.format(
        config, expected_keys))


def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned into tuples.

  Keras serialization uses nest to listify everything. This causes problems
  with the NumericColumn shape, which becomes unhashable. We could try to
  solve this on the Keras side, but that would require lots of tracking to
  avoid changing existing behavior. Instead, we ensure here that we revive
  correctly.

  Args:
    config: Dict that will be used to revive a feature column.

  Returns:
    Shallow copy of config with lists turned into tuples.
  """
  kwargs = config.copy()
  for k, v in kwargs.items():
    if isinstance(v, list):
      kwargs[k] = tuple(v)

  return kwargs


def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes."""
  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
  return invalid_char.sub('_', name)
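

# A minimal sketch (illustrative only) of the two helpers above:
#
#   _standardize_and_copy_config({'shape': [10, 4]})
#   # -> {'shape': (10, 4)}; the list becomes a hashable tuple.
#
#   _sanitize_column_name_for_variable_scope('user location/city')
#   # -> 'user_location_city'; every character outside [A-Za-z0-9_.-] is
#   # replaced with '_', so the name is safe to use as a variable scope.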