# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction.

FeatureColumns provide a high-level abstraction for ingesting and representing
features. FeatureColumns are also the primary way of encoding features for
canned `tf.estimator.Estimator`s.

When using FeatureColumns with `Estimators`, the type of feature column you
should choose depends on (1) the feature type and (2) the model type.

1. Feature type:

  * Continuous features can be represented by `numeric_column`.
  * Categorical features can be represented by any `categorical_column_with_*`
    column:
    - `categorical_column_with_vocabulary_list`
    - `categorical_column_with_vocabulary_file`
    - `categorical_column_with_hash_bucket`
    - `categorical_column_with_identity`
    - `weighted_categorical_column`

2. Model type:

  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).

    Continuous features can be directly fed into deep neural network models.

      age_column = numeric_column("age")

    To feed sparse features into DNN models, wrap the column with
    `embedding_column` or `indicator_column`. `indicator_column` is recommended
    for features with only a few possible values. For features with many
    possible values, to reduce the size of your model, `embedding_column` is
    recommended.

      embedded_dept_column = embedding_column(
          categorical_column_with_vocabulary_list(
              "department", ["math", "philosophy", ...]), dimension=10)

  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).

    Sparse features can be fed directly into linear models. They behave like an
    indicator column but with an efficient implementation.

      dept_column = categorical_column_with_vocabulary_list("department",
          ["math", "philosophy", "english"])

    It is recommended that continuous features be bucketized before being
    fed into linear models.

      bucketized_age_column = bucketized_column(
          source_column=age_column,
          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

    Sparse features can be crossed (also known as conjoined or combined) in
    order to form non-linearities, and then fed into linear models.

      cross_dept_age_column = crossed_column(
          columns=["department", bucketized_age_column],
          hash_bucket_size=1000)

Example of building canned `Estimator`s using FeatureColumns:

  ```python
  # Define features and transformations
  deep_feature_columns = [age_column, embedded_dept_column]
  wide_feature_columns = [dept_column, bucketized_age_column,
                          cross_dept_age_column]

  # Build deep model
  estimator = DNNClassifier(
      feature_columns=deep_feature_columns,
      hidden_units=[500, 250, 50])
  estimator.train(...)

  # Or build a wide model
  estimator = LinearClassifier(
      feature_columns=wide_feature_columns)
  estimator.train(...)

  # Or build a wide and deep model!
  estimator = DNNLinearCombinedClassifier(
      linear_feature_columns=wide_feature_columns,
      dnn_feature_columns=deep_feature_columns,
      dnn_hidden_units=[500, 250, 50])
  estimator.train(...)
  ```


FeatureColumns can also be transformed into a generic input layer for
custom models using `input_layer`.

Example of building a model using FeatureColumns; this can be used in a
`model_fn`, which is given to the `tf.estimator.Estimator`:

  ```python
  # Building model via layers

  deep_feature_columns = [age_column, embedded_dept_column]
  columns_to_tensor = parse_feature_columns_from_examples(
      serialized=my_data,
      feature_columns=deep_feature_columns)
  first_layer = input_layer(
      features=columns_to_tensor,
      feature_columns=deep_feature_columns)
  second_layer = fully_connected(first_layer, ...)
  ```

NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
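
  As an illustration, a minimal manager could simply keep created variables in
  a dictionary (a sketch only; `_DictStateManager` is hypothetical and not
  part of this API):

  ```python
  class _DictStateManager(StateManager):

    def __init__(self):
      self._vars = {}  # Maps (feature_column, name) to a variable.

    def create_variable(self, feature_column, name, shape, dtype=None,
                        trainable=True, use_resource=True, initializer=None):
      # Note: a real manager must also guarantee name uniqueness across
      # feature columns; this sketch does not.
      var = tf.compat.v1.get_variable(
          name, shape=shape, dtype=dtype, trainable=trainable,
          use_resource=use_resource, initializer=initializer)
      self._vars[(feature_column, name)] = var
      return var

    def get_variable(self, feature_column, name):
      return self._vars[(feature_column, name)]
  ```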
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds a new resource.

    Resources can be things such as lookup tables.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.

    Returns:
      The added resource.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as lookup tables.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')


class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer."""

  def __init__(self, layer, trainable):
    """Creates an `_StateManagerImpl` object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether by default, variables created are trainable or not.
274 """ 275 self._trainable = trainable 276 self._layer = layer 277 if self._layer is not None and not hasattr(self._layer, '_resources'): 278 self._layer._resources = [] # pylint: disable=protected-access 279 self._cols_to_vars_map = collections.defaultdict(lambda: {}) 280 # TODO(vbardiovsky): Make sure the resources are tracked by moving them to 281 # the layer (inheriting from AutoTrackable), e.g.: 282 # self._layer._resources_map = data_structures.Mapping() 283 self._cols_to_resources_map = collections.defaultdict(lambda: {}) 284 285 def create_variable(self, 286 feature_column, 287 name, 288 shape, 289 dtype=None, 290 trainable=True, 291 use_resource=True, 292 initializer=None): 293 if name in self._cols_to_vars_map[feature_column]: 294 raise ValueError('Variable already exists.') 295 296 # We explicitly track these variables since `name` is not guaranteed to be 297 # unique and disable manual tracking that the add_weight call does. 298 with trackable.no_manual_dependency_tracking_scope(self._layer): 299 var = self._layer.add_weight( 300 name=name, 301 shape=shape, 302 dtype=dtype, 303 initializer=initializer, 304 trainable=self._trainable and trainable, 305 use_resource=use_resource, 306 # TODO(rohanj): Get rid of this hack once we have a mechanism for 307 # specifying a default partitioner for an entire layer. In that case, 308 # the default getter for Layers should work. 309 getter=variable_scope.get_variable) 310 if isinstance(var, variables.PartitionedVariable): 311 for v in var: 312 part_name = name + '/' + str(v._get_save_slice_info().var_offset[0]) # pylint: disable=protected-access 313 self._layer._track_trackable(v, feature_column.name + '/' + part_name) # pylint: disable=protected-access 314 else: 315 if isinstance(var, trackable.Trackable): 316 self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access 317 318 self._cols_to_vars_map[feature_column][name] = var 319 return var 320 321 def get_variable(self, feature_column, name): 322 if name in self._cols_to_vars_map[feature_column]: 323 return self._cols_to_vars_map[feature_column][name] 324 raise ValueError('Variable does not exist.') 325 326 def add_resource(self, feature_column, name, resource): 327 self._cols_to_resources_map[feature_column][name] = resource 328 if self._layer is not None: 329 self._layer._resources.append(resource) # pylint: disable=protected-access 330 331 def get_resource(self, feature_column, name): 332 if name in self._cols_to_resources_map[feature_column]: 333 return self._cols_to_resources_map[feature_column][name] 334 raise ValueError('Resource does not exist.') 335 336 337class _StateManagerImplV2(_StateManagerImpl): 338 """Manages the state of DenseFeatures.""" 339 340 def create_variable(self, 341 feature_column, 342 name, 343 shape, 344 dtype=None, 345 trainable=True, 346 use_resource=True, 347 initializer=None): 348 if name in self._cols_to_vars_map[feature_column]: 349 raise ValueError('Variable already exists.') 350 351 # We explicitly track these variables since `name` is not guaranteed to be 352 # unique and disable manual tracking that the add_weight call does. 
    with trackable.no_manual_dependency_tracking_scope(self._layer):
      var = self._layer.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource)
    if isinstance(var, trackable.Trackable):
      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access
    self._cols_to_vars_map[feature_column][name] = var
    return var


class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable: Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the DenseFeatures.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = _normalize_feature_columns(feature_columns)
    self._state_manager = _StateManagerImpl(self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    for column in self._feature_columns:
      with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
          self.name,
          partitioner=self._partitioner):
        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
            _sanitize_column_name_for_variable_scope(column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _target_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.
    """
    raise NotImplementedError('Calling an abstract method.')

  def compute_output_shape(self, input_shape):
    total_elements = 0
    for column in self._feature_columns:
      total_elements += column.variable_shape.num_elements()
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor.
    """
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns."""
    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)

    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config['partitioner'], custom_objects)

    return cls(**config_cp)


class _LinearModelLayer(Layer):
  """Layer that contains logic for `LinearModel`."""

  def __init__(self,
               feature_columns,
               units=1,
               sparse_combiner='sum',
               trainable=True,
               name=None,
               **kwargs):
    super(_LinearModelLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)

    self._feature_columns = _normalize_feature_columns(feature_columns)
    for column in self._feature_columns:
      if not isinstance(column, (DenseColumn, CategoricalColumn)):
        raise ValueError(
            'Items of feature_columns must be either a '
            'DenseColumn or CategoricalColumn. Given: {}'.format(column))

    self._units = units
    self._sparse_combiner = sparse_combiner

    self._state_manager = _StateManagerImpl(self, self.trainable)
    self.bias = None

  def build(self, _):
    # We need variable scopes for now because we want the variable partitioning
    # information to percolate down. We also use _pure_variable_scope's here
    # since we want to open up a name_scope in the `call` method while creating
    # the ops.
    with variable_scope._pure_variable_scope(self.name):  # pylint: disable=protected-access
      for column in self._feature_columns:
        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
            _sanitize_column_name_for_variable_scope(column.name)):
          # Create the state for each feature column.
          column.create_state(self._state_manager)

          # Create a weight variable for each column.
          if isinstance(column, CategoricalColumn):
            first_dim = column.num_buckets
          else:
            first_dim = column.variable_shape.num_elements()
          self._state_manager.create_variable(
              column,
              name='weights',
              dtype=dtypes.float32,
              shape=(first_dim, self._units),
              initializer=init_ops.zeros_initializer(),
              trainable=self.trainable)

      # Create a bias variable.
      self.bias = self.add_variable(
          name='bias_weights',
          dtype=dtypes.float32,
          shape=[self._units],
          initializer=init_ops.zeros_initializer(),
          trainable=self.trainable,
          use_resource=True,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)

    super(_LinearModelLayer, self).build(None)

  def call(self, features):
    if not isinstance(features, dict):
      raise ValueError('We expected a dictionary here. Instead we got: {}'
                       .format(features))
    with ops.name_scope(self.name):
      transformation_cache = FeatureTransformationCache(features)
      weighted_sums = []
      for column in self._feature_columns:
        with ops.name_scope(
            _sanitize_column_name_for_variable_scope(column.name)):
          # All the weights used in the linear model are owned by the state
          # manager associated with this Linear Model.
          weight_var = self._state_manager.get_variable(column, 'weights')

          weighted_sum = _create_weighted_sum(
              column=column,
              transformation_cache=transformation_cache,
              state_manager=self._state_manager,
              sparse_combiner=self._sparse_combiner,
              weight_var=weight_var)
          weighted_sums.append(weighted_sum)

      _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
      predictions_no_bias = math_ops.add_n(
          weighted_sums, name='weighted_sum_no_bias')
      predictions = nn_ops.bias_add(
          predictions_no_bias, self.bias, name='weighted_sum')
      return predictions

  def get_config(self):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {
        'feature_columns': column_configs,
        'units': self._units,
        'sparse_combiner': self._sparse_combiner
    }

    base_config = super(  # pylint: disable=bad-super-call
        _LinearModelLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    columns = serialization.deserialize_feature_columns(
        config_cp['feature_columns'], custom_objects=custom_objects)

    del config_cp['feature_columns']
    return cls(feature_columns=columns, **config_cp)


# TODO(tanzheny): Clean it up with respect to Premade model b/132690565.
class LinearModel(training.Model):
  """Produces a linear prediction `Tensor` based on given `feature_columns`.

  This layer generates a weighted sum based on output dimension `units`.
  Weighted sum refers to logits in classification problems. It refers to the
  prediction itself for linear regression problems.

  Note on supported columns: `LinearModel` treats categorical columns as
  `indicator_column`s. To be specific, assume the input as `SparseTensor` looks
  like:

  ```python
  shape = [2, 2]
  {
      [0, 0]: "a"
      [1, 0]: "b"
      [1, 1]: "c"
  }
  ```
  `linear_model` assigns weights for the presence of "a", "b", "c" implicitly,
  just like `indicator_column`, while `input_layer` explicitly requires wrapping
  each of the categorical columns with an `embedding_column` or an
  `indicator_column`.

  Example of usage:

  ```python
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
  keywords = categorical_column_with_hash_bucket("keywords", 10000)
  keywords_price = crossed_column(['keywords', price_buckets], ...)
  columns = [price_buckets, keywords, keywords_price, ...]
  linear_model = LinearModel(columns)

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  prediction = linear_model(features)
  ```
  """

  def __init__(self,
               feature_columns,
               units=1,
               sparse_combiner='sum',
               trainable=True,
               name=None,
               **kwargs):
    """Constructs a LinearModel.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model. All items should be instances of classes derived
        from `_FeatureColumn`s.
      units: An integer, dimensionality of the output space. Default value is 1.
      sparse_combiner: A string specifying how to reduce if a categorical column
        is multivalent. Except `numeric_column`, almost all columns passed to
        `linear_model` are considered as categorical columns. It combines each
        categorical column independently. Currently "mean", "sqrtn" and "sum"
        are supported, with "sum" the default for linear model. "sqrtn" often
        achieves good accuracy, in particular with bag-of-words columns.
          * "sum": do not normalize features in the column
          * "mean": do l1 normalization on features in the column
          * "sqrtn": do l2 normalization on features in the column
        For example, for two features represented as the categorical columns:

        ```python
        # Feature 1

        shape = [2, 2]
        {
            [0, 0]: "a"
            [0, 1]: "b"
            [1, 0]: "c"
        }

        # Feature 2

        shape = [2, 3]
        {
            [0, 0]: "d"
            [1, 0]: "e"
            [1, 1]: "f"
            [1, 2]: "g"
        }
        ```

        with `sparse_combiner` as "mean", the linear model outputs are
        conceptually

        ```
        y_0 = 1.0 / 2.0 * (w_a + w_b) + w_c + b_0
        y_1 = w_d + 1.0 / 3.0 * (w_e + w_f + w_g) + b_1
        ```

        where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
        assigned to the presence of `x` in the input features.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      name: Name to give to the Linear Model. All variables and ops created will
        be scoped by this name.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` is neither a `DenseColumn`
        nor `CategoricalColumn`.
    """

    super(LinearModel, self).__init__(name=name, **kwargs)
    self.layer = _LinearModelLayer(
        feature_columns,
        units,
        sparse_combiner,
        trainable,
        name=self.name,
        **kwargs)

  def call(self, features):
    """Returns a `Tensor` that represents the predictions of a linear model.

    Args:
      features: A mapping from key to tensors. `_FeatureColumn`s look up via
        these keys. For example `numeric_column('price')` will look at 'price'
        key in this dict. Values are `Tensor` or `SparseTensor` depending on
        corresponding `_FeatureColumn`.

    Returns:
      A `Tensor` which represents predictions/logits of a linear model. Its
      shape is (batch_size, units) and its dtype is `float32`.

    Raises:
      ValueError: If features are not a dictionary.
    """
    return self.layer(features)

  @property
  def bias(self):
    return self.layer.bias


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on feature columns passed in.

  Please note that you will most probably not need to use this function.
  Please check `input_layer` and `linear_model` to see whether they will
  satisfy your use case or not.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor` depending on
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs


@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as arg 'features' in
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9), ...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9), ...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```
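
  If the default initializer does not suit you, `initializer` accepts any
  callable initializer. For instance (a sketch only; the 0.01 standard
  deviation is an arbitrary choice, not a recommendation):

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  video_embedded = embedding_column(
      video_id, dimension=9,
      initializer=tf.compat.v1.truncated_normal_initializer(
          mean=0.0, stddev=0.01))
  ```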

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(
      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0.num_buckets
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result


@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  ```python
  price = numeric_column('price')
  columns = [price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  # or
  bucketized_price = bucketized_column(price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```
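
  A `normalizer_fn` can be supplied to transform the parsed value before it is
  used. For instance, a min-max scaling to [0, 1] (a sketch only; the
  0-to-1000 price range is an arbitrary assumption):

  ```python
  def scale_price(x):
    # Map raw prices from [0, 1000] to [0, 1].
    return (x - 0.) / 1000.

  price = numeric_column('price', normalizer_fn=scale_price)
  ```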
1374 """ 1375 shape = _check_shape(shape, key) 1376 if not (dtype.is_integer or dtype.is_floating): 1377 raise ValueError('dtype must be convertible to float. ' 1378 'dtype: {}, key: {}'.format(dtype, key)) 1379 default_value = fc_utils.check_default_value( 1380 shape, default_value, dtype, key) 1381 1382 if normalizer_fn is not None and not callable(normalizer_fn): 1383 raise TypeError( 1384 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) 1385 1386 fc_utils.assert_key_is_string(key) 1387 return NumericColumn( 1388 key, 1389 shape=shape, 1390 default_value=default_value, 1391 dtype=dtype, 1392 normalizer_fn=normalizer_fn) 1393 1394 1395@tf_export('feature_column.bucketized_column') 1396def bucketized_column(source_column, boundaries): 1397 """Represents discretized dense input bucketed by `boundaries`. 1398 1399 Buckets include the left boundary, and exclude the right boundary. Namely, 1400 `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`, 1401 `[1., 2.)`, and `[2., +inf)`. 1402 1403 For example, if the inputs are 1404 1405 ```python 1406 boundaries = [0, 10, 100] 1407 input tensor = [[-5, 10000] 1408 [150, 10] 1409 [5, 100]] 1410 ``` 1411 1412 then the output will be 1413 1414 ```python 1415 output = [[0, 3] 1416 [3, 2] 1417 [1, 3]] 1418 ``` 1419 1420 Example: 1421 1422 ```python 1423 price = tf.feature_column.numeric_column('price') 1424 bucketized_price = tf.feature_column.bucketized_column( 1425 price, boundaries=[...]) 1426 columns = [bucketized_price, ...] 1427 features = tf.io.parse_example( 1428 ..., features=tf.feature_column.make_parse_example_spec(columns)) 1429 dense_tensor = tf.keras.layers.DenseFeatures(columns)(features) 1430 ``` 1431 1432 `bucketized_column` can also be crossed with another categorical column using 1433 `crossed_column`: 1434 1435 ```python 1436 price = tf.feature_column.numeric_column('price') 1437 # bucketized_column converts numerical feature to a categorical one. 1438 bucketized_price = tf.feature_column.bucketized_column( 1439 price, boundaries=[...]) 1440 # 'keywords' is a string feature. 1441 price_x_keywords = tf.feature_column.crossed_column( 1442 [bucketized_price, 'keywords'], 50K) 1443 columns = [price_x_keywords, ...] 1444 features = tf.io.parse_example( 1445 ..., features=tf.feature_column.make_parse_example_spec(columns)) 1446 dense_tensor = tf.keras.layers.DenseFeatures(columns)(features) 1447 linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor) 1448 ``` 1449 1450 Args: 1451 source_column: A one-dimensional dense column which is generated with 1452 `numeric_column`. 1453 boundaries: A sorted list or tuple of floats specifying the boundaries. 1454 1455 Returns: 1456 A `BucketizedColumn`. 1457 1458 Raises: 1459 ValueError: If `source_column` is not a numeric column, or if it is not 1460 one-dimensional. 1461 ValueError: If `boundaries` is not a sorted list or tuple. 1462 """ 1463 if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access 1464 raise ValueError( 1465 'source_column must be a column generated with numeric_column(). ' 1466 'Given: {}'.format(source_column)) 1467 if len(source_column.shape) > 1: 1468 raise ValueError( 1469 'source_column must be one-dimensional column. 
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
    raise ValueError('boundaries must be a sorted list or tuple.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))


@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing.
  output_id = Hash(input_feature_string) % bucket_size for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  keywords = categorical_column_with_hash_bucket("keywords", 10K)
  columns = [keywords, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)

  # or
  keywords_embedded = embedding_column(keywords, 16)
  columns = [keywords_embedded, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)


@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
def categorical_column_with_vocabulary_file(key,
                                            vocabulary_file,
                                            vocabulary_size=None,
                                            num_oov_buckets=0,
                                            default_value=None,
                                            dtype=dtypes.string):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to their line number. All other values are hashed and assigned
  an ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal
  'XX' in input, and other values missing from the file, will be assigned
  ID 0. All others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the number of lines in `vocabulary_file`; if it is less,
      later values are ignored. If `None`, it is set to the number of lines in
      `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  return categorical_column_with_vocabulary_file_v2(
      key, vocabulary_file, vocabulary_size,
      dtype, default_value,
      num_oov_buckets)


@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
def categorical_column_with_vocabulary_file_v2(key,
                                               vocabulary_file,
                                               vocabulary_size=None,
                                               dtype=dtypes.string,
                                               default_value=None,
                                               num_oov_buckets=0):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to their line number. All other values are hashed and assigned
  an ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal
  `'XX'` in input, and other values missing from the file, will be assigned
  ID 0. All others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the number of lines in `vocabulary_file`; if it is less,
      later values are ignored. If `None`, it is set to the number of lines in
      `vocabulary_file`.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets.
All out-of-vocabulary inputs will be assigned IDs in the range 1702 `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of 1703 the input value. A positive `num_oov_buckets` can not be specified with 1704 `default_value`. 1705 1706 Returns: 1707 A `CategoricalColumn` with a vocabulary file. 1708 1709 Raises: 1710 ValueError: `vocabulary_file` is missing or cannot be opened. 1711 ValueError: `vocabulary_size` is missing or < 1. 1712 ValueError: `num_oov_buckets` is a negative integer. 1713 ValueError: `num_oov_buckets` and `default_value` are both specified. 1714 ValueError: `dtype` is neither string nor integer. 1715 """ 1716 if not vocabulary_file: 1717 raise ValueError('Missing vocabulary_file in {}.'.format(key)) 1718 1719 if vocabulary_size is None: 1720 if not gfile.Exists(vocabulary_file): 1721 raise ValueError('vocabulary_file in {} does not exist.'.format(key)) 1722 1723 with gfile.GFile(vocabulary_file, mode='rb') as f: 1724 vocabulary_size = sum(1 for _ in f) 1725 logging.info( 1726 'vocabulary_size = %d in %s is inferred from the number of elements ' 1727 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) 1728 1729 # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. 1730 if vocabulary_size < 1: 1731 raise ValueError('Invalid vocabulary_size in {}.'.format(key)) 1732 if num_oov_buckets: 1733 if default_value is not None: 1734 raise ValueError( 1735 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( 1736 key)) 1737 if num_oov_buckets < 0: 1738 raise ValueError('Invalid num_oov_buckets {} in {}.'.format( 1739 num_oov_buckets, key)) 1740 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1741 fc_utils.assert_key_is_string(key) 1742 return VocabularyFileCategoricalColumn( 1743 key=key, 1744 vocabulary_file=vocabulary_file, 1745 vocabulary_size=vocabulary_size, 1746 num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, 1747 default_value=-1 if default_value is None else default_value, 1748 dtype=dtype) 1749 1750 1751@tf_export('feature_column.categorical_column_with_vocabulary_list') 1752def categorical_column_with_vocabulary_list(key, 1753 vocabulary_list, 1754 dtype=None, 1755 default_value=-1, 1756 num_oov_buckets=0): 1757 """A `CategoricalColumn` with in-memory vocabulary. 1758 1759 Use this when your inputs are in string or integer format, and you have an 1760 in-memory vocabulary mapping each value to an integer ID. By default, 1761 out-of-vocabulary values are ignored. Use either (but not both) of 1762 `num_oov_buckets` and `default_value` to specify how to include 1763 out-of-vocabulary values. 1764 1765 For input dictionary `features`, `features[key]` is either `Tensor` or 1766 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1767 and `''` for string, which will be dropped by this feature column. 1768 1769 Example with `num_oov_buckets`: 1770 In the following example, each input in `vocabulary_list` is assigned an ID 1771 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other 1772 inputs are hashed and assigned an ID 4-5. 1773 1774 ```python 1775 colors = categorical_column_with_vocabulary_list( 1776 key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), 1777 num_oov_buckets=2) 1778 columns = [colors, ...] 
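  # Inputs outside ('R', 'G', 'B', 'Y') hash into the 2 OOV buckets (IDs 4-5).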
1779 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1780 linear_prediction, _, _ = linear_model(features, columns) 1781 ``` 1782 1783 Example with `default_value`: 1784 In the following example, each input in `vocabulary_list` is assigned an ID 1785 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other 1786 inputs are assigned `default_value` 0. 1787 1788 1789 ```python 1790 colors = categorical_column_with_vocabulary_list( 1791 key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) 1792 columns = [colors, ...] 1793 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1794 linear_prediction, _, _ = linear_model(features, columns) 1795 ``` 1796 1797 And to make an embedding with either: 1798 1799 ```python 1800 columns = [embedding_column(colors, 3),...] 1801 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1802 dense_tensor = input_layer(features, columns) 1803 ``` 1804 1805 Args: 1806 key: A unique string identifying the input feature. It is used as the column 1807 name and the dictionary key for feature parsing configs, feature `Tensor` 1808 objects, and feature columns. 1809 vocabulary_list: An ordered iterable defining the vocabulary. Each feature 1810 is mapped to the index of its value (if present) in `vocabulary_list`. 1811 Must be castable to `dtype`. 1812 dtype: The type of features. Only string and integer types are supported. If 1813 `None`, it will be inferred from `vocabulary_list`. 1814 default_value: The integer ID value to return for out-of-vocabulary feature 1815 values, defaults to `-1`. This can not be specified with a positive 1816 `num_oov_buckets`. 1817 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary 1818 buckets. All out-of-vocabulary inputs will be assigned IDs in the range 1819 `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a 1820 hash of the input value. A positive `num_oov_buckets` can not be specified 1821 with `default_value`. 1822 1823 Returns: 1824 A `CategoricalColumn` with in-memory vocabulary. 1825 1826 Raises: 1827 ValueError: if `vocabulary_list` is empty, or contains duplicate keys. 1828 ValueError: `num_oov_buckets` is a negative integer. 1829 ValueError: `num_oov_buckets` and `default_value` are both specified. 1830 ValueError: if `dtype` is not integer or string. 
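
  A small additional sketch with an integer vocabulary (values here are
  illustrative); when `dtype` is not given, it is inferred from
  `vocabulary_list`:

  ```python
  sizes = categorical_column_with_vocabulary_list(
      key='sizes', vocabulary_list=(8, 10, 12))
  # Input 10 maps to ID 1; out-of-vocabulary inputs map to default_value -1.
  ```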
  """
  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
    raise ValueError(
        'vocabulary_list {} must be non-empty, column_name: {}'.format(
            vocabulary_list, key))
  if len(set(vocabulary_list)) != len(vocabulary_list):
    raise ValueError(
        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
            vocabulary_list, key))
  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
  if num_oov_buckets:
    if default_value != -1:
      raise ValueError(
          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
              key))
    if num_oov_buckets < 0:
      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
          num_oov_buckets, key))
  fc_utils.assert_string_or_int(
      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
  if dtype is None:
    dtype = vocabulary_dtype
  elif dtype.is_integer != vocabulary_dtype.is_integer:
    raise ValueError(
        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
            dtype, vocabulary_dtype, key))
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
  fc_utils.assert_key_is_string(key)

  return VocabularyListCategoricalColumn(
      key=key,
      vocabulary_list=tuple(vocabulary_list),
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)


@tf_export('feature_column.categorical_column_with_identity')
def categorical_column_with_identity(key, num_buckets, default_value=None):
  """A `CategoricalColumn` that returns identity values.

  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values outside
  this range will result in `default_value` if specified, otherwise the column
  will fail.

  Typically, this is used for contiguous ranges of integer indexes, but
  it doesn't have to be. This might be inefficient, however, if many IDs
  are unused. Consider `categorical_column_with_hash_bucket` in that case.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for
  int and `''` for string, which will be dropped by this feature column.

  In the following examples, each input in the range `[0, 1000000)` is assigned
  its own value as the categorical ID. All other inputs are assigned
  `default_value` 0. Note that a literal 0 in inputs will result in the same
  default ID.

  Linear model:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [video_id, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction, _, _ = linear_model(features, columns)
  ```

  Embedding for a DNN model:

  ```python
  columns = [embedding_column(video_id, 9),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
    default_value: If set, values outside of range `[0, num_buckets)` will
      be replaced with this value. If not set, values >= num_buckets will
      cause a failure while values < 0 will be dropped.

  Returns:
    A `CategoricalColumn` that returns identity values.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  if num_buckets < 1:
    raise ValueError(
        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
  if (default_value is not None) and (
      (default_value < 0) or (default_value >= num_buckets)):
    raise ValueError(
        'default_value {} not in range [0, {}), column_name {}'.format(
            default_value, num_buckets, key))
  fc_utils.assert_key_is_string(key)
  return IdentityCategoricalColumn(
      key=key, number_buckets=num_buckets, default_value=default_value)


@tf_export('feature_column.indicator_column')
def indicator_column(categorical_column):
  """Represents the multi-hot representation of the given categorical column.

  - For DNN models, `indicator_column` can be used to wrap any
    `categorical_column_*` (e.g., to feed to a DNN). Consider using
    `embedding_column` if the number of buckets/unique values is large.

  - For Wide (aka linear) models, `indicator_column` is the internal
    representation for a categorical column when the categorical column is
    passed directly (as any element in feature_columns) to `linear_model`.
    See `linear_model` for details.

  ```python
  name = indicator_column(categorical_column_with_vocabulary_list(
      'name', ['bob', 'george', 'wanda']))
  columns = [name, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)

  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
  ```

  Args:
    categorical_column: A `CategoricalColumn` which is created by
      `categorical_column_with_*` or `crossed_column` functions.

  Returns:
    An `IndicatorColumn`.

  Raises:
    ValueError: If `categorical_column` is not CategoricalColumn type.
  """
  if not isinstance(categorical_column,
                    (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'Unsupported input type. Input must be a CategoricalColumn. '
        'Given: {}'.format(categorical_column))
  return IndicatorColumn(categorical_column)


@tf_export('feature_column.weighted_categorical_column')
def weighted_categorical_column(categorical_column,
                                weight_feature_key,
                                dtype=dtypes.float32):
  """Applies weight values to a `CategoricalColumn`.

  Use this when each of your sparse inputs has both an ID and a value. For
  example, if you're representing text documents as a collection of word
  frequencies, you can provide 2 parallel sparse input features ('terms' and
  'frequencies' below).
1989 1990 Example: 1991 1992 Input `tf.Example` objects: 1993 1994 ```proto 1995 [ 1996 features { 1997 feature { 1998 key: "terms" 1999 value {bytes_list {value: "very" value: "model"}} 2000 } 2001 feature { 2002 key: "frequencies" 2003 value {float_list {value: 0.3 value: 0.1}} 2004 } 2005 }, 2006 features { 2007 feature { 2008 key: "terms" 2009 value {bytes_list {value: "when" value: "course" value: "human"}} 2010 } 2011 feature { 2012 key: "frequencies" 2013 value {float_list {value: 0.4 value: 0.1 value: 0.2}} 2014 } 2015 } 2016 ] 2017 ``` 2018 2019 ```python 2020 categorical_column = categorical_column_with_hash_bucket( 2021 column_name='terms', hash_bucket_size=1000) 2022 weighted_column = weighted_categorical_column( 2023 categorical_column=categorical_column, weight_feature_key='frequencies') 2024 columns = [weighted_column, ...] 2025 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 2026 linear_prediction, _, _ = linear_model(features, columns) 2027 ``` 2028 2029 This assumes the input dictionary contains a `SparseTensor` for key 2030 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have 2031 the same indices and dense shape. 2032 2033 Args: 2034 categorical_column: A `CategoricalColumn` created by 2035 `categorical_column_with_*` functions. 2036 weight_feature_key: String key for weight values. 2037 dtype: Type of weights, such as `tf.float32`. Only float and integer weights 2038 are supported. 2039 2040 Returns: 2041 A `CategoricalColumn` composed of two sparse features: one represents id, 2042 the other represents weight (value) of the id feature in that example. 2043 2044 Raises: 2045 ValueError: if `dtype` is not convertible to float. 2046 """ 2047 if (dtype is None) or not (dtype.is_integer or dtype.is_floating): 2048 raise ValueError('dtype {} is not convertible to float.'.format(dtype)) 2049 return WeightedCategoricalColumn( 2050 categorical_column=categorical_column, 2051 weight_feature_key=weight_feature_key, 2052 dtype=dtype) 2053 2054 2055@tf_export('feature_column.crossed_column') 2056def crossed_column(keys, hash_bucket_size, hash_key=None): 2057 """Returns a column for performing crosses of categorical features. 2058 2059 Crossed features will be hashed according to `hash_bucket_size`. Conceptually, 2060 the transformation can be thought of as: 2061 Hash(cartesian product of features) % `hash_bucket_size` 2062 2063 For example, if the input features are: 2064 2065 * SparseTensor referred by first key: 2066 2067 ```python 2068 shape = [2, 2] 2069 { 2070 [0, 0]: "a" 2071 [1, 0]: "b" 2072 [1, 1]: "c" 2073 } 2074 ``` 2075 2076 * SparseTensor referred by second key: 2077 2078 ```python 2079 shape = [2, 1] 2080 { 2081 [0, 0]: "d" 2082 [1, 0]: "e" 2083 } 2084 ``` 2085 2086 then crossed feature will look like: 2087 2088 ```python 2089 shape = [2, 2] 2090 { 2091 [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size 2092 [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size 2093 [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size 2094 } 2095 ``` 2096 2097 Here is an example to create a linear model with crosses of string features: 2098 2099 ```python 2100 keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K) 2101 columns = [keywords_x_doc_terms, ...] 
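  # Note: `50K` and `1K` in these examples are shorthand for 50000 and 1000.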
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  You could also use vocabulary lookup before crossing:

  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
  columns = [keywords_x_doc_terms, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  If an input feature is of numeric type, you can use
  `categorical_column_with_identity`, or `bucketized_column`, as in the
  example:

  ```python
  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10K)
  price = numeric_column('price')
  # bucketized_column converts a numerical feature to a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
  columns = [vertical_id_x_price, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  To use a crossed column in a DNN model, you need to wrap it in an embedding
  column, as in this example:

  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
  ```

  Args:
    keys: An iterable identifying the features to be crossed. Each element can
      be either:
      * string: Will use the corresponding feature which must be of string
        type.
      * `CategoricalColumn`: Will use the transformed tensor produced by this
        column. Does not support hashed categorical columns.
    hash_bucket_size: An int > 1. The number of buckets.
    hash_key: Optional. Specifies the hash_key that will be used by the
      `FingerprintCat64` function to combine the fingerprints of the crossed
      values in `SparseCrossOp`.

  Returns:
    A `CrossedColumn`.

  Raises:
    ValueError: If `len(keys) < 2`.
    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
    ValueError: If any of the keys is `HashedCategoricalColumn`.
    ValueError: If `hash_bucket_size < 1`.
  """
  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
  if not keys or len(keys) < 2:
    raise ValueError(
        'keys must be a list with length > 1. Given: {}'.format(keys))
  for key in keys:
    if (not isinstance(key, six.string_types) and
        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
      raise ValueError(
          'Unsupported key type. All keys must be either string, or '
          'categorical column except HashedCategoricalColumn. '
          'Given: {}'.format(key))
    if isinstance(key,
                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'categorical_column_with_hash_bucket is not supported for crossing. '
          'Hashing before crossing will increase probability of collision. '
          'Instead, use the feature name as a string. Given: {}'.format(key))
  return CrossedColumn(
      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)


@six.add_metaclass(abc.ABCMeta)
class FeatureColumn(object):
  """Represents a feature column abstraction.

  WARNING: Do not subclass this layer unless you know what you are doing:
  the API is subject to future changes.

  To distinguish between the concept of a feature family and a specific binary
  feature within a family, we refer to a feature family like "country" as a
  feature column. For example, we can have a feature in a `tf.Example` format:
    {key: "country", value: [ "US" ]}
  In this example the value of the feature is "US" and "country" refers to the
  column of the feature.

  This class is an abstract class. Users should not create instances of this.
  """

  @abc.abstractproperty
  def name(self):
    """Returns string. Used for naming."""
    pass

  def __lt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    In CPython, `__lt__` must be defined for all objects in the
    sequence being sorted.

    If any objects in the sequence being sorted do not have an `__lt__` method
    compatible with feature column objects (such as strings), then CPython
    will fall back to using the `__gt__` method below.
    https://docs.python.org/3/library/stdtypes.html#list.sort

    Args:
      other: The other object to compare to.

    Returns:
      True if the string representation of this object is lexicographically
      less than the string representation of `other`. For FeatureColumn
      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
    """
    return str(self) < str(other)

  def __gt__(self, other):
    """Allows feature columns to be sorted in Python 3 as they are in Python 2.

    Feature columns need to occasionally be sortable, for example when used as
    keys in a features dictionary passed to a layer.

    `__gt__` is called when the "other" object being compared during the sort
    does not have `__lt__` defined.
    Example: http://gpaste/4803354716798976

    Args:
      other: The other object to compare to.

    Returns:
      True if the string representation of this object is lexicographically
      greater than the string representation of `other`. For FeatureColumn
      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
    """
    return str(self) > str(other)
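
  # Example (illustrative): sorted([some_column, 'age']) works even though
  # 'age' is a plain string, because str comparison with a FeatureColumn
  # returns NotImplemented and CPython falls back to the methods above.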

  @abc.abstractmethod
  def transform_feature(self, transformation_cache, state_manager):
    """Returns intermediate representation (usually a `Tensor`).

    Uses `transformation_cache` to create an intermediate representation
    (usually a `Tensor`) that other feature columns can use.

    Example usage of `transformation_cache`:
    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). To access the corresponding `Tensor`s,
    transformation_cache will be used as follows:

    ```python
    raw_tensor = transformation_cache.get('raw', state_manager)
    fc_tensor = transformation_cache.get(input_fc, state_manager)
    ```

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.
    """
    pass

  @abc.abstractproperty
  def parse_example_spec(self):
    """Returns a `tf.Example` parsing spec as dict.

    It is used to construct a parsing spec for `tf.io.parse_example`. The
    returned spec is a dict from keys ('string') to `VarLenFeature`,
    `FixedLenFeature`, and other supported objects. Please check the
    documentation of `tf.io.parse_example` for all supported spec objects.

    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). One possible implementation of
    parse_example_spec is as follows:

    ```python
    spec = {'raw': tf.io.FixedLenFeature(...)}
    spec.update(input_fc.parse_example_spec)
    return spec
    ```
    """
    pass

  def create_state(self, state_manager):
    """Uses the `state_manager` to create state for the FeatureColumn.

    Args:
      state_manager: A `StateManager` to create / access resources such as
        lookup tables and variables.
    """
    pass

  @abc.abstractproperty
  def _is_v2_column(self):
    """Returns whether this FeatureColumn is fully conformant to the new API.

    This is needed for composition type cases where an EmbeddingColumn etc.
    might take in old categorical columns as input and then we want to use the
    old API.
    """
    pass

  @abc.abstractproperty
  def parents(self):
    """Returns a list of immediate raw feature and FeatureColumn dependencies.

    For example:
    # For the following feature columns
    a = numeric_column('f1')
    c = crossed_column([a, 'f2'], hash_bucket_size=10)
    # The expected parents are:
    a.parents = ['f1']
    c.parents = [a, 'f2']
    """
    pass

  def get_config(self):
    """Returns the config of the feature column.

    A FeatureColumn config is a Python dictionary (serializable) containing the
    configuration of a FeatureColumn. The same FeatureColumn can be
    reinstantiated later from this configuration.

    The config of a feature column does not include information about feature
    columns depending on it nor the FeatureColumn class name.

    Example with (de)serialization practices followed in this file:
    ```python
    class SerializationExampleFeatureColumn(
        FeatureColumn, collections.namedtuple(
            'SerializationExampleFeatureColumn',
            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):

      def get_config(self):
        # Create a dict from the namedtuple.
        # Python attribute literals can be directly copied from / to the
        # config. For example 'dimension', assuming it is an integer literal.
        config = dict(zip(self._fields, self))

        # (De)serialization of parent FeatureColumns should use the provided
        # (de)serialize_feature_column() methods that take care of de-duping.
        config['parent'] = serialize_feature_column(self.parent)

        # Many objects provide custom (de)serialization, e.g. for tf.DType,
        # tf.DType.name and tf.as_dtype() can be used.
2360 config['dtype'] = self.dtype.name 2361 2362 # Non-trivial dependencies should be Keras-(de)serializable. 2363 config['normalizer_fn'] = generic_utils.serialize_keras_object( 2364 self.normalizer_fn) 2365 2366 return config 2367 2368 @classmethod 2369 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2370 # This should do the inverse transform from `get_config` and construct 2371 # the namedtuple. 2372 kwargs = config.copy() 2373 kwargs['parent'] = deserialize_feature_column( 2374 config['parent'], custom_objects, columns_by_name) 2375 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2376 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object( 2377 config['normalizer_fn'], custom_objects=custom_objects) 2378 return cls(**kwargs) 2379 2380 ``` 2381 Returns: 2382 A serializable Dict that can be used to deserialize the object with 2383 from_config. 2384 """ 2385 return self._get_config() 2386 2387 def _get_config(self): 2388 raise NotImplementedError('Must be implemented in subclasses.') 2389 2390 @classmethod 2391 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2392 """Creates a FeatureColumn from its config. 2393 2394 This method should be the reverse of `get_config`, capable of instantiating 2395 the same FeatureColumn from the config dictionary. See `get_config` for an 2396 example of common (de)serialization practices followed in this file. 2397 2398 TODO(b/118939620): This is a private method until consensus is reached on 2399 supporting object deserialization deduping within Keras. 2400 2401 Args: 2402 config: A Dict config acquired with `get_config`. 2403 custom_objects: Optional dictionary mapping names (strings) to custom 2404 classes or functions to be considered during deserialization. 2405 columns_by_name: A Dict[String, FeatureColumn] of existing columns in 2406 order to avoid duplication. Should be passed to any calls to 2407 deserialize_feature_column(). 2408 2409 Returns: 2410 A FeatureColumn for the input config. 2411 """ 2412 return cls._from_config(config, custom_objects, columns_by_name) 2413 2414 @classmethod 2415 def _from_config(cls, config, custom_objects=None, columns_by_name=None): 2416 raise NotImplementedError('Must be implemented in subclasses.') 2417 2418 2419class DenseColumn(FeatureColumn): 2420 """Represents a column which can be represented as `Tensor`. 2421 2422 Some examples of this type are: numeric_column, embedding_column, 2423 indicator_column. 2424 """ 2425 2426 @abc.abstractproperty 2427 def variable_shape(self): 2428 """`TensorShape` of `get_dense_tensor`, without batch dimension.""" 2429 pass 2430 2431 @abc.abstractmethod 2432 def get_dense_tensor(self, transformation_cache, state_manager): 2433 """Returns a `Tensor`. 2434 2435 The output of this function will be used by model-builder-functions. For 2436 example the pseudo code of `input_layer` will be like: 2437 2438 ```python 2439 def input_layer(features, feature_columns, ...): 2440 outputs = [fc.get_dense_tensor(...) for fc in feature_columns] 2441 return tf.concat(outputs) 2442 ``` 2443 2444 Args: 2445 transformation_cache: A `FeatureTransformationCache` object to access 2446 features. 2447 state_manager: A `StateManager` to create / access resources such as 2448 lookup tables. 2449 2450 Returns: 2451 `Tensor` of shape [batch_size] + `variable_shape`. 
    """
    pass


def is_feature_column_v2(feature_columns):
  """Returns True if all feature columns are V2."""
  for feature_column in feature_columns:
    if not isinstance(feature_column, FeatureColumn):
      return False
    if not feature_column._is_v2_column:  # pylint: disable=protected-access
      return False
  return True


def _create_weighted_sum(column, transformation_cache, state_manager,
                         sparse_combiner, weight_var):
  """Creates a weighted sum for a dense/categorical column for linear_model."""
  if isinstance(column, CategoricalColumn):
    return _create_categorical_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        sparse_combiner=sparse_combiner,
        weight_var=weight_var)
  else:
    return _create_dense_column_weighted_sum(
        column=column,
        transformation_cache=transformation_cache,
        state_manager=state_manager,
        weight_var=weight_var)


def _create_dense_column_weighted_sum(column, transformation_cache,
                                      state_manager, weight_var):
  """Create a weighted sum of a dense column for linear_model."""
  tensor = column.get_dense_tensor(transformation_cache, state_manager)
  num_elements = column.variable_shape.num_elements()
  batch_size = array_ops.shape(tensor)[0]
  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
  return math_ops.matmul(tensor, weight_var, name='weighted_sum')


class CategoricalColumn(FeatureColumn):
  """Represents a categorical feature.

  A categorical feature is typically handled with a `tf.SparseTensor` of IDs.
  """

  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
      'IdWeightPair', ('id_tensor', 'weight_tensor'))

  @abc.abstractproperty
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    pass

  @abc.abstractmethod
  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    pass


def _create_categorical_column_weighted_sum(
    column, transformation_cache, state_manager, sparse_combiner, weight_var):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Create a weighted sum of a categorical column for linear_model.

  Note to maintainer: As an implementation detail, the weighted sum is
  implemented via embedding_lookup_sparse for efficiency. Mathematically,
  they are the same.

  To be specific, conceptually, a categorical column can be treated as a
  multi-hot vector. Say:

  ```python
  x = [0 0 1]  # categorical column input
  w = [a b c]  # weights
  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.

  Another example is

  ```python
  x = [0 1 1]  # categorical column input
  w = [a b c]  # weights
  ```
  The weighted sum is `b + c` in this case, which is the same as
  `w[1] + w[2]`.

  For both cases, we can implement the weighted sum via embedding_lookup with
  sparse_combiner = "sum".
  """

  sparse_tensors = column.get_sparse_tensors(transformation_cache,
                                             state_manager)
  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
      array_ops.shape(sparse_tensors.id_tensor)[0], -1
  ])
  weight_tensor = sparse_tensors.weight_tensor
  if weight_tensor is not None:
    weight_tensor = sparse_ops.sparse_reshape(
        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])

  return embedding_ops.safe_embedding_lookup_sparse(
      weight_var,
      id_tensor,
      sparse_weights=weight_tensor,
      combiner=sparse_combiner,
      name='weighted_sum')


class SequenceDenseColumn(FeatureColumn):
  """Represents dense sequence data."""

  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))

  @abc.abstractmethod
  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `TensorSequenceLengthPair`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    pass


class FeatureTransformationCache(object):
  """Handles caching of transformations while building the model.

  `FeatureColumn` specifies how to digest an input column to the network. Some
  feature columns require data transformations. This class caches those
  transformations.

  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and in a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.

  Example:
  We're trying to use the following `FeatureColumn`s:

  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"], ...)
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
  ```

  If we transform each column independently, then we'll get duplication of
  bucketization (one for the cross, one for bucketization itself).
  The `FeatureTransformationCache` eliminates this duplication.
  """

  def __init__(self, features):
    """Creates a `FeatureTransformationCache`.

    Args:
      features: A mapping from feature column to objects that are `Tensor` or
        `SparseTensor`, or can be converted to same via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
        signifies a base feature (not-transformed). A `FeatureColumn` key
        means that this `Tensor` is the output of an existing `FeatureColumn`
        which can be reused.
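
    A minimal sketch (the feature name and values are illustrative):

    ```python
    cache = FeatureTransformationCache({'age': [[23.], [31.]]})
    ```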
    """
    self._features = features.copy()
    self._feature_tensors = {}

  def get(self, key, state_manager):
    """Returns a `Tensor` for the given key.

    A `str` key is used to access a base feature (not-transformed). When a
    `FeatureColumn` is passed, the transformed feature is returned if it
    already exists, otherwise the given `FeatureColumn` is asked to provide its
    transformed output, which is then cached.

    Args:
      key: a `str` or a `FeatureColumn`.
      state_manager: A StateManager object that holds the FeatureColumn state.

    Returns:
      The transformed `Tensor` corresponding to the `key`.

    Raises:
      ValueError: if key is not found or a transformed `Tensor` cannot be
        computed.
    """
    if key in self._feature_tensors:
      # FeatureColumn is already transformed or converted.
      return self._feature_tensors[key]

    if key in self._features:
      feature_tensor = self._get_raw_feature_as_tensor(key)
      self._feature_tensors[key] = feature_tensor
      return feature_tensor

    if isinstance(key, six.string_types):
      raise ValueError('Feature {} is not in features dictionary.'.format(key))

    if not isinstance(key, FeatureColumn):
      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
                      'Provided: {}'.format(key))

    column = key
    logging.debug('Transforming feature_column %s.', column)
    transformed = column.transform_feature(self, state_manager)
    if transformed is None:
      raise ValueError('Column {} is not supported.'.format(column.name))
    self._feature_tensors[column] = transformed
    return transformed

  def _get_raw_feature_as_tensor(self, key):
    """Gets the raw feature (keyed by `key`) as a `tensor`.

    The raw feature is converted to a (sparse) tensor and its rank may be
    expanded.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. Dynamic rank is also supported. A rank-0 raw feature is not
    supported and will raise an error.

    Args:
      key: A `str` key to access the raw feature.

    Returns:
      A `Tensor` or `SparseTensor`.

    Raises:
      ValueError: if the raw feature has rank 0.
    """
    raw_feature = self._features[key]
    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        raw_feature)

    def expand_dims(input_tensor):
      # input_tensor must have rank 1.
      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
        return sparse_ops.sparse_reshape(
            input_tensor, [array_ops.shape(input_tensor)[0], 1])
      else:
        return array_ops.expand_dims(input_tensor, -1)

    rank = feature_tensor.get_shape().ndims
    if rank is not None:
      if rank == 0:
        raise ValueError(
            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))
      return feature_tensor if rank != 1 else expand_dims(feature_tensor)

    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_positive(
            array_ops.rank(feature_tensor),
            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
                key, feature_tensor))]):
      return control_flow_ops.cond(
          math_ops.equal(1, array_ops.rank(feature_tensor)),
          lambda: expand_dims(feature_tensor),
          lambda: feature_tensor)


# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.

  If `input_tensor` is already a `SparseTensor`, just return it.

  Args:
    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, the default value
      of `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).

  Returns:
    A `SparseTensor` with the same shape as `input_tensor`.

  Raises:
    ValueError: when `input_tensor`'s rank is `None`.
  """
  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
      input_tensor)
  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
    return input_tensor
  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
    if ignore_value is None:
      if input_tensor.dtype == dtypes.string:
        # Exception because TF strings are converted to numpy objects by
        # default.
        ignore_value = ''
      elif input_tensor.dtype.is_integer:
        ignore_value = -1  # -1 has a special meaning of missing feature
      else:
        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
        # constructing a new numpy object of the given type, which yields the
        # default value for that type.
        ignore_value = input_tensor.dtype.as_numpy_dtype()
    ignore_value = math_ops.cast(
        ignore_value, input_tensor.dtype, name='ignore_value')
    indices = array_ops.where_v2(
        math_ops.not_equal(input_tensor, ignore_value), name='indices')
    return sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=array_ops.gather_nd(input_tensor, indices, name='values'),
        dense_shape=array_ops.shape(
            input_tensor, out_type=dtypes.int64, name='dense_shape'))
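
# Example (illustrative): _to_sparse_input_and_drop_ignore_values applied to a
# dense string Tensor [['a', ''], ['b', 'c']] yields a SparseTensor with
# values ['a', 'b', 'c'], with the '' cells dropped.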


def _normalize_feature_columns(feature_columns):
  """Normalizes the `feature_columns` input.

  This method converts the `feature_columns` to list type as best as it can.
  In addition, it verifies the type and other parts of feature_columns,
  required by downstream libraries.

  Args:
    feature_columns: The raw feature columns, usually passed by users.

  Returns:
    The normalized feature column list.

  Raises:
    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
  """
  if isinstance(feature_columns, FeatureColumn):
    feature_columns = [feature_columns]

  if isinstance(feature_columns, collections_abc.Iterator):
    feature_columns = list(feature_columns)

  if isinstance(feature_columns, dict):
    raise ValueError('Expected feature_columns to be iterable, found dict.')

  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('Items of feature_columns must be a FeatureColumn. '
                       'Given (type {}): {}.'.format(type(column), column))
  if not feature_columns:
    raise ValueError('feature_columns must not be empty.')
  name_to_column = {}
  for column in feature_columns:
    if column.name in name_to_column:
      raise ValueError('Duplicate feature column name found for columns: {} '
                       'and {}. This usually means that these columns refer '
                       'to the same base feature. Either one must be '
                       'discarded or a duplicated but renamed item must be '
                       'inserted in the features dict.'.format(
                           column, name_to_column[column.name]))
    name_to_column[column.name] = column

  return sorted(feature_columns, key=lambda x: x.name)


class NumericColumn(
    DenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'NumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """See `numeric_column`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {
        self.key:
            parsing_ops.FixedLenFeature(self.shape, self.dtype,
                                        self.default_value)
    }

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'The corresponding Tensor of a numeric column must be a dense '
          'Tensor. SparseTensor is not supported. key: {}'.format(self.key))
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return math_ops.cast(input_tensor, dtypes.float32)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    return self._transform_input_tensor(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.

    Raises:
      ValueError: If a SparseTensor is passed in.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(self.shape)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing numeric feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
2909 2910 Returns: 2911 Dense `Tensor` created within `transform_feature`. 2912 """ 2913 # Feature has been already transformed. Return the intermediate 2914 # representation created by _transform_feature. 2915 return transformation_cache.get(self, state_manager) 2916 2917 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2918 _FEATURE_COLUMN_DEPRECATION) 2919 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2920 del weight_collections 2921 del trainable 2922 return inputs.get(self) 2923 2924 @property 2925 def parents(self): 2926 """See 'FeatureColumn` base class.""" 2927 return [self.key] 2928 2929 def get_config(self): 2930 """See 'FeatureColumn` base class.""" 2931 config = dict(zip(self._fields, self)) 2932 config['normalizer_fn'] = generic_utils.serialize_keras_object( 2933 self.normalizer_fn) 2934 config['dtype'] = self.dtype.name 2935 return config 2936 2937 @classmethod 2938 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2939 """See 'FeatureColumn` base class.""" 2940 _check_config_keys(config, cls._fields) 2941 kwargs = _standardize_and_copy_config(config) 2942 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object( 2943 config['normalizer_fn'], custom_objects=custom_objects) 2944 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2945 2946 return cls(**kwargs) 2947 2948 2949class BucketizedColumn( 2950 DenseColumn, 2951 CategoricalColumn, 2952 fc_old._DenseColumn, # pylint: disable=protected-access 2953 fc_old._CategoricalColumn, # pylint: disable=protected-access 2954 collections.namedtuple('BucketizedColumn', 2955 ('source_column', 'boundaries'))): 2956 """See `bucketized_column`.""" 2957 2958 @property 2959 def _is_v2_column(self): 2960 return (isinstance(self.source_column, FeatureColumn) and 2961 self.source_column._is_v2_column) # pylint: disable=protected-access 2962 2963 @property 2964 def name(self): 2965 """See `FeatureColumn` base class.""" 2966 return '{}_bucketized'.format(self.source_column.name) 2967 2968 @property 2969 def parse_example_spec(self): 2970 """See `FeatureColumn` base class.""" 2971 return self.source_column.parse_example_spec 2972 2973 @property 2974 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2975 _FEATURE_COLUMN_DEPRECATION) 2976 def _parse_example_spec(self): 2977 return self.source_column._parse_example_spec # pylint: disable=protected-access 2978 2979 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2980 _FEATURE_COLUMN_DEPRECATION) 2981 def _transform_feature(self, inputs): 2982 """Returns bucketized categorical `source_column` tensor.""" 2983 source_tensor = inputs.get(self.source_column) 2984 return math_ops._bucketize( # pylint: disable=protected-access 2985 source_tensor, 2986 boundaries=self.boundaries) 2987 2988 def transform_feature(self, transformation_cache, state_manager): 2989 """Returns bucketized categorical `source_column` tensor.""" 2990 source_tensor = transformation_cache.get(self.source_column, state_manager) 2991 return math_ops._bucketize( # pylint: disable=protected-access 2992 source_tensor, 2993 boundaries=self.boundaries) 2994 2995 @property 2996 def variable_shape(self): 2997 """See `DenseColumn` base class.""" 2998 return tensor_shape.TensorShape( 2999 tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) 3000 3001 @property 3002 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3003 _FEATURE_COLUMN_DEPRECATION) 3004 def _variable_shape(self): 3005 return self.variable_shape 3006 3007 def 


class BucketizedColumn(
    DenseColumn,
    CategoricalColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('BucketizedColumn',
                           ('source_column', 'boundaries'))):
  """See `bucketized_column`."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.source_column, FeatureColumn) and
            self.source_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_bucketized'.format(self.source_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.source_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.source_column._parse_example_spec  # pylint: disable=protected-access

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = inputs.get(self.source_column)
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = transformation_cache.get(self.source_column, state_manager)
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(
        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def _get_dense_tensor_for_input_tensor(self, input_tensor):
    return array_ops.one_hot(
        indices=math_ops.cast(input_tensor, dtypes.int64),
        depth=len(self.boundaries) + 1,
        on_value=1.,
        off_value=0.)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns one-hot encoded dense `Tensor`."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    # By construction, source_column is always one-dimensional.
    return (len(self.boundaries) + 1) * self.source_column.shape[0]

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
    batch_size = array_ops.shape(input_tensor)[0]
    # By construction, source_column is always one-dimensional.
    source_dimension = self.source_column.shape[0]

    i1 = array_ops.reshape(
        array_ops.tile(
            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
            [1, source_dimension]),
        (-1,))
    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
    # Flatten the bucket indices and make them unique across dimensions,
    # e.g. with k buckets the 2nd dimension indices will range from k to 2*k-1.
    bucket_indices = (
        array_ops.reshape(input_tensor, (-1,)) +
        (len(self.boundaries) + 1) * i2)

    indices = math_ops.cast(
        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
    dense_shape = math_ops.cast(
        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
    sparse_tensor = sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=bucket_indices,
        dense_shape=dense_shape)
    return CategoricalColumn.IdWeightPair(sparse_tensor, None)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.source_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['source_column'] = serialize_feature_column(self.source_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['source_column'] = deserialize_feature_column(
        config['source_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
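

# Illustrative sketch, not part of the original module: how `BucketizedColumn`
# derives its dense and categorical shapes from `boundaries`, assuming the
# module-level `numeric_column` and `bucketized_column` constructors.
def _bucketized_column_sketch():
  age_buckets = bucketized_column(numeric_column('age'),
                                  boundaries=[18, 25, 30])
  assert age_buckets.name == 'age_bucketized'
  # One-hot depth is len(boundaries) + 1 for each source dimension.
  assert age_buckets.variable_shape.as_list() == [1, 4]
  assert age_buckets.num_buckets == 4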


class EmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'EmbeddingColumn',
        ('categorical_column', 'dimension', 'combiner', 'initializer',
         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
         'use_safe_embedding_lookup'))):
  """See `embedding_column`."""

  # __new__ is overridden only to supply a default for
  # `use_safe_embedding_lookup`, so call sites that omit it keep working.
  def __new__(cls,
              categorical_column,
              dimension,
              combiner,
              initializer,
              ckpt_to_load_from,
              tensor_name_in_ckpt,
              max_norm,
              trainable,
              use_safe_embedding_lookup=True):
    return super(EmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """Transforms underlying `categorical_column`."""
    return transformation_cache.get(self.categorical_column, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return inputs.get(self.categorical_column)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape([self.dimension])

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def create_state(self, state_manager):
    """Creates the embedding lookup variable."""
    default_num_buckets = (self.categorical_column.num_buckets
                           if self._is_v2_column
                           else self.categorical_column._num_buckets)  # pylint: disable=protected-access
    num_buckets = getattr(self.categorical_column, 'num_buckets',
                          default_num_buckets)
    embedding_shape = (num_buckets, self.dimension)
    state_manager.create_variable(
        self,
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        trainable=self.trainable,
        use_resource=True,
        initializer=self.initializer)

  def _get_dense_tensor_internal_helper(self, sparse_tensors,
                                        embedding_weights):
    sparse_ids = sparse_tensors.id_tensor
    sparse_weights = sparse_tensors.weight_tensor

    if self.ckpt_to_load_from is not None:
      to_restore = embedding_weights
      if isinstance(to_restore, variables.PartitionedVariable):
        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
          self.tensor_name_in_ckpt: to_restore
      })

    sparse_id_rank = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
        sparse_id_rank <= 2):
      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
    # Return embedding lookup result.
    return embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)

  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
    """Private method that follows the signature of get_dense_tensor."""
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
                                     trainable):
    """Private method that follows the signature of _get_dense_tensor."""
    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns tensor after doing the embedding lookup.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Embedding lookup tensor.

    Raises:
      ValueError: `categorical_column` is SequenceCategoricalColumn.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_dense_tensor_internal(sparse_tensors, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
        inputs, weight_collections, trainable)
    return self._old_get_dense_tensor_internal(sparse_tensors,
                                               weight_collections, trainable)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
                                                   state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    dense_tensor = self._old_get_dense_tensor_internal(
        sparse_tensors,
        weight_collections=weight_collections,
        trainable=trainable)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['initializer'] = initializers.serialize(self.initializer)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    if 'use_safe_embedding_lookup' not in config:
      config['use_safe_embedding_lookup'] = True
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['initializer'] = initializers.deserialize(
        config['initializer'], custom_objects=custom_objects)
    return cls(**kwargs)
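

# Illustrative sketch, not part of the original module: wrapping a categorical
# column in an `EmbeddingColumn`, assuming the module-level `embedding_column`
# and `categorical_column_with_hash_bucket` constructors. Only metadata is
# inspected, so no embedding variable is created here.
def _embedding_column_sketch():
  dept = categorical_column_with_hash_bucket('dept', hash_bucket_size=100)
  dept_embedding = embedding_column(dept, dimension=8)
  assert dept_embedding.name == 'dept_embedding'
  assert dept_embedding.variable_shape.as_list() == [8]
  # The parse spec is inherited from the wrapped categorical column.
  assert dept_embedding.parse_example_spec == dept.parse_example_spec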


def _raise_shared_embedding_column_error():
  raise ValueError('SharedEmbeddingColumns are not supported in '
                   '`linear_model` or `input_layer`. Please use '
                   '`DenseFeatures` or `LinearModel` instead.')


class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
  """Creates `SharedEmbeddingColumn`s that share one embedding variable."""

  def __init__(self,
               dimension,
               initializer,
               ckpt_to_load_from,
               tensor_name_in_ckpt,
               num_buckets,
               trainable,
               name='shared_embedding_column_creator',
               use_safe_embedding_lookup=True):
    self._dimension = dimension
    self._initializer = initializer
    self._ckpt_to_load_from = ckpt_to_load_from
    self._tensor_name_in_ckpt = tensor_name_in_ckpt
    self._num_buckets = num_buckets
    self._trainable = trainable
    self._name = name
    self._use_safe_embedding_lookup = use_safe_embedding_lookup
    # Map from graph keys to embedding_weight variables.
    self._embedding_weights = {}

  def __call__(self, categorical_column, combiner, max_norm):
    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
                                 self._use_safe_embedding_lookup)

  @property
  def embedding_weights(self):
    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if key not in self._embedding_weights:
      embedding_shape = (self._num_buckets, self._dimension)
      var = variable_scope.get_variable(
          name=self._name,
          shape=embedding_shape,
          dtype=dtypes.float32,
          initializer=self._initializer,
          trainable=self._trainable)

      if self._ckpt_to_load_from is not None:
        to_restore = var
        if isinstance(to_restore, variables.PartitionedVariable):
          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
        checkpoint_utils.init_from_checkpoint(
            self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
      self._embedding_weights[key] = var
    return self._embedding_weights[key]

  @property
  def dimension(self):
    return self._dimension


class SharedEmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'SharedEmbeddingColumn',
        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
         'max_norm', 'use_safe_embedding_lookup'))):
  """See `shared_embedding_columns`."""

  # __new__ is overridden only to supply a default for
  # `use_safe_embedding_lookup`, so call sites that omit it keep working.
  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner,
              max_norm,
              use_safe_embedding_lookup=True):
    return super(SharedEmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        shared_embedding_column_creator=shared_embedding_column_creator,
        combiner=combiner,
        max_norm=max_norm,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_shared_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  def _parse_example_spec(self):
    return _raise_shared_embedding_column_error()

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return transformation_cache.get(self.categorical_column, state_manager)

  def _transform_feature(self, inputs):
    return _raise_shared_embedding_column_error()

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(
        [self.shared_embedding_column_creator.dimension])

  @property
  def _variable_shape(self):
    return _raise_shared_embedding_column_error()

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows the signature of _get_dense_tensor."""
    # This method is called from a variable_scope with name _var_scope_name,
    # which is shared among all shared embeddings. Open a name_scope here, so
    # that the ops for different columns have distinct names.
    with ops.name_scope(None, default_name=self.name):
      # Get sparse IDs and weights.
      sparse_tensors = self.categorical_column.get_sparse_tensors(
          transformation_cache, state_manager)
      sparse_ids = sparse_tensors.id_tensor
      sparse_weights = sparse_tensors.weight_tensor

      embedding_weights = self.shared_embedding_column_creator.embedding_weights

      sparse_id_rank = tensor_shape.dimension_value(
          sparse_ids.dense_shape.get_shape()[0])
      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
          sparse_id_rank <= 2):
        embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
      # Return embedding lookup result.
      return embedding_lookup_sparse(
          embedding_weights,
          sparse_ids,
          sparse_weights,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embedding lookup result."""
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    return self._get_dense_tensor_internal(transformation_cache, state_manager)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    return _raise_shared_embedding_column_error()

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
                                                   state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    return _raise_shared_embedding_column_error()

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]
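

# Illustrative sketch, not part of the original module: one
# `SharedEmbeddingColumnCreator` handing out two columns that would share a
# single embedding table. `embedding_weights` is never accessed, so no
# variable is created. Assumes the module-level
# `categorical_column_with_identity` constructor.
def _shared_embedding_sketch():
  creator = SharedEmbeddingColumnCreator(
      dimension=4, initializer=None, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=10, trainable=True)
  col_a = creator(categorical_column_with_identity('a', num_buckets=10),
                  combiner='mean', max_norm=None)
  col_b = creator(categorical_column_with_identity('b', num_buckets=10),
                  combiner='mean', max_norm=None)
  assert col_a.name == 'a_shared_embedding'
  assert col_a.variable_shape.as_list() == [4]
  assert col_b.variable_shape.as_list() == [4]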


def _check_shape(shape, key):
  """Returns shape if it's valid, raises an error otherwise."""
  assert shape is not None
  if not nest.is_sequence(shape):
    shape = [shape]
  shape = tuple(shape)
  for dimension in shape:
    if not isinstance(dimension, int):
      raise TypeError('shape dimensions must be integer. '
                      'shape: {}, key: {}'.format(shape, key))
    if dimension < 1:
      raise ValueError('shape dimensions must be greater than 0. '
                       'shape: {}, key: {}'.format(shape, key))
  return shape


class HashedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('HashedCategoricalColumn',
                           ('key', 'hash_bucket_size', 'dtype'))):
  """See `categorical_column_with_hash_bucket`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Hashes the values in the feature_column."""
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    if self.dtype == dtypes.string:
      sparse_values = input_tensor.values
    else:
      sparse_values = string_ops.as_string(input_tensor.values)

    sparse_id_values = string_ops.string_to_hash_bucket_fast(
        sparse_values, self.hash_bucket_size, name='lookup')
    return sparse_tensor_lib.SparseTensor(
        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Hashes the values in the feature_column."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
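

# Illustrative sketch, not part of the original module: a hash-bucket column's
# bucket count and parse spec follow directly from the definitions above.
# Assumes the module-level `categorical_column_with_hash_bucket` constructor.
def _hashed_column_sketch():
  terms = categorical_column_with_hash_bucket('terms', hash_bucket_size=50)
  assert terms.num_buckets == 50
  # Raw input is parsed as a variable-length string feature.
  assert terms.parse_example_spec == {
      'terms': parsing_ops.VarLenFeature(dtypes.string)
  }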


class VocabularyFileCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('VocabularyFileCategoricalColumn',
                           ('key', 'vocabulary_file', 'vocabulary_size',
                            'num_oov_buckets', 'dtype', 'default_value'))):
  """See `categorical_column_with_vocabulary_file`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_file` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)

    name = '{}_lookup'.format(self.key)
    table = lookup_ops.index_table_from_file(
        vocabulary_file=self.vocabulary_file,
        num_oov_buckets=self.num_oov_buckets,
        vocab_size=self.vocabulary_size,
        default_value=self.default_value,
        key_dtype=key_dtype,
        name=name)
    if state_manager is not None:
      state_manager.add_resource(self, name, table)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.vocabulary_size + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
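

# Illustrative sketch, not part of the original module: `num_buckets` is the
# vocabulary size plus the OOV buckets. The vocabulary path is hypothetical;
# because `vocabulary_size` is supplied, the file is only needed once the
# lookup table is actually built, not at construction time. Assumes the
# module-level `categorical_column_with_vocabulary_file` constructor.
def _vocabulary_file_column_sketch():
  states = categorical_column_with_vocabulary_file(
      key='state', vocabulary_file='/tmp/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  assert states.num_buckets == 50 + 5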


class VocabularyListCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyListCategoricalColumn',
        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
):
  """See `categorical_column_with_vocabulary_list`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary list."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_tensor` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)

    name = '{}_lookup'.format(self.key)
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=tuple(self.vocabulary_list),
        default_value=self.default_value,
        num_oov_buckets=self.num_oov_buckets,
        dtype=key_dtype,
        name=name)
    if state_manager is not None:
      state_manager.add_resource(self, name, table)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary list."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return len(self.vocabulary_list) + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
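

# Illustrative sketch, not part of the original module: with an in-memory
# vocabulary the bucket count is len(vocabulary_list) + num_oov_buckets.
# Assumes the module-level `categorical_column_with_vocabulary_list`
# constructor.
def _vocabulary_list_column_sketch():
  colors = categorical_column_with_vocabulary_list(
      key='color', vocabulary_list=('R', 'G', 'B'), num_oov_buckets=2)
  assert colors.num_buckets == 3 + 2
  assert colors.parse_example_spec == {
      'color': parsing_ops.VarLenFeature(dtypes.string)
  }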


class IdentityCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('IdentityCategoricalColumn',
                           ('key', 'number_buckets', 'default_value'))):
  """See `categorical_column_with_identity`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Returns a SparseTensor with identity values."""
    if not input_tensor.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, input_tensor.dtype))
    values = input_tensor.values
    if input_tensor.values.dtype != dtypes.int64:
      values = math_ops.cast(values, dtypes.int64, name='values')
    if self.default_value is not None:
      values = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
      num_buckets = math_ops.cast(
          self.num_buckets, dtypes.int64, name='num_buckets')
      zero = math_ops.cast(0, dtypes.int64, name='zero')
      # Assign default for out-of-range values.
      values = array_ops.where_v2(
          math_ops.logical_or(
              values < zero, values >= num_buckets, name='out_of_range'),
          array_ops.fill(
              dims=array_ops.shape(values),
              value=math_ops.cast(self.default_value, dtypes.int64),
              name='default_values'), values)

    return sparse_tensor_lib.SparseTensor(
        indices=input_tensor.indices,
        values=values,
        dense_shape=input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a SparseTensor with identity values."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.number_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    return dict(zip(self._fields, self))

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    return cls(**kwargs)
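

# Illustrative sketch, not part of the original module: an identity column
# passes integer ids straight through; ids outside [0, num_buckets) are
# replaced with `default_value` by `_transform_input_tensor` above. Assumes
# the module-level `categorical_column_with_identity` constructor.
def _identity_column_sketch():
  class_ids = categorical_column_with_identity(
      key='class_id', num_buckets=4, default_value=0)
  assert class_ids.num_buckets == 4
  # Inputs are parsed directly as variable-length int64 ids.
  assert class_ids.parse_example_spec == {
      'class_id': parsing_ops.VarLenFeature(dtypes.int64)
  }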


class WeightedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'WeightedCategoricalColumn',
        ('categorical_column', 'weight_feature_key', 'dtype'))):
  """See `weighted_categorical_column`."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_weighted_by_{}'.format(
        self.categorical_column.name, self.weight_feature_key)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = self.categorical_column.parse_example_spec
    if self.weight_feature_key in config:
      raise ValueError('Parse config {} already exists for {}.'.format(
          config[self.weight_feature_key], self.weight_feature_key))
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
    if self.weight_feature_key in config:
      raise ValueError('Parse config {} already exists for {}.'.format(
          config[self.weight_feature_key], self.weight_feature_key))
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _transform_weight_tensor(self, weight_tensor):
    if weight_tensor is None:
      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        weight_tensor)
    if self.dtype != weight_tensor.dtype.base_dtype:
      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
          self.dtype, weight_tensor.dtype))
    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
      # The weight tensor can be a regular Tensor. In this case, sparsify it.
      weight_tensor = _to_sparse_input_and_drop_ignore_values(
          weight_tensor, ignore_value=0.0)
    if not weight_tensor.dtype.is_floating:
      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
    return weight_tensor

  def transform_feature(self, transformation_cache, state_manager):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = transformation_cache.get(self.weight_feature_key,
                                             state_manager)
    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.categorical_column, state_manager))
    return (sparse_categorical_tensor, sparse_weight_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = inputs.get(self.weight_feature_key)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (inputs.get(self.categorical_column), weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    tensors = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    tensors = inputs.get(self)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column, self.weight_feature_key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
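

# Illustrative sketch, not part of the original module: a weighted categorical
# column adds a second parse key for the weights next to the wrapped column's
# own key. Assumes the module-level `weighted_categorical_column` and
# `categorical_column_with_identity` constructors.
def _weighted_column_sketch():
  ids = categorical_column_with_identity(key='ids', num_buckets=10)
  weighted = weighted_categorical_column(
      categorical_column=ids, weight_feature_key='freq')
  assert weighted.name == 'ids_weighted_by_freq'
  spec = weighted.parse_example_spec
  assert set(spec) == {'ids', 'freq'}
  assert spec['freq'] == parsing_ops.VarLenFeature(dtypes.float32)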


class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`."""

  @property
  def _is_v2_column(self):
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        continue
      if not isinstance(key, FeatureColumn):
        return False
      if not key._is_v2_column:  # pylint: disable=protected-access
        return False
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    feature_names = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
        feature_names.append(key.name)
      else:  # key must be a string
        feature_names.append(key)
    return '_X_'.join(sorted(feature_names))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        config.update(key.parse_example_spec)
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        config.update(key._parse_example_spec)  # pylint: disable=protected-access
      else:  # key must be a string
        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(transformation_cache.get(key, state_manager))
      elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key.get_sparse_tensors(transformation_cache,
                                                 state_manager)
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(inputs.get(key))
      elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return list(self.keys)

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['keys'] = tuple([
        deserialize_feature_column(c, custom_objects, columns_by_name)
        for c in config['keys']
    ])
    return cls(**kwargs)
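

# Illustrative sketch, not part of the original module: a cross of a raw
# string key with a bucketized column. The name joins the sorted leaf names
# and the bucket count is simply `hash_bucket_size`. Assumes the module-level
# `crossed_column`, `bucketized_column` and `numeric_column` constructors.
def _crossed_column_sketch():
  age_buckets = bucketized_column(numeric_column('age'), boundaries=[18, 25])
  dept_x_age = crossed_column(['dept', age_buckets], hash_bucket_size=1000)
  assert dept_x_age.name == 'age_bucketized_X_dept'
  assert dept_x_age.num_buckets == 1000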
4303 """ 4304 leaf_level_keys = [] 4305 for k in cross.keys: 4306 if isinstance(k, CrossedColumn): 4307 leaf_level_keys.extend(_collect_leaf_level_keys(k)) 4308 else: 4309 leaf_level_keys.append(k) 4310 return leaf_level_keys 4311 4312 4313def _prune_invalid_ids(sparse_ids, sparse_weights): 4314 """Prune invalid IDs (< 0) from the input ids and weights.""" 4315 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) 4316 if sparse_weights is not None: 4317 is_id_valid = math_ops.logical_and( 4318 is_id_valid, 4319 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) 4320 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) 4321 if sparse_weights is not None: 4322 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) 4323 return sparse_ids, sparse_weights 4324 4325 4326def _prune_invalid_weights(sparse_ids, sparse_weights): 4327 """Prune invalid weights (< 0) from the input ids and weights.""" 4328 if sparse_weights is not None: 4329 is_weights_valid = math_ops.greater(sparse_weights.values, 0) 4330 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) 4331 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) 4332 return sparse_ids, sparse_weights 4333 4334 4335class IndicatorColumn( 4336 DenseColumn, 4337 SequenceDenseColumn, 4338 fc_old._DenseColumn, # pylint: disable=protected-access 4339 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 4340 collections.namedtuple('IndicatorColumn', ('categorical_column'))): 4341 """Represents a one-hot column for use in deep networks. 4342 4343 Args: 4344 categorical_column: A `CategoricalColumn` which is created by 4345 `categorical_column_with_*` function. 4346 """ 4347 4348 @property 4349 def _is_v2_column(self): 4350 return (isinstance(self.categorical_column, FeatureColumn) and 4351 self.categorical_column._is_v2_column) # pylint: disable=protected-access 4352 4353 @property 4354 def name(self): 4355 """See `FeatureColumn` base class.""" 4356 return '{}_indicator'.format(self.categorical_column.name) 4357 4358 def _transform_id_weight_pair(self, id_weight_pair, size): 4359 id_tensor = id_weight_pair.id_tensor 4360 weight_tensor = id_weight_pair.weight_tensor 4361 4362 # If the underlying column is weighted, return the input as a dense tensor. 4363 if weight_tensor is not None: 4364 weighted_column = sparse_ops.sparse_merge( 4365 sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size)) 4366 # Remove (?, -1) index. 4367 weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], 4368 weighted_column.dense_shape) 4369 # Use scatter_nd to merge duplicated indices if existed, 4370 # instead of sparse_tensor_to_dense. 4371 return array_ops.scatter_nd(weighted_column.indices, 4372 weighted_column.values, 4373 weighted_column.dense_shape) 4374 4375 dense_id_tensor = sparse_ops.sparse_tensor_to_dense( 4376 id_tensor, default_value=-1) 4377 4378 # One hot must be float for tf.concat reasons since all other inputs to 4379 # input_layer are float32. 4380 one_hot_id_tensor = array_ops.one_hot( 4381 dense_id_tensor, depth=size, on_value=1.0, off_value=0.0) 4382 4383 # Reduce to get a multi-hot per example. 4384 return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2]) 4385 4386 def transform_feature(self, transformation_cache, state_manager): 4387 """Returns dense `Tensor` representing feature. 4388 4389 Args: 4390 transformation_cache: A `FeatureTransformationCache` object to access 4391 features. 
4392 state_manager: A `StateManager` to create / access resources such as 4393 lookup tables. 4394 4395 Returns: 4396 Transformed feature `Tensor`. 4397 4398 Raises: 4399 ValueError: if input rank is not known at graph building time. 4400 """ 4401 id_weight_pair = self.categorical_column.get_sparse_tensors( 4402 transformation_cache, state_manager) 4403 return self._transform_id_weight_pair(id_weight_pair, 4404 self.variable_shape[-1]) 4405 4406 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4407 _FEATURE_COLUMN_DEPRECATION) 4408 def _transform_feature(self, inputs): 4409 id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4410 return self._transform_id_weight_pair(id_weight_pair, 4411 self._variable_shape[-1]) 4412 4413 @property 4414 def parse_example_spec(self): 4415 """See `FeatureColumn` base class.""" 4416 return self.categorical_column.parse_example_spec 4417 4418 @property 4419 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4420 _FEATURE_COLUMN_DEPRECATION) 4421 def _parse_example_spec(self): 4422 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 4423 4424 @property 4425 def variable_shape(self): 4426 """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" 4427 if isinstance(self.categorical_column, FeatureColumn): 4428 return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) 4429 else: 4430 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4431 4432 @property 4433 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4434 _FEATURE_COLUMN_DEPRECATION) 4435 def _variable_shape(self): 4436 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4437 4438 def get_dense_tensor(self, transformation_cache, state_manager): 4439 """Returns dense `Tensor` representing feature. 4440 4441 Args: 4442 transformation_cache: A `FeatureTransformationCache` object to access 4443 features. 4444 state_manager: A `StateManager` to create / access resources such as 4445 lookup tables. 4446 4447 Returns: 4448 Dense `Tensor` created within `transform_feature`. 4449 4450 Raises: 4451 ValueError: If `categorical_column` is a `SequenceCategoricalColumn`. 4452 """ 4453 if isinstance(self.categorical_column, SequenceCategoricalColumn): 4454 raise ValueError( 4455 'In indicator_column: {}. ' 4456 'categorical_column must not be of type SequenceCategoricalColumn. ' 4457 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4458 'non-sequence categorical_column_with_*. ' 4459 'Suggested fix B: If you wish to create sequence input, use ' 4460 'SequenceFeatures instead of DenseFeatures. ' 4461 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4462 self.categorical_column)) 4463 # Feature has been already transformed. Return the intermediate 4464 # representation created by transform_feature. 4465 return transformation_cache.get(self, state_manager) 4466 4467 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4468 _FEATURE_COLUMN_DEPRECATION) 4469 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 4470 del weight_collections 4471 del trainable 4472 if isinstance( 4473 self.categorical_column, 4474 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4475 raise ValueError( 4476 'In indicator_column: {}. 
  def transform_feature(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.

    Raises:
      ValueError: if input rank is not known at graph building time.
    """
    id_weight_pair = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._transform_id_weight_pair(id_weight_pair,
                                          self.variable_shape[-1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._transform_id_weight_pair(id_weight_pair,
                                          self._variable_shape[-1])

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
    if isinstance(self.categorical_column, FeatureColumn):
      return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
    else:
      return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    dense_tensor = transformation_cache.get(self, state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable since no variables are
    # created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
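

# Illustrative sketch of public-API usage of the class above (hypothetical,
# not executed here; assumes the v2 `tf.feature_column` endpoints and
# `tf.keras.layers.DenseFeatures`). A three-word vocabulary multi-hot encodes
# each example, and duplicated values accumulate per example:
#
#   dept = tf.feature_column.categorical_column_with_vocabulary_list(
#       'department', ['math', 'philosophy', 'english'])
#   dept_indicator = tf.feature_column.indicator_column(dept)
#   # tf.keras.layers.DenseFeatures([dept_indicator])(
#   #     {'department': [['math', 'math']]}) evaluates to [[2., 0., 0.]].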


def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # batch_size is a Dimension object.
  expected_batch_size = None
  for i in range(0, len(tensors)):
    batch_size = tensor_shape.Dimension(tensor_shape.dimension_value(
        tensors[i].shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be the same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))
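

# Illustrative graph-mode sketch of the check above (hypothetical names:
# `col_a` and `col_b` stand for any two feature columns, used only for the
# error message). Two inputs with static batch sizes 32 and 64 fail at graph
# construction time:
#
#   a = array_ops.placeholder(dtypes.float32, shape=[32, 4])
#   b = array_ops.placeholder(dtypes.float32, shape=[64, 4])
#   _verify_static_batch_size_equality([a, b], [col_a, col_b])
#   # -> ValueError: Batch size (first dimension) of each feature must be
#   #    the same. ...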


class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column'))):
  """Represents sequences of categorical data."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expands the third dimension, if necessary, so that embeddings are not
    # combined during embedding lookup. If the tensor is already 3D, leave it
    # as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # -1 does not work for dynamically shaped tensors with 0 elements at
    # runtime, which happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
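
  # Illustrative sketch of the reshape in `_get_sparse_tensors_helper`
  # (assumes the module's `sparse_tensor_lib` import): a rank-2 ids tensor
  # gains a size-1 inner dimension, so embeddings are looked up per sequence
  # step rather than combined across the whole sequence. Rank-3 inputs keep
  # their shape, since reduce_prod over shape[2:] is the existing inner size.
  #
  #   ids = sparse_tensor_lib.SparseTensor(
  #       indices=[[0, 0], [0, 2], [1, 1]], values=[7, 8, 9],
  #       dense_shape=[2, 3])
  #   # After the reshape: indices [[0, 0, 0], [0, 2, 0], [1, 1, 0]],
  #   # dense_shape [2, 3, 1].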
  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `sp_ids`. The expected `SparseTensor` is the same as
    the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


def _check_config_keys(config, expected_keys):
  """Checks that a config has all expected_keys."""
  if set(config.keys()) != set(expected_keys):
    raise ValueError('Invalid config: {}, expected keys: {}'.format(
        config, expected_keys))


def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned into tuples.

  Keras serialization uses nest to listify everything.
  This causes problems with the NumericColumn shape, which becomes
  unhashable. We could try to solve this on the Keras side, but that
  would require lots of tracking to avoid changing existing behavior.
  Instead, we ensure here that we revive correctly.

  Args:
    config: dict that will be used to revive a Feature Column.

  Returns:
    Shallow copy of config with lists turned into tuples.
  """
  kwargs = config.copy()
  for k, v in kwargs.items():
    if isinstance(v, list):
      kwargs[k] = tuple(v)

  return kwargs


def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes."""
  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
  return invalid_char.sub('_', name)
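

# Illustrative sketch of the two helpers above (not part of the module API):
# `_standardize_and_copy_config` turns listified fields back into hashable
# tuples, and the sanitizer maps any character outside [A-Za-z0-9_.-] to '_':
#
#   _standardize_and_copy_config({'shape': [10, 2]})
#   # -> {'shape': (10, 2)}
#   _sanitize_column_name_for_variable_scope('user country!')
#   # -> 'user_country_'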