1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This API defines FeatureColumn abstraction. 16 17FeatureColumns provide a high level abstraction for ingesting and representing 18features. FeatureColumns are also the primary way of encoding features for 19canned `tf.estimator.Estimator`s. 20 21When using FeatureColumns with `Estimators`, the type of feature column you 22should choose depends on (1) the feature type and (2) the model type. 23 241. Feature type: 25 26 * Continuous features can be represented by `numeric_column`. 27 * Categorical features can be represented by any `categorical_column_with_*` 28 column: 29 - `categorical_column_with_vocabulary_list` 30 - `categorical_column_with_vocabulary_file` 31 - `categorical_column_with_hash_bucket` 32 - `categorical_column_with_identity` 33 - `weighted_categorical_column` 34 352. Model type: 36 37 * Deep neural network models (`DNNClassifier`, `DNNRegressor`). 38 39 Continuous features can be directly fed into deep neural network models. 40 41 age_column = numeric_column("age") 42 43 To feed sparse features into DNN models, wrap the column with 44 `embedding_column` or `indicator_column`. `indicator_column` is recommended 45 for features with only a few possible values. For features with many 46 possible values, to reduce the size of your model, `embedding_column` is 47 recommended. 48 49 embedded_dept_column = embedding_column( 50 categorical_column_with_vocabulary_list( 51 "department", ["math", "philosophy", ...]), dimension=10) 52 53 * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`). 54 55 Sparse features can be fed directly into linear models. They behave like an 56 indicator column but with an efficient implementation. 57 58 dept_column = categorical_column_with_vocabulary_list("department", 59 ["math", "philosophy", "english"]) 60 61 It is recommended that continuous features be bucketized before being 62 fed into linear models. 63 64 bucketized_age_column = bucketized_column( 65 source_column=age_column, 66 boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) 67 68 Sparse features can be crossed (also known as conjuncted or combined) in 69 order to form non-linearities, and then fed into linear models. 70 71 cross_dept_age_column = crossed_column( 72 columns=["department", bucketized_age_column], 73 hash_bucket_size=1000) 74 75Example of building canned `Estimator`s using FeatureColumns: 76 77 ```python 78 # Define features and transformations 79 deep_feature_columns = [age_column, embedded_dept_column] 80 wide_feature_columns = [dept_column, bucketized_age_column, 81 cross_dept_age_column] 82 83 # Build deep model 84 estimator = DNNClassifier( 85 feature_columns=deep_feature_columns, 86 hidden_units=[500, 250, 50]) 87 estimator.train(...) 88 89 # Or build a wide model 90 estimator = LinearClassifier( 91 feature_columns=wide_feature_columns) 92 estimator.train(...) 93 94 # Or build a wide and deep model! 95 estimator = DNNLinearCombinedClassifier( 96 linear_feature_columns=wide_feature_columns, 97 dnn_feature_columns=deep_feature_columns, 98 dnn_hidden_units=[500, 250, 50]) 99 estimator.train(...) 100 ``` 101 102 103FeatureColumns can also be transformed into a generic input layer for 104custom models using `input_layer`. 105 106Example of building model using FeatureColumns, this can be used in a 107`model_fn` which is given to the {tf.estimator.Estimator}: 108 109 ```python 110 # Building model via layers 111 112 deep_feature_columns = [age_column, embedded_dept_column] 113 columns_to_tensor = parse_feature_columns_from_examples( 114 serialized=my_data, 115 feature_columns=deep_feature_columns) 116 first_layer = input_layer( 117 features=columns_to_tensor, 118 feature_columns=deep_feature_columns) 119 second_layer = fully_connected(first_layer, ...) 120 ``` 121 122NOTE: Functions prefixed with "_" indicate experimental or private parts of 123the API subject to change, and should not be relied upon! 124""" 125 126from __future__ import absolute_import 127from __future__ import division 128from __future__ import print_function 129 130import abc 131import collections 132import math 133import re 134 135import numpy as np 136import six 137 138from tensorflow.python.eager import context 139from tensorflow.python.feature_column import feature_column as fc_old 140from tensorflow.python.feature_column import utils as fc_utils 141from tensorflow.python.framework import dtypes 142from tensorflow.python.framework import ops 143from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 144from tensorflow.python.framework import tensor_shape 145from tensorflow.python.ops import array_ops 146from tensorflow.python.ops import check_ops 147from tensorflow.python.ops import control_flow_ops 148from tensorflow.python.ops import embedding_ops 149from tensorflow.python.ops import init_ops 150from tensorflow.python.ops import lookup_ops 151from tensorflow.python.ops import math_ops 152from tensorflow.python.ops import parsing_ops 153from tensorflow.python.ops import sparse_ops 154from tensorflow.python.ops import string_ops 155from tensorflow.python.ops import variable_scope 156from tensorflow.python.ops import variables 157from tensorflow.python.platform import gfile 158from tensorflow.python.platform import tf_logging as logging 159from tensorflow.python.training import checkpoint_utils 160from tensorflow.python.training.tracking import base as trackable 161from tensorflow.python.training.tracking import data_structures 162from tensorflow.python.training.tracking import tracking 163from tensorflow.python.util import deprecation 164from tensorflow.python.util import nest 165from tensorflow.python.util import tf_inspect 166from tensorflow.python.util.compat import collections_abc 167from tensorflow.python.util.tf_export import tf_export 168 169 170_FEATURE_COLUMN_DEPRECATION_DATE = None 171_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being ' 172 'deprecated. Please use the new FeatureColumn ' 173 'APIs instead.') 174 175 176class StateManager(object): 177 """Manages the state associated with FeatureColumns. 178 179 Some `FeatureColumn`s create variables or resources to assist their 180 computation. The `StateManager` is responsible for creating and storing these 181 objects since `FeatureColumn`s are supposed to be stateless configuration 182 only. 183 """ 184 185 def create_variable(self, 186 feature_column, 187 name, 188 shape, 189 dtype=None, 190 trainable=True, 191 use_resource=True, 192 initializer=None): 193 """Creates a new variable. 194 195 Args: 196 feature_column: A `FeatureColumn` object this variable corresponds to. 197 name: variable name. 198 shape: variable shape. 199 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 200 trainable: Whether this variable is trainable or not. 201 use_resource: If true, we use resource variables. Otherwise we use 202 RefVariable. 203 initializer: initializer instance (callable). 204 205 Returns: 206 The created variable. 207 """ 208 del feature_column, name, shape, dtype, trainable, use_resource, initializer 209 raise NotImplementedError('StateManager.create_variable') 210 211 def add_variable(self, feature_column, var): 212 """Adds an existing variable to the state. 213 214 Args: 215 feature_column: A `FeatureColumn` object to associate this variable with. 216 var: The variable. 217 """ 218 del feature_column, var 219 raise NotImplementedError('StateManager.add_variable') 220 221 def get_variable(self, feature_column, name): 222 """Returns an existing variable. 223 224 Args: 225 feature_column: A `FeatureColumn` object this variable corresponds to. 226 name: variable name. 227 """ 228 del feature_column, name 229 raise NotImplementedError('StateManager.get_var') 230 231 def add_resource(self, feature_column, name, resource): 232 """Creates a new resource. 233 234 Resources can be things such as tables, variables, trackables, etc. 235 236 Args: 237 feature_column: A `FeatureColumn` object this resource corresponds to. 238 name: Name of the resource. 239 resource: The resource. 240 241 Returns: 242 The created resource. 243 """ 244 del feature_column, name, resource 245 raise NotImplementedError('StateManager.add_resource') 246 247 def has_resource(self, feature_column, name): 248 """Returns true iff a resource with same name exists. 249 250 Resources can be things such as tables, variables, trackables, etc. 251 252 Args: 253 feature_column: A `FeatureColumn` object this variable corresponds to. 254 name: Name of the resource. 255 """ 256 del feature_column, name 257 raise NotImplementedError('StateManager.has_resource') 258 259 def get_resource(self, feature_column, name): 260 """Returns an already created resource. 261 262 Resources can be things such as tables, variables, trackables, etc. 263 264 Args: 265 feature_column: A `FeatureColumn` object this variable corresponds to. 266 name: Name of the resource. 267 """ 268 del feature_column, name 269 raise NotImplementedError('StateManager.get_resource') 270 271 272class _StateManagerImpl(StateManager): 273 """Manages the state of DenseFeatures and LinearLayer.""" 274 275 def __init__(self, layer, trainable): 276 """Creates an _StateManagerImpl object. 277 278 Args: 279 layer: The input layer this state manager is associated with. 280 trainable: Whether by default, variables created are trainable or not. 281 """ 282 self._trainable = trainable 283 self._layer = layer 284 if self._layer is not None and not hasattr(self._layer, '_resources'): 285 self._layer._resources = data_structures.Mapping() # pylint: disable=protected-access 286 self._cols_to_vars_map = collections.defaultdict(lambda: {}) 287 self._cols_to_resources_map = collections.defaultdict(lambda: {}) 288 289 def create_variable(self, 290 feature_column, 291 name, 292 shape, 293 dtype=None, 294 trainable=True, 295 use_resource=True, 296 initializer=None): 297 if name in self._cols_to_vars_map[feature_column]: 298 raise ValueError('Variable already exists.') 299 300 # We explicitly track these variables since `name` is not guaranteed to be 301 # unique and disable manual tracking that the add_weight call does. 302 with trackable.no_manual_dependency_tracking_scope(self._layer): 303 var = self._layer.add_weight( 304 name=name, 305 shape=shape, 306 dtype=dtype, 307 initializer=initializer, 308 trainable=self._trainable and trainable, 309 use_resource=use_resource, 310 # TODO(rohanj): Get rid of this hack once we have a mechanism for 311 # specifying a default partitioner for an entire layer. In that case, 312 # the default getter for Layers should work. 313 getter=variable_scope.get_variable) 314 if isinstance(var, variables.PartitionedVariable): 315 for v in var: 316 part_name = name + '/' + str(v._get_save_slice_info().var_offset[0]) # pylint: disable=protected-access 317 self._layer._track_trackable(v, feature_column.name + '/' + part_name) # pylint: disable=protected-access 318 else: 319 if isinstance(var, trackable.Trackable): 320 self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access 321 322 self._cols_to_vars_map[feature_column][name] = var 323 return var 324 325 def get_variable(self, feature_column, name): 326 if name in self._cols_to_vars_map[feature_column]: 327 return self._cols_to_vars_map[feature_column][name] 328 raise ValueError('Variable does not exist.') 329 330 def add_resource(self, feature_column, resource_name, resource): 331 self._cols_to_resources_map[feature_column][resource_name] = resource 332 # pylint: disable=protected-access 333 if self._layer is not None and isinstance(resource, trackable.Trackable): 334 # Add trackable resources to the layer for serialization. 335 if feature_column.name not in self._layer._resources: 336 self._layer._resources[feature_column.name] = data_structures.Mapping() 337 if resource_name not in self._layer._resources[feature_column.name]: 338 self._layer._resources[feature_column.name][resource_name] = resource 339 # pylint: enable=protected-access 340 341 def has_resource(self, feature_column, resource_name): 342 return resource_name in self._cols_to_resources_map[feature_column] 343 344 def get_resource(self, feature_column, resource_name): 345 if (feature_column not in self._cols_to_resources_map or 346 resource_name not in self._cols_to_resources_map[feature_column]): 347 raise ValueError('Resource does not exist.') 348 return self._cols_to_resources_map[feature_column][resource_name] 349 350 351class _StateManagerImplV2(_StateManagerImpl): 352 """Manages the state of DenseFeatures.""" 353 354 def create_variable(self, 355 feature_column, 356 name, 357 shape, 358 dtype=None, 359 trainable=True, 360 use_resource=True, 361 initializer=None): 362 if name in self._cols_to_vars_map[feature_column]: 363 raise ValueError('Variable already exists.') 364 365 # We explicitly track these variables since `name` is not guaranteed to be 366 # unique and disable manual tracking that the add_weight call does. 367 with trackable.no_manual_dependency_tracking_scope(self._layer): 368 var = self._layer.add_weight( 369 name=name, 370 shape=shape, 371 dtype=dtype, 372 initializer=initializer, 373 trainable=self._trainable and trainable, 374 use_resource=use_resource) 375 if isinstance(var, trackable.Trackable): 376 self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access 377 self._cols_to_vars_map[feature_column][name] = var 378 return var 379 380 381def _transform_features_v2(features, feature_columns, state_manager): 382 """Returns transformed features based on features columns passed in. 383 384 Please note that most probably you would not need to use this function. Please 385 check `input_layer` and `linear_model` to see whether they will 386 satisfy your use case or not. 387 388 Example: 389 390 ```python 391 # Define features and transformations 392 crosses_a_x_b = crossed_column( 393 columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000) 394 price_buckets = bucketized_column( 395 source_column=numeric_column("price"), boundaries=[...]) 396 397 columns = [crosses_a_x_b, price_buckets] 398 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 399 transformed = transform_features(features=features, feature_columns=columns) 400 401 assertCountEqual(columns, transformed.keys()) 402 ``` 403 404 Args: 405 features: A mapping from key to tensors. `FeatureColumn`s look up via these 406 keys. For example `numeric_column('price')` will look at 'price' key in 407 this dict. Values can be a `SparseTensor` or a `Tensor` depends on 408 corresponding `FeatureColumn`. 409 feature_columns: An iterable containing all the `FeatureColumn`s. 410 state_manager: A StateManager object that holds the FeatureColumn state. 411 412 Returns: 413 A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values. 414 """ 415 feature_columns = _normalize_feature_columns(feature_columns) 416 outputs = {} 417 with ops.name_scope( 418 None, default_name='transform_features', values=features.values()): 419 transformation_cache = FeatureTransformationCache(features) 420 for column in feature_columns: 421 with ops.name_scope( 422 None, 423 default_name=_sanitize_column_name_for_variable_scope(column.name)): 424 outputs[column] = transformation_cache.get(column, state_manager) 425 return outputs 426 427 428@tf_export('feature_column.make_parse_example_spec', v1=[]) 429def make_parse_example_spec_v2(feature_columns): 430 """Creates parsing spec dictionary from input feature_columns. 431 432 The returned dictionary can be used as arg 'features' in 433 `tf.io.parse_example`. 434 435 Typical usage example: 436 437 ```python 438 # Define features and transformations 439 feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...) 440 feature_b = tf.feature_column.numeric_column(...) 441 feature_c_bucketized = tf.feature_column.bucketized_column( 442 tf.feature_column.numeric_column("feature_c"), ...) 443 feature_a_x_feature_c = tf.feature_column.crossed_column( 444 columns=["feature_a", feature_c_bucketized], ...) 445 446 feature_columns = set( 447 [feature_b, feature_c_bucketized, feature_a_x_feature_c]) 448 features = tf.io.parse_example( 449 serialized=serialized_examples, 450 features=tf.feature_column.make_parse_example_spec(feature_columns)) 451 ``` 452 453 For the above example, make_parse_example_spec would return the dict: 454 455 ```python 456 { 457 "feature_a": parsing_ops.VarLenFeature(tf.string), 458 "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32), 459 "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32) 460 } 461 ``` 462 463 Args: 464 feature_columns: An iterable containing all feature columns. All items 465 should be instances of classes derived from `FeatureColumn`. 466 467 Returns: 468 A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature` 469 value. 470 471 Raises: 472 ValueError: If any of the given `feature_columns` is not a `FeatureColumn` 473 instance. 474 """ 475 result = {} 476 for column in feature_columns: 477 if not isinstance(column, FeatureColumn): 478 raise ValueError('All feature_columns must be FeatureColumn instances. ' 479 'Given: {}'.format(column)) 480 config = column.parse_example_spec 481 for key, value in six.iteritems(config): 482 if key in result and value != result[key]: 483 raise ValueError( 484 'feature_columns contain different parse_spec for key ' 485 '{}. Given {} and {}'.format(key, value, result[key])) 486 result.update(config) 487 return result 488 489 490@tf_export('feature_column.embedding_column') 491def embedding_column(categorical_column, 492 dimension, 493 combiner='mean', 494 initializer=None, 495 ckpt_to_load_from=None, 496 tensor_name_in_ckpt=None, 497 max_norm=None, 498 trainable=True, 499 use_safe_embedding_lookup=True): 500 """`DenseColumn` that converts from sparse, categorical input. 501 502 Use this when your inputs are sparse, but you want to convert them to a dense 503 representation (e.g., to feed to a DNN). 504 505 Inputs must be a `CategoricalColumn` created by any of the 506 `categorical_column_*` function. Here is an example of using 507 `embedding_column` with `DNNClassifier`: 508 509 ```python 510 video_id = categorical_column_with_identity( 511 key='video_id', num_buckets=1000000, default_value=0) 512 columns = [embedding_column(video_id, 9),...] 513 514 estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) 515 516 label_column = ... 517 def input_fn(): 518 features = tf.io.parse_example( 519 ..., features=make_parse_example_spec(columns + [label_column])) 520 labels = features.pop(label_column.name) 521 return features, labels 522 523 estimator.train(input_fn=input_fn, steps=100) 524 ``` 525 526 Here is an example using `embedding_column` with model_fn: 527 528 ```python 529 def model_fn(features, ...): 530 video_id = categorical_column_with_identity( 531 key='video_id', num_buckets=1000000, default_value=0) 532 columns = [embedding_column(video_id, 9),...] 533 dense_tensor = input_layer(features, columns) 534 # Form DNN layers, calculate loss, and return EstimatorSpec. 535 ... 536 ``` 537 538 Args: 539 categorical_column: A `CategoricalColumn` created by a 540 `categorical_column_with_*` function. This column produces the sparse IDs 541 that are inputs to the embedding lookup. 542 dimension: An integer specifying dimension of the embedding, must be > 0. 543 combiner: A string specifying how to reduce if there are multiple entries in 544 a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with 545 'mean' the default. 'sqrtn' often achieves good accuracy, in particular 546 with bag-of-words columns. Each of this can be thought as example level 547 normalizations on the column. For more information, see 548 `tf.embedding_lookup_sparse`. 549 initializer: A variable initializer function to be used in embedding 550 variable initialization. If not specified, defaults to 551 `truncated_normal_initializer` with mean `0.0` and 552 standard deviation `1/sqrt(dimension)`. 553 ckpt_to_load_from: String representing checkpoint name/pattern from which to 554 restore column weights. Required if `tensor_name_in_ckpt` is not `None`. 555 tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which 556 to restore the column weights. Required if `ckpt_to_load_from` is not 557 `None`. 558 max_norm: If not `None`, embedding values are l2-normalized to this value. 559 trainable: Whether or not the embedding is trainable. Default is True. 560 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 561 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 562 there are no empty rows and all weights and ids are positive at the 563 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 564 input tensors. Defaults to true, consider turning off if the above checks 565 are not needed. Note that having empty rows will not trigger any error 566 though the output result might be 0 or omitted. 567 568 Returns: 569 `DenseColumn` that converts from sparse input. 570 571 Raises: 572 ValueError: if `dimension` not > 0. 573 ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` 574 is specified. 575 ValueError: if `initializer` is specified and is not callable. 576 RuntimeError: If eager execution is enabled. 577 """ 578 if (dimension is None) or (dimension < 1): 579 raise ValueError('Invalid dimension {}.'.format(dimension)) 580 if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): 581 raise ValueError('Must specify both `ckpt_to_load_from` and ' 582 '`tensor_name_in_ckpt` or none of them.') 583 584 if (initializer is not None) and (not callable(initializer)): 585 raise ValueError('initializer must be callable if specified. ' 586 'Embedding of column_name: {}'.format( 587 categorical_column.name)) 588 if initializer is None: 589 initializer = init_ops.truncated_normal_initializer( 590 mean=0.0, stddev=1 / math.sqrt(dimension)) 591 592 return EmbeddingColumn( 593 categorical_column=categorical_column, 594 dimension=dimension, 595 combiner=combiner, 596 initializer=initializer, 597 ckpt_to_load_from=ckpt_to_load_from, 598 tensor_name_in_ckpt=tensor_name_in_ckpt, 599 max_norm=max_norm, 600 trainable=trainable, 601 use_safe_embedding_lookup=use_safe_embedding_lookup) 602 603 604@tf_export(v1=['feature_column.shared_embedding_columns']) 605def shared_embedding_columns(categorical_columns, 606 dimension, 607 combiner='mean', 608 initializer=None, 609 shared_embedding_collection_name=None, 610 ckpt_to_load_from=None, 611 tensor_name_in_ckpt=None, 612 max_norm=None, 613 trainable=True, 614 use_safe_embedding_lookup=True): 615 """List of dense columns that convert from sparse, categorical input. 616 617 This is similar to `embedding_column`, except that it produces a list of 618 embedding columns that share the same embedding weights. 619 620 Use this when your inputs are sparse and of the same type (e.g. watched and 621 impression video IDs that share the same vocabulary), and you want to convert 622 them to a dense representation (e.g., to feed to a DNN). 623 624 Inputs must be a list of categorical columns created by any of the 625 `categorical_column_*` function. They must all be of the same type and have 626 the same arguments except `key`. E.g. they can be 627 categorical_column_with_vocabulary_file with the same vocabulary_file. Some or 628 all columns could also be weighted_categorical_column. 629 630 Here is an example embedding of two features for a DNNClassifier model: 631 632 ```python 633 watched_video_id = categorical_column_with_vocabulary_file( 634 'watched_video_id', video_vocabulary_file, video_vocabulary_size) 635 impression_video_id = categorical_column_with_vocabulary_file( 636 'impression_video_id', video_vocabulary_file, video_vocabulary_size) 637 columns = shared_embedding_columns( 638 [watched_video_id, impression_video_id], dimension=10) 639 640 estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) 641 642 label_column = ... 643 def input_fn(): 644 features = tf.io.parse_example( 645 ..., features=make_parse_example_spec(columns + [label_column])) 646 labels = features.pop(label_column.name) 647 return features, labels 648 649 estimator.train(input_fn=input_fn, steps=100) 650 ``` 651 652 Here is an example using `shared_embedding_columns` with model_fn: 653 654 ```python 655 def model_fn(features, ...): 656 watched_video_id = categorical_column_with_vocabulary_file( 657 'watched_video_id', video_vocabulary_file, video_vocabulary_size) 658 impression_video_id = categorical_column_with_vocabulary_file( 659 'impression_video_id', video_vocabulary_file, video_vocabulary_size) 660 columns = shared_embedding_columns( 661 [watched_video_id, impression_video_id], dimension=10) 662 dense_tensor = input_layer(features, columns) 663 # Form DNN layers, calculate loss, and return EstimatorSpec. 664 ... 665 ``` 666 667 Args: 668 categorical_columns: List of categorical columns created by a 669 `categorical_column_with_*` function. These columns produce the sparse IDs 670 that are inputs to the embedding lookup. All columns must be of the same 671 type and have the same arguments except `key`. E.g. they can be 672 categorical_column_with_vocabulary_file with the same vocabulary_file. 673 Some or all columns could also be weighted_categorical_column. 674 dimension: An integer specifying dimension of the embedding, must be > 0. 675 combiner: A string specifying how to reduce if there are multiple entries in 676 a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with 677 'mean' the default. 'sqrtn' often achieves good accuracy, in particular 678 with bag-of-words columns. Each of this can be thought as example level 679 normalizations on the column. For more information, see 680 `tf.embedding_lookup_sparse`. 681 initializer: A variable initializer function to be used in embedding 682 variable initialization. If not specified, defaults to 683 `truncated_normal_initializer` with mean `0.0` and 684 standard deviation `1/sqrt(dimension)`. 685 shared_embedding_collection_name: Optional name of the collection where 686 shared embedding weights are added. If not given, a reasonable name will 687 be chosen based on the names of `categorical_columns`. This is also used 688 in `variable_scope` when creating shared embedding weights. 689 ckpt_to_load_from: String representing checkpoint name/pattern from which to 690 restore column weights. Required if `tensor_name_in_ckpt` is not `None`. 691 tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which 692 to restore the column weights. Required if `ckpt_to_load_from` is not 693 `None`. 694 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 695 than this value, before combining. 696 trainable: Whether or not the embedding is trainable. Default is True. 697 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 698 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 699 there are no empty rows and all weights and ids are positive at the 700 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 701 input tensors. Defaults to true, consider turning off if the above checks 702 are not needed. Note that having empty rows will not trigger any error 703 though the output result might be 0 or omitted. 704 705 Returns: 706 A list of dense columns that converts from sparse input. The order of 707 results follows the ordering of `categorical_columns`. 708 709 Raises: 710 ValueError: if `dimension` not > 0. 711 ValueError: if any of the given `categorical_columns` is of different type 712 or has different arguments than the others. 713 ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` 714 is specified. 715 ValueError: if `initializer` is specified and is not callable. 716 RuntimeError: if eager execution is enabled. 717 """ 718 if context.executing_eagerly(): 719 raise RuntimeError('shared_embedding_columns are not supported when eager ' 720 'execution is enabled.') 721 722 if (dimension is None) or (dimension < 1): 723 raise ValueError('Invalid dimension {}.'.format(dimension)) 724 if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): 725 raise ValueError('Must specify both `ckpt_to_load_from` and ' 726 '`tensor_name_in_ckpt` or none of them.') 727 728 if (initializer is not None) and (not callable(initializer)): 729 raise ValueError('initializer must be callable if specified.') 730 if initializer is None: 731 initializer = init_ops.truncated_normal_initializer( 732 mean=0.0, stddev=1. / math.sqrt(dimension)) 733 734 # Sort the columns so the default collection name is deterministic even if the 735 # user passes columns from an unsorted collection, such as dict.values(). 736 sorted_columns = sorted(categorical_columns, key=lambda x: x.name) 737 738 c0 = sorted_columns[0] 739 num_buckets = c0._num_buckets # pylint: disable=protected-access 740 if not isinstance(c0, fc_old._CategoricalColumn): # pylint: disable=protected-access 741 raise ValueError( 742 'All categorical_columns must be subclasses of _CategoricalColumn. ' 743 'Given: {}, of type: {}'.format(c0, type(c0))) 744 while isinstance( 745 c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn, # pylint: disable=protected-access 746 fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)): # pylint: disable=protected-access 747 c0 = c0.categorical_column 748 for c in sorted_columns[1:]: 749 while isinstance( 750 c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn, # pylint: disable=protected-access 751 fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)): # pylint: disable=protected-access 752 c = c.categorical_column 753 if not isinstance(c, type(c0)): 754 raise ValueError( 755 'To use shared_embedding_column, all categorical_columns must have ' 756 'the same type, or be weighted_categorical_column or sequence column ' 757 'of the same type. Given column: {} of type: {} does not match given ' 758 'column: {} of type: {}'.format(c0, type(c0), c, type(c))) 759 if num_buckets != c._num_buckets: # pylint: disable=protected-access 760 raise ValueError( 761 'To use shared_embedding_column, all categorical_columns must have ' 762 'the same number of buckets. Given column: {} with buckets: {} does ' 763 'not match column: {} with buckets: {}'.format( 764 c0, num_buckets, c, c._num_buckets)) # pylint: disable=protected-access 765 766 if not shared_embedding_collection_name: 767 shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) 768 shared_embedding_collection_name += '_shared_embedding' 769 770 result = [] 771 for column in categorical_columns: 772 result.append( 773 fc_old._SharedEmbeddingColumn( # pylint: disable=protected-access 774 categorical_column=column, 775 initializer=initializer, 776 dimension=dimension, 777 combiner=combiner, 778 shared_embedding_collection_name=shared_embedding_collection_name, 779 ckpt_to_load_from=ckpt_to_load_from, 780 tensor_name_in_ckpt=tensor_name_in_ckpt, 781 max_norm=max_norm, 782 trainable=trainable, 783 use_safe_embedding_lookup=use_safe_embedding_lookup)) 784 785 return result 786 787 788@tf_export('feature_column.shared_embeddings', v1=[]) 789def shared_embedding_columns_v2(categorical_columns, 790 dimension, 791 combiner='mean', 792 initializer=None, 793 shared_embedding_collection_name=None, 794 ckpt_to_load_from=None, 795 tensor_name_in_ckpt=None, 796 max_norm=None, 797 trainable=True, 798 use_safe_embedding_lookup=True): 799 """List of dense columns that convert from sparse, categorical input. 800 801 This is similar to `embedding_column`, except that it produces a list of 802 embedding columns that share the same embedding weights. 803 804 Use this when your inputs are sparse and of the same type (e.g. watched and 805 impression video IDs that share the same vocabulary), and you want to convert 806 them to a dense representation (e.g., to feed to a DNN). 807 808 Inputs must be a list of categorical columns created by any of the 809 `categorical_column_*` function. They must all be of the same type and have 810 the same arguments except `key`. E.g. they can be 811 categorical_column_with_vocabulary_file with the same vocabulary_file. Some or 812 all columns could also be weighted_categorical_column. 813 814 Here is an example embedding of two features for a DNNClassifier model: 815 816 ```python 817 watched_video_id = categorical_column_with_vocabulary_file( 818 'watched_video_id', video_vocabulary_file, video_vocabulary_size) 819 impression_video_id = categorical_column_with_vocabulary_file( 820 'impression_video_id', video_vocabulary_file, video_vocabulary_size) 821 columns = shared_embedding_columns( 822 [watched_video_id, impression_video_id], dimension=10) 823 824 estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) 825 826 label_column = ... 827 def input_fn(): 828 features = tf.io.parse_example( 829 ..., features=make_parse_example_spec(columns + [label_column])) 830 labels = features.pop(label_column.name) 831 return features, labels 832 833 estimator.train(input_fn=input_fn, steps=100) 834 ``` 835 836 Here is an example using `shared_embedding_columns` with model_fn: 837 838 ```python 839 def model_fn(features, ...): 840 watched_video_id = categorical_column_with_vocabulary_file( 841 'watched_video_id', video_vocabulary_file, video_vocabulary_size) 842 impression_video_id = categorical_column_with_vocabulary_file( 843 'impression_video_id', video_vocabulary_file, video_vocabulary_size) 844 columns = shared_embedding_columns( 845 [watched_video_id, impression_video_id], dimension=10) 846 dense_tensor = input_layer(features, columns) 847 # Form DNN layers, calculate loss, and return EstimatorSpec. 848 ... 849 ``` 850 851 Args: 852 categorical_columns: List of categorical columns created by a 853 `categorical_column_with_*` function. These columns produce the sparse IDs 854 that are inputs to the embedding lookup. All columns must be of the same 855 type and have the same arguments except `key`. E.g. they can be 856 categorical_column_with_vocabulary_file with the same vocabulary_file. 857 Some or all columns could also be weighted_categorical_column. 858 dimension: An integer specifying dimension of the embedding, must be > 0. 859 combiner: A string specifying how to reduce if there are multiple entries 860 in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with 861 'mean' the default. 'sqrtn' often achieves good accuracy, in particular 862 with bag-of-words columns. Each of this can be thought as example level 863 normalizations on the column. For more information, see 864 `tf.embedding_lookup_sparse`. 865 initializer: A variable initializer function to be used in embedding 866 variable initialization. If not specified, defaults to 867 `truncated_normal_initializer` with mean `0.0` and standard 868 deviation `1/sqrt(dimension)`. 869 shared_embedding_collection_name: Optional collective name of these columns. 870 If not given, a reasonable name will be chosen based on the names of 871 `categorical_columns`. 872 ckpt_to_load_from: String representing checkpoint name/pattern from which to 873 restore column weights. Required if `tensor_name_in_ckpt` is not `None`. 874 tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from 875 which to restore the column weights. Required if `ckpt_to_load_from` is 876 not `None`. 877 max_norm: If not `None`, each embedding is clipped if its l2-norm is 878 larger than this value, before combining. 879 trainable: Whether or not the embedding is trainable. Default is True. 880 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 881 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 882 there are no empty rows and all weights and ids are positive at the 883 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 884 input tensors. Defaults to true, consider turning off if the above checks 885 are not needed. Note that having empty rows will not trigger any error 886 though the output result might be 0 or omitted. 887 888 Returns: 889 A list of dense columns that converts from sparse input. The order of 890 results follows the ordering of `categorical_columns`. 891 892 Raises: 893 ValueError: if `dimension` not > 0. 894 ValueError: if any of the given `categorical_columns` is of different type 895 or has different arguments than the others. 896 ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` 897 is specified. 898 ValueError: if `initializer` is specified and is not callable. 899 RuntimeError: if eager execution is enabled. 900 """ 901 if context.executing_eagerly(): 902 raise RuntimeError('shared_embedding_columns are not supported when eager ' 903 'execution is enabled.') 904 905 if (dimension is None) or (dimension < 1): 906 raise ValueError('Invalid dimension {}.'.format(dimension)) 907 if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): 908 raise ValueError('Must specify both `ckpt_to_load_from` and ' 909 '`tensor_name_in_ckpt` or none of them.') 910 911 if (initializer is not None) and (not callable(initializer)): 912 raise ValueError('initializer must be callable if specified.') 913 if initializer is None: 914 initializer = init_ops.truncated_normal_initializer( 915 mean=0.0, stddev=1. / math.sqrt(dimension)) 916 917 # Sort the columns so the default collection name is deterministic even if the 918 # user passes columns from an unsorted collection, such as dict.values(). 919 sorted_columns = sorted(categorical_columns, key=lambda x: x.name) 920 921 c0 = sorted_columns[0] 922 num_buckets = c0.num_buckets 923 if not isinstance(c0, CategoricalColumn): 924 raise ValueError( 925 'All categorical_columns must be subclasses of CategoricalColumn. ' 926 'Given: {}, of type: {}'.format(c0, type(c0))) 927 while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)): 928 c0 = c0.categorical_column 929 for c in sorted_columns[1:]: 930 while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)): 931 c = c.categorical_column 932 if not isinstance(c, type(c0)): 933 raise ValueError( 934 'To use shared_embedding_column, all categorical_columns must have ' 935 'the same type, or be weighted_categorical_column or sequence column ' 936 'of the same type. Given column: {} of type: {} does not match given ' 937 'column: {} of type: {}'.format(c0, type(c0), c, type(c))) 938 if num_buckets != c.num_buckets: 939 raise ValueError( 940 'To use shared_embedding_column, all categorical_columns must have ' 941 'the same number of buckets. Given column: {} with buckets: {} does ' 942 'not match column: {} with buckets: {}'.format( 943 c0, num_buckets, c, c.num_buckets)) 944 945 if not shared_embedding_collection_name: 946 shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) 947 shared_embedding_collection_name += '_shared_embedding' 948 949 column_creator = SharedEmbeddingColumnCreator( 950 dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt, 951 num_buckets, trainable, shared_embedding_collection_name, 952 use_safe_embedding_lookup) 953 954 result = [] 955 for column in categorical_columns: 956 result.append( 957 column_creator( 958 categorical_column=column, combiner=combiner, max_norm=max_norm)) 959 960 return result 961 962 963@tf_export('feature_column.numeric_column') 964def numeric_column(key, 965 shape=(1,), 966 default_value=None, 967 dtype=dtypes.float32, 968 normalizer_fn=None): 969 """Represents real valued or numerical features. 970 971 Example: 972 973 ```python 974 price = numeric_column('price') 975 columns = [price, ...] 976 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 977 dense_tensor = input_layer(features, columns) 978 979 # or 980 bucketized_price = bucketized_column(price, boundaries=[...]) 981 columns = [bucketized_price, ...] 982 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 983 linear_prediction = linear_model(features, columns) 984 ``` 985 986 Args: 987 key: A unique string identifying the input feature. It is used as the 988 column name and the dictionary key for feature parsing configs, feature 989 `Tensor` objects, and feature columns. 990 shape: An iterable of integers specifies the shape of the `Tensor`. An 991 integer can be given which means a single dimension `Tensor` with given 992 width. The `Tensor` representing the column will have the shape of 993 [batch_size] + `shape`. 994 default_value: A single value compatible with `dtype` or an iterable of 995 values compatible with `dtype` which the column takes on during 996 `tf.Example` parsing if data is missing. A default value of `None` will 997 cause `tf.io.parse_example` to fail if an example does not contain this 998 column. If a single value is provided, the same value will be applied as 999 the default value for every item. If an iterable of values is provided, 1000 the shape of the `default_value` should be equal to the given `shape`. 1001 dtype: defines the type of values. Default value is `tf.float32`. Must be a 1002 non-quantized, real integer or floating point type. 1003 normalizer_fn: If not `None`, a function that can be used to normalize the 1004 value of the tensor after `default_value` is applied for parsing. 1005 Normalizer function takes the input `Tensor` as its argument, and returns 1006 the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that 1007 even though the most common use case of this function is normalization, it 1008 can be used for any kind of Tensorflow transformations. 1009 1010 Returns: 1011 A `NumericColumn`. 1012 1013 Raises: 1014 TypeError: if any dimension in shape is not an int 1015 ValueError: if any dimension in shape is not a positive integer 1016 TypeError: if `default_value` is an iterable but not compatible with `shape` 1017 TypeError: if `default_value` is not compatible with `dtype`. 1018 ValueError: if `dtype` is not convertible to `tf.float32`. 1019 """ 1020 shape = _check_shape(shape, key) 1021 if not (dtype.is_integer or dtype.is_floating): 1022 raise ValueError('dtype must be convertible to float. ' 1023 'dtype: {}, key: {}'.format(dtype, key)) 1024 default_value = fc_utils.check_default_value( 1025 shape, default_value, dtype, key) 1026 1027 if normalizer_fn is not None and not callable(normalizer_fn): 1028 raise TypeError( 1029 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) 1030 1031 fc_utils.assert_key_is_string(key) 1032 return NumericColumn( 1033 key, 1034 shape=shape, 1035 default_value=default_value, 1036 dtype=dtype, 1037 normalizer_fn=normalizer_fn) 1038 1039 1040@tf_export('feature_column.bucketized_column') 1041def bucketized_column(source_column, boundaries): 1042 """Represents discretized dense input bucketed by `boundaries`. 1043 1044 Buckets include the left boundary, and exclude the right boundary. Namely, 1045 `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`, 1046 `[1., 2.)`, and `[2., +inf)`. 1047 1048 For example, if the inputs are 1049 1050 ```python 1051 boundaries = [0, 10, 100] 1052 input tensor = [[-5, 10000] 1053 [150, 10] 1054 [5, 100]] 1055 ``` 1056 1057 then the output will be 1058 1059 ```python 1060 output = [[0, 3] 1061 [3, 2] 1062 [1, 3]] 1063 ``` 1064 1065 Example: 1066 1067 ```python 1068 price = tf.feature_column.numeric_column('price') 1069 bucketized_price = tf.feature_column.bucketized_column( 1070 price, boundaries=[...]) 1071 columns = [bucketized_price, ...] 1072 features = tf.io.parse_example( 1073 ..., features=tf.feature_column.make_parse_example_spec(columns)) 1074 dense_tensor = tf.keras.layers.DenseFeatures(columns)(features) 1075 ``` 1076 1077 A `bucketized_column` can also be crossed with another categorical column 1078 using `crossed_column`: 1079 1080 ```python 1081 price = tf.feature_column.numeric_column('price') 1082 # bucketized_column converts numerical feature to a categorical one. 1083 bucketized_price = tf.feature_column.bucketized_column( 1084 price, boundaries=[...]) 1085 # 'keywords' is a string feature. 1086 price_x_keywords = tf.feature_column.crossed_column( 1087 [bucketized_price, 'keywords'], 50K) 1088 columns = [price_x_keywords, ...] 1089 features = tf.io.parse_example( 1090 ..., features=tf.feature_column.make_parse_example_spec(columns)) 1091 dense_tensor = tf.keras.layers.DenseFeatures(columns)(features) 1092 linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor) 1093 ``` 1094 1095 Args: 1096 source_column: A one-dimensional dense column which is generated with 1097 `numeric_column`. 1098 boundaries: A sorted list or tuple of floats specifying the boundaries. 1099 1100 Returns: 1101 A `BucketizedColumn`. 1102 1103 Raises: 1104 ValueError: If `source_column` is not a numeric column, or if it is not 1105 one-dimensional. 1106 ValueError: If `boundaries` is not a sorted list or tuple. 1107 """ 1108 if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access 1109 raise ValueError( 1110 'source_column must be a column generated with numeric_column(). ' 1111 'Given: {}'.format(source_column)) 1112 if len(source_column.shape) > 1: 1113 raise ValueError( 1114 'source_column must be one-dimensional column. ' 1115 'Given: {}'.format(source_column)) 1116 if not boundaries: 1117 raise ValueError('boundaries must not be empty.') 1118 if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)): 1119 raise ValueError('boundaries must be a sorted list.') 1120 for i in range(len(boundaries) - 1): 1121 if boundaries[i] >= boundaries[i + 1]: 1122 raise ValueError('boundaries must be a sorted list.') 1123 return BucketizedColumn(source_column, tuple(boundaries)) 1124 1125 1126@tf_export('feature_column.categorical_column_with_hash_bucket') 1127def categorical_column_with_hash_bucket(key, 1128 hash_bucket_size, 1129 dtype=dtypes.string): 1130 """Represents sparse feature where ids are set by hashing. 1131 1132 Use this when your sparse features are in string or integer format, and you 1133 want to distribute your inputs into a finite number of buckets by hashing. 1134 output_id = Hash(input_feature_string) % bucket_size for string type input. 1135 For int type input, the value is converted to its string representation first 1136 and then hashed by the same formula. 1137 1138 For input dictionary `features`, `features[key]` is either `Tensor` or 1139 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1140 and `''` for string, which will be dropped by this feature column. 1141 1142 Example: 1143 1144 ```python 1145 import tensorflow as tf 1146 keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords", 1147 10000) 1148 columns = [keywords] 1149 features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM', 1150 'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN', 'Tensorflow', 1151 'LSTM', 'Keras', 'RNN']])} 1152 linear_prediction, _, _ = tf.compat.v1.feature_column.linear_model(features, 1153 columns) 1154 1155 # or 1156 import tensorflow as tf 1157 keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords", 1158 10000) 1159 keywords_embedded = tf.feature_column.embedding_column(keywords, 16) 1160 columns = [keywords_embedded] 1161 features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM', 1162 'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN', 'Tensorflow', 1163 'LSTM', 'Keras', 'RNN']])} 1164 input_layer = tf.keras.layers.DenseFeatures(columns) 1165 dense_tensor = input_layer(features) 1166 ``` 1167 1168 Args: 1169 key: A unique string identifying the input feature. It is used as the 1170 column name and the dictionary key for feature parsing configs, feature 1171 `Tensor` objects, and feature columns. 1172 hash_bucket_size: An int > 1. The number of buckets. 1173 dtype: The type of features. Only string and integer types are supported. 1174 1175 Returns: 1176 A `HashedCategoricalColumn`. 1177 1178 Raises: 1179 ValueError: `hash_bucket_size` is not greater than 1. 1180 ValueError: `dtype` is neither string nor integer. 1181 """ 1182 if hash_bucket_size is None: 1183 raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key)) 1184 1185 if hash_bucket_size < 1: 1186 raise ValueError('hash_bucket_size must be at least 1. ' 1187 'hash_bucket_size: {}, key: {}'.format( 1188 hash_bucket_size, key)) 1189 1190 fc_utils.assert_key_is_string(key) 1191 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1192 1193 return HashedCategoricalColumn(key, hash_bucket_size, dtype) 1194 1195 1196@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file']) 1197def categorical_column_with_vocabulary_file(key, 1198 vocabulary_file, 1199 vocabulary_size=None, 1200 num_oov_buckets=0, 1201 default_value=None, 1202 dtype=dtypes.string): 1203 """A `CategoricalColumn` with a vocabulary file. 1204 1205 Use this when your inputs are in string or integer format, and you have a 1206 vocabulary file that maps each value to an integer ID. By default, 1207 out-of-vocabulary values are ignored. Use either (but not both) of 1208 `num_oov_buckets` and `default_value` to specify how to include 1209 out-of-vocabulary values. 1210 1211 For input dictionary `features`, `features[key]` is either `Tensor` or 1212 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1213 and `''` for string, which will be dropped by this feature column. 1214 1215 Example with `num_oov_buckets`: 1216 File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state 1217 abbreviation. All inputs with values in that file are assigned an ID 0-49, 1218 corresponding to its line number. All other values are hashed and assigned an 1219 ID 50-54. 1220 1221 ```python 1222 import tensorflow as tf 1223 states = tf.feature_column.categorical_column_with_vocabulary_file( 1224 key='states', vocabulary_file='states.txt', vocabulary_size=5, 1225 num_oov_buckets=1) 1226 columns = [states] 1227 features = {'states':tf.constant([['california', 'georgia', 'michigan', 1228 'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan', 1229 'texas']])} 1230 linear_prediction = tf.compat.v1.feature_column.linear_model(features, 1231 columns) 1232 ``` 1233 1234 Example with `default_value`: 1235 File '/us/states.txt' contains 51 lines - the first line is 'XX', and the 1236 other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' 1237 in input, and other values missing from the file, will be assigned ID 0. All 1238 others are assigned the corresponding line number 1-50. 1239 1240 ```python 1241 import tensorflow as tf 1242 states = tf.feature_column.categorical_column_with_vocabulary_file( 1243 key='states', vocabulary_file='states.txt', vocabulary_size=6, 1244 default_value=0) 1245 columns = [states] 1246 features = {'states':tf.constant([['california', 'georgia', 'michigan', 1247 'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan', 1248 'texas']])} 1249 linear_prediction = tf.compat.v1.feature_column.linear_model(features, 1250 columns) 1251 ``` 1252 1253 And to make an embedding with either: 1254 1255 ```python 1256 import tensorflow as tf 1257 states = tf.feature_column.categorical_column_with_vocabulary_file( 1258 key='states', vocabulary_file='states.txt', vocabulary_size=5, 1259 num_oov_buckets=1) 1260 columns = [tf.feature_column.embedding_column(states, 3)] 1261 features = {'states':tf.constant([['california', 'georgia', 'michigan', 1262 'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan', 1263 'texas']])} 1264 input_layer = tf.keras.layers.DenseFeatures(columns) 1265 dense_tensor = input_layer(features) 1266 ``` 1267 1268 Args: 1269 key: A unique string identifying the input feature. It is used as the 1270 column name and the dictionary key for feature parsing configs, feature 1271 `Tensor` objects, and feature columns. 1272 vocabulary_file: The vocabulary file name. 1273 vocabulary_size: Number of the elements in the vocabulary. This must be no 1274 greater than length of `vocabulary_file`, if less than length, later 1275 values are ignored. If None, it is set to the length of `vocabulary_file`. 1276 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary 1277 buckets. All out-of-vocabulary inputs will be assigned IDs in the range 1278 `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of 1279 the input value. A positive `num_oov_buckets` can not be specified with 1280 `default_value`. 1281 default_value: The integer ID value to return for out-of-vocabulary feature 1282 values, defaults to `-1`. This can not be specified with a positive 1283 `num_oov_buckets`. 1284 dtype: The type of features. Only string and integer types are supported. 1285 1286 Returns: 1287 A `CategoricalColumn` with a vocabulary file. 1288 1289 Raises: 1290 ValueError: `vocabulary_file` is missing or cannot be opened. 1291 ValueError: `vocabulary_size` is missing or < 1. 1292 ValueError: `num_oov_buckets` is a negative integer. 1293 ValueError: `num_oov_buckets` and `default_value` are both specified. 1294 ValueError: `dtype` is neither string nor integer. 1295 """ 1296 return categorical_column_with_vocabulary_file_v2( 1297 key, vocabulary_file, vocabulary_size, 1298 dtype, default_value, 1299 num_oov_buckets) 1300 1301 1302@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[]) 1303def categorical_column_with_vocabulary_file_v2(key, 1304 vocabulary_file, 1305 vocabulary_size=None, 1306 dtype=dtypes.string, 1307 default_value=None, 1308 num_oov_buckets=0): 1309 """A `CategoricalColumn` with a vocabulary file. 1310 1311 Use this when your inputs are in string or integer format, and you have a 1312 vocabulary file that maps each value to an integer ID. By default, 1313 out-of-vocabulary values are ignored. Use either (but not both) of 1314 `num_oov_buckets` and `default_value` to specify how to include 1315 out-of-vocabulary values. 1316 1317 For input dictionary `features`, `features[key]` is either `Tensor` or 1318 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1319 and `''` for string, which will be dropped by this feature column. 1320 1321 Example with `num_oov_buckets`: 1322 File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state 1323 abbreviation. All inputs with values in that file are assigned an ID 0-49, 1324 corresponding to its line number. All other values are hashed and assigned an 1325 ID 50-54. 1326 1327 ```python 1328 states = categorical_column_with_vocabulary_file( 1329 key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, 1330 num_oov_buckets=5) 1331 columns = [states, ...] 1332 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1333 linear_prediction = linear_model(features, columns) 1334 ``` 1335 1336 Example with `default_value`: 1337 File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the 1338 other 50 each have a 2-character U.S. state abbreviation. Both a literal 1339 `'XX'` in input, and other values missing from the file, will be assigned 1340 ID 0. All others are assigned the corresponding line number 1-50. 1341 1342 ```python 1343 states = categorical_column_with_vocabulary_file( 1344 key='states', vocabulary_file='/us/states.txt', vocabulary_size=51, 1345 default_value=0) 1346 columns = [states, ...] 1347 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1348 linear_prediction, _, _ = linear_model(features, columns) 1349 ``` 1350 1351 And to make an embedding with either: 1352 1353 ```python 1354 columns = [embedding_column(states, 3),...] 1355 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1356 dense_tensor = input_layer(features, columns) 1357 ``` 1358 1359 Args: 1360 key: A unique string identifying the input feature. It is used as the 1361 column name and the dictionary key for feature parsing configs, feature 1362 `Tensor` objects, and feature columns. 1363 vocabulary_file: The vocabulary file name. 1364 vocabulary_size: Number of the elements in the vocabulary. This must be no 1365 greater than length of `vocabulary_file`, if less than length, later 1366 values are ignored. If None, it is set to the length of `vocabulary_file`. 1367 dtype: The type of features. Only string and integer types are supported. 1368 default_value: The integer ID value to return for out-of-vocabulary feature 1369 values, defaults to `-1`. This can not be specified with a positive 1370 `num_oov_buckets`. 1371 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary 1372 buckets. All out-of-vocabulary inputs will be assigned IDs in the range 1373 `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of 1374 the input value. A positive `num_oov_buckets` can not be specified with 1375 `default_value`. 1376 1377 Returns: 1378 A `CategoricalColumn` with a vocabulary file. 1379 1380 Raises: 1381 ValueError: `vocabulary_file` is missing or cannot be opened. 1382 ValueError: `vocabulary_size` is missing or < 1. 1383 ValueError: `num_oov_buckets` is a negative integer. 1384 ValueError: `num_oov_buckets` and `default_value` are both specified. 1385 ValueError: `dtype` is neither string nor integer. 1386 """ 1387 if not vocabulary_file: 1388 raise ValueError('Missing vocabulary_file in {}.'.format(key)) 1389 1390 if vocabulary_size is None: 1391 if not gfile.Exists(vocabulary_file): 1392 raise ValueError('vocabulary_file in {} does not exist.'.format(key)) 1393 1394 with gfile.GFile(vocabulary_file, mode='rb') as f: 1395 vocabulary_size = sum(1 for _ in f) 1396 logging.info( 1397 'vocabulary_size = %d in %s is inferred from the number of elements ' 1398 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) 1399 1400 # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. 1401 if vocabulary_size < 1: 1402 raise ValueError('Invalid vocabulary_size in {}.'.format(key)) 1403 if num_oov_buckets: 1404 if default_value is not None: 1405 raise ValueError( 1406 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( 1407 key)) 1408 if num_oov_buckets < 0: 1409 raise ValueError('Invalid num_oov_buckets {} in {}.'.format( 1410 num_oov_buckets, key)) 1411 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1412 fc_utils.assert_key_is_string(key) 1413 return VocabularyFileCategoricalColumn( 1414 key=key, 1415 vocabulary_file=vocabulary_file, 1416 vocabulary_size=vocabulary_size, 1417 num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, 1418 default_value=-1 if default_value is None else default_value, 1419 dtype=dtype) 1420 1421 1422@tf_export('feature_column.categorical_column_with_vocabulary_list') 1423def categorical_column_with_vocabulary_list(key, 1424 vocabulary_list, 1425 dtype=None, 1426 default_value=-1, 1427 num_oov_buckets=0): 1428 """A `CategoricalColumn` with in-memory vocabulary. 1429 1430 Use this when your inputs are in string or integer format, and you have an 1431 in-memory vocabulary mapping each value to an integer ID. By default, 1432 out-of-vocabulary values are ignored. Use either (but not both) of 1433 `num_oov_buckets` and `default_value` to specify how to include 1434 out-of-vocabulary values. 1435 1436 For input dictionary `features`, `features[key]` is either `Tensor` or 1437 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1438 and `''` for string, which will be dropped by this feature column. 1439 1440 Example with `num_oov_buckets`: 1441 In the following example, each input in `vocabulary_list` is assigned an ID 1442 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other 1443 inputs are hashed and assigned an ID 4-5. 1444 1445 ```python 1446 colors = categorical_column_with_vocabulary_list( 1447 key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), 1448 num_oov_buckets=2) 1449 columns = [colors, ...] 1450 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1451 linear_prediction, _, _ = linear_model(features, columns) 1452 ``` 1453 1454 Example with `default_value`: 1455 In the following example, each input in `vocabulary_list` is assigned an ID 1456 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other 1457 inputs are assigned `default_value` 0. 1458 1459 1460 ```python 1461 colors = categorical_column_with_vocabulary_list( 1462 key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) 1463 columns = [colors, ...] 1464 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1465 linear_prediction, _, _ = linear_model(features, columns) 1466 ``` 1467 1468 And to make an embedding with either: 1469 1470 ```python 1471 columns = [embedding_column(colors, 3),...] 1472 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1473 dense_tensor = input_layer(features, columns) 1474 ``` 1475 1476 Args: 1477 key: A unique string identifying the input feature. It is used as the column 1478 name and the dictionary key for feature parsing configs, feature `Tensor` 1479 objects, and feature columns. 1480 vocabulary_list: An ordered iterable defining the vocabulary. Each feature 1481 is mapped to the index of its value (if present) in `vocabulary_list`. 1482 Must be castable to `dtype`. 1483 dtype: The type of features. Only string and integer types are supported. If 1484 `None`, it will be inferred from `vocabulary_list`. 1485 default_value: The integer ID value to return for out-of-vocabulary feature 1486 values, defaults to `-1`. This can not be specified with a positive 1487 `num_oov_buckets`. 1488 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary 1489 buckets. All out-of-vocabulary inputs will be assigned IDs in the range 1490 `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a 1491 hash of the input value. A positive `num_oov_buckets` can not be specified 1492 with `default_value`. 1493 1494 Returns: 1495 A `CategoricalColumn` with in-memory vocabulary. 1496 1497 Raises: 1498 ValueError: if `vocabulary_list` is empty, or contains duplicate keys. 1499 ValueError: `num_oov_buckets` is a negative integer. 1500 ValueError: `num_oov_buckets` and `default_value` are both specified. 1501 ValueError: if `dtype` is not integer or string. 1502 """ 1503 if (vocabulary_list is None) or (len(vocabulary_list) < 1): 1504 raise ValueError( 1505 'vocabulary_list {} must be non-empty, column_name: {}'.format( 1506 vocabulary_list, key)) 1507 if len(set(vocabulary_list)) != len(vocabulary_list): 1508 raise ValueError( 1509 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( 1510 vocabulary_list, key)) 1511 vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) 1512 if num_oov_buckets: 1513 if default_value != -1: 1514 raise ValueError( 1515 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( 1516 key)) 1517 if num_oov_buckets < 0: 1518 raise ValueError('Invalid num_oov_buckets {} in {}.'.format( 1519 num_oov_buckets, key)) 1520 fc_utils.assert_string_or_int( 1521 vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) 1522 if dtype is None: 1523 dtype = vocabulary_dtype 1524 elif dtype.is_integer != vocabulary_dtype.is_integer: 1525 raise ValueError( 1526 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( 1527 dtype, vocabulary_dtype, key)) 1528 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1529 fc_utils.assert_key_is_string(key) 1530 1531 return VocabularyListCategoricalColumn( 1532 key=key, 1533 vocabulary_list=tuple(vocabulary_list), 1534 dtype=dtype, 1535 default_value=default_value, 1536 num_oov_buckets=num_oov_buckets) 1537 1538 1539@tf_export('feature_column.categorical_column_with_identity') 1540def categorical_column_with_identity(key, num_buckets, default_value=None): 1541 """A `CategoricalColumn` that returns identity values. 1542 1543 Use this when your inputs are integers in the range `[0, num_buckets)`, and 1544 you want to use the input value itself as the categorical ID. Values outside 1545 this range will result in `default_value` if specified, otherwise it will 1546 fail. 1547 1548 Typically, this is used for contiguous ranges of integer indexes, but 1549 it doesn't have to be. This might be inefficient, however, if many of IDs 1550 are unused. Consider `categorical_column_with_hash_bucket` in that case. 1551 1552 For input dictionary `features`, `features[key]` is either `Tensor` or 1553 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1554 and `''` for string, which will be dropped by this feature column. 1555 1556 In the following examples, each input in the range `[0, 1000000)` is assigned 1557 the same value. All other inputs are assigned `default_value` 0. Note that a 1558 literal 0 in inputs will result in the same default ID. 1559 1560 Linear model: 1561 1562 ```python 1563 import tensorflow as tf 1564 video_id = tf.feature_column.categorical_column_with_identity( 1565 key='video_id', num_buckets=1000000, default_value=0) 1566 columns = [video_id] 1567 features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0], 1568 [33,78, 2, 73, 1]])} 1569 linear_prediction = tf.compat.v1.feature_column.linear_model(features, 1570 columns) 1571 ``` 1572 1573 Embedding for a DNN model: 1574 1575 ```python 1576 import tensorflow as tf 1577 video_id = tf.feature_column.categorical_column_with_identity( 1578 key='video_id', num_buckets=1000000, default_value=0) 1579 columns = [tf.feature_column.embedding_column(video_id, 9)] 1580 features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0], 1581 [33,78, 2, 73, 1]])} 1582 input_layer = tf.keras.layers.DenseFeatures(columns) 1583 dense_tensor = input_layer(features) 1584 ``` 1585 1586 Args: 1587 key: A unique string identifying the input feature. It is used as the 1588 column name and the dictionary key for feature parsing configs, feature 1589 `Tensor` objects, and feature columns. 1590 num_buckets: Range of inputs and outputs is `[0, num_buckets)`. 1591 default_value: If set, values outside of range `[0, num_buckets)` will 1592 be replaced with this value. If not set, values >= num_buckets will 1593 cause a failure while values < 0 will be dropped. 1594 1595 Returns: 1596 A `CategoricalColumn` that returns identity values. 1597 1598 Raises: 1599 ValueError: if `num_buckets` is less than one. 1600 ValueError: if `default_value` is not in range `[0, num_buckets)`. 1601 """ 1602 if num_buckets < 1: 1603 raise ValueError( 1604 'num_buckets {} < 1, column_name {}'.format(num_buckets, key)) 1605 if (default_value is not None) and ( 1606 (default_value < 0) or (default_value >= num_buckets)): 1607 raise ValueError( 1608 'default_value {} not in range [0, {}), column_name {}'.format( 1609 default_value, num_buckets, key)) 1610 fc_utils.assert_key_is_string(key) 1611 return IdentityCategoricalColumn( 1612 key=key, number_buckets=num_buckets, default_value=default_value) 1613 1614 1615@tf_export('feature_column.indicator_column') 1616def indicator_column(categorical_column): 1617 """Represents multi-hot representation of given categorical column. 1618 1619 - For DNN model, `indicator_column` can be used to wrap any 1620 `categorical_column_*` (e.g., to feed to DNN). Consider to Use 1621 `embedding_column` if the number of buckets/unique(values) are large. 1622 1623 - For Wide (aka linear) model, `indicator_column` is the internal 1624 representation for categorical column when passing categorical column 1625 directly (as any element in feature_columns) to `linear_model`. See 1626 `linear_model` for details. 1627 1628 ```python 1629 name = indicator_column(categorical_column_with_vocabulary_list( 1630 'name', ['bob', 'george', 'wanda'])) 1631 columns = [name, ...] 1632 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1633 dense_tensor = input_layer(features, columns) 1634 1635 dense_tensor == [[1, 0, 0]] # If "name" bytes_list is ["bob"] 1636 dense_tensor == [[1, 0, 1]] # If "name" bytes_list is ["bob", "wanda"] 1637 dense_tensor == [[2, 0, 0]] # If "name" bytes_list is ["bob", "bob"] 1638 ``` 1639 1640 Args: 1641 categorical_column: A `CategoricalColumn` which is created by 1642 `categorical_column_with_*` or `crossed_column` functions. 1643 1644 Returns: 1645 An `IndicatorColumn`. 1646 1647 Raises: 1648 ValueError: If `categorical_column` is not CategoricalColumn type. 1649 """ 1650 if not isinstance(categorical_column, 1651 (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access 1652 raise ValueError( 1653 'Unsupported input type. Input must be a CategoricalColumn. ' 1654 'Given: {}'.format(categorical_column)) 1655 return IndicatorColumn(categorical_column) 1656 1657 1658@tf_export('feature_column.weighted_categorical_column') 1659def weighted_categorical_column(categorical_column, 1660 weight_feature_key, 1661 dtype=dtypes.float32): 1662 """Applies weight values to a `CategoricalColumn`. 1663 1664 Use this when each of your sparse inputs has both an ID and a value. For 1665 example, if you're representing text documents as a collection of word 1666 frequencies, you can provide 2 parallel sparse input features ('terms' and 1667 'frequencies' below). 1668 1669 Example: 1670 1671 Input `tf.Example` objects: 1672 1673 ```proto 1674 [ 1675 features { 1676 feature { 1677 key: "terms" 1678 value {bytes_list {value: "very" value: "model"}} 1679 } 1680 feature { 1681 key: "frequencies" 1682 value {float_list {value: 0.3 value: 0.1}} 1683 } 1684 }, 1685 features { 1686 feature { 1687 key: "terms" 1688 value {bytes_list {value: "when" value: "course" value: "human"}} 1689 } 1690 feature { 1691 key: "frequencies" 1692 value {float_list {value: 0.4 value: 0.1 value: 0.2}} 1693 } 1694 } 1695 ] 1696 ``` 1697 1698 ```python 1699 categorical_column = categorical_column_with_hash_bucket( 1700 column_name='terms', hash_bucket_size=1000) 1701 weighted_column = weighted_categorical_column( 1702 categorical_column=categorical_column, weight_feature_key='frequencies') 1703 columns = [weighted_column, ...] 1704 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1705 linear_prediction, _, _ = linear_model(features, columns) 1706 ``` 1707 1708 This assumes the input dictionary contains a `SparseTensor` for key 1709 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have 1710 the same indices and dense shape. 1711 1712 Args: 1713 categorical_column: A `CategoricalColumn` created by 1714 `categorical_column_with_*` functions. 1715 weight_feature_key: String key for weight values. 1716 dtype: Type of weights, such as `tf.float32`. Only float and integer weights 1717 are supported. 1718 1719 Returns: 1720 A `CategoricalColumn` composed of two sparse features: one represents id, 1721 the other represents weight (value) of the id feature in that example. 1722 1723 Raises: 1724 ValueError: if `dtype` is not convertible to float. 1725 """ 1726 if (dtype is None) or not (dtype.is_integer or dtype.is_floating): 1727 raise ValueError('dtype {} is not convertible to float.'.format(dtype)) 1728 return WeightedCategoricalColumn( 1729 categorical_column=categorical_column, 1730 weight_feature_key=weight_feature_key, 1731 dtype=dtype) 1732 1733 1734@tf_export('feature_column.crossed_column') 1735def crossed_column(keys, hash_bucket_size, hash_key=None): 1736 """Returns a column for performing crosses of categorical features. 1737 1738 Crossed features will be hashed according to `hash_bucket_size`. Conceptually, 1739 the transformation can be thought of as: 1740 Hash(cartesian product of features) % `hash_bucket_size` 1741 1742 For example, if the input features are: 1743 1744 * SparseTensor referred by first key: 1745 1746 ```python 1747 shape = [2, 2] 1748 { 1749 [0, 0]: "a" 1750 [1, 0]: "b" 1751 [1, 1]: "c" 1752 } 1753 ``` 1754 1755 * SparseTensor referred by second key: 1756 1757 ```python 1758 shape = [2, 1] 1759 { 1760 [0, 0]: "d" 1761 [1, 0]: "e" 1762 } 1763 ``` 1764 1765 then crossed feature will look like: 1766 1767 ```python 1768 shape = [2, 2] 1769 { 1770 [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size 1771 [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size 1772 [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size 1773 } 1774 ``` 1775 1776 Here is an example to create a linear model with crosses of string features: 1777 1778 ```python 1779 keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K) 1780 columns = [keywords_x_doc_terms, ...] 1781 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1782 linear_prediction = linear_model(features, columns) 1783 ``` 1784 1785 You could also use vocabulary lookup before crossing: 1786 1787 ```python 1788 keywords = categorical_column_with_vocabulary_file( 1789 'keywords', '/path/to/vocabulary/file', vocabulary_size=1K) 1790 keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K) 1791 columns = [keywords_x_doc_terms, ...] 1792 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1793 linear_prediction = linear_model(features, columns) 1794 ``` 1795 1796 If an input feature is of numeric type, you can use 1797 `categorical_column_with_identity`, or `bucketized_column`, as in the example: 1798 1799 ```python 1800 # vertical_id is an integer categorical feature. 1801 vertical_id = categorical_column_with_identity('vertical_id', 10K) 1802 price = numeric_column('price') 1803 # bucketized_column converts numerical feature to a categorical one. 1804 bucketized_price = bucketized_column(price, boundaries=[...]) 1805 vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K) 1806 columns = [vertical_id_x_price, ...] 1807 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1808 linear_prediction = linear_model(features, columns) 1809 ``` 1810 1811 To use crossed column in DNN model, you need to add it in an embedding column 1812 as in this example: 1813 1814 ```python 1815 vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K) 1816 vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10) 1817 dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...]) 1818 ``` 1819 1820 Args: 1821 keys: An iterable identifying the features to be crossed. Each element can 1822 be either: 1823 * string: Will use the corresponding feature which must be of string type. 1824 * `CategoricalColumn`: Will use the transformed tensor produced by this 1825 column. Does not support hashed categorical column. 1826 hash_bucket_size: An int > 1. The number of buckets. 1827 hash_key: Specify the hash_key that will be used by the `FingerprintCat64` 1828 function to combine the crosses fingerprints on SparseCrossOp (optional). 1829 1830 Returns: 1831 A `CrossedColumn`. 1832 1833 Raises: 1834 ValueError: If `len(keys) < 2`. 1835 ValueError: If any of the keys is neither a string nor `CategoricalColumn`. 1836 ValueError: If any of the keys is `HashedCategoricalColumn`. 1837 ValueError: If `hash_bucket_size < 1`. 1838 """ 1839 if not hash_bucket_size or hash_bucket_size < 1: 1840 raise ValueError('hash_bucket_size must be > 1. ' 1841 'hash_bucket_size: {}'.format(hash_bucket_size)) 1842 if not keys or len(keys) < 2: 1843 raise ValueError( 1844 'keys must be a list with length > 1. Given: {}'.format(keys)) 1845 for key in keys: 1846 if (not isinstance(key, six.string_types) and 1847 not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access 1848 raise ValueError( 1849 'Unsupported key type. All keys must be either string, or ' 1850 'categorical column except HashedCategoricalColumn. ' 1851 'Given: {}'.format(key)) 1852 if isinstance(key, 1853 (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access 1854 raise ValueError( 1855 'categorical_column_with_hash_bucket is not supported for crossing. ' 1856 'Hashing before crossing will increase probability of collision. ' 1857 'Instead, use the feature name as a string. Given: {}'.format(key)) 1858 return CrossedColumn( 1859 keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key) 1860 1861 1862@six.add_metaclass(abc.ABCMeta) 1863class FeatureColumn(object): 1864 """Represents a feature column abstraction. 1865 1866 WARNING: Do not subclass this layer unless you know what you are doing: 1867 the API is subject to future changes. 1868 1869 To distinguish between the concept of a feature family and a specific binary 1870 feature within a family, we refer to a feature family like "country" as a 1871 feature column. For example, we can have a feature in a `tf.Example` format: 1872 {key: "country", value: [ "US" ]} 1873 In this example the value of feature is "US" and "country" refers to the 1874 column of the feature. 1875 1876 This class is an abstract class. Users should not create instances of this. 1877 """ 1878 1879 @abc.abstractproperty 1880 def name(self): 1881 """Returns string. Used for naming.""" 1882 pass 1883 1884 def __lt__(self, other): 1885 """Allows feature columns to be sorted in Python 3 as they are in Python 2. 1886 1887 Feature columns need to occasionally be sortable, for example when used as 1888 keys in a features dictionary passed to a layer. 1889 1890 In CPython, `__lt__` must be defined for all objects in the 1891 sequence being sorted. 1892 1893 If any objects in the sequence being sorted do not have an `__lt__` method 1894 compatible with feature column objects (such as strings), then CPython will 1895 fall back to using the `__gt__` method below. 1896 https://docs.python.org/3/library/stdtypes.html#list.sort 1897 1898 Args: 1899 other: The other object to compare to. 1900 1901 Returns: 1902 True if the string representation of this object is lexicographically less 1903 than the string representation of `other`. For FeatureColumn objects, 1904 this looks like "<__main__.FeatureColumn object at 0xa>". 1905 """ 1906 return str(self) < str(other) 1907 1908 def __gt__(self, other): 1909 """Allows feature columns to be sorted in Python 3 as they are in Python 2. 1910 1911 Feature columns need to occasionally be sortable, for example when used as 1912 keys in a features dictionary passed to a layer. 1913 1914 `__gt__` is called when the "other" object being compared during the sort 1915 does not have `__lt__` defined. 1916 Example: 1917 ``` 1918 # __lt__ only class 1919 class A(): 1920 def __lt__(self, other): return str(self) < str(other) 1921 1922 a = A() 1923 a < "b" # True 1924 "0" < a # Error 1925 1926 # __lt__ and __gt__ class 1927 class B(): 1928 def __lt__(self, other): return str(self) < str(other) 1929 def __gt__(self, other): return str(self) > str(other) 1930 1931 b = B() 1932 b < "c" # True 1933 "0" < b # True 1934 ``` 1935 1936 Args: 1937 other: The other object to compare to. 1938 1939 Returns: 1940 True if the string representation of this object is lexicographically 1941 greater than the string representation of `other`. For FeatureColumn 1942 objects, this looks like "<__main__.FeatureColumn object at 0xa>". 1943 """ 1944 return str(self) > str(other) 1945 1946 @abc.abstractmethod 1947 def transform_feature(self, transformation_cache, state_manager): 1948 """Returns intermediate representation (usually a `Tensor`). 1949 1950 Uses `transformation_cache` to create an intermediate representation 1951 (usually a `Tensor`) that other feature columns can use. 1952 1953 Example usage of `transformation_cache`: 1954 Let's say a Feature column depends on raw feature ('raw') and another 1955 `FeatureColumn` (input_fc). To access corresponding `Tensor`s, 1956 transformation_cache will be used as follows: 1957 1958 ```python 1959 raw_tensor = transformation_cache.get('raw', state_manager) 1960 fc_tensor = transformation_cache.get(input_fc, state_manager) 1961 ``` 1962 1963 Args: 1964 transformation_cache: A `FeatureTransformationCache` object to access 1965 features. 1966 state_manager: A `StateManager` to create / access resources such as 1967 lookup tables. 1968 1969 Returns: 1970 Transformed feature `Tensor`. 1971 """ 1972 pass 1973 1974 @abc.abstractproperty 1975 def parse_example_spec(self): 1976 """Returns a `tf.Example` parsing spec as dict. 1977 1978 It is used for get_parsing_spec for `tf.io.parse_example`. Returned spec is 1979 a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other 1980 supported objects. Please check documentation of `tf.io.parse_example` for 1981 all supported spec objects. 1982 1983 Let's say a Feature column depends on raw feature ('raw') and another 1984 `FeatureColumn` (input_fc). One possible implementation of 1985 parse_example_spec is as follows: 1986 1987 ```python 1988 spec = {'raw': tf.io.FixedLenFeature(...)} 1989 spec.update(input_fc.parse_example_spec) 1990 return spec 1991 ``` 1992 """ 1993 pass 1994 1995 def create_state(self, state_manager): 1996 """Uses the `state_manager` to create state for the FeatureColumn. 1997 1998 Args: 1999 state_manager: A `StateManager` to create / access resources such as 2000 lookup tables and variables. 2001 """ 2002 pass 2003 2004 @abc.abstractproperty 2005 def _is_v2_column(self): 2006 """Returns whether this FeatureColumn is fully conformant to the new API. 2007 2008 This is needed for composition type cases where an EmbeddingColumn etc. 2009 might take in old categorical columns as input and then we want to use the 2010 old API. 2011 """ 2012 pass 2013 2014 @abc.abstractproperty 2015 def parents(self): 2016 """Returns a list of immediate raw feature and FeatureColumn dependencies. 2017 2018 For example: 2019 # For the following feature columns 2020 a = numeric_column('f1') 2021 c = crossed_column(a, 'f2') 2022 # The expected parents are: 2023 a.parents = ['f1'] 2024 c.parents = [a, 'f2'] 2025 """ 2026 pass 2027 2028 def get_config(self): 2029 """Returns the config of the feature column. 2030 2031 A FeatureColumn config is a Python dictionary (serializable) containing the 2032 configuration of a FeatureColumn. The same FeatureColumn can be 2033 reinstantiated later from this configuration. 2034 2035 The config of a feature column does not include information about feature 2036 columns depending on it nor the FeatureColumn class name. 2037 2038 Example with (de)serialization practices followed in this file: 2039 ```python 2040 class SerializationExampleFeatureColumn( 2041 FeatureColumn, collections.namedtuple( 2042 'SerializationExampleFeatureColumn', 2043 ('dimension', 'parent', 'dtype', 'normalizer_fn'))): 2044 2045 def get_config(self): 2046 # Create a dict from the namedtuple. 2047 # Python attribute literals can be directly copied from / to the config. 2048 # For example 'dimension', assuming it is an integer literal. 2049 config = dict(zip(self._fields, self)) 2050 2051 # (De)serialization of parent FeatureColumns should use the provided 2052 # (de)serialize_feature_column() methods that take care of de-duping. 2053 config['parent'] = serialize_feature_column(self.parent) 2054 2055 # Many objects provide custom (de)serialization e.g: for tf.DType 2056 # tf.DType.name, tf.as_dtype() can be used. 2057 config['dtype'] = self.dtype.name 2058 2059 # Non-trivial dependencies should be Keras-(de)serializable. 2060 config['normalizer_fn'] = generic_utils.serialize_keras_object( 2061 self.normalizer_fn) 2062 2063 return config 2064 2065 @classmethod 2066 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2067 # This should do the inverse transform from `get_config` and construct 2068 # the namedtuple. 2069 kwargs = config.copy() 2070 kwargs['parent'] = deserialize_feature_column( 2071 config['parent'], custom_objects, columns_by_name) 2072 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2073 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object( 2074 config['normalizer_fn'], custom_objects=custom_objects) 2075 return cls(**kwargs) 2076 2077 ``` 2078 Returns: 2079 A serializable Dict that can be used to deserialize the object with 2080 from_config. 2081 """ 2082 return self._get_config() 2083 2084 def _get_config(self): 2085 raise NotImplementedError('Must be implemented in subclasses.') 2086 2087 @classmethod 2088 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2089 """Creates a FeatureColumn from its config. 2090 2091 This method should be the reverse of `get_config`, capable of instantiating 2092 the same FeatureColumn from the config dictionary. See `get_config` for an 2093 example of common (de)serialization practices followed in this file. 2094 2095 TODO(b/118939620): This is a private method until consensus is reached on 2096 supporting object deserialization deduping within Keras. 2097 2098 Args: 2099 config: A Dict config acquired with `get_config`. 2100 custom_objects: Optional dictionary mapping names (strings) to custom 2101 classes or functions to be considered during deserialization. 2102 columns_by_name: A Dict[String, FeatureColumn] of existing columns in 2103 order to avoid duplication. Should be passed to any calls to 2104 deserialize_feature_column(). 2105 2106 Returns: 2107 A FeatureColumn for the input config. 2108 """ 2109 return cls._from_config(config, custom_objects, columns_by_name) 2110 2111 @classmethod 2112 def _from_config(cls, config, custom_objects=None, columns_by_name=None): 2113 raise NotImplementedError('Must be implemented in subclasses.') 2114 2115 2116class DenseColumn(FeatureColumn): 2117 """Represents a column which can be represented as `Tensor`. 2118 2119 Some examples of this type are: numeric_column, embedding_column, 2120 indicator_column. 2121 """ 2122 2123 @abc.abstractproperty 2124 def variable_shape(self): 2125 """`TensorShape` of `get_dense_tensor`, without batch dimension.""" 2126 pass 2127 2128 @abc.abstractmethod 2129 def get_dense_tensor(self, transformation_cache, state_manager): 2130 """Returns a `Tensor`. 2131 2132 The output of this function will be used by model-builder-functions. For 2133 example the pseudo code of `input_layer` will be like: 2134 2135 ```python 2136 def input_layer(features, feature_columns, ...): 2137 outputs = [fc.get_dense_tensor(...) for fc in feature_columns] 2138 return tf.concat(outputs) 2139 ``` 2140 2141 Args: 2142 transformation_cache: A `FeatureTransformationCache` object to access 2143 features. 2144 state_manager: A `StateManager` to create / access resources such as 2145 lookup tables. 2146 2147 Returns: 2148 `Tensor` of shape [batch_size] + `variable_shape`. 2149 """ 2150 pass 2151 2152 2153def is_feature_column_v2(feature_columns): 2154 """Returns True if all feature columns are V2.""" 2155 for feature_column in feature_columns: 2156 if not isinstance(feature_column, FeatureColumn): 2157 return False 2158 if not feature_column._is_v2_column: # pylint: disable=protected-access 2159 return False 2160 return True 2161 2162 2163def _create_weighted_sum(column, transformation_cache, state_manager, 2164 sparse_combiner, weight_var): 2165 """Creates a weighted sum for a dense/categorical column for linear_model.""" 2166 if isinstance(column, CategoricalColumn): 2167 return _create_categorical_column_weighted_sum( 2168 column=column, 2169 transformation_cache=transformation_cache, 2170 state_manager=state_manager, 2171 sparse_combiner=sparse_combiner, 2172 weight_var=weight_var) 2173 else: 2174 return _create_dense_column_weighted_sum( 2175 column=column, 2176 transformation_cache=transformation_cache, 2177 state_manager=state_manager, 2178 weight_var=weight_var) 2179 2180 2181def _create_dense_column_weighted_sum(column, transformation_cache, 2182 state_manager, weight_var): 2183 """Create a weighted sum of a dense column for linear_model.""" 2184 tensor = column.get_dense_tensor(transformation_cache, state_manager) 2185 num_elements = column.variable_shape.num_elements() 2186 batch_size = array_ops.shape(tensor)[0] 2187 tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements)) 2188 return math_ops.matmul(tensor, weight_var, name='weighted_sum') 2189 2190 2191class CategoricalColumn(FeatureColumn): 2192 """Represents a categorical feature. 2193 2194 A categorical feature typically handled with a `tf.sparse.SparseTensor` of 2195 IDs. 2196 """ 2197 2198 IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name 2199 'IdWeightPair', ('id_tensor', 'weight_tensor')) 2200 2201 @abc.abstractproperty 2202 def num_buckets(self): 2203 """Returns number of buckets in this sparse feature.""" 2204 pass 2205 2206 @abc.abstractmethod 2207 def get_sparse_tensors(self, transformation_cache, state_manager): 2208 """Returns an IdWeightPair. 2209 2210 `IdWeightPair` is a pair of `SparseTensor`s which represents ids and 2211 weights. 2212 2213 `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets` 2214 `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a 2215 `SparseTensor` of `float` or `None` to indicate all weights should be 2216 taken to be 1. If specified, `weight_tensor` must have exactly the same 2217 shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing 2218 output of a `VarLenFeature` which is a ragged matrix. 2219 2220 Args: 2221 transformation_cache: A `FeatureTransformationCache` object to access 2222 features. 2223 state_manager: A `StateManager` to create / access resources such as 2224 lookup tables. 2225 """ 2226 pass 2227 2228 2229def _create_categorical_column_weighted_sum( 2230 column, transformation_cache, state_manager, sparse_combiner, weight_var): 2231 # pylint: disable=g-doc-return-or-yield,g-doc-args 2232 """Create a weighted sum of a categorical column for linear_model. 2233 2234 Note to maintainer: As implementation details, the weighted sum is 2235 implemented via embedding_lookup_sparse toward efficiency. Mathematically, 2236 they are the same. 2237 2238 To be specific, conceptually, categorical column can be treated as multi-hot 2239 vector. Say: 2240 2241 ```python 2242 x = [0 0 1] # categorical column input 2243 w = [a b c] # weights 2244 ``` 2245 The weighted sum is `c` in this case, which is same as `w[2]`. 2246 2247 Another example is 2248 2249 ```python 2250 x = [0 1 1] # categorical column input 2251 w = [a b c] # weights 2252 ``` 2253 The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`. 2254 2255 For both cases, we can implement weighted sum via embedding_lookup with 2256 sparse_combiner = "sum". 2257 """ 2258 2259 sparse_tensors = column.get_sparse_tensors(transformation_cache, 2260 state_manager) 2261 id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [ 2262 array_ops.shape(sparse_tensors.id_tensor)[0], -1 2263 ]) 2264 weight_tensor = sparse_tensors.weight_tensor 2265 if weight_tensor is not None: 2266 weight_tensor = sparse_ops.sparse_reshape( 2267 weight_tensor, [array_ops.shape(weight_tensor)[0], -1]) 2268 2269 return embedding_ops.safe_embedding_lookup_sparse( 2270 weight_var, 2271 id_tensor, 2272 sparse_weights=weight_tensor, 2273 combiner=sparse_combiner, 2274 name='weighted_sum') 2275 2276 2277class SequenceDenseColumn(FeatureColumn): 2278 """Represents dense sequence data.""" 2279 2280 TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name 2281 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length')) 2282 2283 @abc.abstractmethod 2284 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 2285 """Returns a `TensorSequenceLengthPair`. 2286 2287 Args: 2288 transformation_cache: A `FeatureTransformationCache` object to access 2289 features. 2290 state_manager: A `StateManager` to create / access resources such as 2291 lookup tables. 2292 """ 2293 pass 2294 2295 2296class FeatureTransformationCache(object): 2297 """Handles caching of transformations while building the model. 2298 2299 `FeatureColumn` specifies how to digest an input column to the network. Some 2300 feature columns require data transformations. This class caches those 2301 transformations. 2302 2303 Some features may be used in more than one place. For example, one can use a 2304 bucketized feature by itself and a cross with it. In that case we 2305 should create only one bucketization op instead of creating ops for each 2306 feature column separately. To handle re-use of transformed columns, 2307 `FeatureTransformationCache` caches all previously transformed columns. 2308 2309 Example: 2310 We're trying to use the following `FeatureColumn`s: 2311 2312 ```python 2313 bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...) 2314 keywords = fc.categorical_column_with_hash_buckets("keywords", ...) 2315 age_X_keywords = fc.crossed_column([bucketized_age, "keywords"]) 2316 ... = linear_model(features, 2317 [bucketized_age, keywords, age_X_keywords] 2318 ``` 2319 2320 If we transform each column independently, then we'll get duplication of 2321 bucketization (one for cross, one for bucketization itself). 2322 The `FeatureTransformationCache` eliminates this duplication. 2323 """ 2324 2325 def __init__(self, features): 2326 """Creates a `FeatureTransformationCache`. 2327 2328 Args: 2329 features: A mapping from feature column to objects that are `Tensor` or 2330 `SparseTensor`, or can be converted to same via 2331 `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key 2332 signifies a base feature (not-transformed). A `FeatureColumn` key 2333 means that this `Tensor` is the output of an existing `FeatureColumn` 2334 which can be reused. 2335 """ 2336 self._features = features.copy() 2337 self._feature_tensors = {} 2338 2339 def get(self, key, state_manager, training=None): 2340 """Returns a `Tensor` for the given key. 2341 2342 A `str` key is used to access a base feature (not-transformed). When a 2343 `FeatureColumn` is passed, the transformed feature is returned if it 2344 already exists, otherwise the given `FeatureColumn` is asked to provide its 2345 transformed output, which is then cached. 2346 2347 Args: 2348 key: a `str` or a `FeatureColumn`. 2349 state_manager: A StateManager object that holds the FeatureColumn state. 2350 training: Boolean indicating whether to the column is being used in 2351 training mode. This argument is passed to the transform_feature method 2352 of any `FeatureColumn` that takes a `training` argument. For example, if 2353 a `FeatureColumn` performed dropout, it could expose a `training` 2354 argument to control whether the dropout should be applied. 2355 2356 Returns: 2357 The transformed `Tensor` corresponding to the `key`. 2358 2359 Raises: 2360 ValueError: if key is not found or a transformed `Tensor` cannot be 2361 computed. 2362 """ 2363 if key in self._feature_tensors: 2364 # FeatureColumn is already transformed or converted. 2365 return self._feature_tensors[key] 2366 2367 if key in self._features: 2368 feature_tensor = self._get_raw_feature_as_tensor(key) 2369 self._feature_tensors[key] = feature_tensor 2370 return feature_tensor 2371 2372 if isinstance(key, six.string_types): 2373 raise ValueError('Feature {} is not in features dictionary.'.format(key)) 2374 2375 if not isinstance(key, FeatureColumn): 2376 raise TypeError('"key" must be either a "str" or "FeatureColumn". ' 2377 'Provided: {}'.format(key)) 2378 2379 column = key 2380 logging.debug('Transforming feature_column %s.', column) 2381 2382 # Some columns may need information about whether the transformation is 2383 # happening in training or prediction mode, but not all columns expose this 2384 # argument. 2385 try: 2386 transformed = column.transform_feature( 2387 self, state_manager, training=training) 2388 except TypeError: 2389 transformed = column.transform_feature(self, state_manager) 2390 if transformed is None: 2391 raise ValueError('Column {} is not supported.'.format(column.name)) 2392 self._feature_tensors[column] = transformed 2393 return transformed 2394 2395 def _get_raw_feature_as_tensor(self, key): 2396 """Gets the raw_feature (keyed by `key`) as `tensor`. 2397 2398 The raw feature is converted to (sparse) tensor and maybe expand dim. 2399 2400 For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if 2401 the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will 2402 error out as it is not supported. 2403 2404 Args: 2405 key: A `str` key to access the raw feature. 2406 2407 Returns: 2408 A `Tensor` or `SparseTensor`. 2409 2410 Raises: 2411 ValueError: if the raw feature has rank 0. 2412 """ 2413 raw_feature = self._features[key] 2414 feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( 2415 raw_feature) 2416 2417 def expand_dims(input_tensor): 2418 # Input_tensor must have rank 1. 2419 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 2420 return sparse_ops.sparse_reshape( 2421 input_tensor, [array_ops.shape(input_tensor)[0], 1]) 2422 else: 2423 return array_ops.expand_dims(input_tensor, -1) 2424 2425 rank = feature_tensor.get_shape().ndims 2426 if rank is not None: 2427 if rank == 0: 2428 raise ValueError( 2429 'Feature (key: {}) cannot have rank 0. Given: {}'.format( 2430 key, feature_tensor)) 2431 return feature_tensor if rank != 1 else expand_dims(feature_tensor) 2432 2433 # Handle dynamic rank. 2434 with ops.control_dependencies([ 2435 check_ops.assert_positive( 2436 array_ops.rank(feature_tensor), 2437 message='Feature (key: {}) cannot have rank 0. Given: {}'.format( 2438 key, feature_tensor))]): 2439 return control_flow_ops.cond( 2440 math_ops.equal(1, array_ops.rank(feature_tensor)), 2441 lambda: expand_dims(feature_tensor), 2442 lambda: feature_tensor) 2443 2444 2445# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py 2446def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None): 2447 """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. 2448 2449 If `input_tensor` is already a `SparseTensor`, just return it. 2450 2451 Args: 2452 input_tensor: A string or integer `Tensor`. 2453 ignore_value: Entries in `dense_tensor` equal to this value will be 2454 absent from the resulting `SparseTensor`. If `None`, default value of 2455 `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). 2456 2457 Returns: 2458 A `SparseTensor` with the same shape as `input_tensor`. 2459 2460 Raises: 2461 ValueError: when `input_tensor`'s rank is `None`. 2462 """ 2463 input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( 2464 input_tensor) 2465 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 2466 return input_tensor 2467 with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)): 2468 if ignore_value is None: 2469 if input_tensor.dtype == dtypes.string: 2470 # Exception due to TF strings are converted to numpy objects by default. 2471 ignore_value = '' 2472 elif input_tensor.dtype.is_integer: 2473 ignore_value = -1 # -1 has a special meaning of missing feature 2474 else: 2475 # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is 2476 # constructing a new numpy object of the given type, which yields the 2477 # default value for that type. 2478 ignore_value = input_tensor.dtype.as_numpy_dtype() 2479 ignore_value = math_ops.cast( 2480 ignore_value, input_tensor.dtype, name='ignore_value') 2481 indices = array_ops.where_v2( 2482 math_ops.not_equal(input_tensor, ignore_value), name='indices') 2483 return sparse_tensor_lib.SparseTensor( 2484 indices=indices, 2485 values=array_ops.gather_nd(input_tensor, indices, name='values'), 2486 dense_shape=array_ops.shape( 2487 input_tensor, out_type=dtypes.int64, name='dense_shape')) 2488 2489 2490def _normalize_feature_columns(feature_columns): 2491 """Normalizes the `feature_columns` input. 2492 2493 This method converts the `feature_columns` to list type as best as it can. In 2494 addition, verifies the type and other parts of feature_columns, required by 2495 downstream library. 2496 2497 Args: 2498 feature_columns: The raw feature columns, usually passed by users. 2499 2500 Returns: 2501 The normalized feature column list. 2502 2503 Raises: 2504 ValueError: for any invalid inputs, such as empty, duplicated names, etc. 2505 """ 2506 if isinstance(feature_columns, FeatureColumn): 2507 feature_columns = [feature_columns] 2508 2509 if isinstance(feature_columns, collections_abc.Iterator): 2510 feature_columns = list(feature_columns) 2511 2512 if isinstance(feature_columns, dict): 2513 raise ValueError('Expected feature_columns to be iterable, found dict.') 2514 2515 for column in feature_columns: 2516 if not isinstance(column, FeatureColumn): 2517 raise ValueError('Items of feature_columns must be a FeatureColumn. ' 2518 'Given (type {}): {}.'.format(type(column), column)) 2519 if not feature_columns: 2520 raise ValueError('feature_columns must not be empty.') 2521 name_to_column = {} 2522 for column in feature_columns: 2523 if column.name in name_to_column: 2524 raise ValueError('Duplicate feature column name found for columns: {} ' 2525 'and {}. This usually means that these columns refer to ' 2526 'same base feature. Either one must be discarded or a ' 2527 'duplicated but renamed item must be inserted in ' 2528 'features dict.'.format(column, 2529 name_to_column[column.name])) 2530 name_to_column[column.name] = column 2531 2532 return sorted(feature_columns, key=lambda x: x.name) 2533 2534 2535class NumericColumn( 2536 DenseColumn, 2537 fc_old._DenseColumn, # pylint: disable=protected-access 2538 collections.namedtuple( 2539 'NumericColumn', 2540 ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): 2541 """see `numeric_column`.""" 2542 2543 @property 2544 def _is_v2_column(self): 2545 return True 2546 2547 @property 2548 def name(self): 2549 """See `FeatureColumn` base class.""" 2550 return self.key 2551 2552 @property 2553 def parse_example_spec(self): 2554 """See `FeatureColumn` base class.""" 2555 return { 2556 self.key: 2557 parsing_ops.FixedLenFeature(self.shape, self.dtype, 2558 self.default_value) 2559 } 2560 2561 @property 2562 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2563 _FEATURE_COLUMN_DEPRECATION) 2564 def _parse_example_spec(self): 2565 return self.parse_example_spec 2566 2567 def _transform_input_tensor(self, input_tensor): 2568 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 2569 raise ValueError( 2570 'The corresponding Tensor of numerical column must be a Tensor. ' 2571 'SparseTensor is not supported. key: {}'.format(self.key)) 2572 if self.normalizer_fn is not None: 2573 input_tensor = self.normalizer_fn(input_tensor) 2574 return math_ops.cast(input_tensor, dtypes.float32) 2575 2576 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2577 _FEATURE_COLUMN_DEPRECATION) 2578 def _transform_feature(self, inputs): 2579 input_tensor = inputs.get(self.key) 2580 return self._transform_input_tensor(input_tensor) 2581 2582 def transform_feature(self, transformation_cache, state_manager): 2583 """See `FeatureColumn` base class. 2584 2585 In this case, we apply the `normalizer_fn` to the input tensor. 2586 2587 Args: 2588 transformation_cache: A `FeatureTransformationCache` object to access 2589 features. 2590 state_manager: A `StateManager` to create / access resources such as 2591 lookup tables. 2592 2593 Returns: 2594 Normalized input tensor. 2595 Raises: 2596 ValueError: If a SparseTensor is passed in. 2597 """ 2598 input_tensor = transformation_cache.get(self.key, state_manager) 2599 return self._transform_input_tensor(input_tensor) 2600 2601 @property 2602 def variable_shape(self): 2603 """See `DenseColumn` base class.""" 2604 return tensor_shape.TensorShape(self.shape) 2605 2606 @property 2607 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2608 _FEATURE_COLUMN_DEPRECATION) 2609 def _variable_shape(self): 2610 return self.variable_shape 2611 2612 def get_dense_tensor(self, transformation_cache, state_manager): 2613 """Returns dense `Tensor` representing numeric feature. 2614 2615 Args: 2616 transformation_cache: A `FeatureTransformationCache` object to access 2617 features. 2618 state_manager: A `StateManager` to create / access resources such as 2619 lookup tables. 2620 2621 Returns: 2622 Dense `Tensor` created within `transform_feature`. 2623 """ 2624 # Feature has been already transformed. Return the intermediate 2625 # representation created by _transform_feature. 2626 return transformation_cache.get(self, state_manager) 2627 2628 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2629 _FEATURE_COLUMN_DEPRECATION) 2630 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2631 del weight_collections 2632 del trainable 2633 return inputs.get(self) 2634 2635 @property 2636 def parents(self): 2637 """See 'FeatureColumn` base class.""" 2638 return [self.key] 2639 2640 def get_config(self): 2641 """See 'FeatureColumn` base class.""" 2642 config = dict(zip(self._fields, self)) 2643 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2644 config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access 2645 self.normalizer_fn) 2646 config['dtype'] = self.dtype.name 2647 return config 2648 2649 @classmethod 2650 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2651 """See 'FeatureColumn` base class.""" 2652 _check_config_keys(config, cls._fields) 2653 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2654 kwargs = _standardize_and_copy_config(config) 2655 kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access 2656 config['normalizer_fn'], custom_objects=custom_objects) 2657 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2658 2659 return cls(**kwargs) 2660 2661 2662class BucketizedColumn( 2663 DenseColumn, 2664 CategoricalColumn, 2665 fc_old._DenseColumn, # pylint: disable=protected-access 2666 fc_old._CategoricalColumn, # pylint: disable=protected-access 2667 collections.namedtuple('BucketizedColumn', 2668 ('source_column', 'boundaries'))): 2669 """See `bucketized_column`.""" 2670 2671 @property 2672 def _is_v2_column(self): 2673 return (isinstance(self.source_column, FeatureColumn) and 2674 self.source_column._is_v2_column) # pylint: disable=protected-access 2675 2676 @property 2677 def name(self): 2678 """See `FeatureColumn` base class.""" 2679 return '{}_bucketized'.format(self.source_column.name) 2680 2681 @property 2682 def parse_example_spec(self): 2683 """See `FeatureColumn` base class.""" 2684 return self.source_column.parse_example_spec 2685 2686 @property 2687 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2688 _FEATURE_COLUMN_DEPRECATION) 2689 def _parse_example_spec(self): 2690 return self.source_column._parse_example_spec # pylint: disable=protected-access 2691 2692 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2693 _FEATURE_COLUMN_DEPRECATION) 2694 def _transform_feature(self, inputs): 2695 """Returns bucketized categorical `source_column` tensor.""" 2696 source_tensor = inputs.get(self.source_column) 2697 return math_ops._bucketize( # pylint: disable=protected-access 2698 source_tensor, 2699 boundaries=self.boundaries) 2700 2701 def transform_feature(self, transformation_cache, state_manager): 2702 """Returns bucketized categorical `source_column` tensor.""" 2703 source_tensor = transformation_cache.get(self.source_column, state_manager) 2704 return math_ops._bucketize( # pylint: disable=protected-access 2705 source_tensor, 2706 boundaries=self.boundaries) 2707 2708 @property 2709 def variable_shape(self): 2710 """See `DenseColumn` base class.""" 2711 return tensor_shape.TensorShape( 2712 tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) 2713 2714 @property 2715 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2716 _FEATURE_COLUMN_DEPRECATION) 2717 def _variable_shape(self): 2718 return self.variable_shape 2719 2720 def _get_dense_tensor_for_input_tensor(self, input_tensor): 2721 return array_ops.one_hot( 2722 indices=math_ops.cast(input_tensor, dtypes.int64), 2723 depth=len(self.boundaries) + 1, 2724 on_value=1., 2725 off_value=0.) 2726 2727 def get_dense_tensor(self, transformation_cache, state_manager): 2728 """Returns one hot encoded dense `Tensor`.""" 2729 input_tensor = transformation_cache.get(self, state_manager) 2730 return self._get_dense_tensor_for_input_tensor(input_tensor) 2731 2732 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2733 _FEATURE_COLUMN_DEPRECATION) 2734 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2735 del weight_collections 2736 del trainable 2737 input_tensor = inputs.get(self) 2738 return self._get_dense_tensor_for_input_tensor(input_tensor) 2739 2740 @property 2741 def num_buckets(self): 2742 """See `CategoricalColumn` base class.""" 2743 # By construction, source_column is always one-dimensional. 2744 return (len(self.boundaries) + 1) * self.source_column.shape[0] 2745 2746 @property 2747 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2748 _FEATURE_COLUMN_DEPRECATION) 2749 def _num_buckets(self): 2750 return self.num_buckets 2751 2752 def _get_sparse_tensors_for_input_tensor(self, input_tensor): 2753 batch_size = array_ops.shape(input_tensor)[0] 2754 # By construction, source_column is always one-dimensional. 2755 source_dimension = self.source_column.shape[0] 2756 2757 i1 = array_ops.reshape( 2758 array_ops.tile( 2759 array_ops.expand_dims(math_ops.range(0, batch_size), 1), 2760 [1, source_dimension]), 2761 (-1,)) 2762 i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size]) 2763 # Flatten the bucket indices and unique them across dimensions 2764 # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets 2765 bucket_indices = ( 2766 array_ops.reshape(input_tensor, (-1,)) + 2767 (len(self.boundaries) + 1) * i2) 2768 2769 indices = math_ops.cast( 2770 array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64) 2771 dense_shape = math_ops.cast( 2772 array_ops.stack([batch_size, source_dimension]), dtypes.int64) 2773 sparse_tensor = sparse_tensor_lib.SparseTensor( 2774 indices=indices, 2775 values=bucket_indices, 2776 dense_shape=dense_shape) 2777 return CategoricalColumn.IdWeightPair(sparse_tensor, None) 2778 2779 def get_sparse_tensors(self, transformation_cache, state_manager): 2780 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2781 input_tensor = transformation_cache.get(self, state_manager) 2782 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2783 2784 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2785 _FEATURE_COLUMN_DEPRECATION) 2786 def _get_sparse_tensors(self, inputs, weight_collections=None, 2787 trainable=None): 2788 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2789 del weight_collections 2790 del trainable 2791 input_tensor = inputs.get(self) 2792 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2793 2794 @property 2795 def parents(self): 2796 """See 'FeatureColumn` base class.""" 2797 return [self.source_column] 2798 2799 def get_config(self): 2800 """See 'FeatureColumn` base class.""" 2801 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 2802 config = dict(zip(self._fields, self)) 2803 config['source_column'] = serialize_feature_column(self.source_column) 2804 return config 2805 2806 @classmethod 2807 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2808 """See 'FeatureColumn` base class.""" 2809 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 2810 _check_config_keys(config, cls._fields) 2811 kwargs = _standardize_and_copy_config(config) 2812 kwargs['source_column'] = deserialize_feature_column( 2813 config['source_column'], custom_objects, columns_by_name) 2814 return cls(**kwargs) 2815 2816 2817class EmbeddingColumn( 2818 DenseColumn, 2819 SequenceDenseColumn, 2820 fc_old._DenseColumn, # pylint: disable=protected-access 2821 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 2822 collections.namedtuple( 2823 'EmbeddingColumn', 2824 ('categorical_column', 'dimension', 'combiner', 'initializer', 2825 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable', 2826 'use_safe_embedding_lookup'))): 2827 """See `embedding_column`.""" 2828 2829 def __new__(cls, 2830 categorical_column, 2831 dimension, 2832 combiner, 2833 initializer, 2834 ckpt_to_load_from, 2835 tensor_name_in_ckpt, 2836 max_norm, 2837 trainable, 2838 use_safe_embedding_lookup=True): 2839 return super(EmbeddingColumn, cls).__new__( 2840 cls, 2841 categorical_column=categorical_column, 2842 dimension=dimension, 2843 combiner=combiner, 2844 initializer=initializer, 2845 ckpt_to_load_from=ckpt_to_load_from, 2846 tensor_name_in_ckpt=tensor_name_in_ckpt, 2847 max_norm=max_norm, 2848 trainable=trainable, 2849 use_safe_embedding_lookup=use_safe_embedding_lookup) 2850 2851 @property 2852 def _is_v2_column(self): 2853 return (isinstance(self.categorical_column, FeatureColumn) and 2854 self.categorical_column._is_v2_column) # pylint: disable=protected-access 2855 2856 @property 2857 def name(self): 2858 """See `FeatureColumn` base class.""" 2859 return '{}_embedding'.format(self.categorical_column.name) 2860 2861 @property 2862 def parse_example_spec(self): 2863 """See `FeatureColumn` base class.""" 2864 return self.categorical_column.parse_example_spec 2865 2866 @property 2867 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2868 _FEATURE_COLUMN_DEPRECATION) 2869 def _parse_example_spec(self): 2870 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 2871 2872 def transform_feature(self, transformation_cache, state_manager): 2873 """Transforms underlying `categorical_column`.""" 2874 return transformation_cache.get(self.categorical_column, state_manager) 2875 2876 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2877 _FEATURE_COLUMN_DEPRECATION) 2878 def _transform_feature(self, inputs): 2879 return inputs.get(self.categorical_column) 2880 2881 @property 2882 def variable_shape(self): 2883 """See `DenseColumn` base class.""" 2884 return tensor_shape.TensorShape([self.dimension]) 2885 2886 @property 2887 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2888 _FEATURE_COLUMN_DEPRECATION) 2889 def _variable_shape(self): 2890 return self.variable_shape 2891 2892 def create_state(self, state_manager): 2893 """Creates the embedding lookup variable.""" 2894 default_num_buckets = (self.categorical_column.num_buckets 2895 if self._is_v2_column 2896 else self.categorical_column._num_buckets) # pylint: disable=protected-access 2897 num_buckets = getattr(self.categorical_column, 'num_buckets', 2898 default_num_buckets) 2899 embedding_shape = (num_buckets, self.dimension) 2900 state_manager.create_variable( 2901 self, 2902 name='embedding_weights', 2903 shape=embedding_shape, 2904 dtype=dtypes.float32, 2905 trainable=self.trainable, 2906 use_resource=True, 2907 initializer=self.initializer) 2908 2909 def _get_dense_tensor_internal_helper(self, sparse_tensors, 2910 embedding_weights): 2911 sparse_ids = sparse_tensors.id_tensor 2912 sparse_weights = sparse_tensors.weight_tensor 2913 2914 if self.ckpt_to_load_from is not None: 2915 to_restore = embedding_weights 2916 if isinstance(to_restore, variables.PartitionedVariable): 2917 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access 2918 checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, { 2919 self.tensor_name_in_ckpt: to_restore 2920 }) 2921 2922 sparse_id_rank = tensor_shape.dimension_value( 2923 sparse_ids.dense_shape.get_shape()[0]) 2924 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse 2925 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and 2926 sparse_id_rank <= 2): 2927 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 2928 # Return embedding lookup result. 2929 return embedding_lookup_sparse( 2930 embedding_weights, 2931 sparse_ids, 2932 sparse_weights, 2933 combiner=self.combiner, 2934 name='%s_weights' % self.name, 2935 max_norm=self.max_norm) 2936 2937 def _get_dense_tensor_internal(self, sparse_tensors, state_manager): 2938 """Private method that follows the signature of get_dense_tensor.""" 2939 embedding_weights = state_manager.get_variable( 2940 self, name='embedding_weights') 2941 return self._get_dense_tensor_internal_helper(sparse_tensors, 2942 embedding_weights) 2943 2944 def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections, 2945 trainable): 2946 """Private method that follows the signature of _get_dense_tensor.""" 2947 embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access 2948 if (weight_collections and 2949 ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections): 2950 weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) 2951 embedding_weights = variable_scope.get_variable( 2952 name='embedding_weights', 2953 shape=embedding_shape, 2954 dtype=dtypes.float32, 2955 initializer=self.initializer, 2956 trainable=self.trainable and trainable, 2957 collections=weight_collections) 2958 return self._get_dense_tensor_internal_helper(sparse_tensors, 2959 embedding_weights) 2960 2961 def get_dense_tensor(self, transformation_cache, state_manager): 2962 """Returns tensor after doing the embedding lookup. 2963 2964 Args: 2965 transformation_cache: A `FeatureTransformationCache` object to access 2966 features. 2967 state_manager: A `StateManager` to create / access resources such as 2968 lookup tables. 2969 2970 Returns: 2971 Embedding lookup tensor. 2972 2973 Raises: 2974 ValueError: `categorical_column` is SequenceCategoricalColumn. 2975 """ 2976 if isinstance(self.categorical_column, SequenceCategoricalColumn): 2977 raise ValueError( 2978 'In embedding_column: {}. ' 2979 'categorical_column must not be of type SequenceCategoricalColumn. ' 2980 'Suggested fix A: If you wish to use DenseFeatures, use a ' 2981 'non-sequence categorical_column_with_*. ' 2982 'Suggested fix B: If you wish to create sequence input, use ' 2983 'SequenceFeatures instead of DenseFeatures. ' 2984 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 2985 self.categorical_column)) 2986 # Get sparse IDs and weights. 2987 sparse_tensors = self.categorical_column.get_sparse_tensors( 2988 transformation_cache, state_manager) 2989 return self._get_dense_tensor_internal(sparse_tensors, state_manager) 2990 2991 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2992 _FEATURE_COLUMN_DEPRECATION) 2993 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2994 if isinstance( 2995 self.categorical_column, 2996 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 2997 raise ValueError( 2998 'In embedding_column: {}. ' 2999 'categorical_column must not be of type _SequenceCategoricalColumn. ' 3000 'Suggested fix A: If you wish to use DenseFeatures, use a ' 3001 'non-sequence categorical_column_with_*. ' 3002 'Suggested fix B: If you wish to create sequence input, use ' 3003 'SequenceFeatures instead of DenseFeatures. ' 3004 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3005 self.categorical_column)) 3006 sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access 3007 inputs, weight_collections, trainable) 3008 return self._old_get_dense_tensor_internal(sparse_tensors, 3009 weight_collections, trainable) 3010 3011 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 3012 """See `SequenceDenseColumn` base class.""" 3013 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 3014 raise ValueError( 3015 'In embedding_column: {}. ' 3016 'categorical_column must be of type SequenceCategoricalColumn ' 3017 'to use SequenceFeatures. ' 3018 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 3019 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3020 self.categorical_column)) 3021 sparse_tensors = self.categorical_column.get_sparse_tensors( 3022 transformation_cache, state_manager) 3023 dense_tensor = self._get_dense_tensor_internal(sparse_tensors, 3024 state_manager) 3025 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3026 sparse_tensors.id_tensor) 3027 return SequenceDenseColumn.TensorSequenceLengthPair( 3028 dense_tensor=dense_tensor, sequence_length=sequence_length) 3029 3030 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3031 _FEATURE_COLUMN_DEPRECATION) 3032 def _get_sequence_dense_tensor(self, 3033 inputs, 3034 weight_collections=None, 3035 trainable=None): 3036 if not isinstance( 3037 self.categorical_column, 3038 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 3039 raise ValueError( 3040 'In embedding_column: {}. ' 3041 'categorical_column must be of type SequenceCategoricalColumn ' 3042 'to use SequenceFeatures. ' 3043 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 3044 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3045 self.categorical_column)) 3046 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 3047 dense_tensor = self._old_get_dense_tensor_internal( 3048 sparse_tensors, 3049 weight_collections=weight_collections, 3050 trainable=trainable) 3051 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3052 sparse_tensors.id_tensor) 3053 return SequenceDenseColumn.TensorSequenceLengthPair( 3054 dense_tensor=dense_tensor, sequence_length=sequence_length) 3055 3056 @property 3057 def parents(self): 3058 """See 'FeatureColumn` base class.""" 3059 return [self.categorical_column] 3060 3061 def get_config(self): 3062 """See 'FeatureColumn` base class.""" 3063 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 3064 config = dict(zip(self._fields, self)) 3065 config['categorical_column'] = serialization.serialize_feature_column( 3066 self.categorical_column) 3067 config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access 3068 self.initializer) 3069 return config 3070 3071 @classmethod 3072 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3073 """See 'FeatureColumn` base class.""" 3074 if 'use_safe_embedding_lookup' not in config: 3075 config['use_safe_embedding_lookup'] = True 3076 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 3077 _check_config_keys(config, cls._fields) 3078 kwargs = _standardize_and_copy_config(config) 3079 kwargs['categorical_column'] = serialization.deserialize_feature_column( 3080 config['categorical_column'], custom_objects, columns_by_name) 3081 all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass)) 3082 kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access 3083 config['initializer'], 3084 module_objects=all_initializers, 3085 custom_objects=custom_objects) 3086 return cls(**kwargs) 3087 3088 3089def _raise_shared_embedding_column_error(): 3090 raise ValueError('SharedEmbeddingColumns are not supported in ' 3091 '`linear_model` or `input_layer`. Please use ' 3092 '`DenseFeatures` or `LinearModel` instead.') 3093 3094 3095class SharedEmbeddingColumnCreator(tracking.AutoTrackable): 3096 3097 def __init__(self, 3098 dimension, 3099 initializer, 3100 ckpt_to_load_from, 3101 tensor_name_in_ckpt, 3102 num_buckets, 3103 trainable, 3104 name='shared_embedding_column_creator', 3105 use_safe_embedding_lookup=True): 3106 self._dimension = dimension 3107 self._initializer = initializer 3108 self._ckpt_to_load_from = ckpt_to_load_from 3109 self._tensor_name_in_ckpt = tensor_name_in_ckpt 3110 self._num_buckets = num_buckets 3111 self._trainable = trainable 3112 self._name = name 3113 self._use_safe_embedding_lookup = use_safe_embedding_lookup 3114 # Map from graph keys to embedding_weight variables. 3115 self._embedding_weights = {} 3116 3117 def __call__(self, categorical_column, combiner, max_norm): 3118 return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm, 3119 self._use_safe_embedding_lookup) 3120 3121 @property 3122 def embedding_weights(self): 3123 key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 3124 if key not in self._embedding_weights: 3125 embedding_shape = (self._num_buckets, self._dimension) 3126 var = variable_scope.get_variable( 3127 name=self._name, 3128 shape=embedding_shape, 3129 dtype=dtypes.float32, 3130 initializer=self._initializer, 3131 trainable=self._trainable) 3132 3133 if self._ckpt_to_load_from is not None: 3134 to_restore = var 3135 if isinstance(to_restore, variables.PartitionedVariable): 3136 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access 3137 checkpoint_utils.init_from_checkpoint( 3138 self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore}) 3139 self._embedding_weights[key] = var 3140 return self._embedding_weights[key] 3141 3142 @property 3143 def dimension(self): 3144 return self._dimension 3145 3146 3147class SharedEmbeddingColumn( 3148 DenseColumn, 3149 SequenceDenseColumn, 3150 fc_old._DenseColumn, # pylint: disable=protected-access 3151 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 3152 collections.namedtuple( 3153 'SharedEmbeddingColumn', 3154 ('categorical_column', 'shared_embedding_column_creator', 'combiner', 3155 'max_norm', 'use_safe_embedding_lookup'))): 3156 """See `embedding_column`.""" 3157 3158 def __new__(cls, 3159 categorical_column, 3160 shared_embedding_column_creator, 3161 combiner, 3162 max_norm, 3163 use_safe_embedding_lookup=True): 3164 return super(SharedEmbeddingColumn, cls).__new__( 3165 cls, 3166 categorical_column=categorical_column, 3167 shared_embedding_column_creator=shared_embedding_column_creator, 3168 combiner=combiner, 3169 max_norm=max_norm, 3170 use_safe_embedding_lookup=use_safe_embedding_lookup) 3171 3172 @property 3173 def _is_v2_column(self): 3174 return True 3175 3176 @property 3177 def name(self): 3178 """See `FeatureColumn` base class.""" 3179 return '{}_shared_embedding'.format(self.categorical_column.name) 3180 3181 @property 3182 def parse_example_spec(self): 3183 """See `FeatureColumn` base class.""" 3184 return self.categorical_column.parse_example_spec 3185 3186 @property 3187 def _parse_example_spec(self): 3188 return _raise_shared_embedding_column_error() 3189 3190 def transform_feature(self, transformation_cache, state_manager): 3191 """See `FeatureColumn` base class.""" 3192 return transformation_cache.get(self.categorical_column, state_manager) 3193 3194 def _transform_feature(self, inputs): 3195 return _raise_shared_embedding_column_error() 3196 3197 @property 3198 def variable_shape(self): 3199 """See `DenseColumn` base class.""" 3200 return tensor_shape.TensorShape( 3201 [self.shared_embedding_column_creator.dimension]) 3202 3203 @property 3204 def _variable_shape(self): 3205 return _raise_shared_embedding_column_error() 3206 3207 def _get_dense_tensor_internal(self, transformation_cache, state_manager): 3208 """Private method that follows the signature of _get_dense_tensor.""" 3209 # This method is called from a variable_scope with name _var_scope_name, 3210 # which is shared among all shared embeddings. Open a name_scope here, so 3211 # that the ops for different columns have distinct names. 3212 with ops.name_scope(None, default_name=self.name): 3213 # Get sparse IDs and weights. 3214 sparse_tensors = self.categorical_column.get_sparse_tensors( 3215 transformation_cache, state_manager) 3216 sparse_ids = sparse_tensors.id_tensor 3217 sparse_weights = sparse_tensors.weight_tensor 3218 3219 embedding_weights = self.shared_embedding_column_creator.embedding_weights 3220 3221 sparse_id_rank = tensor_shape.dimension_value( 3222 sparse_ids.dense_shape.get_shape()[0]) 3223 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse 3224 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and 3225 sparse_id_rank <= 2): 3226 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 3227 # Return embedding lookup result. 3228 return embedding_lookup_sparse( 3229 embedding_weights, 3230 sparse_ids, 3231 sparse_weights, 3232 combiner=self.combiner, 3233 name='%s_weights' % self.name, 3234 max_norm=self.max_norm) 3235 3236 def get_dense_tensor(self, transformation_cache, state_manager): 3237 """Returns the embedding lookup result.""" 3238 if isinstance(self.categorical_column, SequenceCategoricalColumn): 3239 raise ValueError( 3240 'In embedding_column: {}. ' 3241 'categorical_column must not be of type SequenceCategoricalColumn. ' 3242 'Suggested fix A: If you wish to use DenseFeatures, use a ' 3243 'non-sequence categorical_column_with_*. ' 3244 'Suggested fix B: If you wish to create sequence input, use ' 3245 'SequenceFeatures instead of DenseFeatures. ' 3246 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3247 self.categorical_column)) 3248 return self._get_dense_tensor_internal(transformation_cache, state_manager) 3249 3250 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 3251 return _raise_shared_embedding_column_error() 3252 3253 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 3254 """See `SequenceDenseColumn` base class.""" 3255 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 3256 raise ValueError( 3257 'In embedding_column: {}. ' 3258 'categorical_column must be of type SequenceCategoricalColumn ' 3259 'to use SequenceFeatures. ' 3260 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 3261 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 3262 self.categorical_column)) 3263 dense_tensor = self._get_dense_tensor_internal(transformation_cache, 3264 state_manager) 3265 sparse_tensors = self.categorical_column.get_sparse_tensors( 3266 transformation_cache, state_manager) 3267 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 3268 sparse_tensors.id_tensor) 3269 return SequenceDenseColumn.TensorSequenceLengthPair( 3270 dense_tensor=dense_tensor, sequence_length=sequence_length) 3271 3272 def _get_sequence_dense_tensor(self, 3273 inputs, 3274 weight_collections=None, 3275 trainable=None): 3276 return _raise_shared_embedding_column_error() 3277 3278 @property 3279 def parents(self): 3280 """See 'FeatureColumn` base class.""" 3281 return [self.categorical_column] 3282 3283 3284def _check_shape(shape, key): 3285 """Returns shape if it's valid, raises error otherwise.""" 3286 assert shape is not None 3287 if not nest.is_sequence(shape): 3288 shape = [shape] 3289 shape = tuple(shape) 3290 for dimension in shape: 3291 if not isinstance(dimension, int): 3292 raise TypeError('shape dimensions must be integer. ' 3293 'shape: {}, key: {}'.format(shape, key)) 3294 if dimension < 1: 3295 raise ValueError('shape dimensions must be greater than 0. ' 3296 'shape: {}, key: {}'.format(shape, key)) 3297 return shape 3298 3299 3300class HashedCategoricalColumn( 3301 CategoricalColumn, 3302 fc_old._CategoricalColumn, # pylint: disable=protected-access 3303 collections.namedtuple('HashedCategoricalColumn', 3304 ('key', 'hash_bucket_size', 'dtype'))): 3305 """see `categorical_column_with_hash_bucket`.""" 3306 3307 @property 3308 def _is_v2_column(self): 3309 return True 3310 3311 @property 3312 def name(self): 3313 """See `FeatureColumn` base class.""" 3314 return self.key 3315 3316 @property 3317 def parse_example_spec(self): 3318 """See `FeatureColumn` base class.""" 3319 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3320 3321 @property 3322 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3323 _FEATURE_COLUMN_DEPRECATION) 3324 def _parse_example_spec(self): 3325 return self.parse_example_spec 3326 3327 def _transform_input_tensor(self, input_tensor): 3328 """Hashes the values in the feature_column.""" 3329 if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 3330 raise ValueError('SparseColumn input must be a SparseTensor.') 3331 3332 fc_utils.assert_string_or_int( 3333 input_tensor.dtype, 3334 prefix='column_name: {} input_tensor'.format(self.key)) 3335 3336 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3337 raise ValueError( 3338 'Column dtype and SparseTensors dtype must be compatible. ' 3339 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3340 self.key, self.dtype, input_tensor.dtype)) 3341 3342 if self.dtype == dtypes.string: 3343 sparse_values = input_tensor.values 3344 else: 3345 sparse_values = string_ops.as_string(input_tensor.values) 3346 3347 sparse_id_values = string_ops.string_to_hash_bucket_fast( 3348 sparse_values, self.hash_bucket_size, name='lookup') 3349 return sparse_tensor_lib.SparseTensor( 3350 input_tensor.indices, sparse_id_values, input_tensor.dense_shape) 3351 3352 def transform_feature(self, transformation_cache, state_manager): 3353 """Hashes the values in the feature_column.""" 3354 input_tensor = _to_sparse_input_and_drop_ignore_values( 3355 transformation_cache.get(self.key, state_manager)) 3356 return self._transform_input_tensor(input_tensor) 3357 3358 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3359 _FEATURE_COLUMN_DEPRECATION) 3360 def _transform_feature(self, inputs): 3361 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3362 return self._transform_input_tensor(input_tensor) 3363 3364 @property 3365 def num_buckets(self): 3366 """Returns number of buckets in this sparse feature.""" 3367 return self.hash_bucket_size 3368 3369 @property 3370 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3371 _FEATURE_COLUMN_DEPRECATION) 3372 def _num_buckets(self): 3373 return self.num_buckets 3374 3375 def get_sparse_tensors(self, transformation_cache, state_manager): 3376 """See `CategoricalColumn` base class.""" 3377 return CategoricalColumn.IdWeightPair( 3378 transformation_cache.get(self, state_manager), None) 3379 3380 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3381 _FEATURE_COLUMN_DEPRECATION) 3382 def _get_sparse_tensors(self, inputs, weight_collections=None, 3383 trainable=None): 3384 del weight_collections 3385 del trainable 3386 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3387 3388 @property 3389 def parents(self): 3390 """See 'FeatureColumn` base class.""" 3391 return [self.key] 3392 3393 def get_config(self): 3394 """See 'FeatureColumn` base class.""" 3395 config = dict(zip(self._fields, self)) 3396 config['dtype'] = self.dtype.name 3397 return config 3398 3399 @classmethod 3400 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3401 """See 'FeatureColumn` base class.""" 3402 _check_config_keys(config, cls._fields) 3403 kwargs = _standardize_and_copy_config(config) 3404 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3405 return cls(**kwargs) 3406 3407 3408class VocabularyFileCategoricalColumn( 3409 CategoricalColumn, 3410 fc_old._CategoricalColumn, # pylint: disable=protected-access 3411 collections.namedtuple('VocabularyFileCategoricalColumn', 3412 ('key', 'vocabulary_file', 'vocabulary_size', 3413 'num_oov_buckets', 'dtype', 'default_value'))): 3414 """See `categorical_column_with_vocabulary_file`.""" 3415 3416 @property 3417 def _is_v2_column(self): 3418 return True 3419 3420 @property 3421 def name(self): 3422 """See `FeatureColumn` base class.""" 3423 return self.key 3424 3425 @property 3426 def parse_example_spec(self): 3427 """See `FeatureColumn` base class.""" 3428 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3429 3430 @property 3431 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3432 _FEATURE_COLUMN_DEPRECATION) 3433 def _parse_example_spec(self): 3434 return self.parse_example_spec 3435 3436 def _transform_input_tensor(self, input_tensor, state_manager=None): 3437 """Creates a lookup table for the vocabulary.""" 3438 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3439 raise ValueError( 3440 'Column dtype and SparseTensors dtype must be compatible. ' 3441 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3442 self.key, self.dtype, input_tensor.dtype)) 3443 3444 fc_utils.assert_string_or_int( 3445 input_tensor.dtype, 3446 prefix='column_name: {} input_tensor'.format(self.key)) 3447 3448 key_dtype = self.dtype 3449 if input_tensor.dtype.is_integer: 3450 # `index_table_from_file` requires 64-bit integer keys. 3451 key_dtype = dtypes.int64 3452 input_tensor = math_ops.cast(input_tensor, dtypes.int64) 3453 3454 name = '{}_lookup'.format(self.key) 3455 if state_manager is None or not state_manager.has_resource(self, name): 3456 with ops.init_scope(): 3457 table = lookup_ops.index_table_from_file( 3458 vocabulary_file=self.vocabulary_file, 3459 num_oov_buckets=self.num_oov_buckets, 3460 vocab_size=self.vocabulary_size, 3461 default_value=self.default_value, 3462 key_dtype=key_dtype, 3463 name=name) 3464 if state_manager is not None: 3465 state_manager.add_resource(self, name, table) 3466 else: 3467 # Reuse the table from the previous run. 3468 table = state_manager.get_resource(self, name) 3469 return table.lookup(input_tensor) 3470 3471 def transform_feature(self, transformation_cache, state_manager): 3472 """Creates a lookup table for the vocabulary.""" 3473 input_tensor = _to_sparse_input_and_drop_ignore_values( 3474 transformation_cache.get(self.key, state_manager)) 3475 return self._transform_input_tensor(input_tensor, state_manager) 3476 3477 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3478 _FEATURE_COLUMN_DEPRECATION) 3479 def _transform_feature(self, inputs): 3480 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3481 return self._transform_input_tensor(input_tensor) 3482 3483 @property 3484 def num_buckets(self): 3485 """Returns number of buckets in this sparse feature.""" 3486 return self.vocabulary_size + self.num_oov_buckets 3487 3488 @property 3489 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3490 _FEATURE_COLUMN_DEPRECATION) 3491 def _num_buckets(self): 3492 return self.num_buckets 3493 3494 def get_sparse_tensors(self, transformation_cache, state_manager): 3495 """See `CategoricalColumn` base class.""" 3496 return CategoricalColumn.IdWeightPair( 3497 transformation_cache.get(self, state_manager), None) 3498 3499 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3500 _FEATURE_COLUMN_DEPRECATION) 3501 def _get_sparse_tensors(self, inputs, weight_collections=None, 3502 trainable=None): 3503 del weight_collections 3504 del trainable 3505 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3506 3507 @property 3508 def parents(self): 3509 """See 'FeatureColumn` base class.""" 3510 return [self.key] 3511 3512 def get_config(self): 3513 """See 'FeatureColumn` base class.""" 3514 config = dict(zip(self._fields, self)) 3515 config['dtype'] = self.dtype.name 3516 return config 3517 3518 @classmethod 3519 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3520 """See 'FeatureColumn` base class.""" 3521 _check_config_keys(config, cls._fields) 3522 kwargs = _standardize_and_copy_config(config) 3523 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3524 return cls(**kwargs) 3525 3526 3527class VocabularyListCategoricalColumn( 3528 CategoricalColumn, 3529 fc_old._CategoricalColumn, # pylint: disable=protected-access 3530 collections.namedtuple( 3531 'VocabularyListCategoricalColumn', 3532 ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets')) 3533): 3534 """See `categorical_column_with_vocabulary_list`.""" 3535 3536 @property 3537 def _is_v2_column(self): 3538 return True 3539 3540 @property 3541 def name(self): 3542 """See `FeatureColumn` base class.""" 3543 return self.key 3544 3545 @property 3546 def parse_example_spec(self): 3547 """See `FeatureColumn` base class.""" 3548 return {self.key: parsing_ops.VarLenFeature(self.dtype)} 3549 3550 @property 3551 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3552 _FEATURE_COLUMN_DEPRECATION) 3553 def _parse_example_spec(self): 3554 return self.parse_example_spec 3555 3556 def _transform_input_tensor(self, input_tensor, state_manager=None): 3557 """Creates a lookup table for the vocabulary list.""" 3558 if self.dtype.is_integer != input_tensor.dtype.is_integer: 3559 raise ValueError( 3560 'Column dtype and SparseTensors dtype must be compatible. ' 3561 'key: {}, column dtype: {}, tensor dtype: {}'.format( 3562 self.key, self.dtype, input_tensor.dtype)) 3563 3564 fc_utils.assert_string_or_int( 3565 input_tensor.dtype, 3566 prefix='column_name: {} input_tensor'.format(self.key)) 3567 3568 key_dtype = self.dtype 3569 if input_tensor.dtype.is_integer: 3570 # `index_table_from_tensor` requires 64-bit integer keys. 3571 key_dtype = dtypes.int64 3572 input_tensor = math_ops.cast(input_tensor, dtypes.int64) 3573 3574 name = '{}_lookup'.format(self.key) 3575 if state_manager is None or not state_manager.has_resource(self, name): 3576 with ops.init_scope(): 3577 table = lookup_ops.index_table_from_tensor( 3578 vocabulary_list=tuple(self.vocabulary_list), 3579 default_value=self.default_value, 3580 num_oov_buckets=self.num_oov_buckets, 3581 dtype=key_dtype, 3582 name=name) 3583 if state_manager is not None: 3584 state_manager.add_resource(self, name, table) 3585 else: 3586 # Reuse the table from the previous run. 3587 table = state_manager.get_resource(self, name) 3588 return table.lookup(input_tensor) 3589 3590 def transform_feature(self, transformation_cache, state_manager): 3591 """Creates a lookup table for the vocabulary list.""" 3592 input_tensor = _to_sparse_input_and_drop_ignore_values( 3593 transformation_cache.get(self.key, state_manager)) 3594 return self._transform_input_tensor(input_tensor, state_manager) 3595 3596 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3597 _FEATURE_COLUMN_DEPRECATION) 3598 def _transform_feature(self, inputs): 3599 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3600 return self._transform_input_tensor(input_tensor) 3601 3602 @property 3603 def num_buckets(self): 3604 """Returns number of buckets in this sparse feature.""" 3605 return len(self.vocabulary_list) + self.num_oov_buckets 3606 3607 @property 3608 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3609 _FEATURE_COLUMN_DEPRECATION) 3610 def _num_buckets(self): 3611 return self.num_buckets 3612 3613 def get_sparse_tensors(self, transformation_cache, state_manager): 3614 """See `CategoricalColumn` base class.""" 3615 return CategoricalColumn.IdWeightPair( 3616 transformation_cache.get(self, state_manager), None) 3617 3618 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3619 _FEATURE_COLUMN_DEPRECATION) 3620 def _get_sparse_tensors(self, inputs, weight_collections=None, 3621 trainable=None): 3622 del weight_collections 3623 del trainable 3624 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3625 3626 @property 3627 def parents(self): 3628 """See 'FeatureColumn` base class.""" 3629 return [self.key] 3630 3631 def get_config(self): 3632 """See 'FeatureColumn` base class.""" 3633 config = dict(zip(self._fields, self)) 3634 config['dtype'] = self.dtype.name 3635 return config 3636 3637 @classmethod 3638 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3639 """See 'FeatureColumn` base class.""" 3640 _check_config_keys(config, cls._fields) 3641 kwargs = _standardize_and_copy_config(config) 3642 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3643 return cls(**kwargs) 3644 3645 3646class IdentityCategoricalColumn( 3647 CategoricalColumn, 3648 fc_old._CategoricalColumn, # pylint: disable=protected-access 3649 collections.namedtuple('IdentityCategoricalColumn', 3650 ('key', 'number_buckets', 'default_value'))): 3651 3652 """See `categorical_column_with_identity`.""" 3653 3654 @property 3655 def _is_v2_column(self): 3656 return True 3657 3658 @property 3659 def name(self): 3660 """See `FeatureColumn` base class.""" 3661 return self.key 3662 3663 @property 3664 def parse_example_spec(self): 3665 """See `FeatureColumn` base class.""" 3666 return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} 3667 3668 @property 3669 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3670 _FEATURE_COLUMN_DEPRECATION) 3671 def _parse_example_spec(self): 3672 return self.parse_example_spec 3673 3674 def _transform_input_tensor(self, input_tensor): 3675 """Returns a SparseTensor with identity values.""" 3676 if not input_tensor.dtype.is_integer: 3677 raise ValueError( 3678 'Invalid input, not integer. key: {} dtype: {}'.format( 3679 self.key, input_tensor.dtype)) 3680 values = input_tensor.values 3681 if input_tensor.values.dtype != dtypes.int64: 3682 values = math_ops.cast(values, dtypes.int64, name='values') 3683 if self.default_value is not None: 3684 values = math_ops.cast(input_tensor.values, dtypes.int64, name='values') 3685 num_buckets = math_ops.cast( 3686 self.num_buckets, dtypes.int64, name='num_buckets') 3687 zero = math_ops.cast(0, dtypes.int64, name='zero') 3688 # Assign default for out-of-range values. 3689 values = array_ops.where_v2( 3690 math_ops.logical_or( 3691 values < zero, values >= num_buckets, name='out_of_range'), 3692 array_ops.fill( 3693 dims=array_ops.shape(values), 3694 value=math_ops.cast(self.default_value, dtypes.int64), 3695 name='default_values'), values) 3696 3697 return sparse_tensor_lib.SparseTensor( 3698 indices=input_tensor.indices, 3699 values=values, 3700 dense_shape=input_tensor.dense_shape) 3701 3702 def transform_feature(self, transformation_cache, state_manager): 3703 """Returns a SparseTensor with identity values.""" 3704 input_tensor = _to_sparse_input_and_drop_ignore_values( 3705 transformation_cache.get(self.key, state_manager)) 3706 return self._transform_input_tensor(input_tensor) 3707 3708 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3709 _FEATURE_COLUMN_DEPRECATION) 3710 def _transform_feature(self, inputs): 3711 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) 3712 return self._transform_input_tensor(input_tensor) 3713 3714 @property 3715 def num_buckets(self): 3716 """Returns number of buckets in this sparse feature.""" 3717 return self.number_buckets 3718 3719 @property 3720 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3721 _FEATURE_COLUMN_DEPRECATION) 3722 def _num_buckets(self): 3723 return self.num_buckets 3724 3725 def get_sparse_tensors(self, transformation_cache, state_manager): 3726 """See `CategoricalColumn` base class.""" 3727 return CategoricalColumn.IdWeightPair( 3728 transformation_cache.get(self, state_manager), None) 3729 3730 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3731 _FEATURE_COLUMN_DEPRECATION) 3732 def _get_sparse_tensors(self, inputs, weight_collections=None, 3733 trainable=None): 3734 del weight_collections 3735 del trainable 3736 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3737 3738 @property 3739 def parents(self): 3740 """See 'FeatureColumn` base class.""" 3741 return [self.key] 3742 3743 def get_config(self): 3744 """See 'FeatureColumn` base class.""" 3745 return dict(zip(self._fields, self)) 3746 3747 @classmethod 3748 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3749 """See 'FeatureColumn` base class.""" 3750 _check_config_keys(config, cls._fields) 3751 kwargs = _standardize_and_copy_config(config) 3752 return cls(**kwargs) 3753 3754 3755class WeightedCategoricalColumn( 3756 CategoricalColumn, 3757 fc_old._CategoricalColumn, # pylint: disable=protected-access 3758 collections.namedtuple( 3759 'WeightedCategoricalColumn', 3760 ('categorical_column', 'weight_feature_key', 'dtype'))): 3761 """See `weighted_categorical_column`.""" 3762 3763 @property 3764 def _is_v2_column(self): 3765 return (isinstance(self.categorical_column, FeatureColumn) and 3766 self.categorical_column._is_v2_column) # pylint: disable=protected-access 3767 3768 @property 3769 def name(self): 3770 """See `FeatureColumn` base class.""" 3771 return '{}_weighted_by_{}'.format( 3772 self.categorical_column.name, self.weight_feature_key) 3773 3774 @property 3775 def parse_example_spec(self): 3776 """See `FeatureColumn` base class.""" 3777 config = self.categorical_column.parse_example_spec 3778 if self.weight_feature_key in config: 3779 raise ValueError('Parse config {} already exists for {}.'.format( 3780 config[self.weight_feature_key], self.weight_feature_key)) 3781 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) 3782 return config 3783 3784 @property 3785 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3786 _FEATURE_COLUMN_DEPRECATION) 3787 def _parse_example_spec(self): 3788 config = self.categorical_column._parse_example_spec # pylint: disable=protected-access 3789 if self.weight_feature_key in config: 3790 raise ValueError('Parse config {} already exists for {}.'.format( 3791 config[self.weight_feature_key], self.weight_feature_key)) 3792 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) 3793 return config 3794 3795 @property 3796 def num_buckets(self): 3797 """See `DenseColumn` base class.""" 3798 return self.categorical_column.num_buckets 3799 3800 @property 3801 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3802 _FEATURE_COLUMN_DEPRECATION) 3803 def _num_buckets(self): 3804 return self.categorical_column._num_buckets # pylint: disable=protected-access 3805 3806 def _transform_weight_tensor(self, weight_tensor): 3807 if weight_tensor is None: 3808 raise ValueError('Missing weights {}.'.format(self.weight_feature_key)) 3809 weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( 3810 weight_tensor) 3811 if self.dtype != weight_tensor.dtype.base_dtype: 3812 raise ValueError('Bad dtype, expected {}, but got {}.'.format( 3813 self.dtype, weight_tensor.dtype)) 3814 if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor): 3815 # The weight tensor can be a regular Tensor. In this case, sparsify it. 3816 weight_tensor = _to_sparse_input_and_drop_ignore_values( 3817 weight_tensor, ignore_value=0.0) 3818 if not weight_tensor.dtype.is_floating: 3819 weight_tensor = math_ops.cast(weight_tensor, dtypes.float32) 3820 return weight_tensor 3821 3822 def transform_feature(self, transformation_cache, state_manager): 3823 """Applies weights to tensor generated from `categorical_column`'.""" 3824 weight_tensor = transformation_cache.get(self.weight_feature_key, 3825 state_manager) 3826 sparse_weight_tensor = self._transform_weight_tensor(weight_tensor) 3827 sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values( 3828 transformation_cache.get(self.categorical_column, state_manager)) 3829 return (sparse_categorical_tensor, sparse_weight_tensor) 3830 3831 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3832 _FEATURE_COLUMN_DEPRECATION) 3833 def _transform_feature(self, inputs): 3834 """Applies weights to tensor generated from `categorical_column`'.""" 3835 weight_tensor = inputs.get(self.weight_feature_key) 3836 weight_tensor = self._transform_weight_tensor(weight_tensor) 3837 return (inputs.get(self.categorical_column), weight_tensor) 3838 3839 def get_sparse_tensors(self, transformation_cache, state_manager): 3840 """See `CategoricalColumn` base class.""" 3841 tensors = transformation_cache.get(self, state_manager) 3842 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) 3843 3844 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3845 _FEATURE_COLUMN_DEPRECATION) 3846 def _get_sparse_tensors(self, inputs, weight_collections=None, 3847 trainable=None): 3848 del weight_collections 3849 del trainable 3850 tensors = inputs.get(self) 3851 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) 3852 3853 @property 3854 def parents(self): 3855 """See 'FeatureColumn` base class.""" 3856 return [self.categorical_column, self.weight_feature_key] 3857 3858 def get_config(self): 3859 """See 'FeatureColumn` base class.""" 3860 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 3861 config = dict(zip(self._fields, self)) 3862 config['categorical_column'] = serialize_feature_column( 3863 self.categorical_column) 3864 config['dtype'] = self.dtype.name 3865 return config 3866 3867 @classmethod 3868 def from_config(cls, config, custom_objects=None, columns_by_name=None): 3869 """See 'FeatureColumn` base class.""" 3870 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 3871 _check_config_keys(config, cls._fields) 3872 kwargs = _standardize_and_copy_config(config) 3873 kwargs['categorical_column'] = deserialize_feature_column( 3874 config['categorical_column'], custom_objects, columns_by_name) 3875 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 3876 return cls(**kwargs) 3877 3878 3879class CrossedColumn( 3880 CategoricalColumn, 3881 fc_old._CategoricalColumn, # pylint: disable=protected-access 3882 collections.namedtuple('CrossedColumn', 3883 ('keys', 'hash_bucket_size', 'hash_key'))): 3884 """See `crossed_column`.""" 3885 3886 @property 3887 def _is_v2_column(self): 3888 for key in _collect_leaf_level_keys(self): 3889 if isinstance(key, six.string_types): 3890 continue 3891 if not isinstance(key, FeatureColumn): 3892 return False 3893 if not key._is_v2_column: # pylint: disable=protected-access 3894 return False 3895 return True 3896 3897 @property 3898 def name(self): 3899 """See `FeatureColumn` base class.""" 3900 feature_names = [] 3901 for key in _collect_leaf_level_keys(self): 3902 if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access 3903 feature_names.append(key.name) 3904 else: # key must be a string 3905 feature_names.append(key) 3906 return '_X_'.join(sorted(feature_names)) 3907 3908 @property 3909 def parse_example_spec(self): 3910 """See `FeatureColumn` base class.""" 3911 config = {} 3912 for key in self.keys: 3913 if isinstance(key, FeatureColumn): 3914 config.update(key.parse_example_spec) 3915 elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access 3916 config.update(key._parse_example_spec) # pylint: disable=protected-access 3917 else: # key must be a string 3918 config.update({key: parsing_ops.VarLenFeature(dtypes.string)}) 3919 return config 3920 3921 @property 3922 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3923 _FEATURE_COLUMN_DEPRECATION) 3924 def _parse_example_spec(self): 3925 return self.parse_example_spec 3926 3927 def transform_feature(self, transformation_cache, state_manager): 3928 """Generates a hashed sparse cross from the input tensors.""" 3929 feature_tensors = [] 3930 for key in _collect_leaf_level_keys(self): 3931 if isinstance(key, six.string_types): 3932 feature_tensors.append(transformation_cache.get(key, state_manager)) 3933 elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access 3934 ids_and_weights = key.get_sparse_tensors(transformation_cache, 3935 state_manager) 3936 if ids_and_weights.weight_tensor is not None: 3937 raise ValueError( 3938 'crossed_column does not support weight_tensor, but the given ' 3939 'column populates weight_tensor. ' 3940 'Given column: {}'.format(key.name)) 3941 feature_tensors.append(ids_and_weights.id_tensor) 3942 else: 3943 raise ValueError('Unsupported column type. Given: {}'.format(key)) 3944 return sparse_ops.sparse_cross_hashed( 3945 inputs=feature_tensors, 3946 num_buckets=self.hash_bucket_size, 3947 hash_key=self.hash_key) 3948 3949 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3950 _FEATURE_COLUMN_DEPRECATION) 3951 def _transform_feature(self, inputs): 3952 """Generates a hashed sparse cross from the input tensors.""" 3953 feature_tensors = [] 3954 for key in _collect_leaf_level_keys(self): 3955 if isinstance(key, six.string_types): 3956 feature_tensors.append(inputs.get(key)) 3957 elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access 3958 ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access 3959 if ids_and_weights.weight_tensor is not None: 3960 raise ValueError( 3961 'crossed_column does not support weight_tensor, but the given ' 3962 'column populates weight_tensor. ' 3963 'Given column: {}'.format(key.name)) 3964 feature_tensors.append(ids_and_weights.id_tensor) 3965 else: 3966 raise ValueError('Unsupported column type. Given: {}'.format(key)) 3967 return sparse_ops.sparse_cross_hashed( 3968 inputs=feature_tensors, 3969 num_buckets=self.hash_bucket_size, 3970 hash_key=self.hash_key) 3971 3972 @property 3973 def num_buckets(self): 3974 """Returns number of buckets in this sparse feature.""" 3975 return self.hash_bucket_size 3976 3977 @property 3978 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3979 _FEATURE_COLUMN_DEPRECATION) 3980 def _num_buckets(self): 3981 return self.num_buckets 3982 3983 def get_sparse_tensors(self, transformation_cache, state_manager): 3984 """See `CategoricalColumn` base class.""" 3985 return CategoricalColumn.IdWeightPair( 3986 transformation_cache.get(self, state_manager), None) 3987 3988 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 3989 _FEATURE_COLUMN_DEPRECATION) 3990 def _get_sparse_tensors(self, inputs, weight_collections=None, 3991 trainable=None): 3992 """See `CategoricalColumn` base class.""" 3993 del weight_collections 3994 del trainable 3995 return CategoricalColumn.IdWeightPair(inputs.get(self), None) 3996 3997 @property 3998 def parents(self): 3999 """See 'FeatureColumn` base class.""" 4000 return list(self.keys) 4001 4002 def get_config(self): 4003 """See 'FeatureColumn` base class.""" 4004 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 4005 config = dict(zip(self._fields, self)) 4006 config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys]) 4007 return config 4008 4009 @classmethod 4010 def from_config(cls, config, custom_objects=None, columns_by_name=None): 4011 """See 'FeatureColumn` base class.""" 4012 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 4013 _check_config_keys(config, cls._fields) 4014 kwargs = _standardize_and_copy_config(config) 4015 kwargs['keys'] = tuple([ 4016 deserialize_feature_column(c, custom_objects, columns_by_name) 4017 for c in config['keys'] 4018 ]) 4019 return cls(**kwargs) 4020 4021 4022def _collect_leaf_level_keys(cross): 4023 """Collects base keys by expanding all nested crosses. 4024 4025 Args: 4026 cross: A `CrossedColumn`. 4027 4028 Returns: 4029 A list of strings or `CategoricalColumn` instances. 4030 """ 4031 leaf_level_keys = [] 4032 for k in cross.keys: 4033 if isinstance(k, CrossedColumn): 4034 leaf_level_keys.extend(_collect_leaf_level_keys(k)) 4035 else: 4036 leaf_level_keys.append(k) 4037 return leaf_level_keys 4038 4039 4040def _prune_invalid_ids(sparse_ids, sparse_weights): 4041 """Prune invalid IDs (< 0) from the input ids and weights.""" 4042 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) 4043 if sparse_weights is not None: 4044 is_id_valid = math_ops.logical_and( 4045 is_id_valid, 4046 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) 4047 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) 4048 if sparse_weights is not None: 4049 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) 4050 return sparse_ids, sparse_weights 4051 4052 4053def _prune_invalid_weights(sparse_ids, sparse_weights): 4054 """Prune invalid weights (< 0) from the input ids and weights.""" 4055 if sparse_weights is not None: 4056 is_weights_valid = math_ops.greater(sparse_weights.values, 0) 4057 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) 4058 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) 4059 return sparse_ids, sparse_weights 4060 4061 4062class IndicatorColumn( 4063 DenseColumn, 4064 SequenceDenseColumn, 4065 fc_old._DenseColumn, # pylint: disable=protected-access 4066 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 4067 collections.namedtuple('IndicatorColumn', ('categorical_column'))): 4068 """Represents a one-hot column for use in deep networks. 4069 4070 Args: 4071 categorical_column: A `CategoricalColumn` which is created by 4072 `categorical_column_with_*` function. 4073 """ 4074 4075 @property 4076 def _is_v2_column(self): 4077 return (isinstance(self.categorical_column, FeatureColumn) and 4078 self.categorical_column._is_v2_column) # pylint: disable=protected-access 4079 4080 @property 4081 def name(self): 4082 """See `FeatureColumn` base class.""" 4083 return '{}_indicator'.format(self.categorical_column.name) 4084 4085 def _transform_id_weight_pair(self, id_weight_pair, size): 4086 id_tensor = id_weight_pair.id_tensor 4087 weight_tensor = id_weight_pair.weight_tensor 4088 4089 # If the underlying column is weighted, return the input as a dense tensor. 4090 if weight_tensor is not None: 4091 weighted_column = sparse_ops.sparse_merge( 4092 sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size)) 4093 # Remove (?, -1) index. 4094 weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], 4095 weighted_column.dense_shape) 4096 # Use scatter_nd to merge duplicated indices if existed, 4097 # instead of sparse_tensor_to_dense. 4098 return array_ops.scatter_nd(weighted_column.indices, 4099 weighted_column.values, 4100 weighted_column.dense_shape) 4101 4102 dense_id_tensor = sparse_ops.sparse_tensor_to_dense( 4103 id_tensor, default_value=-1) 4104 4105 # One hot must be float for tf.concat reasons since all other inputs to 4106 # input_layer are float32. 4107 one_hot_id_tensor = array_ops.one_hot( 4108 dense_id_tensor, depth=size, on_value=1.0, off_value=0.0) 4109 4110 # Reduce to get a multi-hot per example. 4111 return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2]) 4112 4113 def transform_feature(self, transformation_cache, state_manager): 4114 """Returns dense `Tensor` representing feature. 4115 4116 Args: 4117 transformation_cache: A `FeatureTransformationCache` object to access 4118 features. 4119 state_manager: A `StateManager` to create / access resources such as 4120 lookup tables. 4121 4122 Returns: 4123 Transformed feature `Tensor`. 4124 4125 Raises: 4126 ValueError: if input rank is not known at graph building time. 4127 """ 4128 id_weight_pair = self.categorical_column.get_sparse_tensors( 4129 transformation_cache, state_manager) 4130 return self._transform_id_weight_pair(id_weight_pair, 4131 self.variable_shape[-1]) 4132 4133 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4134 _FEATURE_COLUMN_DEPRECATION) 4135 def _transform_feature(self, inputs): 4136 id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4137 return self._transform_id_weight_pair(id_weight_pair, 4138 self._variable_shape[-1]) 4139 4140 @property 4141 def parse_example_spec(self): 4142 """See `FeatureColumn` base class.""" 4143 return self.categorical_column.parse_example_spec 4144 4145 @property 4146 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4147 _FEATURE_COLUMN_DEPRECATION) 4148 def _parse_example_spec(self): 4149 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 4150 4151 @property 4152 def variable_shape(self): 4153 """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" 4154 if isinstance(self.categorical_column, FeatureColumn): 4155 return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) 4156 else: 4157 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4158 4159 @property 4160 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4161 _FEATURE_COLUMN_DEPRECATION) 4162 def _variable_shape(self): 4163 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4164 4165 def get_dense_tensor(self, transformation_cache, state_manager): 4166 """Returns dense `Tensor` representing feature. 4167 4168 Args: 4169 transformation_cache: A `FeatureTransformationCache` object to access 4170 features. 4171 state_manager: A `StateManager` to create / access resources such as 4172 lookup tables. 4173 4174 Returns: 4175 Dense `Tensor` created within `transform_feature`. 4176 4177 Raises: 4178 ValueError: If `categorical_column` is a `SequenceCategoricalColumn`. 4179 """ 4180 if isinstance(self.categorical_column, SequenceCategoricalColumn): 4181 raise ValueError( 4182 'In indicator_column: {}. ' 4183 'categorical_column must not be of type SequenceCategoricalColumn. ' 4184 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4185 'non-sequence categorical_column_with_*. ' 4186 'Suggested fix B: If you wish to create sequence input, use ' 4187 'SequenceFeatures instead of DenseFeatures. ' 4188 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4189 self.categorical_column)) 4190 # Feature has been already transformed. Return the intermediate 4191 # representation created by transform_feature. 4192 return transformation_cache.get(self, state_manager) 4193 4194 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4195 _FEATURE_COLUMN_DEPRECATION) 4196 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 4197 del weight_collections 4198 del trainable 4199 if isinstance( 4200 self.categorical_column, 4201 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4202 raise ValueError( 4203 'In indicator_column: {}. ' 4204 'categorical_column must not be of type _SequenceCategoricalColumn. ' 4205 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4206 'non-sequence categorical_column_with_*. ' 4207 'Suggested fix B: If you wish to create sequence input, use ' 4208 'SequenceFeatures instead of DenseFeatures. ' 4209 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4210 self.categorical_column)) 4211 # Feature has been already transformed. Return the intermediate 4212 # representation created by transform_feature. 4213 return inputs.get(self) 4214 4215 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 4216 """See `SequenceDenseColumn` base class.""" 4217 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 4218 raise ValueError( 4219 'In indicator_column: {}. ' 4220 'categorical_column must be of type SequenceCategoricalColumn ' 4221 'to use SequenceFeatures. ' 4222 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 4223 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4224 self.categorical_column)) 4225 # Feature has been already transformed. Return the intermediate 4226 # representation created by transform_feature. 4227 dense_tensor = transformation_cache.get(self, state_manager) 4228 sparse_tensors = self.categorical_column.get_sparse_tensors( 4229 transformation_cache, state_manager) 4230 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 4231 sparse_tensors.id_tensor) 4232 return SequenceDenseColumn.TensorSequenceLengthPair( 4233 dense_tensor=dense_tensor, sequence_length=sequence_length) 4234 4235 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4236 _FEATURE_COLUMN_DEPRECATION) 4237 def _get_sequence_dense_tensor(self, 4238 inputs, 4239 weight_collections=None, 4240 trainable=None): 4241 # Do nothing with weight_collections and trainable since no variables are 4242 # created in this function. 4243 del weight_collections 4244 del trainable 4245 if not isinstance( 4246 self.categorical_column, 4247 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4248 raise ValueError( 4249 'In indicator_column: {}. ' 4250 'categorical_column must be of type _SequenceCategoricalColumn ' 4251 'to use SequenceFeatures. ' 4252 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 4253 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4254 self.categorical_column)) 4255 # Feature has been already transformed. Return the intermediate 4256 # representation created by _transform_feature. 4257 dense_tensor = inputs.get(self) 4258 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4259 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 4260 sparse_tensors.id_tensor) 4261 return SequenceDenseColumn.TensorSequenceLengthPair( 4262 dense_tensor=dense_tensor, sequence_length=sequence_length) 4263 4264 @property 4265 def parents(self): 4266 """See 'FeatureColumn` base class.""" 4267 return [self.categorical_column] 4268 4269 def get_config(self): 4270 """See 'FeatureColumn` base class.""" 4271 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 4272 config = dict(zip(self._fields, self)) 4273 config['categorical_column'] = serialize_feature_column( 4274 self.categorical_column) 4275 return config 4276 4277 @classmethod 4278 def from_config(cls, config, custom_objects=None, columns_by_name=None): 4279 """See 'FeatureColumn` base class.""" 4280 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 4281 _check_config_keys(config, cls._fields) 4282 kwargs = _standardize_and_copy_config(config) 4283 kwargs['categorical_column'] = deserialize_feature_column( 4284 config['categorical_column'], custom_objects, columns_by_name) 4285 return cls(**kwargs) 4286 4287 4288def _verify_static_batch_size_equality(tensors, columns): 4289 """Verify equality between static batch sizes. 4290 4291 Args: 4292 tensors: iterable of input tensors. 4293 columns: Corresponding feature columns. 4294 4295 Raises: 4296 ValueError: in case of mismatched batch sizes. 4297 """ 4298 # bath_size is a Dimension object. 4299 expected_batch_size = None 4300 for i in range(0, len(tensors)): 4301 batch_size = tensor_shape.Dimension(tensor_shape.dimension_value( 4302 tensors[i].shape[0])) 4303 if batch_size.value is not None: 4304 if expected_batch_size is None: 4305 bath_size_column_index = i 4306 expected_batch_size = batch_size 4307 elif not expected_batch_size.is_compatible_with(batch_size): 4308 raise ValueError( 4309 'Batch size (first dimension) of each feature must be same. ' 4310 'Batch size of columns ({}, {}): ({}, {})'.format( 4311 columns[bath_size_column_index].name, columns[i].name, 4312 expected_batch_size, batch_size)) 4313 4314 4315class SequenceCategoricalColumn( 4316 CategoricalColumn, 4317 fc_old._SequenceCategoricalColumn, # pylint: disable=protected-access 4318 collections.namedtuple('SequenceCategoricalColumn', 4319 ('categorical_column'))): 4320 """Represents sequences of categorical data.""" 4321 4322 @property 4323 def _is_v2_column(self): 4324 return (isinstance(self.categorical_column, FeatureColumn) and 4325 self.categorical_column._is_v2_column) # pylint: disable=protected-access 4326 4327 @property 4328 def name(self): 4329 """See `FeatureColumn` base class.""" 4330 return self.categorical_column.name 4331 4332 @property 4333 def parse_example_spec(self): 4334 """See `FeatureColumn` base class.""" 4335 return self.categorical_column.parse_example_spec 4336 4337 @property 4338 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4339 _FEATURE_COLUMN_DEPRECATION) 4340 def _parse_example_spec(self): 4341 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 4342 4343 def transform_feature(self, transformation_cache, state_manager): 4344 """See `FeatureColumn` base class.""" 4345 return self.categorical_column.transform_feature(transformation_cache, 4346 state_manager) 4347 4348 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4349 _FEATURE_COLUMN_DEPRECATION) 4350 def _transform_feature(self, inputs): 4351 return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access 4352 4353 @property 4354 def num_buckets(self): 4355 """Returns number of buckets in this sparse feature.""" 4356 return self.categorical_column.num_buckets 4357 4358 @property 4359 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4360 _FEATURE_COLUMN_DEPRECATION) 4361 def _num_buckets(self): 4362 return self.categorical_column._num_buckets # pylint: disable=protected-access 4363 4364 def _get_sparse_tensors_helper(self, sparse_tensors): 4365 id_tensor = sparse_tensors.id_tensor 4366 weight_tensor = sparse_tensors.weight_tensor 4367 # Expands third dimension, if necessary so that embeddings are not 4368 # combined during embedding lookup. If the tensor is already 3D, leave 4369 # as-is. 4370 shape = array_ops.shape(id_tensor) 4371 # Compute the third dimension explicitly instead of setting it to -1, as 4372 # that doesn't work for dynamically shaped tensors with 0-length at runtime. 4373 # This happens for empty sequences. 4374 target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])] 4375 id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape) 4376 if weight_tensor is not None: 4377 weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape) 4378 return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) 4379 4380 def get_sparse_tensors(self, transformation_cache, state_manager): 4381 """Returns an IdWeightPair. 4382 4383 `IdWeightPair` is a pair of `SparseTensor`s which represents ids and 4384 weights. 4385 4386 `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets` 4387 `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a 4388 `SparseTensor` of `float` or `None` to indicate all weights should be 4389 taken to be 1. If specified, `weight_tensor` must have exactly the same 4390 shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing 4391 output of a `VarLenFeature` which is a ragged matrix. 4392 4393 Args: 4394 transformation_cache: A `FeatureTransformationCache` object to access 4395 features. 4396 state_manager: A `StateManager` to create / access resources such as 4397 lookup tables. 4398 """ 4399 sparse_tensors = self.categorical_column.get_sparse_tensors( 4400 transformation_cache, state_manager) 4401 return self._get_sparse_tensors_helper(sparse_tensors) 4402 4403 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4404 _FEATURE_COLUMN_DEPRECATION) 4405 def _get_sparse_tensors(self, inputs, weight_collections=None, 4406 trainable=None): 4407 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4408 return self._get_sparse_tensors_helper(sparse_tensors) 4409 4410 @property 4411 def parents(self): 4412 """See 'FeatureColumn` base class.""" 4413 return [self.categorical_column] 4414 4415 def get_config(self): 4416 """See 'FeatureColumn` base class.""" 4417 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 4418 config = dict(zip(self._fields, self)) 4419 config['categorical_column'] = serialize_feature_column( 4420 self.categorical_column) 4421 return config 4422 4423 @classmethod 4424 def from_config(cls, config, custom_objects=None, columns_by_name=None): 4425 """See 'FeatureColumn` base class.""" 4426 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 4427 _check_config_keys(config, cls._fields) 4428 kwargs = _standardize_and_copy_config(config) 4429 kwargs['categorical_column'] = deserialize_feature_column( 4430 config['categorical_column'], custom_objects, columns_by_name) 4431 return cls(**kwargs) 4432 4433 4434def _check_config_keys(config, expected_keys): 4435 """Checks that a config has all expected_keys.""" 4436 if set(config.keys()) != set(expected_keys): 4437 raise ValueError('Invalid config: {}, expected keys: {}'.format( 4438 config, expected_keys)) 4439 4440 4441def _standardize_and_copy_config(config): 4442 """Returns a shallow copy of config with lists turned to tuples. 4443 4444 Keras serialization uses nest to listify everything. 4445 This causes problems with the NumericColumn shape, which becomes 4446 unhashable. We could try to solve this on the Keras side, but that 4447 would require lots of tracking to avoid changing existing behavior. 4448 Instead, we ensure here that we revive correctly. 4449 4450 Args: 4451 config: dict that will be used to revive a Feature Column 4452 4453 Returns: 4454 Shallow copy of config with lists turned to tuples. 4455 """ 4456 kwargs = config.copy() 4457 for k, v in kwargs.items(): 4458 if isinstance(v, list): 4459 kwargs[k] = tuple(v) 4460 4461 return kwargs 4462 4463 4464def _sanitize_column_name_for_variable_scope(name): 4465 """Sanitizes user-provided feature names for use as variable scopes.""" 4466 invalid_char = re.compile('[^A-Za-z0-9_.\\-]') 4467 return invalid_char.sub('_', name) 4468