• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This API defines FeatureColumn abstraction.
16
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
20
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
23
241. Feature type:
25
26  * Continuous features can be represented by `numeric_column`.
27  * Categorical features can be represented by any `categorical_column_with_*`
28  column:
29    - `categorical_column_with_vocabulary_list`
30    - `categorical_column_with_vocabulary_file`
31    - `categorical_column_with_hash_bucket`
32    - `categorical_column_with_identity`
33    - `weighted_categorical_column`
34
352. Model type:
36
37  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
38
39    Continuous features can be directly fed into deep neural network models.
40
41      age_column = numeric_column("age")
42
43    To feed sparse features into DNN models, wrap the column with
44    `embedding_column` or `indicator_column`. `indicator_column` is recommended
45    for features with only a few possible values. For features with many
46    possible values, to reduce the size of your model, `embedding_column` is
47    recommended.
48
49      embedded_dept_column = embedding_column(
50          categorical_column_with_vocabulary_list(
51              "department", ["math", "philosophy", ...]), dimension=10)
52
53  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
54
55    Sparse features can be fed directly into linear models. They behave like an
56    indicator column but with an efficient implementation.
57
58      dept_column = categorical_column_with_vocabulary_list("department",
59          ["math", "philosophy", "english"])
60
61    It is recommended that continuous features be bucketized before being
62    fed into linear models.
63
64      bucketized_age_column = bucketized_column(
65          source_column=age_column,
66          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
67
68    Sparse features can be crossed (also known as conjuncted or combined) in
69    order to form non-linearities, and then fed into linear models.
70
71      cross_dept_age_column = crossed_column(
72          columns=["department", bucketized_age_column],
73          hash_bucket_size=1000)
74
75Example of building canned `Estimator`s using FeatureColumns:
76
77  ```python
78  # Define features and transformations
79  deep_feature_columns = [age_column, embedded_dept_column]
80  wide_feature_columns = [dept_column, bucketized_age_column,
81      cross_dept_age_column]
82
83  # Build deep model
84  estimator = DNNClassifier(
85      feature_columns=deep_feature_columns,
86      hidden_units=[500, 250, 50])
87  estimator.train(...)
88
89  # Or build a wide model
90  estimator = LinearClassifier(
91      feature_columns=wide_feature_columns)
92  estimator.train(...)
93
94  # Or build a wide and deep model!
95  estimator = DNNLinearCombinedClassifier(
96      linear_feature_columns=wide_feature_columns,
97      dnn_feature_columns=deep_feature_columns,
98      dnn_hidden_units=[500, 250, 50])
99  estimator.train(...)
100  ```
101
102
103FeatureColumns can also be transformed into a generic input layer for
104custom models using `input_layer`.
105
106Example of building model using FeatureColumns, this can be used in a
107`model_fn` which is given to the {tf.estimator.Estimator}:
108
109  ```python
110  # Building model via layers
111
112  deep_feature_columns = [age_column, embedded_dept_column]
113  columns_to_tensor = parse_feature_columns_from_examples(
114      serialized=my_data,
115      feature_columns=deep_feature_columns)
116  first_layer = input_layer(
117      features=columns_to_tensor,
118      feature_columns=deep_feature_columns)
119  second_layer = fully_connected(first_layer, ...)
120  ```
121
122NOTE: Functions prefixed with "_" indicate experimental or private parts of
123the API subject to change, and should not be relied upon!
124"""
125
126from __future__ import absolute_import
127from __future__ import division
128from __future__ import print_function
129
130import abc
131import collections
132import math
133
134import re
135import numpy as np
136import six
137
138from tensorflow.python.eager import context
139from tensorflow.python.feature_column import feature_column as fc_old
140from tensorflow.python.feature_column import utils as fc_utils
141from tensorflow.python.framework import dtypes
142from tensorflow.python.framework import ops
143from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
144from tensorflow.python.framework import tensor_shape
145# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
146# of the main repo.
147from tensorflow.python.keras import initializers
148from tensorflow.python.keras.utils import generic_utils
149from tensorflow.python.keras.engine import training
150from tensorflow.python.keras.engine.base_layer import Layer
151from tensorflow.python.ops import array_ops
152from tensorflow.python.ops import check_ops
153from tensorflow.python.ops import control_flow_ops
154from tensorflow.python.ops import embedding_ops
155from tensorflow.python.ops import init_ops
156from tensorflow.python.ops import lookup_ops
157from tensorflow.python.ops import math_ops
158from tensorflow.python.ops import nn_ops
159from tensorflow.python.ops import parsing_ops
160from tensorflow.python.ops import sparse_ops
161from tensorflow.python.ops import string_ops
162from tensorflow.python.ops import variable_scope
163from tensorflow.python.ops import variables
164from tensorflow.python.platform import gfile
165from tensorflow.python.platform import tf_logging as logging
166from tensorflow.python.training import checkpoint_utils
167from tensorflow.python.training.tracking import tracking
168from tensorflow.python.training.tracking import base as trackable
169from tensorflow.python.util import deprecation
170from tensorflow.python.util import nest
171from tensorflow.python.util.tf_export import tf_export
172from tensorflow.python.util.compat import collections_abc
173
174
175_FEATURE_COLUMN_DEPRECATION_DATE = None
176_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
177                               'deprecated. Please use the new FeatureColumn '
178                               'APIs instead.')
179
180
181class StateManager(object):
182  """Manages the state associated with FeatureColumns.
183
184  Some `FeatureColumn`s create variables or resources to assist their
185  computation. The `StateManager` is responsible for creating and storing these
186  objects since `FeatureColumn`s are supposed to be stateless configuration
187  only.
188  """
189
190  def create_variable(self,
191                      feature_column,
192                      name,
193                      shape,
194                      dtype=None,
195                      trainable=True,
196                      use_resource=True,
197                      initializer=None):
198    """Creates a new variable.
199
200    Args:
201      feature_column: A `FeatureColumn` object this variable corresponds to.
202      name: variable name.
203      shape: variable shape.
204      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
205      trainable: Whether this variable is trainable or not.
206      use_resource: If true, we use resource variables. Otherwise we use
207        RefVariable.
208      initializer: initializer instance (callable).
209
210    Returns:
211      The created variable.
212    """
213    del feature_column, name, shape, dtype, trainable, use_resource, initializer
214    raise NotImplementedError('StateManager.create_variable')
215
216  def add_variable(self, feature_column, var):
217    """Adds an existing variable to the state.
218
219    Args:
220      feature_column: A `FeatureColumn` object to associate this variable with.
221      var: The variable.
222    """
223    del feature_column, var
224    raise NotImplementedError('StateManager.add_variable')
225
226  def get_variable(self, feature_column, name):
227    """Returns an existing variable.
228
229    Args:
230      feature_column: A `FeatureColumn` object this variable corresponds to.
231      name: variable name.
232    """
233    del feature_column, name
234    raise NotImplementedError('StateManager.get_var')
235
236  def add_resource(self, feature_column, name, resource):
237    """Creates a new resource.
238
239    Resources can be things such as tables etc.
240
241    Args:
242      feature_column: A `FeatureColumn` object this resource corresponds to.
243      name: Name of the resource.
244      resource: The resource.
245
246    Returns:
247      The created resource.
248    """
249    del feature_column, name, resource
250    raise NotImplementedError('StateManager.add_resource')
251
252  def get_resource(self, feature_column, name):
253    """Returns an already created resource.
254
255    Resources can be things such as tables etc.
256
257    Args:
258      feature_column: A `FeatureColumn` object this variable corresponds to.
259      name: Name of the resource.
260    """
261    del feature_column, name
262    raise NotImplementedError('StateManager.get_resource')
263
264
265class _StateManagerImpl(StateManager):
266  """Manages the state of DenseFeatures and LinearLayer."""
267
268  def __init__(self, layer, trainable):
269    """Creates an _StateManagerImpl object.
270
271    Args:
272      layer: The input layer this state manager is associated with.
273      trainable: Whether by default, variables created are trainable or not.
274    """
275    self._trainable = trainable
276    self._layer = layer
277    if self._layer is not None and not hasattr(self._layer, '_resources'):
278      self._layer._resources = []  # pylint: disable=protected-access
279    self._cols_to_vars_map = collections.defaultdict(lambda: {})
280    # TODO(vbardiovsky): Make sure the resources are tracked by moving them to
281    # the layer (inheriting from AutoTrackable), e.g.:
282    # self._layer._resources_map = data_structures.Mapping()
283    self._cols_to_resources_map = collections.defaultdict(lambda: {})
284
285  def create_variable(self,
286                      feature_column,
287                      name,
288                      shape,
289                      dtype=None,
290                      trainable=True,
291                      use_resource=True,
292                      initializer=None):
293    if name in self._cols_to_vars_map[feature_column]:
294      raise ValueError('Variable already exists.')
295
296    # We explicitly track these variables since `name` is not guaranteed to be
297    # unique and disable manual tracking that the add_weight call does.
298    with trackable.no_manual_dependency_tracking_scope(self._layer):
299      var = self._layer.add_weight(
300          name=name,
301          shape=shape,
302          dtype=dtype,
303          initializer=initializer,
304          trainable=self._trainable and trainable,
305          use_resource=use_resource,
306          # TODO(rohanj): Get rid of this hack once we have a mechanism for
307          # specifying a default partitioner for an entire layer. In that case,
308          # the default getter for Layers should work.
309          getter=variable_scope.get_variable)
310    if isinstance(var, variables.PartitionedVariable):
311      for v in var:
312        part_name = name + '/' + str(v._get_save_slice_info().var_offset[0])  # pylint: disable=protected-access
313        self._layer._track_trackable(v, feature_column.name + '/' + part_name)  # pylint: disable=protected-access
314    else:
315      if isinstance(var, trackable.Trackable):
316        self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access
317
318    self._cols_to_vars_map[feature_column][name] = var
319    return var
320
321  def get_variable(self, feature_column, name):
322    if name in self._cols_to_vars_map[feature_column]:
323      return self._cols_to_vars_map[feature_column][name]
324    raise ValueError('Variable does not exist.')
325
326  def add_resource(self, feature_column, name, resource):
327    self._cols_to_resources_map[feature_column][name] = resource
328    if self._layer is not None:
329      self._layer._resources.append(resource)  # pylint: disable=protected-access
330
331  def get_resource(self, feature_column, name):
332    if name in self._cols_to_resources_map[feature_column]:
333      return self._cols_to_resources_map[feature_column][name]
334    raise ValueError('Resource does not exist.')
335
336
337class _StateManagerImplV2(_StateManagerImpl):
338  """Manages the state of DenseFeatures."""
339
340  def create_variable(self,
341                      feature_column,
342                      name,
343                      shape,
344                      dtype=None,
345                      trainable=True,
346                      use_resource=True,
347                      initializer=None):
348    if name in self._cols_to_vars_map[feature_column]:
349      raise ValueError('Variable already exists.')
350
351    # We explicitly track these variables since `name` is not guaranteed to be
352    # unique and disable manual tracking that the add_weight call does.
353    with trackable.no_manual_dependency_tracking_scope(self._layer):
354      var = self._layer.add_weight(
355          name=name,
356          shape=shape,
357          dtype=dtype,
358          initializer=initializer,
359          trainable=self._trainable and trainable,
360          use_resource=use_resource)
361    if isinstance(var, trackable.Trackable):
362      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access
363    self._cols_to_vars_map[feature_column][name] = var
364    return var
365
366
367class _BaseFeaturesLayer(Layer):
368  """Base class for DenseFeatures and SequenceFeatures.
369
370  Defines common methods and helpers.
371
372  Args:
373    feature_columns: An iterable containing the FeatureColumns to use as
374      inputs to your model.
375    expected_column_type: Expected class for provided feature columns.
376    trainable:  Boolean, whether the layer's variables will be updated via
377      gradient descent during training.
378    name: Name to give to the DenseFeatures.
379    **kwargs: Keyword arguments to construct a layer.
380
381  Raises:
382    ValueError: if an item in `feature_columns` doesn't match
383      `expected_column_type`.
384  """
385
386  def __init__(self,
387               feature_columns,
388               expected_column_type,
389               trainable,
390               name,
391               partitioner=None,
392               **kwargs):
393    super(_BaseFeaturesLayer, self).__init__(
394        name=name, trainable=trainable, **kwargs)
395    self._feature_columns = _normalize_feature_columns(feature_columns)
396    self._state_manager = _StateManagerImpl(self, self.trainable)
397    self._partitioner = partitioner
398    for column in self._feature_columns:
399      if not isinstance(column, expected_column_type):
400        raise ValueError(
401            'Items of feature_columns must be a {}. '
402            'You can wrap a categorical column with an '
403            'embedding_column or indicator_column. Given: {}'.format(
404                expected_column_type, column))
405
406  def build(self, _):
407    for column in self._feature_columns:
408      with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
409          self.name,
410          partitioner=self._partitioner):
411        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
412            _sanitize_column_name_for_variable_scope(column.name)):
413          column.create_state(self._state_manager)
414    super(_BaseFeaturesLayer, self).build(None)
415
416  def _output_shape(self, input_shape, num_elements):
417    """Computes expected output shape of the layer or a column's dense tensor.
418
419    Args:
420      input_shape: Tensor or array with batch shape.
421      num_elements: Size of the last dimension of the output.
422
423    Returns:
424      Tuple with output shape.
425    """
426    raise NotImplementedError('Calling an abstract method.')
427
428  def compute_output_shape(self, input_shape):
429    total_elements = 0
430    for column in self._feature_columns:
431      total_elements += column.variable_shape.num_elements()
432    return self._target_shape(input_shape, total_elements)
433
434  def _process_dense_tensor(self, column, tensor):
435    """Reshapes the dense tensor output of a column based on expected shape.
436
437    Args:
438      column: A DenseColumn or SequenceDenseColumn object.
439      tensor: A dense tensor obtained from the same column.
440
441    Returns:
442      Reshaped dense tensor."""
443    num_elements = column.variable_shape.num_elements()
444    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
445    return array_ops.reshape(tensor, shape=target_shape)
446
447  def _verify_and_concat_tensors(self, output_tensors):
448    """Verifies and concatenates the dense output of several columns."""
449    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
450    return array_ops.concat(output_tensors, -1)
451
452  def get_config(self):
453    # Import here to avoid circular imports.
454    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
455    column_configs = serialization.serialize_feature_columns(
456        self._feature_columns)
457    config = {'feature_columns': column_configs}
458    config['partitioner'] = generic_utils.serialize_keras_object(
459        self._partitioner)
460
461    base_config = super(  # pylint: disable=bad-super-call
462        _BaseFeaturesLayer, self).get_config()
463    return dict(list(base_config.items()) + list(config.items()))
464
465  @classmethod
466  def from_config(cls, config, custom_objects=None):
467    # Import here to avoid circular imports.
468    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
469    config_cp = config.copy()
470    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
471        config['feature_columns'], custom_objects=custom_objects)
472    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
473        config['partitioner'], custom_objects)
474
475    return cls(**config_cp)
476
477
478class _LinearModelLayer(Layer):
479  """Layer that contains logic for `LinearModel`."""
480
481  def __init__(self,
482               feature_columns,
483               units=1,
484               sparse_combiner='sum',
485               trainable=True,
486               name=None,
487               **kwargs):
488    super(_LinearModelLayer, self).__init__(
489        name=name, trainable=trainable, **kwargs)
490
491    self._feature_columns = _normalize_feature_columns(feature_columns)
492    for column in self._feature_columns:
493      if not isinstance(column, (DenseColumn, CategoricalColumn)):
494        raise ValueError(
495            'Items of feature_columns must be either a '
496            'DenseColumn or CategoricalColumn. Given: {}'.format(column))
497
498    self._units = units
499    self._sparse_combiner = sparse_combiner
500
501    self._state_manager = _StateManagerImpl(self, self.trainable)
502    self.bias = None
503
504  def build(self, _):
505    # We need variable scopes for now because we want the variable partitioning
506    # information to percolate down. We also use _pure_variable_scope's here
507    # since we want to open up a name_scope in the `call` method while creating
508    # the ops.
509    with variable_scope._pure_variable_scope(self.name):  # pylint: disable=protected-access
510      for column in self._feature_columns:
511        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
512            _sanitize_column_name_for_variable_scope(column.name)):
513          # Create the state for each feature column
514          column.create_state(self._state_manager)
515
516          # Create a weight variable for each column.
517          if isinstance(column, CategoricalColumn):
518            first_dim = column.num_buckets
519          else:
520            first_dim = column.variable_shape.num_elements()
521          self._state_manager.create_variable(
522              column,
523              name='weights',
524              dtype=dtypes.float32,
525              shape=(first_dim, self._units),
526              initializer=init_ops.zeros_initializer(),
527              trainable=self.trainable)
528
529      # Create a bias variable.
530      self.bias = self.add_variable(
531          name='bias_weights',
532          dtype=dtypes.float32,
533          shape=[self._units],
534          initializer=init_ops.zeros_initializer(),
535          trainable=self.trainable,
536          use_resource=True,
537          # TODO(rohanj): Get rid of this hack once we have a mechanism for
538          # specifying a default partitioner for an entire layer. In that case,
539          # the default getter for Layers should work.
540          getter=variable_scope.get_variable)
541
542    super(_LinearModelLayer, self).build(None)
543
544  def call(self, features):
545    if not isinstance(features, dict):
546      raise ValueError('We expected a dictionary here. Instead we got: {}'
547                       .format(features))
548    with ops.name_scope(self.name):
549      transformation_cache = FeatureTransformationCache(features)
550      weighted_sums = []
551      for column in self._feature_columns:
552        with ops.name_scope(
553            _sanitize_column_name_for_variable_scope(column.name)):
554          # All the weights used in the linear model are owned by the state
555          # manager associated with this Linear Model.
556          weight_var = self._state_manager.get_variable(column, 'weights')
557
558          weighted_sum = _create_weighted_sum(
559              column=column,
560              transformation_cache=transformation_cache,
561              state_manager=self._state_manager,
562              sparse_combiner=self._sparse_combiner,
563              weight_var=weight_var)
564          weighted_sums.append(weighted_sum)
565
566      _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
567      predictions_no_bias = math_ops.add_n(
568          weighted_sums, name='weighted_sum_no_bias')
569      predictions = nn_ops.bias_add(
570          predictions_no_bias, self.bias, name='weighted_sum')
571      return predictions
572
573  def get_config(self):
574    # Import here to avoid circular imports.
575    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
576    column_configs = serialization.serialize_feature_columns(
577        self._feature_columns)
578    config = {
579        'feature_columns': column_configs,
580        'units': self._units,
581        'sparse_combiner': self._sparse_combiner
582    }
583
584    base_config = super(  # pylint: disable=bad-super-call
585        _LinearModelLayer, self).get_config()
586    return dict(list(base_config.items()) + list(config.items()))
587
588  @classmethod
589  def from_config(cls, config, custom_objects=None):
590    # Import here to avoid circular imports.
591    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
592    config_cp = config.copy()
593    columns = serialization.deserialize_feature_columns(
594        config_cp['feature_columns'], custom_objects=custom_objects)
595
596    del config_cp['feature_columns']
597    return cls(feature_columns=columns, **config_cp)
598
599
600# TODO(tanzheny): Cleanup it with respect to Premade model b/132690565.
601class LinearModel(training.Model):
602  """Produces a linear prediction `Tensor` based on given `feature_columns`.
603
604  This layer generates a weighted sum based on output dimension `units`.
605  Weighted sum refers to logits in classification problems. It refers to the
606  prediction itself for linear regression problems.
607
608  Note on supported columns: `LinearLayer` treats categorical columns as
609  `indicator_column`s. To be specific, assume the input as `SparseTensor` looks
610  like:
611
612  ```python
613    shape = [2, 2]
614    {
615        [0, 0]: "a"
616        [1, 0]: "b"
617        [1, 1]: "c"
618    }
619  ```
620  `linear_model` assigns weights for the presence of "a", "b", "c' implicitly,
621  just like `indicator_column`, while `input_layer` explicitly requires wrapping
622  each of categorical columns with an `embedding_column` or an
623  `indicator_column`.
624
625  Example of usage:
626
627  ```python
628  price = numeric_column('price')
629  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
630  keywords = categorical_column_with_hash_bucket("keywords", 10K)
631  keywords_price = crossed_column('keywords', price_buckets, ...)
632  columns = [price_buckets, keywords, keywords_price ...]
633  linear_model = LinearLayer(columns)
634
635  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
636  prediction = linear_model(features)
637  ```
638  """
639
640  def __init__(self,
641               feature_columns,
642               units=1,
643               sparse_combiner='sum',
644               trainable=True,
645               name=None,
646               **kwargs):
647    """Constructs a LinearLayer.
648
649    Args:
650      feature_columns: An iterable containing the FeatureColumns to use as
651        inputs to your model. All items should be instances of classes derived
652        from `_FeatureColumn`s.
653      units: An integer, dimensionality of the output space. Default value is 1.
654      sparse_combiner: A string specifying how to reduce if a categorical column
655        is multivalent. Except `numeric_column`, almost all columns passed to
656        `linear_model` are considered as categorical columns.  It combines each
657        categorical column independently. Currently "mean", "sqrtn" and "sum"
658        are supported, with "sum" the default for linear model. "sqrtn" often
659        achieves good accuracy, in particular with bag-of-words columns.
660          * "sum": do not normalize features in the column
661          * "mean": do l1 normalization on features in the column
662          * "sqrtn": do l2 normalization on features in the column
663        For example, for two features represented as the categorical columns:
664
665          ```python
666          # Feature 1
667
668          shape = [2, 2]
669          {
670              [0, 0]: "a"
671              [0, 1]: "b"
672              [1, 0]: "c"
673          }
674
675          # Feature 2
676
677          shape = [2, 3]
678          {
679              [0, 0]: "d"
680              [1, 0]: "e"
681              [1, 1]: "f"
682              [1, 2]: "g"
683          }
684          ```
685
686        with `sparse_combiner` as "mean", the linear model outputs conceptly are
687        ```
688        y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
689        y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
690        ```
691        where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
692        assigned to the presence of `x` in the input features.
693      trainable: If `True` also add the variable to the graph collection
694        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
695      name: Name to give to the Linear Model. All variables and ops created will
696        be scoped by this name.
697      **kwargs: Keyword arguments to construct a layer.
698
699    Raises:
700      ValueError: if an item in `feature_columns` is neither a `DenseColumn`
701        nor `CategoricalColumn`.
702    """
703
704    super(LinearModel, self).__init__(name=name, **kwargs)
705    self.layer = _LinearModelLayer(
706        feature_columns,
707        units,
708        sparse_combiner,
709        trainable,
710        name=self.name,
711        **kwargs)
712
713  def call(self, features):
714    """Returns a `Tensor` the represents the predictions of a linear model.
715
716    Args:
717      features: A mapping from key to tensors. `_FeatureColumn`s look up via
718        these keys. For example `numeric_column('price')` will look at 'price'
719        key in this dict. Values are `Tensor` or `SparseTensor` depending on
720        corresponding `_FeatureColumn`.
721
722    Returns:
723      A `Tensor` which represents predictions/logits of a linear model. Its
724      shape is (batch_size, units) and its dtype is `float32`.
725
726    Raises:
727      ValueError: If features are not a dictionary.
728    """
729    return self.layer(features)
730
731  @property
732  def bias(self):
733    return self.layer.bias
734
735
736def _transform_features_v2(features, feature_columns, state_manager):
737  """Returns transformed features based on features columns passed in.
738
739  Please note that most probably you would not need to use this function. Please
740  check `input_layer` and `linear_model` to see whether they will
741  satisfy your use case or not.
742
743  Example:
744
745  ```python
746  # Define features and transformations
747  crosses_a_x_b = crossed_column(
748      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
749  price_buckets = bucketized_column(
750      source_column=numeric_column("price"), boundaries=[...])
751
752  columns = [crosses_a_x_b, price_buckets]
753  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
754  transformed = transform_features(features=features, feature_columns=columns)
755
756  assertCountEqual(columns, transformed.keys())
757  ```
758
759  Args:
760    features: A mapping from key to tensors. `FeatureColumn`s look up via these
761      keys. For example `numeric_column('price')` will look at 'price' key in
762      this dict. Values can be a `SparseTensor` or a `Tensor` depends on
763      corresponding `FeatureColumn`.
764    feature_columns: An iterable containing all the `FeatureColumn`s.
765    state_manager: A StateManager object that holds the FeatureColumn state.
766
767  Returns:
768    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
769  """
770  feature_columns = _normalize_feature_columns(feature_columns)
771  outputs = {}
772  with ops.name_scope(
773      None, default_name='transform_features', values=features.values()):
774    transformation_cache = FeatureTransformationCache(features)
775    for column in feature_columns:
776      with ops.name_scope(
777          None,
778          default_name=_sanitize_column_name_for_variable_scope(column.name)):
779        outputs[column] = transformation_cache.get(column, state_manager)
780  return outputs
781
782
783@tf_export('feature_column.make_parse_example_spec', v1=[])
784def make_parse_example_spec_v2(feature_columns):
785  """Creates parsing spec dictionary from input feature_columns.
786
787  The returned dictionary can be used as arg 'features' in
788  `tf.io.parse_example`.
789
790  Typical usage example:
791
792  ```python
793  # Define features and transformations
794  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
795  feature_b = tf.feature_column.numeric_column(...)
796  feature_c_bucketized = tf.feature_column.bucketized_column(
797      tf.feature_column.numeric_column("feature_c"), ...)
798  feature_a_x_feature_c = tf.feature_column.crossed_column(
799      columns=["feature_a", feature_c_bucketized], ...)
800
801  feature_columns = set(
802      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
803  features = tf.io.parse_example(
804      serialized=serialized_examples,
805      features=tf.feature_column.make_parse_example_spec(feature_columns))
806  ```
807
808  For the above example, make_parse_example_spec would return the dict:
809
810  ```python
811  {
812      "feature_a": parsing_ops.VarLenFeature(tf.string),
813      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
814      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
815  }
816  ```
817
818  Args:
819    feature_columns: An iterable containing all feature columns. All items
820      should be instances of classes derived from `FeatureColumn`.
821
822  Returns:
823    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
824    value.
825
826  Raises:
827    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
828      instance.
829  """
830  result = {}
831  for column in feature_columns:
832    if not isinstance(column, FeatureColumn):
833      raise ValueError('All feature_columns must be FeatureColumn instances. '
834                       'Given: {}'.format(column))
835    config = column.parse_example_spec
836    for key, value in six.iteritems(config):
837      if key in result and value != result[key]:
838        raise ValueError(
839            'feature_columns contain different parse_spec for key '
840            '{}. Given {} and {}'.format(key, value, result[key]))
841    result.update(config)
842  return result
843
844
845@tf_export('feature_column.embedding_column')
846def embedding_column(categorical_column,
847                     dimension,
848                     combiner='mean',
849                     initializer=None,
850                     ckpt_to_load_from=None,
851                     tensor_name_in_ckpt=None,
852                     max_norm=None,
853                     trainable=True,
854                     use_safe_embedding_lookup=True):
855  """`DenseColumn` that converts from sparse, categorical input.
856
857  Use this when your inputs are sparse, but you want to convert them to a dense
858  representation (e.g., to feed to a DNN).
859
860  Inputs must be a `CategoricalColumn` created by any of the
861  `categorical_column_*` function. Here is an example of using
862  `embedding_column` with `DNNClassifier`:
863
864  ```python
865  video_id = categorical_column_with_identity(
866      key='video_id', num_buckets=1000000, default_value=0)
867  columns = [embedding_column(video_id, 9),...]
868
869  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
870
871  label_column = ...
872  def input_fn():
873    features = tf.io.parse_example(
874        ..., features=make_parse_example_spec(columns + [label_column]))
875    labels = features.pop(label_column.name)
876    return features, labels
877
878  estimator.train(input_fn=input_fn, steps=100)
879  ```
880
881  Here is an example using `embedding_column` with model_fn:
882
883  ```python
884  def model_fn(features, ...):
885    video_id = categorical_column_with_identity(
886        key='video_id', num_buckets=1000000, default_value=0)
887    columns = [embedding_column(video_id, 9),...]
888    dense_tensor = input_layer(features, columns)
889    # Form DNN layers, calculate loss, and return EstimatorSpec.
890    ...
891  ```
892
893  Args:
894    categorical_column: A `CategoricalColumn` created by a
895      `categorical_column_with_*` function. This column produces the sparse IDs
896      that are inputs to the embedding lookup.
897    dimension: An integer specifying dimension of the embedding, must be > 0.
898    combiner: A string specifying how to reduce if there are multiple entries in
899      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
900      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
901      with bag-of-words columns. Each of this can be thought as example level
902      normalizations on the column. For more information, see
903      `tf.embedding_lookup_sparse`.
904    initializer: A variable initializer function to be used in embedding
905      variable initialization. If not specified, defaults to
906      `truncated_normal_initializer` with mean `0.0` and
907      standard deviation `1/sqrt(dimension)`.
908    ckpt_to_load_from: String representing checkpoint name/pattern from which to
909      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
910    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
911      to restore the column weights. Required if `ckpt_to_load_from` is not
912      `None`.
913    max_norm: If not `None`, embedding values are l2-normalized to this value.
914    trainable: Whether or not the embedding is trainable. Default is True.
915    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
916      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
917      there are no empty rows and all weights and ids are positive at the
918      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
919      input tensors. Defaults to true, consider turning off if the above checks
920      are not needed. Note that having empty rows will not trigger any error
921      though the output result might be 0 or omitted.
922
923  Returns:
924    `DenseColumn` that converts from sparse input.
925
926  Raises:
927    ValueError: if `dimension` not > 0.
928    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
929      is specified.
930    ValueError: if `initializer` is specified and is not callable.
931    RuntimeError: If eager execution is enabled.
932  """
933  if (dimension is None) or (dimension < 1):
934    raise ValueError('Invalid dimension {}.'.format(dimension))
935  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
936    raise ValueError('Must specify both `ckpt_to_load_from` and '
937                     '`tensor_name_in_ckpt` or none of them.')
938
939  if (initializer is not None) and (not callable(initializer)):
940    raise ValueError('initializer must be callable if specified. '
941                     'Embedding of column_name: {}'.format(
942                         categorical_column.name))
943  if initializer is None:
944    initializer = init_ops.truncated_normal_initializer(
945        mean=0.0, stddev=1 / math.sqrt(dimension))
946
947  return EmbeddingColumn(
948      categorical_column=categorical_column,
949      dimension=dimension,
950      combiner=combiner,
951      initializer=initializer,
952      ckpt_to_load_from=ckpt_to_load_from,
953      tensor_name_in_ckpt=tensor_name_in_ckpt,
954      max_norm=max_norm,
955      trainable=trainable,
956      use_safe_embedding_lookup=use_safe_embedding_lookup)
957
958
959@tf_export(v1=['feature_column.shared_embedding_columns'])
960def shared_embedding_columns(categorical_columns,
961                             dimension,
962                             combiner='mean',
963                             initializer=None,
964                             shared_embedding_collection_name=None,
965                             ckpt_to_load_from=None,
966                             tensor_name_in_ckpt=None,
967                             max_norm=None,
968                             trainable=True,
969                             use_safe_embedding_lookup=True):
970  """List of dense columns that convert from sparse, categorical input.
971
972  This is similar to `embedding_column`, except that it produces a list of
973  embedding columns that share the same embedding weights.
974
975  Use this when your inputs are sparse and of the same type (e.g. watched and
976  impression video IDs that share the same vocabulary), and you want to convert
977  them to a dense representation (e.g., to feed to a DNN).
978
979  Inputs must be a list of categorical columns created by any of the
980  `categorical_column_*` function. They must all be of the same type and have
981  the same arguments except `key`. E.g. they can be
982  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
983  all columns could also be weighted_categorical_column.
984
985  Here is an example embedding of two features for a DNNClassifier model:
986
987  ```python
988  watched_video_id = categorical_column_with_vocabulary_file(
989      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
990  impression_video_id = categorical_column_with_vocabulary_file(
991      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
992  columns = shared_embedding_columns(
993      [watched_video_id, impression_video_id], dimension=10)
994
995  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
996
997  label_column = ...
998  def input_fn():
999    features = tf.io.parse_example(
1000        ..., features=make_parse_example_spec(columns + [label_column]))
1001    labels = features.pop(label_column.name)
1002    return features, labels
1003
1004  estimator.train(input_fn=input_fn, steps=100)
1005  ```
1006
1007  Here is an example using `shared_embedding_columns` with model_fn:
1008
1009  ```python
1010  def model_fn(features, ...):
1011    watched_video_id = categorical_column_with_vocabulary_file(
1012        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
1013    impression_video_id = categorical_column_with_vocabulary_file(
1014        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
1015    columns = shared_embedding_columns(
1016        [watched_video_id, impression_video_id], dimension=10)
1017    dense_tensor = input_layer(features, columns)
1018    # Form DNN layers, calculate loss, and return EstimatorSpec.
1019    ...
1020  ```
1021
1022  Args:
1023    categorical_columns: List of categorical columns created by a
1024      `categorical_column_with_*` function. These columns produce the sparse IDs
1025      that are inputs to the embedding lookup. All columns must be of the same
1026      type and have the same arguments except `key`. E.g. they can be
1027      categorical_column_with_vocabulary_file with the same vocabulary_file.
1028      Some or all columns could also be weighted_categorical_column.
1029    dimension: An integer specifying dimension of the embedding, must be > 0.
1030    combiner: A string specifying how to reduce if there are multiple entries in
1031      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
1032      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
1033      with bag-of-words columns. Each of this can be thought as example level
1034      normalizations on the column. For more information, see
1035      `tf.embedding_lookup_sparse`.
1036    initializer: A variable initializer function to be used in embedding
1037      variable initialization. If not specified, defaults to
1038      `truncated_normal_initializer` with mean `0.0` and
1039      standard deviation `1/sqrt(dimension)`.
1040    shared_embedding_collection_name: Optional name of the collection where
1041      shared embedding weights are added. If not given, a reasonable name will
1042      be chosen based on the names of `categorical_columns`. This is also used
1043      in `variable_scope` when creating shared embedding weights.
1044    ckpt_to_load_from: String representing checkpoint name/pattern from which to
1045      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
1046    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
1047      to restore the column weights. Required if `ckpt_to_load_from` is not
1048      `None`.
1049    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
1050      than this value, before combining.
1051    trainable: Whether or not the embedding is trainable. Default is True.
1052    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
1053      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
1054      there are no empty rows and all weights and ids are positive at the
1055      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
1056      input tensors. Defaults to true, consider turning off if the above checks
1057      are not needed. Note that having empty rows will not trigger any error
1058      though the output result might be 0 or omitted.
1059
1060  Returns:
1061    A list of dense columns that converts from sparse input. The order of
1062    results follows the ordering of `categorical_columns`.
1063
1064  Raises:
1065    ValueError: if `dimension` not > 0.
1066    ValueError: if any of the given `categorical_columns` is of different type
1067      or has different arguments than the others.
1068    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
1069      is specified.
1070    ValueError: if `initializer` is specified and is not callable.
1071    RuntimeError: if eager execution is enabled.
1072  """
1073  if context.executing_eagerly():
1074    raise RuntimeError('shared_embedding_columns are not supported when eager '
1075                       'execution is enabled.')
1076
1077  if (dimension is None) or (dimension < 1):
1078    raise ValueError('Invalid dimension {}.'.format(dimension))
1079  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
1080    raise ValueError('Must specify both `ckpt_to_load_from` and '
1081                     '`tensor_name_in_ckpt` or none of them.')
1082
1083  if (initializer is not None) and (not callable(initializer)):
1084    raise ValueError('initializer must be callable if specified.')
1085  if initializer is None:
1086    initializer = init_ops.truncated_normal_initializer(
1087        mean=0.0, stddev=1. / math.sqrt(dimension))
1088
1089  # Sort the columns so the default collection name is deterministic even if the
1090  # user passes columns from an unsorted collection, such as dict.values().
1091  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
1092
1093  c0 = sorted_columns[0]
1094  num_buckets = c0._num_buckets  # pylint: disable=protected-access
1095  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
1096    raise ValueError(
1097        'All categorical_columns must be subclasses of _CategoricalColumn. '
1098        'Given: {}, of type: {}'.format(c0, type(c0)))
1099  while isinstance(
1100      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
1101           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
1102    c0 = c0.categorical_column
1103  for c in sorted_columns[1:]:
1104    while isinstance(
1105        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
1106            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
1107      c = c.categorical_column
1108    if not isinstance(c, type(c0)):
1109      raise ValueError(
1110          'To use shared_embedding_column, all categorical_columns must have '
1111          'the same type, or be weighted_categorical_column or sequence column '
1112          'of the same type. Given column: {} of type: {} does not match given '
1113          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
1114    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
1115      raise ValueError(
1116          'To use shared_embedding_column, all categorical_columns must have '
1117          'the same number of buckets. Given column: {} with buckets: {} does  '
1118          'not match column: {} with buckets: {}'.format(
1119              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access
1120
1121  if not shared_embedding_collection_name:
1122    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
1123    shared_embedding_collection_name += '_shared_embedding'
1124
1125  result = []
1126  for column in categorical_columns:
1127    result.append(
1128        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
1129            categorical_column=column,
1130            initializer=initializer,
1131            dimension=dimension,
1132            combiner=combiner,
1133            shared_embedding_collection_name=shared_embedding_collection_name,
1134            ckpt_to_load_from=ckpt_to_load_from,
1135            tensor_name_in_ckpt=tensor_name_in_ckpt,
1136            max_norm=max_norm,
1137            trainable=trainable,
1138            use_safe_embedding_lookup=use_safe_embedding_lookup))
1139
1140  return result
1141
1142
1143@tf_export('feature_column.shared_embeddings', v1=[])
1144def shared_embedding_columns_v2(categorical_columns,
1145                                dimension,
1146                                combiner='mean',
1147                                initializer=None,
1148                                shared_embedding_collection_name=None,
1149                                ckpt_to_load_from=None,
1150                                tensor_name_in_ckpt=None,
1151                                max_norm=None,
1152                                trainable=True,
1153                                use_safe_embedding_lookup=True):
1154  """List of dense columns that convert from sparse, categorical input.
1155
1156  This is similar to `embedding_column`, except that it produces a list of
1157  embedding columns that share the same embedding weights.
1158
1159  Use this when your inputs are sparse and of the same type (e.g. watched and
1160  impression video IDs that share the same vocabulary), and you want to convert
1161  them to a dense representation (e.g., to feed to a DNN).
1162
1163  Inputs must be a list of categorical columns created by any of the
1164  `categorical_column_*` function. They must all be of the same type and have
1165  the same arguments except `key`. E.g. they can be
1166  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
1167  all columns could also be weighted_categorical_column.
1168
1169  Here is an example embedding of two features for a DNNClassifier model:
1170
1171  ```python
1172  watched_video_id = categorical_column_with_vocabulary_file(
1173      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
1174  impression_video_id = categorical_column_with_vocabulary_file(
1175      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
1176  columns = shared_embedding_columns(
1177      [watched_video_id, impression_video_id], dimension=10)
1178
1179  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
1180
1181  label_column = ...
1182  def input_fn():
1183    features = tf.io.parse_example(
1184        ..., features=make_parse_example_spec(columns + [label_column]))
1185    labels = features.pop(label_column.name)
1186    return features, labels
1187
1188  estimator.train(input_fn=input_fn, steps=100)
1189  ```
1190
1191  Here is an example using `shared_embedding_columns` with model_fn:
1192
1193  ```python
1194  def model_fn(features, ...):
1195    watched_video_id = categorical_column_with_vocabulary_file(
1196        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
1197    impression_video_id = categorical_column_with_vocabulary_file(
1198        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
1199    columns = shared_embedding_columns(
1200        [watched_video_id, impression_video_id], dimension=10)
1201    dense_tensor = input_layer(features, columns)
1202    # Form DNN layers, calculate loss, and return EstimatorSpec.
1203    ...
1204  ```
1205
1206  Args:
1207    categorical_columns: List of categorical columns created by a
1208      `categorical_column_with_*` function. These columns produce the sparse IDs
1209      that are inputs to the embedding lookup. All columns must be of the same
1210      type and have the same arguments except `key`. E.g. they can be
1211      categorical_column_with_vocabulary_file with the same vocabulary_file.
1212      Some or all columns could also be weighted_categorical_column.
1213    dimension: An integer specifying dimension of the embedding, must be > 0.
1214    combiner: A string specifying how to reduce if there are multiple entries
1215      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
1216      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
1217      with bag-of-words columns. Each of this can be thought as example level
1218      normalizations on the column. For more information, see
1219      `tf.embedding_lookup_sparse`.
1220    initializer: A variable initializer function to be used in embedding
1221      variable initialization. If not specified, defaults to
1222      `truncated_normal_initializer` with mean `0.0` and standard
1223      deviation `1/sqrt(dimension)`.
1224    shared_embedding_collection_name: Optional collective name of these columns.
1225      If not given, a reasonable name will be chosen based on the names of
1226      `categorical_columns`.
1227    ckpt_to_load_from: String representing checkpoint name/pattern from which to
1228      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
1229    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
1230      which to restore the column weights. Required if `ckpt_to_load_from` is
1231      not `None`.
1232    max_norm: If not `None`, each embedding is clipped if its l2-norm is
1233      larger than this value, before combining.
1234    trainable: Whether or not the embedding is trainable. Default is True.
1235    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
1236      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
1237      there are no empty rows and all weights and ids are positive at the
1238      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
1239      input tensors. Defaults to true, consider turning off if the above checks
1240      are not needed. Note that having empty rows will not trigger any error
1241      though the output result might be 0 or omitted.
1242
1243  Returns:
1244    A list of dense columns that converts from sparse input. The order of
1245    results follows the ordering of `categorical_columns`.
1246
1247  Raises:
1248    ValueError: if `dimension` not > 0.
1249    ValueError: if any of the given `categorical_columns` is of different type
1250      or has different arguments than the others.
1251    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
1252      is specified.
1253    ValueError: if `initializer` is specified and is not callable.
1254    RuntimeError: if eager execution is enabled.
1255  """
1256  if context.executing_eagerly():
1257    raise RuntimeError('shared_embedding_columns are not supported when eager '
1258                       'execution is enabled.')
1259
1260  if (dimension is None) or (dimension < 1):
1261    raise ValueError('Invalid dimension {}.'.format(dimension))
1262  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
1263    raise ValueError('Must specify both `ckpt_to_load_from` and '
1264                     '`tensor_name_in_ckpt` or none of them.')
1265
1266  if (initializer is not None) and (not callable(initializer)):
1267    raise ValueError('initializer must be callable if specified.')
1268  if initializer is None:
1269    initializer = init_ops.truncated_normal_initializer(
1270        mean=0.0, stddev=1. / math.sqrt(dimension))
1271
1272  # Sort the columns so the default collection name is deterministic even if the
1273  # user passes columns from an unsorted collection, such as dict.values().
1274  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
1275
1276  c0 = sorted_columns[0]
1277  num_buckets = c0.num_buckets
1278  if not isinstance(c0, CategoricalColumn):
1279    raise ValueError(
1280        'All categorical_columns must be subclasses of CategoricalColumn. '
1281        'Given: {}, of type: {}'.format(c0, type(c0)))
1282  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
1283    c0 = c0.categorical_column
1284  for c in sorted_columns[1:]:
1285    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
1286      c = c.categorical_column
1287    if not isinstance(c, type(c0)):
1288      raise ValueError(
1289          'To use shared_embedding_column, all categorical_columns must have '
1290          'the same type, or be weighted_categorical_column or sequence column '
1291          'of the same type. Given column: {} of type: {} does not match given '
1292          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
1293    if num_buckets != c.num_buckets:
1294      raise ValueError(
1295          'To use shared_embedding_column, all categorical_columns must have '
1296          'the same number of buckets. Given column: {} with buckets: {} does  '
1297          'not match column: {} with buckets: {}'.format(
1298              c0, num_buckets, c, c.num_buckets))
1299
1300  if not shared_embedding_collection_name:
1301    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
1302    shared_embedding_collection_name += '_shared_embedding'
1303
1304  column_creator = SharedEmbeddingColumnCreator(
1305      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
1306      num_buckets, trainable, shared_embedding_collection_name,
1307      use_safe_embedding_lookup)
1308
1309  result = []
1310  for column in categorical_columns:
1311    result.append(
1312        column_creator(
1313            categorical_column=column, combiner=combiner, max_norm=max_norm))
1314
1315  return result
1316
1317
1318@tf_export('feature_column.numeric_column')
1319def numeric_column(key,
1320                   shape=(1,),
1321                   default_value=None,
1322                   dtype=dtypes.float32,
1323                   normalizer_fn=None):
1324  """Represents real valued or numerical features.
1325
1326  Example:
1327
1328  ```python
1329  price = numeric_column('price')
1330  columns = [price, ...]
1331  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1332  dense_tensor = input_layer(features, columns)
1333
1334  # or
1335  bucketized_price = bucketized_column(price, boundaries=[...])
1336  columns = [bucketized_price, ...]
1337  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1338  linear_prediction = linear_model(features, columns)
1339  ```
1340
1341  Args:
1342    key: A unique string identifying the input feature. It is used as the
1343      column name and the dictionary key for feature parsing configs, feature
1344      `Tensor` objects, and feature columns.
1345    shape: An iterable of integers specifies the shape of the `Tensor`. An
1346      integer can be given which means a single dimension `Tensor` with given
1347      width. The `Tensor` representing the column will have the shape of
1348      [batch_size] + `shape`.
1349    default_value: A single value compatible with `dtype` or an iterable of
1350      values compatible with `dtype` which the column takes on during
1351      `tf.Example` parsing if data is missing. A default value of `None` will
1352      cause `tf.io.parse_example` to fail if an example does not contain this
1353      column. If a single value is provided, the same value will be applied as
1354      the default value for every item. If an iterable of values is provided,
1355      the shape of the `default_value` should be equal to the given `shape`.
1356    dtype: defines the type of values. Default value is `tf.float32`. Must be a
1357      non-quantized, real integer or floating point type.
1358    normalizer_fn: If not `None`, a function that can be used to normalize the
1359      value of the tensor after `default_value` is applied for parsing.
1360      Normalizer function takes the input `Tensor` as its argument, and returns
1361      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
1362      even though the most common use case of this function is normalization, it
1363      can be used for any kind of Tensorflow transformations.
1364
1365  Returns:
1366    A `NumericColumn`.
1367
1368  Raises:
1369    TypeError: if any dimension in shape is not an int
1370    ValueError: if any dimension in shape is not a positive integer
1371    TypeError: if `default_value` is an iterable but not compatible with `shape`
1372    TypeError: if `default_value` is not compatible with `dtype`.
1373    ValueError: if `dtype` is not convertible to `tf.float32`.
1374  """
1375  shape = _check_shape(shape, key)
1376  if not (dtype.is_integer or dtype.is_floating):
1377    raise ValueError('dtype must be convertible to float. '
1378                     'dtype: {}, key: {}'.format(dtype, key))
1379  default_value = fc_utils.check_default_value(
1380      shape, default_value, dtype, key)
1381
1382  if normalizer_fn is not None and not callable(normalizer_fn):
1383    raise TypeError(
1384        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
1385
1386  fc_utils.assert_key_is_string(key)
1387  return NumericColumn(
1388      key,
1389      shape=shape,
1390      default_value=default_value,
1391      dtype=dtype,
1392      normalizer_fn=normalizer_fn)
1393
1394
1395@tf_export('feature_column.bucketized_column')
1396def bucketized_column(source_column, boundaries):
1397  """Represents discretized dense input bucketed by `boundaries`.
1398
1399  Buckets include the left boundary, and exclude the right boundary. Namely,
1400  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
1401  `[1., 2.)`, and `[2., +inf)`.
1402
1403  For example, if the inputs are
1404
1405  ```python
1406  boundaries = [0, 10, 100]
1407  input tensor = [[-5, 10000]
1408                  [150,   10]
1409                  [5,    100]]
1410  ```
1411
1412  then the output will be
1413
1414  ```python
1415  output = [[0, 3]
1416            [3, 2]
1417            [1, 3]]
1418  ```
1419
1420  Example:
1421
1422  ```python
1423  price = tf.feature_column.numeric_column('price')
1424  bucketized_price = tf.feature_column.bucketized_column(
1425      price, boundaries=[...])
1426  columns = [bucketized_price, ...]
1427  features = tf.io.parse_example(
1428      ..., features=tf.feature_column.make_parse_example_spec(columns))
1429  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
1430  ```
1431
1432  `bucketized_column` can also be crossed with another categorical column using
1433  `crossed_column`:
1434
1435  ```python
1436  price = tf.feature_column.numeric_column('price')
1437  # bucketized_column converts numerical feature to a categorical one.
1438  bucketized_price = tf.feature_column.bucketized_column(
1439      price, boundaries=[...])
1440  # 'keywords' is a string feature.
1441  price_x_keywords = tf.feature_column.crossed_column(
1442      [bucketized_price, 'keywords'], 50K)
1443  columns = [price_x_keywords, ...]
1444  features = tf.io.parse_example(
1445      ..., features=tf.feature_column.make_parse_example_spec(columns))
1446  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
1447  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
1448  ```
1449
1450  Args:
1451    source_column: A one-dimensional dense column which is generated with
1452      `numeric_column`.
1453    boundaries: A sorted list or tuple of floats specifying the boundaries.
1454
1455  Returns:
1456    A `BucketizedColumn`.
1457
1458  Raises:
1459    ValueError: If `source_column` is not a numeric column, or if it is not
1460      one-dimensional.
1461    ValueError: If `boundaries` is not a sorted list or tuple.
1462  """
1463  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
1464    raise ValueError(
1465        'source_column must be a column generated with numeric_column(). '
1466        'Given: {}'.format(source_column))
1467  if len(source_column.shape) > 1:
1468    raise ValueError(
1469        'source_column must be one-dimensional column. '
1470        'Given: {}'.format(source_column))
1471  if not boundaries:
1472    raise ValueError('boundaries must not be empty.')
1473  if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
1474    raise ValueError('boundaries must be a sorted list.')
1475  for i in range(len(boundaries) - 1):
1476    if boundaries[i] >= boundaries[i + 1]:
1477      raise ValueError('boundaries must be a sorted list.')
1478  return BucketizedColumn(source_column, tuple(boundaries))
1479
1480
1481@tf_export('feature_column.categorical_column_with_hash_bucket')
1482def categorical_column_with_hash_bucket(key,
1483                                        hash_bucket_size,
1484                                        dtype=dtypes.string):
1485  """Represents sparse feature where ids are set by hashing.
1486
1487  Use this when your sparse features are in string or integer format, and you
1488  want to distribute your inputs into a finite number of buckets by hashing.
1489  output_id = Hash(input_feature_string) % bucket_size for string type input.
1490  For int type input, the value is converted to its string representation first
1491  and then hashed by the same formula.
1492
1493  For input dictionary `features`, `features[key]` is either `Tensor` or
1494  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1495  and `''` for string, which will be dropped by this feature column.
1496
1497  Example:
1498
1499  ```python
1500  keywords = categorical_column_with_hash_bucket("keywords", 10K)
1501  columns = [keywords, ...]
1502  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1503  linear_prediction = linear_model(features, columns)
1504
1505  # or
1506  keywords_embedded = embedding_column(keywords, 16)
1507  columns = [keywords_embedded, ...]
1508  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1509  dense_tensor = input_layer(features, columns)
1510  ```
1511
1512  Args:
1513    key: A unique string identifying the input feature. It is used as the
1514      column name and the dictionary key for feature parsing configs, feature
1515      `Tensor` objects, and feature columns.
1516    hash_bucket_size: An int > 1. The number of buckets.
1517    dtype: The type of features. Only string and integer types are supported.
1518
1519  Returns:
1520    A `HashedCategoricalColumn`.
1521
1522  Raises:
1523    ValueError: `hash_bucket_size` is not greater than 1.
1524    ValueError: `dtype` is neither string nor integer.
1525  """
1526  if hash_bucket_size is None:
1527    raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key))
1528
1529  if hash_bucket_size < 1:
1530    raise ValueError('hash_bucket_size must be at least 1. '
1531                     'hash_bucket_size: {}, key: {}'.format(
1532                         hash_bucket_size, key))
1533
1534  fc_utils.assert_key_is_string(key)
1535  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1536
1537  return HashedCategoricalColumn(key, hash_bucket_size, dtype)
1538
1539
1540@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
1541def categorical_column_with_vocabulary_file(key,
1542                                            vocabulary_file,
1543                                            vocabulary_size=None,
1544                                            num_oov_buckets=0,
1545                                            default_value=None,
1546                                            dtype=dtypes.string):
1547  """A `CategoricalColumn` with a vocabulary file.
1548
1549  Use this when your inputs are in string or integer format, and you have a
1550  vocabulary file that maps each value to an integer ID. By default,
1551  out-of-vocabulary values are ignored. Use either (but not both) of
1552  `num_oov_buckets` and `default_value` to specify how to include
1553  out-of-vocabulary values.
1554
1555  For input dictionary `features`, `features[key]` is either `Tensor` or
1556  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1557  and `''` for string, which will be dropped by this feature column.
1558
1559  Example with `num_oov_buckets`:
1560  File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
1561  abbreviation. All inputs with values in that file are assigned an ID 0-49,
1562  corresponding to its line number. All other values are hashed and assigned an
1563  ID 50-54.
1564
1565  ```python
1566  states = categorical_column_with_vocabulary_file(
1567      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
1568      num_oov_buckets=5)
1569  columns = [states, ...]
1570  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1571  linear_prediction = linear_model(features, columns)
1572  ```
1573
1574  Example with `default_value`:
1575  File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
1576  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
1577  in input, and other values missing from the file, will be assigned ID 0. All
1578  others are assigned the corresponding line number 1-50.
1579
1580  ```python
1581  states = categorical_column_with_vocabulary_file(
1582      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
1583      default_value=0)
1584  columns = [states, ...]
1585  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1586  linear_prediction, _, _ = linear_model(features, columns)
1587  ```
1588
1589  And to make an embedding with either:
1590
1591  ```python
1592  columns = [embedding_column(states, 3),...]
1593  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1594  dense_tensor = input_layer(features, columns)
1595  ```
1596
1597  Args:
1598    key: A unique string identifying the input feature. It is used as the
1599      column name and the dictionary key for feature parsing configs, feature
1600      `Tensor` objects, and feature columns.
1601    vocabulary_file: The vocabulary file name.
1602    vocabulary_size: Number of the elements in the vocabulary. This must be no
1603      greater than length of `vocabulary_file`, if less than length, later
1604      values are ignored. If None, it is set to the length of `vocabulary_file`.
1605    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1606      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1607      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
1608      the input value. A positive `num_oov_buckets` can not be specified with
1609      `default_value`.
1610    default_value: The integer ID value to return for out-of-vocabulary feature
1611      values, defaults to `-1`. This can not be specified with a positive
1612      `num_oov_buckets`.
1613    dtype: The type of features. Only string and integer types are supported.
1614
1615  Returns:
1616    A `CategoricalColumn` with a vocabulary file.
1617
1618  Raises:
1619    ValueError: `vocabulary_file` is missing or cannot be opened.
1620    ValueError: `vocabulary_size` is missing or < 1.
1621    ValueError: `num_oov_buckets` is a negative integer.
1622    ValueError: `num_oov_buckets` and `default_value` are both specified.
1623    ValueError: `dtype` is neither string nor integer.
1624  """
1625  return categorical_column_with_vocabulary_file_v2(
1626      key, vocabulary_file, vocabulary_size,
1627      dtype, default_value,
1628      num_oov_buckets)
1629
1630
1631@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
1632def categorical_column_with_vocabulary_file_v2(key,
1633                                               vocabulary_file,
1634                                               vocabulary_size=None,
1635                                               dtype=dtypes.string,
1636                                               default_value=None,
1637                                               num_oov_buckets=0):
1638  """A `CategoricalColumn` with a vocabulary file.
1639
1640  Use this when your inputs are in string or integer format, and you have a
1641  vocabulary file that maps each value to an integer ID. By default,
1642  out-of-vocabulary values are ignored. Use either (but not both) of
1643  `num_oov_buckets` and `default_value` to specify how to include
1644  out-of-vocabulary values.
1645
1646  For input dictionary `features`, `features[key]` is either `Tensor` or
1647  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1648  and `''` for string, which will be dropped by this feature column.
1649
1650  Example with `num_oov_buckets`:
1651  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
1652  abbreviation. All inputs with values in that file are assigned an ID 0-49,
1653  corresponding to its line number. All other values are hashed and assigned an
1654  ID 50-54.
1655
1656  ```python
1657  states = categorical_column_with_vocabulary_file(
1658      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
1659      num_oov_buckets=5)
1660  columns = [states, ...]
1661  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1662  linear_prediction = linear_model(features, columns)
1663  ```
1664
1665  Example with `default_value`:
1666  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
1667  other 50 each have a 2-character U.S. state abbreviation. Both a literal
1668  `'XX'` in input, and other values missing from the file, will be assigned
1669  ID 0. All others are assigned the corresponding line number 1-50.
1670
1671  ```python
1672  states = categorical_column_with_vocabulary_file(
1673      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
1674      default_value=0)
1675  columns = [states, ...]
1676  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1677  linear_prediction, _, _ = linear_model(features, columns)
1678  ```
1679
1680  And to make an embedding with either:
1681
1682  ```python
1683  columns = [embedding_column(states, 3),...]
1684  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1685  dense_tensor = input_layer(features, columns)
1686  ```
1687
1688  Args:
1689    key: A unique string identifying the input feature. It is used as the
1690      column name and the dictionary key for feature parsing configs, feature
1691      `Tensor` objects, and feature columns.
1692    vocabulary_file: The vocabulary file name.
1693    vocabulary_size: Number of the elements in the vocabulary. This must be no
1694      greater than length of `vocabulary_file`, if less than length, later
1695      values are ignored. If None, it is set to the length of `vocabulary_file`.
1696    dtype: The type of features. Only string and integer types are supported.
1697    default_value: The integer ID value to return for out-of-vocabulary feature
1698      values, defaults to `-1`. This can not be specified with a positive
1699      `num_oov_buckets`.
1700    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1701      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1702      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
1703      the input value. A positive `num_oov_buckets` can not be specified with
1704      `default_value`.
1705
1706  Returns:
1707    A `CategoricalColumn` with a vocabulary file.
1708
1709  Raises:
1710    ValueError: `vocabulary_file` is missing or cannot be opened.
1711    ValueError: `vocabulary_size` is missing or < 1.
1712    ValueError: `num_oov_buckets` is a negative integer.
1713    ValueError: `num_oov_buckets` and `default_value` are both specified.
1714    ValueError: `dtype` is neither string nor integer.
1715  """
1716  if not vocabulary_file:
1717    raise ValueError('Missing vocabulary_file in {}.'.format(key))
1718
1719  if vocabulary_size is None:
1720    if not gfile.Exists(vocabulary_file):
1721      raise ValueError('vocabulary_file in {} does not exist.'.format(key))
1722
1723    with gfile.GFile(vocabulary_file, mode='rb') as f:
1724      vocabulary_size = sum(1 for _ in f)
1725    logging.info(
1726        'vocabulary_size = %d in %s is inferred from the number of elements '
1727        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
1728
1729  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
1730  if vocabulary_size < 1:
1731    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
1732  if num_oov_buckets:
1733    if default_value is not None:
1734      raise ValueError(
1735          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1736              key))
1737    if num_oov_buckets < 0:
1738      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1739          num_oov_buckets, key))
1740  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1741  fc_utils.assert_key_is_string(key)
1742  return VocabularyFileCategoricalColumn(
1743      key=key,
1744      vocabulary_file=vocabulary_file,
1745      vocabulary_size=vocabulary_size,
1746      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
1747      default_value=-1 if default_value is None else default_value,
1748      dtype=dtype)
1749
1750
1751@tf_export('feature_column.categorical_column_with_vocabulary_list')
1752def categorical_column_with_vocabulary_list(key,
1753                                            vocabulary_list,
1754                                            dtype=None,
1755                                            default_value=-1,
1756                                            num_oov_buckets=0):
1757  """A `CategoricalColumn` with in-memory vocabulary.
1758
1759  Use this when your inputs are in string or integer format, and you have an
1760  in-memory vocabulary mapping each value to an integer ID. By default,
1761  out-of-vocabulary values are ignored. Use either (but not both) of
1762  `num_oov_buckets` and `default_value` to specify how to include
1763  out-of-vocabulary values.
1764
1765  For input dictionary `features`, `features[key]` is either `Tensor` or
1766  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1767  and `''` for string, which will be dropped by this feature column.
1768
1769  Example with `num_oov_buckets`:
1770  In the following example, each input in `vocabulary_list` is assigned an ID
1771  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1772  inputs are hashed and assigned an ID 4-5.
1773
1774  ```python
1775  colors = categorical_column_with_vocabulary_list(
1776      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1777      num_oov_buckets=2)
1778  columns = [colors, ...]
1779  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1780  linear_prediction, _, _ = linear_model(features, columns)
1781  ```
1782
1783  Example with `default_value`:
1784  In the following example, each input in `vocabulary_list` is assigned an ID
1785  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1786  inputs are assigned `default_value` 0.
1787
1788
1789  ```python
1790  colors = categorical_column_with_vocabulary_list(
1791      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1792  columns = [colors, ...]
1793  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1794  linear_prediction, _, _ = linear_model(features, columns)
1795  ```
1796
1797  And to make an embedding with either:
1798
1799  ```python
1800  columns = [embedding_column(colors, 3),...]
1801  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1802  dense_tensor = input_layer(features, columns)
1803  ```
1804
1805  Args:
1806    key: A unique string identifying the input feature. It is used as the column
1807      name and the dictionary key for feature parsing configs, feature `Tensor`
1808      objects, and feature columns.
1809    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1810      is mapped to the index of its value (if present) in `vocabulary_list`.
1811      Must be castable to `dtype`.
1812    dtype: The type of features. Only string and integer types are supported. If
1813      `None`, it will be inferred from `vocabulary_list`.
1814    default_value: The integer ID value to return for out-of-vocabulary feature
1815      values, defaults to `-1`. This can not be specified with a positive
1816      `num_oov_buckets`.
1817    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1818      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1819      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
1820      hash of the input value. A positive `num_oov_buckets` can not be specified
1821      with `default_value`.
1822
1823  Returns:
1824    A `CategoricalColumn` with in-memory vocabulary.
1825
1826  Raises:
1827    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1828    ValueError: `num_oov_buckets` is a negative integer.
1829    ValueError: `num_oov_buckets` and `default_value` are both specified.
1830    ValueError: if `dtype` is not integer or string.
1831  """
1832  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1833    raise ValueError(
1834        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1835            vocabulary_list, key))
1836  if len(set(vocabulary_list)) != len(vocabulary_list):
1837    raise ValueError(
1838        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1839            vocabulary_list, key))
1840  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1841  if num_oov_buckets:
1842    if default_value != -1:
1843      raise ValueError(
1844          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1845              key))
1846    if num_oov_buckets < 0:
1847      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1848          num_oov_buckets, key))
1849  fc_utils.assert_string_or_int(
1850      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1851  if dtype is None:
1852    dtype = vocabulary_dtype
1853  elif dtype.is_integer != vocabulary_dtype.is_integer:
1854    raise ValueError(
1855        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1856            dtype, vocabulary_dtype, key))
1857  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1858  fc_utils.assert_key_is_string(key)
1859
1860  return VocabularyListCategoricalColumn(
1861      key=key,
1862      vocabulary_list=tuple(vocabulary_list),
1863      dtype=dtype,
1864      default_value=default_value,
1865      num_oov_buckets=num_oov_buckets)
1866
1867
1868@tf_export('feature_column.categorical_column_with_identity')
1869def categorical_column_with_identity(key, num_buckets, default_value=None):
1870  """A `CategoricalColumn` that returns identity values.
1871
1872  Use this when your inputs are integers in the range `[0, num_buckets)`, and
1873  you want to use the input value itself as the categorical ID. Values outside
1874  this range will result in `default_value` if specified, otherwise it will
1875  fail.
1876
1877  Typically, this is used for contiguous ranges of integer indexes, but
1878  it doesn't have to be. This might be inefficient, however, if many of IDs
1879  are unused. Consider `categorical_column_with_hash_bucket` in that case.
1880
1881  For input dictionary `features`, `features[key]` is either `Tensor` or
1882  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1883  and `''` for string, which will be dropped by this feature column.
1884
1885  In the following examples, each input in the range `[0, 1000000)` is assigned
1886  the same value. All other inputs are assigned `default_value` 0. Note that a
1887  literal 0 in inputs will result in the same default ID.
1888
1889  Linear model:
1890
1891  ```python
1892  video_id = categorical_column_with_identity(
1893      key='video_id', num_buckets=1000000, default_value=0)
1894  columns = [video_id, ...]
1895  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1896  linear_prediction, _, _ = linear_model(features, columns)
1897  ```
1898
1899  Embedding for a DNN model:
1900
1901  ```python
1902  columns = [embedding_column(video_id, 9),...]
1903  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1904  dense_tensor = input_layer(features, columns)
1905  ```
1906
1907  Args:
1908    key: A unique string identifying the input feature. It is used as the
1909      column name and the dictionary key for feature parsing configs, feature
1910      `Tensor` objects, and feature columns.
1911    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1912    default_value: If set, values outside of range `[0, num_buckets)` will
1913      be replaced with this value. If not set, values >= num_buckets will
1914      cause a failure while values < 0 will be dropped.
1915
1916  Returns:
1917    A `CategoricalColumn` that returns identity values.
1918
1919  Raises:
1920    ValueError: if `num_buckets` is less than one.
1921    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1922  """
1923  if num_buckets < 1:
1924    raise ValueError(
1925        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1926  if (default_value is not None) and (
1927      (default_value < 0) or (default_value >= num_buckets)):
1928    raise ValueError(
1929        'default_value {} not in range [0, {}), column_name {}'.format(
1930            default_value, num_buckets, key))
1931  fc_utils.assert_key_is_string(key)
1932  return IdentityCategoricalColumn(
1933      key=key, number_buckets=num_buckets, default_value=default_value)
1934
1935
1936@tf_export('feature_column.indicator_column')
1937def indicator_column(categorical_column):
1938  """Represents multi-hot representation of given categorical column.
1939
1940  - For DNN model, `indicator_column` can be used to wrap any
1941    `categorical_column_*` (e.g., to feed to DNN). Consider to Use
1942    `embedding_column` if the number of buckets/unique(values) are large.
1943
1944  - For Wide (aka linear) model, `indicator_column` is the internal
1945    representation for categorical column when passing categorical column
1946    directly (as any element in feature_columns) to `linear_model`. See
1947    `linear_model` for details.
1948
1949  ```python
1950  name = indicator_column(categorical_column_with_vocabulary_list(
1951      'name', ['bob', 'george', 'wanda'])
1952  columns = [name, ...]
1953  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1954  dense_tensor = input_layer(features, columns)
1955
1956  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
1957  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
1958  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
1959  ```
1960
1961  Args:
1962    categorical_column: A `CategoricalColumn` which is created by
1963      `categorical_column_with_*` or `crossed_column` functions.
1964
1965  Returns:
1966    An `IndicatorColumn`.
1967
1968  Raises:
1969    ValueError: If `categorical_column` is not CategoricalColumn type.
1970  """
1971  if not isinstance(categorical_column,
1972                    (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
1973    raise ValueError(
1974        'Unsupported input type. Input must be a CategoricalColumn. '
1975        'Given: {}'.format(categorical_column))
1976  return IndicatorColumn(categorical_column)
1977
1978
1979@tf_export('feature_column.weighted_categorical_column')
1980def weighted_categorical_column(categorical_column,
1981                                weight_feature_key,
1982                                dtype=dtypes.float32):
1983  """Applies weight values to a `CategoricalColumn`.
1984
1985  Use this when each of your sparse inputs has both an ID and a value. For
1986  example, if you're representing text documents as a collection of word
1987  frequencies, you can provide 2 parallel sparse input features ('terms' and
1988  'frequencies' below).
1989
1990  Example:
1991
1992  Input `tf.Example` objects:
1993
1994  ```proto
1995  [
1996    features {
1997      feature {
1998        key: "terms"
1999        value {bytes_list {value: "very" value: "model"}}
2000      }
2001      feature {
2002        key: "frequencies"
2003        value {float_list {value: 0.3 value: 0.1}}
2004      }
2005    },
2006    features {
2007      feature {
2008        key: "terms"
2009        value {bytes_list {value: "when" value: "course" value: "human"}}
2010      }
2011      feature {
2012        key: "frequencies"
2013        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
2014      }
2015    }
2016  ]
2017  ```
2018
2019  ```python
2020  categorical_column = categorical_column_with_hash_bucket(
2021      column_name='terms', hash_bucket_size=1000)
2022  weighted_column = weighted_categorical_column(
2023      categorical_column=categorical_column, weight_feature_key='frequencies')
2024  columns = [weighted_column, ...]
2025  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
2026  linear_prediction, _, _ = linear_model(features, columns)
2027  ```
2028
2029  This assumes the input dictionary contains a `SparseTensor` for key
2030  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
2031  the same indices and dense shape.
2032
2033  Args:
2034    categorical_column: A `CategoricalColumn` created by
2035      `categorical_column_with_*` functions.
2036    weight_feature_key: String key for weight values.
2037    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
2038      are supported.
2039
2040  Returns:
2041    A `CategoricalColumn` composed of two sparse features: one represents id,
2042    the other represents weight (value) of the id feature in that example.
2043
2044  Raises:
2045    ValueError: if `dtype` is not convertible to float.
2046  """
2047  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
2048    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
2049  return WeightedCategoricalColumn(
2050      categorical_column=categorical_column,
2051      weight_feature_key=weight_feature_key,
2052      dtype=dtype)
2053
2054
2055@tf_export('feature_column.crossed_column')
2056def crossed_column(keys, hash_bucket_size, hash_key=None):
2057  """Returns a column for performing crosses of categorical features.
2058
2059  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
2060  the transformation can be thought of as:
2061    Hash(cartesian product of features) % `hash_bucket_size`
2062
2063  For example, if the input features are:
2064
2065  * SparseTensor referred by first key:
2066
2067    ```python
2068    shape = [2, 2]
2069    {
2070        [0, 0]: "a"
2071        [1, 0]: "b"
2072        [1, 1]: "c"
2073    }
2074    ```
2075
2076  * SparseTensor referred by second key:
2077
2078    ```python
2079    shape = [2, 1]
2080    {
2081        [0, 0]: "d"
2082        [1, 0]: "e"
2083    }
2084    ```
2085
2086  then crossed feature will look like:
2087
2088  ```python
2089   shape = [2, 2]
2090  {
2091      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
2092      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
2093      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
2094  }
2095  ```
2096
2097  Here is an example to create a linear model with crosses of string features:
2098
2099  ```python
2100  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
2101  columns = [keywords_x_doc_terms, ...]
2102  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
2103  linear_prediction = linear_model(features, columns)
2104  ```
2105
2106  You could also use vocabulary lookup before crossing:
2107
2108  ```python
2109  keywords = categorical_column_with_vocabulary_file(
2110      'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
2111  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
2112  columns = [keywords_x_doc_terms, ...]
2113  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
2114  linear_prediction = linear_model(features, columns)
2115  ```
2116
2117  If an input feature is of numeric type, you can use
2118  `categorical_column_with_identity`, or `bucketized_column`, as in the example:
2119
2120  ```python
2121  # vertical_id is an integer categorical feature.
2122  vertical_id = categorical_column_with_identity('vertical_id', 10K)
2123  price = numeric_column('price')
2124  # bucketized_column converts numerical feature to a categorical one.
2125  bucketized_price = bucketized_column(price, boundaries=[...])
2126  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
2127  columns = [vertical_id_x_price, ...]
2128  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
2129  linear_prediction = linear_model(features, columns)
2130  ```
2131
2132  To use crossed column in DNN model, you need to add it in an embedding column
2133  as in this example:
2134
2135  ```python
2136  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
2137  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
2138  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
2139  ```
2140
2141  Args:
2142    keys: An iterable identifying the features to be crossed. Each element can
2143      be either:
2144      * string: Will use the corresponding feature which must be of string type.
2145      * `CategoricalColumn`: Will use the transformed tensor produced by this
2146        column. Does not support hashed categorical column.
2147    hash_bucket_size: An int > 1. The number of buckets.
2148    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
2149      function to combine the crosses fingerprints on SparseCrossOp (optional).
2150
2151  Returns:
2152    A `CrossedColumn`.
2153
2154  Raises:
2155    ValueError: If `len(keys) < 2`.
2156    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
2157    ValueError: If any of the keys is `HashedCategoricalColumn`.
2158    ValueError: If `hash_bucket_size < 1`.
2159  """
2160  if not hash_bucket_size or hash_bucket_size < 1:
2161    raise ValueError('hash_bucket_size must be > 1. '
2162                     'hash_bucket_size: {}'.format(hash_bucket_size))
2163  if not keys or len(keys) < 2:
2164    raise ValueError(
2165        'keys must be a list with length > 1. Given: {}'.format(keys))
2166  for key in keys:
2167    if (not isinstance(key, six.string_types) and
2168        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
2169      raise ValueError(
2170          'Unsupported key type. All keys must be either string, or '
2171          'categorical column except HashedCategoricalColumn. '
2172          'Given: {}'.format(key))
2173    if isinstance(key,
2174                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
2175      raise ValueError(
2176          'categorical_column_with_hash_bucket is not supported for crossing. '
2177          'Hashing before crossing will increase probability of collision. '
2178          'Instead, use the feature name as a string. Given: {}'.format(key))
2179  return CrossedColumn(
2180      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
2181
2182
2183@six.add_metaclass(abc.ABCMeta)
2184class FeatureColumn(object):
2185  """Represents a feature column abstraction.
2186
2187  WARNING: Do not subclass this layer unless you know what you are doing:
2188  the API is subject to future changes.
2189
2190  To distinguish between the concept of a feature family and a specific binary
2191  feature within a family, we refer to a feature family like "country" as a
2192  feature column. For example, we can have a feature in a `tf.Example` format:
2193    {key: "country",  value: [ "US" ]}
2194  In this example the value of feature is "US" and "country" refers to the
2195  column of the feature.
2196
2197  This class is an abstract class. Users should not create instances of this.
2198  """
2199
2200  @abc.abstractproperty
2201  def name(self):
2202    """Returns string. Used for naming."""
2203    pass
2204
2205  def __lt__(self, other):
2206    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
2207
2208    Feature columns need to occasionally be sortable, for example when used as
2209    keys in a features dictionary passed to a layer.
2210
2211    In CPython, `__lt__` must be defined for all objects in the
2212    sequence being sorted.
2213
2214    If any objects in teh sequence being sorted do not have an `__lt__` method
2215    compatible with feature column objects (such as strings), then CPython will
2216    fall back to using the `__gt__` method below.
2217    https://docs.python.org/3/library/stdtypes.html#list.sort
2218
2219    Args:
2220      other: The other object to compare to.
2221
2222    Returns:
2223      True if the string representation of this object is lexicographically less
2224      than the string representation of `other`. For FeatureColumn objects,
2225      this looks like "<__main__.FeatureColumn object at 0xa>".
2226    """
2227    return str(self) < str(other)
2228
2229  def __gt__(self, other):
2230    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
2231
2232    Feature columns need to occasionally be sortable, for example when used as
2233    keys in a features dictionary passed to a layer.
2234
2235    `__gt__` is called when the "other" object being compared during the sort
2236    does not have `__lt__` defined.
2237    Example: http://gpaste/4803354716798976
2238
2239    Args:
2240      other: The other object to compare to.
2241
2242    Returns:
2243      True if the string representation of this object is lexicographically
2244      greater than the string representation of `other`. For FeatureColumn
2245      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
2246    """
2247    return str(self) > str(other)
2248
2249  @abc.abstractmethod
2250  def transform_feature(self, transformation_cache, state_manager):
2251    """Returns intermediate representation (usually a `Tensor`).
2252
2253    Uses `transformation_cache` to create an intermediate representation
2254    (usually a `Tensor`) that other feature columns can use.
2255
2256    Example usage of `transformation_cache`:
2257    Let's say a Feature column depends on raw feature ('raw') and another
2258    `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
2259    transformation_cache will be used as follows:
2260
2261    ```python
2262    raw_tensor = transformation_cache.get('raw', state_manager)
2263    fc_tensor = transformation_cache.get(input_fc, state_manager)
2264    ```
2265
2266    Args:
2267      transformation_cache: A `FeatureTransformationCache` object to access
2268        features.
2269      state_manager: A `StateManager` to create / access resources such as
2270        lookup tables.
2271
2272    Returns:
2273      Transformed feature `Tensor`.
2274    """
2275    pass
2276
2277  @abc.abstractproperty
2278  def parse_example_spec(self):
2279    """Returns a `tf.Example` parsing spec as dict.
2280
2281    It is used for get_parsing_spec for `tf.io.parse_example`. Returned spec is
2282    a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other
2283    supported objects. Please check documentation of `tf.io.parse_example` for
2284    all supported spec objects.
2285
2286    Let's say a Feature column depends on raw feature ('raw') and another
2287    `FeatureColumn` (input_fc). One possible implementation of
2288    parse_example_spec is as follows:
2289
2290    ```python
2291    spec = {'raw': tf.io.FixedLenFeature(...)}
2292    spec.update(input_fc.parse_example_spec)
2293    return spec
2294    ```
2295    """
2296    pass
2297
2298  def create_state(self, state_manager):
2299    """Uses the `state_manager` to create state for the FeatureColumn.
2300
2301    Args:
2302      state_manager: A `StateManager` to create / access resources such as
2303        lookup tables and variables.
2304    """
2305    pass
2306
2307  @abc.abstractproperty
2308  def _is_v2_column(self):
2309    """Returns whether this FeatureColumn is fully conformant to the new API.
2310
2311    This is needed for composition type cases where an EmbeddingColumn etc.
2312    might take in old categorical columns as input and then we want to use the
2313    old API.
2314    """
2315    pass
2316
2317  @abc.abstractproperty
2318  def parents(self):
2319    """Returns a list of immediate raw feature and FeatureColumn dependencies.
2320
2321    For example:
2322    # For the following feature columns
2323    a = numeric_column('f1')
2324    c = crossed_column(a, 'f2')
2325    # The expected parents are:
2326    a.parents = ['f1']
2327    c.parents = [a, 'f2']
2328    """
2329    pass
2330
2331  def get_config(self):
2332    """Returns the config of the feature column.
2333
2334    A FeatureColumn config is a Python dictionary (serializable) containing the
2335    configuration of a FeatureColumn. The same FeatureColumn can be
2336    reinstantiated later from this configuration.
2337
2338    The config of a feature column does not include information about feature
2339    columns depending on it nor the FeatureColumn class name.
2340
2341    Example with (de)serialization practices followed in this file:
2342    ```python
2343    class SerializationExampleFeatureColumn(
2344        FeatureColumn, collections.namedtuple(
2345            'SerializationExampleFeatureColumn',
2346            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):
2347
2348      def get_config(self):
2349        # Create a dict from the namedtuple.
2350        # Python attribute literals can be directly copied from / to the config.
2351        # For example 'dimension', assuming it is an integer literal.
2352        config = dict(zip(self._fields, self))
2353
2354        # (De)serialization of parent FeatureColumns should use the provided
2355        # (de)serialize_feature_column() methods that take care of de-duping.
2356        config['parent'] = serialize_feature_column(self.parent)
2357
2358        # Many objects provide custom (de)serialization e.g: for tf.DType
2359        # tf.DType.name, tf.as_dtype() can be used.
2360        config['dtype'] = self.dtype.name
2361
2362        # Non-trivial dependencies should be Keras-(de)serializable.
2363        config['normalizer_fn'] = generic_utils.serialize_keras_object(
2364            self.normalizer_fn)
2365
2366        return config
2367
2368      @classmethod
2369      def from_config(cls, config, custom_objects=None, columns_by_name=None):
2370        # This should do the inverse transform from `get_config` and construct
2371        # the namedtuple.
2372        kwargs = config.copy()
2373        kwargs['parent'] = deserialize_feature_column(
2374            config['parent'], custom_objects, columns_by_name)
2375        kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2376        kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
2377          config['normalizer_fn'], custom_objects=custom_objects)
2378        return cls(**kwargs)
2379
2380    ```
2381    Returns:
2382      A serializable Dict that can be used to deserialize the object with
2383      from_config.
2384    """
2385    return self._get_config()
2386
2387  def _get_config(self):
2388    raise NotImplementedError('Must be implemented in subclasses.')
2389
2390  @classmethod
2391  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2392    """Creates a FeatureColumn from its config.
2393
2394    This method should be the reverse of `get_config`, capable of instantiating
2395    the same FeatureColumn from the config dictionary. See `get_config` for an
2396    example of common (de)serialization practices followed in this file.
2397
2398    TODO(b/118939620): This is a private method until consensus is reached on
2399    supporting object deserialization deduping within Keras.
2400
2401    Args:
2402      config: A Dict config acquired with `get_config`.
2403      custom_objects: Optional dictionary mapping names (strings) to custom
2404        classes or functions to be considered during deserialization.
2405      columns_by_name: A Dict[String, FeatureColumn] of existing columns in
2406        order to avoid duplication. Should be passed to any calls to
2407        deserialize_feature_column().
2408
2409    Returns:
2410      A FeatureColumn for the input config.
2411    """
2412    return cls._from_config(config, custom_objects, columns_by_name)
2413
2414  @classmethod
2415  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2416    raise NotImplementedError('Must be implemented in subclasses.')
2417
2418
2419class DenseColumn(FeatureColumn):
2420  """Represents a column which can be represented as `Tensor`.
2421
2422  Some examples of this type are: numeric_column, embedding_column,
2423  indicator_column.
2424  """
2425
2426  @abc.abstractproperty
2427  def variable_shape(self):
2428    """`TensorShape` of `get_dense_tensor`, without batch dimension."""
2429    pass
2430
2431  @abc.abstractmethod
2432  def get_dense_tensor(self, transformation_cache, state_manager):
2433    """Returns a `Tensor`.
2434
2435    The output of this function will be used by model-builder-functions. For
2436    example the pseudo code of `input_layer` will be like:
2437
2438    ```python
2439    def input_layer(features, feature_columns, ...):
2440      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
2441      return tf.concat(outputs)
2442    ```
2443
2444    Args:
2445      transformation_cache: A `FeatureTransformationCache` object to access
2446        features.
2447      state_manager: A `StateManager` to create / access resources such as
2448        lookup tables.
2449
2450    Returns:
2451      `Tensor` of shape [batch_size] + `variable_shape`.
2452    """
2453    pass
2454
2455
2456def is_feature_column_v2(feature_columns):
2457  """Returns True if all feature columns are V2."""
2458  for feature_column in feature_columns:
2459    if not isinstance(feature_column, FeatureColumn):
2460      return False
2461    if not feature_column._is_v2_column:  # pylint: disable=protected-access
2462      return False
2463  return True
2464
2465
2466def _create_weighted_sum(column, transformation_cache, state_manager,
2467                         sparse_combiner, weight_var):
2468  """Creates a weighted sum for a dense/categorical column for linear_model."""
2469  if isinstance(column, CategoricalColumn):
2470    return _create_categorical_column_weighted_sum(
2471        column=column,
2472        transformation_cache=transformation_cache,
2473        state_manager=state_manager,
2474        sparse_combiner=sparse_combiner,
2475        weight_var=weight_var)
2476  else:
2477    return _create_dense_column_weighted_sum(
2478        column=column,
2479        transformation_cache=transformation_cache,
2480        state_manager=state_manager,
2481        weight_var=weight_var)
2482
2483
2484def _create_dense_column_weighted_sum(column, transformation_cache,
2485                                      state_manager, weight_var):
2486  """Create a weighted sum of a dense column for linear_model."""
2487  tensor = column.get_dense_tensor(transformation_cache, state_manager)
2488  num_elements = column.variable_shape.num_elements()
2489  batch_size = array_ops.shape(tensor)[0]
2490  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
2491  return math_ops.matmul(tensor, weight_var, name='weighted_sum')
2492
2493
2494class CategoricalColumn(FeatureColumn):
2495  """Represents a categorical feature.
2496
2497  A categorical feature typically handled with a `tf.SparseTensor` of IDs.
2498  """
2499
2500  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
2501      'IdWeightPair', ('id_tensor', 'weight_tensor'))
2502
2503  @abc.abstractproperty
2504  def num_buckets(self):
2505    """Returns number of buckets in this sparse feature."""
2506    pass
2507
2508  @abc.abstractmethod
2509  def get_sparse_tensors(self, transformation_cache, state_manager):
2510    """Returns an IdWeightPair.
2511
2512    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
2513    weights.
2514
2515    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
2516    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
2517    `SparseTensor` of `float` or `None` to indicate all weights should be
2518    taken to be 1. If specified, `weight_tensor` must have exactly the same
2519    shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
2520    output of a `VarLenFeature` which is a ragged matrix.
2521
2522    Args:
2523      transformation_cache: A `FeatureTransformationCache` object to access
2524        features.
2525      state_manager: A `StateManager` to create / access resources such as
2526        lookup tables.
2527    """
2528    pass
2529
2530
2531def _create_categorical_column_weighted_sum(
2532    column, transformation_cache, state_manager, sparse_combiner, weight_var):
2533  # pylint: disable=g-doc-return-or-yield,g-doc-args
2534  """Create a weighted sum of a categorical column for linear_model.
2535
2536  Note to maintainer: As implementation details, the weighted sum is
2537  implemented via embedding_lookup_sparse toward efficiency. Mathematically,
2538  they are the same.
2539
2540  To be specific, conceptually, categorical column can be treated as multi-hot
2541  vector. Say:
2542
2543  ```python
2544    x = [0 0 1]  # categorical column input
2545    w = [a b c]  # weights
2546  ```
2547  The weighted sum is `c` in this case, which is same as `w[2]`.
2548
2549  Another example is
2550
2551  ```python
2552    x = [0 1 1]  # categorical column input
2553    w = [a b c]  # weights
2554  ```
2555  The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`.
2556
2557  For both cases, we can implement weighted sum via embedding_lookup with
2558  sparse_combiner = "sum".
2559  """
2560
2561  sparse_tensors = column.get_sparse_tensors(transformation_cache,
2562                                             state_manager)
2563  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
2564      array_ops.shape(sparse_tensors.id_tensor)[0], -1
2565  ])
2566  weight_tensor = sparse_tensors.weight_tensor
2567  if weight_tensor is not None:
2568    weight_tensor = sparse_ops.sparse_reshape(
2569        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2570
2571  return embedding_ops.safe_embedding_lookup_sparse(
2572      weight_var,
2573      id_tensor,
2574      sparse_weights=weight_tensor,
2575      combiner=sparse_combiner,
2576      name='weighted_sum')
2577
2578
2579class SequenceDenseColumn(FeatureColumn):
2580  """Represents dense sequence data."""
2581
2582  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
2583      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
2584
2585  @abc.abstractmethod
2586  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
2587    """Returns a `TensorSequenceLengthPair`.
2588
2589    Args:
2590      transformation_cache: A `FeatureTransformationCache` object to access
2591        features.
2592      state_manager: A `StateManager` to create / access resources such as
2593        lookup tables.
2594    """
2595    pass
2596
2597
2598class FeatureTransformationCache(object):
2599  """Handles caching of transformations while building the model.
2600
2601  `FeatureColumn` specifies how to digest an input column to the network. Some
2602  feature columns require data transformations. This class caches those
2603  transformations.
2604
2605  Some features may be used in more than one place. For example, one can use a
2606  bucketized feature by itself and a cross with it. In that case we
2607  should create only one bucketization op instead of creating ops for each
2608  feature column separately. To handle re-use of transformed columns,
2609  `FeatureTransformationCache` caches all previously transformed columns.
2610
2611  Example:
2612  We're trying to use the following `FeatureColumn`s:
2613
2614  ```python
2615  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
2616  keywords = fc.categorical_column_with_hash_buckets("keywords", ...)
2617  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
2618  ... = linear_model(features,
2619                          [bucketized_age, keywords, age_X_keywords]
2620  ```
2621
2622  If we transform each column independently, then we'll get duplication of
2623  bucketization (one for cross, one for bucketization itself).
2624  The `FeatureTransformationCache` eliminates this duplication.
2625  """
2626
2627  def __init__(self, features):
2628    """Creates a `FeatureTransformationCache`.
2629
2630    Args:
2631      features: A mapping from feature column to objects that are `Tensor` or
2632        `SparseTensor`, or can be converted to same via
2633        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
2634        signifies a base feature (not-transformed). A `FeatureColumn` key
2635        means that this `Tensor` is the output of an existing `FeatureColumn`
2636        which can be reused.
2637    """
2638    self._features = features.copy()
2639    self._feature_tensors = {}
2640
2641  def get(self, key, state_manager):
2642    """Returns a `Tensor` for the given key.
2643
2644    A `str` key is used to access a base feature (not-transformed). When a
2645    `FeatureColumn` is passed, the transformed feature is returned if it
2646    already exists, otherwise the given `FeatureColumn` is asked to provide its
2647    transformed output, which is then cached.
2648
2649    Args:
2650      key: a `str` or a `FeatureColumn`.
2651      state_manager: A StateManager object that holds the FeatureColumn state.
2652
2653    Returns:
2654      The transformed `Tensor` corresponding to the `key`.
2655
2656    Raises:
2657      ValueError: if key is not found or a transformed `Tensor` cannot be
2658        computed.
2659    """
2660    if key in self._feature_tensors:
2661      # FeatureColumn is already transformed or converted.
2662      return self._feature_tensors[key]
2663
2664    if key in self._features:
2665      feature_tensor = self._get_raw_feature_as_tensor(key)
2666      self._feature_tensors[key] = feature_tensor
2667      return feature_tensor
2668
2669    if isinstance(key, six.string_types):
2670      raise ValueError('Feature {} is not in features dictionary.'.format(key))
2671
2672    if not isinstance(key, FeatureColumn):
2673      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
2674                      'Provided: {}'.format(key))
2675
2676    column = key
2677    logging.debug('Transforming feature_column %s.', column)
2678    transformed = column.transform_feature(self, state_manager)
2679    if transformed is None:
2680      raise ValueError('Column {} is not supported.'.format(column.name))
2681    self._feature_tensors[column] = transformed
2682    return transformed
2683
2684  def _get_raw_feature_as_tensor(self, key):
2685    """Gets the raw_feature (keyed by `key`) as `tensor`.
2686
2687    The raw feature is converted to (sparse) tensor and maybe expand dim.
2688
2689    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
2690    the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
2691    error out as it is not supported.
2692
2693    Args:
2694      key: A `str` key to access the raw feature.
2695
2696    Returns:
2697      A `Tensor` or `SparseTensor`.
2698
2699    Raises:
2700      ValueError: if the raw feature has rank 0.
2701    """
2702    raw_feature = self._features[key]
2703    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2704        raw_feature)
2705
2706    def expand_dims(input_tensor):
2707      # Input_tensor must have rank 1.
2708      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2709        return sparse_ops.sparse_reshape(
2710            input_tensor, [array_ops.shape(input_tensor)[0], 1])
2711      else:
2712        return array_ops.expand_dims(input_tensor, -1)
2713
2714    rank = feature_tensor.get_shape().ndims
2715    if rank is not None:
2716      if rank == 0:
2717        raise ValueError(
2718            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2719                key, feature_tensor))
2720      return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2721
2722    # Handle dynamic rank.
2723    with ops.control_dependencies([
2724        check_ops.assert_positive(
2725            array_ops.rank(feature_tensor),
2726            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2727                key, feature_tensor))]):
2728      return control_flow_ops.cond(
2729          math_ops.equal(1, array_ops.rank(feature_tensor)),
2730          lambda: expand_dims(feature_tensor),
2731          lambda: feature_tensor)
2732
2733
2734# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2735def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2736  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2737
2738  If `input_tensor` is already a `SparseTensor`, just return it.
2739
2740  Args:
2741    input_tensor: A string or integer `Tensor`.
2742    ignore_value: Entries in `dense_tensor` equal to this value will be
2743      absent from the resulting `SparseTensor`. If `None`, default value of
2744      `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2745
2746  Returns:
2747    A `SparseTensor` with the same shape as `input_tensor`.
2748
2749  Raises:
2750    ValueError: when `input_tensor`'s rank is `None`.
2751  """
2752  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2753      input_tensor)
2754  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2755    return input_tensor
2756  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
2757    if ignore_value is None:
2758      if input_tensor.dtype == dtypes.string:
2759        # Exception due to TF strings are converted to numpy objects by default.
2760        ignore_value = ''
2761      elif input_tensor.dtype.is_integer:
2762        ignore_value = -1  # -1 has a special meaning of missing feature
2763      else:
2764        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2765        # constructing a new numpy object of the given type, which yields the
2766        # default value for that type.
2767        ignore_value = input_tensor.dtype.as_numpy_dtype()
2768    ignore_value = math_ops.cast(
2769        ignore_value, input_tensor.dtype, name='ignore_value')
2770    indices = array_ops.where_v2(
2771        math_ops.not_equal(input_tensor, ignore_value), name='indices')
2772    return sparse_tensor_lib.SparseTensor(
2773        indices=indices,
2774        values=array_ops.gather_nd(input_tensor, indices, name='values'),
2775        dense_shape=array_ops.shape(
2776            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2777
2778
2779def _normalize_feature_columns(feature_columns):
2780  """Normalizes the `feature_columns` input.
2781
2782  This method converts the `feature_columns` to list type as best as it can. In
2783  addition, verifies the type and other parts of feature_columns, required by
2784  downstream library.
2785
2786  Args:
2787    feature_columns: The raw feature columns, usually passed by users.
2788
2789  Returns:
2790    The normalized feature column list.
2791
2792  Raises:
2793    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2794  """
2795  if isinstance(feature_columns, FeatureColumn):
2796    feature_columns = [feature_columns]
2797
2798  if isinstance(feature_columns, collections_abc.Iterator):
2799    feature_columns = list(feature_columns)
2800
2801  if isinstance(feature_columns, dict):
2802    raise ValueError('Expected feature_columns to be iterable, found dict.')
2803
2804  for column in feature_columns:
2805    if not isinstance(column, FeatureColumn):
2806      raise ValueError('Items of feature_columns must be a FeatureColumn. '
2807                       'Given (type {}): {}.'.format(type(column), column))
2808  if not feature_columns:
2809    raise ValueError('feature_columns must not be empty.')
2810  name_to_column = {}
2811  for column in feature_columns:
2812    if column.name in name_to_column:
2813      raise ValueError('Duplicate feature column name found for columns: {} '
2814                       'and {}. This usually means that these columns refer to '
2815                       'same base feature. Either one must be discarded or a '
2816                       'duplicated but renamed item must be inserted in '
2817                       'features dict.'.format(column,
2818                                               name_to_column[column.name]))
2819    name_to_column[column.name] = column
2820
2821  return sorted(feature_columns, key=lambda x: x.name)
2822
2823
2824class NumericColumn(
2825    DenseColumn,
2826    fc_old._DenseColumn,  # pylint: disable=protected-access
2827    collections.namedtuple(
2828        'NumericColumn',
2829        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
2830  """see `numeric_column`."""
2831
2832  @property
2833  def _is_v2_column(self):
2834    return True
2835
2836  @property
2837  def name(self):
2838    """See `FeatureColumn` base class."""
2839    return self.key
2840
2841  @property
2842  def parse_example_spec(self):
2843    """See `FeatureColumn` base class."""
2844    return {
2845        self.key:
2846            parsing_ops.FixedLenFeature(self.shape, self.dtype,
2847                                        self.default_value)
2848    }
2849
2850  @property
2851  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2852                          _FEATURE_COLUMN_DEPRECATION)
2853  def _parse_example_spec(self):
2854    return self.parse_example_spec
2855
2856  def _transform_input_tensor(self, input_tensor):
2857    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2858      raise ValueError(
2859          'The corresponding Tensor of numerical column must be a Tensor. '
2860          'SparseTensor is not supported. key: {}'.format(self.key))
2861    if self.normalizer_fn is not None:
2862      input_tensor = self.normalizer_fn(input_tensor)
2863    return math_ops.cast(input_tensor, dtypes.float32)
2864
2865  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2866                          _FEATURE_COLUMN_DEPRECATION)
2867  def _transform_feature(self, inputs):
2868    input_tensor = inputs.get(self.key)
2869    return self._transform_input_tensor(input_tensor)
2870
2871  def transform_feature(self, transformation_cache, state_manager):
2872    """See `FeatureColumn` base class.
2873
2874    In this case, we apply the `normalizer_fn` to the input tensor.
2875
2876    Args:
2877      transformation_cache: A `FeatureTransformationCache` object to access
2878        features.
2879      state_manager: A `StateManager` to create / access resources such as
2880        lookup tables.
2881
2882    Returns:
2883      Normalized input tensor.
2884    Raises:
2885      ValueError: If a SparseTensor is passed in.
2886    """
2887    input_tensor = transformation_cache.get(self.key, state_manager)
2888    return self._transform_input_tensor(input_tensor)
2889
2890  @property
2891  def variable_shape(self):
2892    """See `DenseColumn` base class."""
2893    return tensor_shape.TensorShape(self.shape)
2894
2895  @property
2896  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2897                          _FEATURE_COLUMN_DEPRECATION)
2898  def _variable_shape(self):
2899    return self.variable_shape
2900
2901  def get_dense_tensor(self, transformation_cache, state_manager):
2902    """Returns dense `Tensor` representing numeric feature.
2903
2904    Args:
2905      transformation_cache: A `FeatureTransformationCache` object to access
2906        features.
2907      state_manager: A `StateManager` to create / access resources such as
2908        lookup tables.
2909
2910    Returns:
2911      Dense `Tensor` created within `transform_feature`.
2912    """
2913    # Feature has been already transformed. Return the intermediate
2914    # representation created by _transform_feature.
2915    return transformation_cache.get(self, state_manager)
2916
2917  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2918                          _FEATURE_COLUMN_DEPRECATION)
2919  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2920    del weight_collections
2921    del trainable
2922    return inputs.get(self)
2923
2924  @property
2925  def parents(self):
2926    """See 'FeatureColumn` base class."""
2927    return [self.key]
2928
2929  def get_config(self):
2930    """See 'FeatureColumn` base class."""
2931    config = dict(zip(self._fields, self))
2932    config['normalizer_fn'] = generic_utils.serialize_keras_object(
2933        self.normalizer_fn)
2934    config['dtype'] = self.dtype.name
2935    return config
2936
2937  @classmethod
2938  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2939    """See 'FeatureColumn` base class."""
2940    _check_config_keys(config, cls._fields)
2941    kwargs = _standardize_and_copy_config(config)
2942    kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
2943        config['normalizer_fn'], custom_objects=custom_objects)
2944    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2945
2946    return cls(**kwargs)
2947
2948
2949class BucketizedColumn(
2950    DenseColumn,
2951    CategoricalColumn,
2952    fc_old._DenseColumn,  # pylint: disable=protected-access
2953    fc_old._CategoricalColumn,  # pylint: disable=protected-access
2954    collections.namedtuple('BucketizedColumn',
2955                           ('source_column', 'boundaries'))):
2956  """See `bucketized_column`."""
2957
2958  @property
2959  def _is_v2_column(self):
2960    return (isinstance(self.source_column, FeatureColumn) and
2961            self.source_column._is_v2_column)  # pylint: disable=protected-access
2962
2963  @property
2964  def name(self):
2965    """See `FeatureColumn` base class."""
2966    return '{}_bucketized'.format(self.source_column.name)
2967
2968  @property
2969  def parse_example_spec(self):
2970    """See `FeatureColumn` base class."""
2971    return self.source_column.parse_example_spec
2972
2973  @property
2974  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2975                          _FEATURE_COLUMN_DEPRECATION)
2976  def _parse_example_spec(self):
2977    return self.source_column._parse_example_spec  # pylint: disable=protected-access
2978
2979  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2980                          _FEATURE_COLUMN_DEPRECATION)
2981  def _transform_feature(self, inputs):
2982    """Returns bucketized categorical `source_column` tensor."""
2983    source_tensor = inputs.get(self.source_column)
2984    return math_ops._bucketize(  # pylint: disable=protected-access
2985        source_tensor,
2986        boundaries=self.boundaries)
2987
2988  def transform_feature(self, transformation_cache, state_manager):
2989    """Returns bucketized categorical `source_column` tensor."""
2990    source_tensor = transformation_cache.get(self.source_column, state_manager)
2991    return math_ops._bucketize(  # pylint: disable=protected-access
2992        source_tensor,
2993        boundaries=self.boundaries)
2994
2995  @property
2996  def variable_shape(self):
2997    """See `DenseColumn` base class."""
2998    return tensor_shape.TensorShape(
2999        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
3000
3001  @property
3002  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3003                          _FEATURE_COLUMN_DEPRECATION)
3004  def _variable_shape(self):
3005    return self.variable_shape
3006
3007  def _get_dense_tensor_for_input_tensor(self, input_tensor):
3008    return array_ops.one_hot(
3009        indices=math_ops.cast(input_tensor, dtypes.int64),
3010        depth=len(self.boundaries) + 1,
3011        on_value=1.,
3012        off_value=0.)
3013
3014  def get_dense_tensor(self, transformation_cache, state_manager):
3015    """Returns one hot encoded dense `Tensor`."""
3016    input_tensor = transformation_cache.get(self, state_manager)
3017    return self._get_dense_tensor_for_input_tensor(input_tensor)
3018
3019  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3020                          _FEATURE_COLUMN_DEPRECATION)
3021  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3022    del weight_collections
3023    del trainable
3024    input_tensor = inputs.get(self)
3025    return self._get_dense_tensor_for_input_tensor(input_tensor)
3026
3027  @property
3028  def num_buckets(self):
3029    """See `CategoricalColumn` base class."""
3030    # By construction, source_column is always one-dimensional.
3031    return (len(self.boundaries) + 1) * self.source_column.shape[0]
3032
3033  @property
3034  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3035                          _FEATURE_COLUMN_DEPRECATION)
3036  def _num_buckets(self):
3037    return self.num_buckets
3038
3039  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
3040    batch_size = array_ops.shape(input_tensor)[0]
3041    # By construction, source_column is always one-dimensional.
3042    source_dimension = self.source_column.shape[0]
3043
3044    i1 = array_ops.reshape(
3045        array_ops.tile(
3046            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
3047            [1, source_dimension]),
3048        (-1,))
3049    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
3050    # Flatten the bucket indices and unique them across dimensions
3051    # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
3052    bucket_indices = (
3053        array_ops.reshape(input_tensor, (-1,)) +
3054        (len(self.boundaries) + 1) * i2)
3055
3056    indices = math_ops.cast(
3057        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
3058    dense_shape = math_ops.cast(
3059        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
3060    sparse_tensor = sparse_tensor_lib.SparseTensor(
3061        indices=indices,
3062        values=bucket_indices,
3063        dense_shape=dense_shape)
3064    return CategoricalColumn.IdWeightPair(sparse_tensor, None)
3065
3066  def get_sparse_tensors(self, transformation_cache, state_manager):
3067    """Converts dense inputs to SparseTensor so downstream code can use it."""
3068    input_tensor = transformation_cache.get(self, state_manager)
3069    return self._get_sparse_tensors_for_input_tensor(input_tensor)
3070
3071  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3072                          _FEATURE_COLUMN_DEPRECATION)
3073  def _get_sparse_tensors(self, inputs, weight_collections=None,
3074                          trainable=None):
3075    """Converts dense inputs to SparseTensor so downstream code can use it."""
3076    del weight_collections
3077    del trainable
3078    input_tensor = inputs.get(self)
3079    return self._get_sparse_tensors_for_input_tensor(input_tensor)
3080
3081  @property
3082  def parents(self):
3083    """See 'FeatureColumn` base class."""
3084    return [self.source_column]
3085
3086  def get_config(self):
3087    """See 'FeatureColumn` base class."""
3088    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
3089    config = dict(zip(self._fields, self))
3090    config['source_column'] = serialize_feature_column(self.source_column)
3091    return config
3092
3093  @classmethod
3094  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3095    """See 'FeatureColumn` base class."""
3096    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
3097    _check_config_keys(config, cls._fields)
3098    kwargs = _standardize_and_copy_config(config)
3099    kwargs['source_column'] = deserialize_feature_column(
3100        config['source_column'], custom_objects, columns_by_name)
3101    return cls(**kwargs)
3102
3103
3104class EmbeddingColumn(
3105    DenseColumn,
3106    SequenceDenseColumn,
3107    fc_old._DenseColumn,  # pylint: disable=protected-access
3108    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
3109    collections.namedtuple(
3110        'EmbeddingColumn',
3111        ('categorical_column', 'dimension', 'combiner', 'initializer',
3112         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
3113         'use_safe_embedding_lookup'))):
3114  """See `embedding_column`."""
3115
3116  def __new__(cls,
3117              categorical_column,
3118              dimension,
3119              combiner,
3120              initializer,
3121              ckpt_to_load_from,
3122              tensor_name_in_ckpt,
3123              max_norm,
3124              trainable,
3125              use_safe_embedding_lookup=True):
3126    return super(EmbeddingColumn, cls).__new__(
3127        cls,
3128        categorical_column=categorical_column,
3129        dimension=dimension,
3130        combiner=combiner,
3131        initializer=initializer,
3132        ckpt_to_load_from=ckpt_to_load_from,
3133        tensor_name_in_ckpt=tensor_name_in_ckpt,
3134        max_norm=max_norm,
3135        trainable=trainable,
3136        use_safe_embedding_lookup=use_safe_embedding_lookup)
3137
3138  @property
3139  def _is_v2_column(self):
3140    return (isinstance(self.categorical_column, FeatureColumn) and
3141            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
3142
3143  @property
3144  def name(self):
3145    """See `FeatureColumn` base class."""
3146    return '{}_embedding'.format(self.categorical_column.name)
3147
3148  @property
3149  def parse_example_spec(self):
3150    """See `FeatureColumn` base class."""
3151    return self.categorical_column.parse_example_spec
3152
3153  @property
3154  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3155                          _FEATURE_COLUMN_DEPRECATION)
3156  def _parse_example_spec(self):
3157    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
3158
3159  def transform_feature(self, transformation_cache, state_manager):
3160    """Transforms underlying `categorical_column`."""
3161    return transformation_cache.get(self.categorical_column, state_manager)
3162
3163  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3164                          _FEATURE_COLUMN_DEPRECATION)
3165  def _transform_feature(self, inputs):
3166    return inputs.get(self.categorical_column)
3167
3168  @property
3169  def variable_shape(self):
3170    """See `DenseColumn` base class."""
3171    return tensor_shape.TensorShape([self.dimension])
3172
3173  @property
3174  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3175                          _FEATURE_COLUMN_DEPRECATION)
3176  def _variable_shape(self):
3177    return self.variable_shape
3178
3179  def create_state(self, state_manager):
3180    """Creates the embedding lookup variable."""
3181    default_num_buckets = (self.categorical_column.num_buckets
3182                           if self._is_v2_column
3183                           else self.categorical_column._num_buckets)   # pylint: disable=protected-access
3184    num_buckets = getattr(self.categorical_column, 'num_buckets',
3185                          default_num_buckets)
3186    embedding_shape = (num_buckets, self.dimension)
3187    state_manager.create_variable(
3188        self,
3189        name='embedding_weights',
3190        shape=embedding_shape,
3191        dtype=dtypes.float32,
3192        trainable=self.trainable,
3193        use_resource=True,
3194        initializer=self.initializer)
3195
3196  def _get_dense_tensor_internal_helper(self, sparse_tensors,
3197                                        embedding_weights):
3198    sparse_ids = sparse_tensors.id_tensor
3199    sparse_weights = sparse_tensors.weight_tensor
3200
3201    if self.ckpt_to_load_from is not None:
3202      to_restore = embedding_weights
3203      if isinstance(to_restore, variables.PartitionedVariable):
3204        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
3205      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
3206          self.tensor_name_in_ckpt: to_restore
3207      })
3208
3209    sparse_id_rank = tensor_shape.dimension_value(
3210        sparse_ids.dense_shape.get_shape()[0])
3211    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
3212    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
3213        sparse_id_rank <= 2):
3214      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse
3215    # Return embedding lookup result.
3216    return embedding_lookup_sparse(
3217        embedding_weights,
3218        sparse_ids,
3219        sparse_weights,
3220        combiner=self.combiner,
3221        name='%s_weights' % self.name,
3222        max_norm=self.max_norm)
3223
3224  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
3225    """Private method that follows the signature of get_dense_tensor."""
3226    embedding_weights = state_manager.get_variable(
3227        self, name='embedding_weights')
3228    return self._get_dense_tensor_internal_helper(sparse_tensors,
3229                                                  embedding_weights)
3230
3231  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
3232                                     trainable):
3233    """Private method that follows the signature of _get_dense_tensor."""
3234    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
3235    if (weight_collections and
3236        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
3237      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
3238    embedding_weights = variable_scope.get_variable(
3239        name='embedding_weights',
3240        shape=embedding_shape,
3241        dtype=dtypes.float32,
3242        initializer=self.initializer,
3243        trainable=self.trainable and trainable,
3244        collections=weight_collections)
3245    return self._get_dense_tensor_internal_helper(sparse_tensors,
3246                                                  embedding_weights)
3247
3248  def get_dense_tensor(self, transformation_cache, state_manager):
3249    """Returns tensor after doing the embedding lookup.
3250
3251    Args:
3252      transformation_cache: A `FeatureTransformationCache` object to access
3253        features.
3254      state_manager: A `StateManager` to create / access resources such as
3255        lookup tables.
3256
3257    Returns:
3258      Embedding lookup tensor.
3259
3260    Raises:
3261      ValueError: `categorical_column` is SequenceCategoricalColumn.
3262    """
3263    if isinstance(self.categorical_column, SequenceCategoricalColumn):
3264      raise ValueError(
3265          'In embedding_column: {}. '
3266          'categorical_column must not be of type SequenceCategoricalColumn. '
3267          'Suggested fix A: If you wish to use DenseFeatures, use a '
3268          'non-sequence categorical_column_with_*. '
3269          'Suggested fix B: If you wish to create sequence input, use '
3270          'SequenceFeatures instead of DenseFeatures. '
3271          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3272                                       self.categorical_column))
3273    # Get sparse IDs and weights.
3274    sparse_tensors = self.categorical_column.get_sparse_tensors(
3275        transformation_cache, state_manager)
3276    return self._get_dense_tensor_internal(sparse_tensors, state_manager)
3277
3278  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3279                          _FEATURE_COLUMN_DEPRECATION)
3280  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3281    if isinstance(
3282        self.categorical_column,
3283        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
3284      raise ValueError(
3285          'In embedding_column: {}. '
3286          'categorical_column must not be of type _SequenceCategoricalColumn. '
3287          'Suggested fix A: If you wish to use DenseFeatures, use a '
3288          'non-sequence categorical_column_with_*. '
3289          'Suggested fix B: If you wish to create sequence input, use '
3290          'SequenceFeatures instead of DenseFeatures. '
3291          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3292                                       self.categorical_column))
3293    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
3294        inputs, weight_collections, trainable)
3295    return self._old_get_dense_tensor_internal(sparse_tensors,
3296                                               weight_collections, trainable)
3297
3298  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3299    """See `SequenceDenseColumn` base class."""
3300    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3301      raise ValueError(
3302          'In embedding_column: {}. '
3303          'categorical_column must be of type SequenceCategoricalColumn '
3304          'to use SequenceFeatures. '
3305          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3306          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3307                                       self.categorical_column))
3308    sparse_tensors = self.categorical_column.get_sparse_tensors(
3309        transformation_cache, state_manager)
3310    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
3311                                                   state_manager)
3312    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3313        sparse_tensors.id_tensor)
3314    return SequenceDenseColumn.TensorSequenceLengthPair(
3315        dense_tensor=dense_tensor, sequence_length=sequence_length)
3316
3317  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3318                          _FEATURE_COLUMN_DEPRECATION)
3319  def _get_sequence_dense_tensor(self,
3320                                 inputs,
3321                                 weight_collections=None,
3322                                 trainable=None):
3323    if not isinstance(
3324        self.categorical_column,
3325        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
3326      raise ValueError(
3327          'In embedding_column: {}. '
3328          'categorical_column must be of type SequenceCategoricalColumn '
3329          'to use SequenceFeatures. '
3330          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3331          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3332                                       self.categorical_column))
3333    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3334    dense_tensor = self._old_get_dense_tensor_internal(
3335        sparse_tensors,
3336        weight_collections=weight_collections,
3337        trainable=trainable)
3338    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3339        sparse_tensors.id_tensor)
3340    return SequenceDenseColumn.TensorSequenceLengthPair(
3341        dense_tensor=dense_tensor, sequence_length=sequence_length)
3342
3343  @property
3344  def parents(self):
3345    """See 'FeatureColumn` base class."""
3346    return [self.categorical_column]
3347
3348  def get_config(self):
3349    """See 'FeatureColumn` base class."""
3350    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
3351    config = dict(zip(self._fields, self))
3352    config['categorical_column'] = serialize_feature_column(
3353        self.categorical_column)
3354    config['initializer'] = initializers.serialize(self.initializer)
3355    return config
3356
3357  @classmethod
3358  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3359    """See 'FeatureColumn` base class."""
3360    if 'use_safe_embedding_lookup' not in config:
3361      config['use_safe_embedding_lookup'] = True
3362    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
3363    _check_config_keys(config, cls._fields)
3364    kwargs = _standardize_and_copy_config(config)
3365    kwargs['categorical_column'] = deserialize_feature_column(
3366        config['categorical_column'], custom_objects, columns_by_name)
3367    kwargs['initializer'] = initializers.deserialize(
3368        config['initializer'], custom_objects=custom_objects)
3369    return cls(**kwargs)
3370
3371
3372def _raise_shared_embedding_column_error():
3373  raise ValueError('SharedEmbeddingColumns are not supported in '
3374                   '`linear_model` or `input_layer`. Please use '
3375                   '`DenseFeatures` or `LinearModel` instead.')
3376
3377
3378class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
3379
3380  def __init__(self,
3381               dimension,
3382               initializer,
3383               ckpt_to_load_from,
3384               tensor_name_in_ckpt,
3385               num_buckets,
3386               trainable,
3387               name='shared_embedding_column_creator',
3388               use_safe_embedding_lookup=True):
3389    self._dimension = dimension
3390    self._initializer = initializer
3391    self._ckpt_to_load_from = ckpt_to_load_from
3392    self._tensor_name_in_ckpt = tensor_name_in_ckpt
3393    self._num_buckets = num_buckets
3394    self._trainable = trainable
3395    self._name = name
3396    self._use_safe_embedding_lookup = use_safe_embedding_lookup
3397    # Map from graph keys to embedding_weight variables.
3398    self._embedding_weights = {}
3399
3400  def __call__(self, categorical_column, combiner, max_norm):
3401    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
3402                                 self._use_safe_embedding_lookup)
3403
3404  @property
3405  def embedding_weights(self):
3406    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
3407    if key not in self._embedding_weights:
3408      embedding_shape = (self._num_buckets, self._dimension)
3409      var = variable_scope.get_variable(
3410          name=self._name,
3411          shape=embedding_shape,
3412          dtype=dtypes.float32,
3413          initializer=self._initializer,
3414          trainable=self._trainable)
3415
3416      if self._ckpt_to_load_from is not None:
3417        to_restore = var
3418        if isinstance(to_restore, variables.PartitionedVariable):
3419          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
3420        checkpoint_utils.init_from_checkpoint(
3421            self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
3422      self._embedding_weights[key] = var
3423    return self._embedding_weights[key]
3424
3425  @property
3426  def dimension(self):
3427    return self._dimension
3428
3429
3430class SharedEmbeddingColumn(
3431    DenseColumn,
3432    SequenceDenseColumn,
3433    fc_old._DenseColumn,  # pylint: disable=protected-access
3434    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
3435    collections.namedtuple(
3436        'SharedEmbeddingColumn',
3437        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
3438         'max_norm', 'use_safe_embedding_lookup'))):
3439  """See `embedding_column`."""
3440
3441  def __new__(cls,
3442              categorical_column,
3443              shared_embedding_column_creator,
3444              combiner,
3445              max_norm,
3446              use_safe_embedding_lookup=True):
3447    return super(SharedEmbeddingColumn, cls).__new__(
3448        cls,
3449        categorical_column=categorical_column,
3450        shared_embedding_column_creator=shared_embedding_column_creator,
3451        combiner=combiner,
3452        max_norm=max_norm,
3453        use_safe_embedding_lookup=use_safe_embedding_lookup)
3454
3455  @property
3456  def _is_v2_column(self):
3457    return True
3458
3459  @property
3460  def name(self):
3461    """See `FeatureColumn` base class."""
3462    return '{}_shared_embedding'.format(self.categorical_column.name)
3463
3464  @property
3465  def parse_example_spec(self):
3466    """See `FeatureColumn` base class."""
3467    return self.categorical_column.parse_example_spec
3468
3469  @property
3470  def _parse_example_spec(self):
3471    return _raise_shared_embedding_column_error()
3472
3473  def transform_feature(self, transformation_cache, state_manager):
3474    """See `FeatureColumn` base class."""
3475    return transformation_cache.get(self.categorical_column, state_manager)
3476
3477  def _transform_feature(self, inputs):
3478    return _raise_shared_embedding_column_error()
3479
3480  @property
3481  def variable_shape(self):
3482    """See `DenseColumn` base class."""
3483    return tensor_shape.TensorShape(
3484        [self.shared_embedding_column_creator.dimension])
3485
3486  @property
3487  def _variable_shape(self):
3488    return _raise_shared_embedding_column_error()
3489
3490  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
3491    """Private method that follows the signature of _get_dense_tensor."""
3492    # This method is called from a variable_scope with name _var_scope_name,
3493    # which is shared among all shared embeddings. Open a name_scope here, so
3494    # that the ops for different columns have distinct names.
3495    with ops.name_scope(None, default_name=self.name):
3496      # Get sparse IDs and weights.
3497      sparse_tensors = self.categorical_column.get_sparse_tensors(
3498          transformation_cache, state_manager)
3499      sparse_ids = sparse_tensors.id_tensor
3500      sparse_weights = sparse_tensors.weight_tensor
3501
3502      embedding_weights = self.shared_embedding_column_creator.embedding_weights
3503
3504      sparse_id_rank = tensor_shape.dimension_value(
3505          sparse_ids.dense_shape.get_shape()[0])
3506      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
3507      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
3508          sparse_id_rank <= 2):
3509        embedding_lookup_sparse = (embedding_ops.embedding_lookup_sparse)
3510      # Return embedding lookup result.
3511      return embedding_lookup_sparse(
3512          embedding_weights,
3513          sparse_ids,
3514          sparse_weights,
3515          combiner=self.combiner,
3516          name='%s_weights' % self.name,
3517          max_norm=self.max_norm)
3518
3519  def get_dense_tensor(self, transformation_cache, state_manager):
3520    """Returns the embedding lookup result."""
3521    if isinstance(self.categorical_column, SequenceCategoricalColumn):
3522      raise ValueError(
3523          'In embedding_column: {}. '
3524          'categorical_column must not be of type SequenceCategoricalColumn. '
3525          'Suggested fix A: If you wish to use DenseFeatures, use a '
3526          'non-sequence categorical_column_with_*. '
3527          'Suggested fix B: If you wish to create sequence input, use '
3528          'SequenceFeatures instead of DenseFeatures. '
3529          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3530                                       self.categorical_column))
3531    return self._get_dense_tensor_internal(transformation_cache, state_manager)
3532
3533  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3534    return _raise_shared_embedding_column_error()
3535
3536  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3537    """See `SequenceDenseColumn` base class."""
3538    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3539      raise ValueError(
3540          'In embedding_column: {}. '
3541          'categorical_column must be of type SequenceCategoricalColumn '
3542          'to use SequenceFeatures. '
3543          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3544          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3545                                       self.categorical_column))
3546    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
3547                                                   state_manager)
3548    sparse_tensors = self.categorical_column.get_sparse_tensors(
3549        transformation_cache, state_manager)
3550    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3551        sparse_tensors.id_tensor)
3552    return SequenceDenseColumn.TensorSequenceLengthPair(
3553        dense_tensor=dense_tensor, sequence_length=sequence_length)
3554
3555  def _get_sequence_dense_tensor(self,
3556                                 inputs,
3557                                 weight_collections=None,
3558                                 trainable=None):
3559    return _raise_shared_embedding_column_error()
3560
3561  @property
3562  def parents(self):
3563    """See 'FeatureColumn` base class."""
3564    return [self.categorical_column]
3565
3566
3567def _check_shape(shape, key):
3568  """Returns shape if it's valid, raises error otherwise."""
3569  assert shape is not None
3570  if not nest.is_sequence(shape):
3571    shape = [shape]
3572  shape = tuple(shape)
3573  for dimension in shape:
3574    if not isinstance(dimension, int):
3575      raise TypeError('shape dimensions must be integer. '
3576                      'shape: {}, key: {}'.format(shape, key))
3577    if dimension < 1:
3578      raise ValueError('shape dimensions must be greater than 0. '
3579                       'shape: {}, key: {}'.format(shape, key))
3580  return shape
3581
3582
3583class HashedCategoricalColumn(
3584    CategoricalColumn,
3585    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3586    collections.namedtuple('HashedCategoricalColumn',
3587                           ('key', 'hash_bucket_size', 'dtype'))):
3588  """see `categorical_column_with_hash_bucket`."""
3589
3590  @property
3591  def _is_v2_column(self):
3592    return True
3593
3594  @property
3595  def name(self):
3596    """See `FeatureColumn` base class."""
3597    return self.key
3598
3599  @property
3600  def parse_example_spec(self):
3601    """See `FeatureColumn` base class."""
3602    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3603
3604  @property
3605  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3606                          _FEATURE_COLUMN_DEPRECATION)
3607  def _parse_example_spec(self):
3608    return self.parse_example_spec
3609
3610  def _transform_input_tensor(self, input_tensor):
3611    """Hashes the values in the feature_column."""
3612    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
3613      raise ValueError('SparseColumn input must be a SparseTensor.')
3614
3615    fc_utils.assert_string_or_int(
3616        input_tensor.dtype,
3617        prefix='column_name: {} input_tensor'.format(self.key))
3618
3619    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3620      raise ValueError(
3621          'Column dtype and SparseTensors dtype must be compatible. '
3622          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3623              self.key, self.dtype, input_tensor.dtype))
3624
3625    if self.dtype == dtypes.string:
3626      sparse_values = input_tensor.values
3627    else:
3628      sparse_values = string_ops.as_string(input_tensor.values)
3629
3630    sparse_id_values = string_ops.string_to_hash_bucket_fast(
3631        sparse_values, self.hash_bucket_size, name='lookup')
3632    return sparse_tensor_lib.SparseTensor(
3633        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
3634
3635  def transform_feature(self, transformation_cache, state_manager):
3636    """Hashes the values in the feature_column."""
3637    input_tensor = _to_sparse_input_and_drop_ignore_values(
3638        transformation_cache.get(self.key, state_manager))
3639    return self._transform_input_tensor(input_tensor)
3640
3641  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3642                          _FEATURE_COLUMN_DEPRECATION)
3643  def _transform_feature(self, inputs):
3644    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3645    return self._transform_input_tensor(input_tensor)
3646
3647  @property
3648  def num_buckets(self):
3649    """Returns number of buckets in this sparse feature."""
3650    return self.hash_bucket_size
3651
3652  @property
3653  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3654                          _FEATURE_COLUMN_DEPRECATION)
3655  def _num_buckets(self):
3656    return self.num_buckets
3657
3658  def get_sparse_tensors(self, transformation_cache, state_manager):
3659    """See `CategoricalColumn` base class."""
3660    return CategoricalColumn.IdWeightPair(
3661        transformation_cache.get(self, state_manager), None)
3662
3663  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3664                          _FEATURE_COLUMN_DEPRECATION)
3665  def _get_sparse_tensors(self, inputs, weight_collections=None,
3666                          trainable=None):
3667    del weight_collections
3668    del trainable
3669    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3670
3671  @property
3672  def parents(self):
3673    """See 'FeatureColumn` base class."""
3674    return [self.key]
3675
3676  def get_config(self):
3677    """See 'FeatureColumn` base class."""
3678    config = dict(zip(self._fields, self))
3679    config['dtype'] = self.dtype.name
3680    return config
3681
3682  @classmethod
3683  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3684    """See 'FeatureColumn` base class."""
3685    _check_config_keys(config, cls._fields)
3686    kwargs = _standardize_and_copy_config(config)
3687    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3688    return cls(**kwargs)
3689
3690
3691class VocabularyFileCategoricalColumn(
3692    CategoricalColumn,
3693    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3694    collections.namedtuple('VocabularyFileCategoricalColumn',
3695                           ('key', 'vocabulary_file', 'vocabulary_size',
3696                            'num_oov_buckets', 'dtype', 'default_value'))):
3697  """See `categorical_column_with_vocabulary_file`."""
3698
3699  @property
3700  def _is_v2_column(self):
3701    return True
3702
3703  @property
3704  def name(self):
3705    """See `FeatureColumn` base class."""
3706    return self.key
3707
3708  @property
3709  def parse_example_spec(self):
3710    """See `FeatureColumn` base class."""
3711    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3712
3713  @property
3714  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3715                          _FEATURE_COLUMN_DEPRECATION)
3716  def _parse_example_spec(self):
3717    return self.parse_example_spec
3718
3719  def _transform_input_tensor(self, input_tensor, state_manager=None):
3720    """Creates a lookup table for the vocabulary."""
3721    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3722      raise ValueError(
3723          'Column dtype and SparseTensors dtype must be compatible. '
3724          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3725              self.key, self.dtype, input_tensor.dtype))
3726
3727    fc_utils.assert_string_or_int(
3728        input_tensor.dtype,
3729        prefix='column_name: {} input_tensor'.format(self.key))
3730
3731    key_dtype = self.dtype
3732    if input_tensor.dtype.is_integer:
3733      # `index_table_from_file` requires 64-bit integer keys.
3734      key_dtype = dtypes.int64
3735      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3736
3737    name = '{}_lookup'.format(self.key)
3738    table = lookup_ops.index_table_from_file(
3739        vocabulary_file=self.vocabulary_file,
3740        num_oov_buckets=self.num_oov_buckets,
3741        vocab_size=self.vocabulary_size,
3742        default_value=self.default_value,
3743        key_dtype=key_dtype,
3744        name=name)
3745    if state_manager is not None:
3746      state_manager.add_resource(self, name, table)
3747    return table.lookup(input_tensor)
3748
3749  def transform_feature(self, transformation_cache, state_manager):
3750    """Creates a lookup table for the vocabulary."""
3751    input_tensor = _to_sparse_input_and_drop_ignore_values(
3752        transformation_cache.get(self.key, state_manager))
3753    return self._transform_input_tensor(input_tensor, state_manager)
3754
3755  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3756                          _FEATURE_COLUMN_DEPRECATION)
3757  def _transform_feature(self, inputs):
3758    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3759    return self._transform_input_tensor(input_tensor)
3760
3761  @property
3762  def num_buckets(self):
3763    """Returns number of buckets in this sparse feature."""
3764    return self.vocabulary_size + self.num_oov_buckets
3765
3766  @property
3767  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3768                          _FEATURE_COLUMN_DEPRECATION)
3769  def _num_buckets(self):
3770    return self.num_buckets
3771
3772  def get_sparse_tensors(self, transformation_cache, state_manager):
3773    """See `CategoricalColumn` base class."""
3774    return CategoricalColumn.IdWeightPair(
3775        transformation_cache.get(self, state_manager), None)
3776
3777  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3778                          _FEATURE_COLUMN_DEPRECATION)
3779  def _get_sparse_tensors(self, inputs, weight_collections=None,
3780                          trainable=None):
3781    del weight_collections
3782    del trainable
3783    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3784
3785  @property
3786  def parents(self):
3787    """See 'FeatureColumn` base class."""
3788    return [self.key]
3789
3790  def get_config(self):
3791    """See 'FeatureColumn` base class."""
3792    config = dict(zip(self._fields, self))
3793    config['dtype'] = self.dtype.name
3794    return config
3795
3796  @classmethod
3797  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3798    """See 'FeatureColumn` base class."""
3799    _check_config_keys(config, cls._fields)
3800    kwargs = _standardize_and_copy_config(config)
3801    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3802    return cls(**kwargs)
3803
3804
3805class VocabularyListCategoricalColumn(
3806    CategoricalColumn,
3807    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3808    collections.namedtuple(
3809        'VocabularyListCategoricalColumn',
3810        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
3811):
3812  """See `categorical_column_with_vocabulary_list`."""
3813
3814  @property
3815  def _is_v2_column(self):
3816    return True
3817
3818  @property
3819  def name(self):
3820    """See `FeatureColumn` base class."""
3821    return self.key
3822
3823  @property
3824  def parse_example_spec(self):
3825    """See `FeatureColumn` base class."""
3826    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3827
3828  @property
3829  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3830                          _FEATURE_COLUMN_DEPRECATION)
3831  def _parse_example_spec(self):
3832    return self.parse_example_spec
3833
3834  def _transform_input_tensor(self, input_tensor, state_manager=None):
3835    """Creates a lookup table for the vocabulary list."""
3836    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3837      raise ValueError(
3838          'Column dtype and SparseTensors dtype must be compatible. '
3839          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3840              self.key, self.dtype, input_tensor.dtype))
3841
3842    fc_utils.assert_string_or_int(
3843        input_tensor.dtype,
3844        prefix='column_name: {} input_tensor'.format(self.key))
3845
3846    key_dtype = self.dtype
3847    if input_tensor.dtype.is_integer:
3848      # `index_table_from_tensor` requires 64-bit integer keys.
3849      key_dtype = dtypes.int64
3850      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3851
3852    name = '{}_lookup'.format(self.key)
3853    table = lookup_ops.index_table_from_tensor(
3854        vocabulary_list=tuple(self.vocabulary_list),
3855        default_value=self.default_value,
3856        num_oov_buckets=self.num_oov_buckets,
3857        dtype=key_dtype,
3858        name=name)
3859    if state_manager is not None:
3860      state_manager.add_resource(self, name, table)
3861    return table.lookup(input_tensor)
3862
3863  def transform_feature(self, transformation_cache, state_manager):
3864    """Creates a lookup table for the vocabulary list."""
3865    input_tensor = _to_sparse_input_and_drop_ignore_values(
3866        transformation_cache.get(self.key, state_manager))
3867    return self._transform_input_tensor(input_tensor, state_manager)
3868
3869  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3870                          _FEATURE_COLUMN_DEPRECATION)
3871  def _transform_feature(self, inputs):
3872    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3873    return self._transform_input_tensor(input_tensor)
3874
3875  @property
3876  def num_buckets(self):
3877    """Returns number of buckets in this sparse feature."""
3878    return len(self.vocabulary_list) + self.num_oov_buckets
3879
3880  @property
3881  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3882                          _FEATURE_COLUMN_DEPRECATION)
3883  def _num_buckets(self):
3884    return self.num_buckets
3885
3886  def get_sparse_tensors(self, transformation_cache, state_manager):
3887    """See `CategoricalColumn` base class."""
3888    return CategoricalColumn.IdWeightPair(
3889        transformation_cache.get(self, state_manager), None)
3890
3891  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3892                          _FEATURE_COLUMN_DEPRECATION)
3893  def _get_sparse_tensors(self, inputs, weight_collections=None,
3894                          trainable=None):
3895    del weight_collections
3896    del trainable
3897    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3898
3899  @property
3900  def parents(self):
3901    """See 'FeatureColumn` base class."""
3902    return [self.key]
3903
3904  def get_config(self):
3905    """See 'FeatureColumn` base class."""
3906    config = dict(zip(self._fields, self))
3907    config['dtype'] = self.dtype.name
3908    return config
3909
3910  @classmethod
3911  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3912    """See 'FeatureColumn` base class."""
3913    _check_config_keys(config, cls._fields)
3914    kwargs = _standardize_and_copy_config(config)
3915    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3916    return cls(**kwargs)
3917
3918
3919class IdentityCategoricalColumn(
3920    CategoricalColumn,
3921    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3922    collections.namedtuple('IdentityCategoricalColumn',
3923                           ('key', 'number_buckets', 'default_value'))):
3924
3925  """See `categorical_column_with_identity`."""
3926
3927  @property
3928  def _is_v2_column(self):
3929    return True
3930
3931  @property
3932  def name(self):
3933    """See `FeatureColumn` base class."""
3934    return self.key
3935
3936  @property
3937  def parse_example_spec(self):
3938    """See `FeatureColumn` base class."""
3939    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
3940
3941  @property
3942  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3943                          _FEATURE_COLUMN_DEPRECATION)
3944  def _parse_example_spec(self):
3945    return self.parse_example_spec
3946
3947  def _transform_input_tensor(self, input_tensor):
3948    """Returns a SparseTensor with identity values."""
3949    if not input_tensor.dtype.is_integer:
3950      raise ValueError(
3951          'Invalid input, not integer. key: {} dtype: {}'.format(
3952              self.key, input_tensor.dtype))
3953    values = input_tensor.values
3954    if input_tensor.values.dtype != dtypes.int64:
3955      values = math_ops.cast(values, dtypes.int64, name='values')
3956    if self.default_value is not None:
3957      values = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
3958      num_buckets = math_ops.cast(
3959          self.num_buckets, dtypes.int64, name='num_buckets')
3960      zero = math_ops.cast(0, dtypes.int64, name='zero')
3961      # Assign default for out-of-range values.
3962      values = array_ops.where_v2(
3963          math_ops.logical_or(
3964              values < zero, values >= num_buckets, name='out_of_range'),
3965          array_ops.fill(
3966              dims=array_ops.shape(values),
3967              value=math_ops.cast(self.default_value, dtypes.int64),
3968              name='default_values'), values)
3969
3970    return sparse_tensor_lib.SparseTensor(
3971        indices=input_tensor.indices,
3972        values=values,
3973        dense_shape=input_tensor.dense_shape)
3974
3975  def transform_feature(self, transformation_cache, state_manager):
3976    """Returns a SparseTensor with identity values."""
3977    input_tensor = _to_sparse_input_and_drop_ignore_values(
3978        transformation_cache.get(self.key, state_manager))
3979    return self._transform_input_tensor(input_tensor)
3980
3981  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3982                          _FEATURE_COLUMN_DEPRECATION)
3983  def _transform_feature(self, inputs):
3984    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3985    return self._transform_input_tensor(input_tensor)
3986
3987  @property
3988  def num_buckets(self):
3989    """Returns number of buckets in this sparse feature."""
3990    return self.number_buckets
3991
3992  @property
3993  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3994                          _FEATURE_COLUMN_DEPRECATION)
3995  def _num_buckets(self):
3996    return self.num_buckets
3997
3998  def get_sparse_tensors(self, transformation_cache, state_manager):
3999    """See `CategoricalColumn` base class."""
4000    return CategoricalColumn.IdWeightPair(
4001        transformation_cache.get(self, state_manager), None)
4002
4003  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4004                          _FEATURE_COLUMN_DEPRECATION)
4005  def _get_sparse_tensors(self, inputs, weight_collections=None,
4006                          trainable=None):
4007    del weight_collections
4008    del trainable
4009    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
4010
4011  @property
4012  def parents(self):
4013    """See 'FeatureColumn` base class."""
4014    return [self.key]
4015
4016  def get_config(self):
4017    """See 'FeatureColumn` base class."""
4018    return dict(zip(self._fields, self))
4019
4020  @classmethod
4021  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4022    """See 'FeatureColumn` base class."""
4023    _check_config_keys(config, cls._fields)
4024    kwargs = _standardize_and_copy_config(config)
4025    return cls(**kwargs)
4026
4027
4028class WeightedCategoricalColumn(
4029    CategoricalColumn,
4030    fc_old._CategoricalColumn,  # pylint: disable=protected-access
4031    collections.namedtuple(
4032        'WeightedCategoricalColumn',
4033        ('categorical_column', 'weight_feature_key', 'dtype'))):
4034  """See `weighted_categorical_column`."""
4035
4036  @property
4037  def _is_v2_column(self):
4038    return (isinstance(self.categorical_column, FeatureColumn) and
4039            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
4040
4041  @property
4042  def name(self):
4043    """See `FeatureColumn` base class."""
4044    return '{}_weighted_by_{}'.format(
4045        self.categorical_column.name, self.weight_feature_key)
4046
4047  @property
4048  def parse_example_spec(self):
4049    """See `FeatureColumn` base class."""
4050    config = self.categorical_column.parse_example_spec
4051    if self.weight_feature_key in config:
4052      raise ValueError('Parse config {} already exists for {}.'.format(
4053          config[self.weight_feature_key], self.weight_feature_key))
4054    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
4055    return config
4056
4057  @property
4058  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4059                          _FEATURE_COLUMN_DEPRECATION)
4060  def _parse_example_spec(self):
4061    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
4062    if self.weight_feature_key in config:
4063      raise ValueError('Parse config {} already exists for {}.'.format(
4064          config[self.weight_feature_key], self.weight_feature_key))
4065    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
4066    return config
4067
4068  @property
4069  def num_buckets(self):
4070    """See `DenseColumn` base class."""
4071    return self.categorical_column.num_buckets
4072
4073  @property
4074  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4075                          _FEATURE_COLUMN_DEPRECATION)
4076  def _num_buckets(self):
4077    return self.categorical_column._num_buckets  # pylint: disable=protected-access
4078
4079  def _transform_weight_tensor(self, weight_tensor):
4080    if weight_tensor is None:
4081      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
4082    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
4083        weight_tensor)
4084    if self.dtype != weight_tensor.dtype.base_dtype:
4085      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
4086          self.dtype, weight_tensor.dtype))
4087    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
4088      # The weight tensor can be a regular Tensor. In this case, sparsify it.
4089      weight_tensor = _to_sparse_input_and_drop_ignore_values(
4090          weight_tensor, ignore_value=0.0)
4091    if not weight_tensor.dtype.is_floating:
4092      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
4093    return weight_tensor
4094
4095  def transform_feature(self, transformation_cache, state_manager):
4096    """Applies weights to tensor generated from `categorical_column`'."""
4097    weight_tensor = transformation_cache.get(self.weight_feature_key,
4098                                             state_manager)
4099    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
4100    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
4101        transformation_cache.get(self.categorical_column, state_manager))
4102    return (sparse_categorical_tensor, sparse_weight_tensor)
4103
4104  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4105                          _FEATURE_COLUMN_DEPRECATION)
4106  def _transform_feature(self, inputs):
4107    """Applies weights to tensor generated from `categorical_column`'."""
4108    weight_tensor = inputs.get(self.weight_feature_key)
4109    weight_tensor = self._transform_weight_tensor(weight_tensor)
4110    return (inputs.get(self.categorical_column), weight_tensor)
4111
4112  def get_sparse_tensors(self, transformation_cache, state_manager):
4113    """See `CategoricalColumn` base class."""
4114    tensors = transformation_cache.get(self, state_manager)
4115    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
4116
4117  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4118                          _FEATURE_COLUMN_DEPRECATION)
4119  def _get_sparse_tensors(self, inputs, weight_collections=None,
4120                          trainable=None):
4121    del weight_collections
4122    del trainable
4123    tensors = inputs.get(self)
4124    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
4125
4126  @property
4127  def parents(self):
4128    """See 'FeatureColumn` base class."""
4129    return [self.categorical_column, self.weight_feature_key]
4130
4131  def get_config(self):
4132    """See 'FeatureColumn` base class."""
4133    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
4134    config = dict(zip(self._fields, self))
4135    config['categorical_column'] = serialize_feature_column(
4136        self.categorical_column)
4137    config['dtype'] = self.dtype.name
4138    return config
4139
4140  @classmethod
4141  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4142    """See 'FeatureColumn` base class."""
4143    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
4144    _check_config_keys(config, cls._fields)
4145    kwargs = _standardize_and_copy_config(config)
4146    kwargs['categorical_column'] = deserialize_feature_column(
4147        config['categorical_column'], custom_objects, columns_by_name)
4148    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
4149    return cls(**kwargs)
4150
4151
4152class CrossedColumn(
4153    CategoricalColumn,
4154    fc_old._CategoricalColumn,  # pylint: disable=protected-access
4155    collections.namedtuple('CrossedColumn',
4156                           ('keys', 'hash_bucket_size', 'hash_key'))):
4157  """See `crossed_column`."""
4158
4159  @property
4160  def _is_v2_column(self):
4161    for key in _collect_leaf_level_keys(self):
4162      if isinstance(key, six.string_types):
4163        continue
4164      if not isinstance(key, FeatureColumn):
4165        return False
4166      if not key._is_v2_column:  # pylint: disable=protected-access
4167        return False
4168    return True
4169
4170  @property
4171  def name(self):
4172    """See `FeatureColumn` base class."""
4173    feature_names = []
4174    for key in _collect_leaf_level_keys(self):
4175      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
4176        feature_names.append(key.name)
4177      else:  # key must be a string
4178        feature_names.append(key)
4179    return '_X_'.join(sorted(feature_names))
4180
4181  @property
4182  def parse_example_spec(self):
4183    """See `FeatureColumn` base class."""
4184    config = {}
4185    for key in self.keys:
4186      if isinstance(key, FeatureColumn):
4187        config.update(key.parse_example_spec)
4188      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
4189        config.update(key._parse_example_spec)  # pylint: disable=protected-access
4190      else:  # key must be a string
4191        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
4192    return config
4193
4194  @property
4195  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4196                          _FEATURE_COLUMN_DEPRECATION)
4197  def _parse_example_spec(self):
4198    return self.parse_example_spec
4199
4200  def transform_feature(self, transformation_cache, state_manager):
4201    """Generates a hashed sparse cross from the input tensors."""
4202    feature_tensors = []
4203    for key in _collect_leaf_level_keys(self):
4204      if isinstance(key, six.string_types):
4205        feature_tensors.append(transformation_cache.get(key, state_manager))
4206      elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
4207        ids_and_weights = key.get_sparse_tensors(transformation_cache,
4208                                                 state_manager)
4209        if ids_and_weights.weight_tensor is not None:
4210          raise ValueError(
4211              'crossed_column does not support weight_tensor, but the given '
4212              'column populates weight_tensor. '
4213              'Given column: {}'.format(key.name))
4214        feature_tensors.append(ids_and_weights.id_tensor)
4215      else:
4216        raise ValueError('Unsupported column type. Given: {}'.format(key))
4217    return sparse_ops.sparse_cross_hashed(
4218        inputs=feature_tensors,
4219        num_buckets=self.hash_bucket_size,
4220        hash_key=self.hash_key)
4221
4222  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4223                          _FEATURE_COLUMN_DEPRECATION)
4224  def _transform_feature(self, inputs):
4225    """Generates a hashed sparse cross from the input tensors."""
4226    feature_tensors = []
4227    for key in _collect_leaf_level_keys(self):
4228      if isinstance(key, six.string_types):
4229        feature_tensors.append(inputs.get(key))
4230      elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
4231        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
4232        if ids_and_weights.weight_tensor is not None:
4233          raise ValueError(
4234              'crossed_column does not support weight_tensor, but the given '
4235              'column populates weight_tensor. '
4236              'Given column: {}'.format(key.name))
4237        feature_tensors.append(ids_and_weights.id_tensor)
4238      else:
4239        raise ValueError('Unsupported column type. Given: {}'.format(key))
4240    return sparse_ops.sparse_cross_hashed(
4241        inputs=feature_tensors,
4242        num_buckets=self.hash_bucket_size,
4243        hash_key=self.hash_key)
4244
4245  @property
4246  def num_buckets(self):
4247    """Returns number of buckets in this sparse feature."""
4248    return self.hash_bucket_size
4249
4250  @property
4251  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4252                          _FEATURE_COLUMN_DEPRECATION)
4253  def _num_buckets(self):
4254    return self.num_buckets
4255
4256  def get_sparse_tensors(self, transformation_cache, state_manager):
4257    """See `CategoricalColumn` base class."""
4258    return CategoricalColumn.IdWeightPair(
4259        transformation_cache.get(self, state_manager), None)
4260
4261  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4262                          _FEATURE_COLUMN_DEPRECATION)
4263  def _get_sparse_tensors(self, inputs, weight_collections=None,
4264                          trainable=None):
4265    """See `CategoricalColumn` base class."""
4266    del weight_collections
4267    del trainable
4268    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
4269
4270  @property
4271  def parents(self):
4272    """See 'FeatureColumn` base class."""
4273    return list(self.keys)
4274
4275  def get_config(self):
4276    """See 'FeatureColumn` base class."""
4277    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
4278    config = dict(zip(self._fields, self))
4279    config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
4280    return config
4281
4282  @classmethod
4283  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4284    """See 'FeatureColumn` base class."""
4285    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
4286    _check_config_keys(config, cls._fields)
4287    kwargs = _standardize_and_copy_config(config)
4288    kwargs['keys'] = tuple([
4289        deserialize_feature_column(c, custom_objects, columns_by_name)
4290        for c in config['keys']
4291    ])
4292    return cls(**kwargs)
4293
4294
4295def _collect_leaf_level_keys(cross):
4296  """Collects base keys by expanding all nested crosses.
4297
4298  Args:
4299    cross: A `CrossedColumn`.
4300
4301  Returns:
4302    A list of strings or `CategoricalColumn` instances.
4303  """
4304  leaf_level_keys = []
4305  for k in cross.keys:
4306    if isinstance(k, CrossedColumn):
4307      leaf_level_keys.extend(_collect_leaf_level_keys(k))
4308    else:
4309      leaf_level_keys.append(k)
4310  return leaf_level_keys
4311
4312
4313def _prune_invalid_ids(sparse_ids, sparse_weights):
4314  """Prune invalid IDs (< 0) from the input ids and weights."""
4315  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
4316  if sparse_weights is not None:
4317    is_id_valid = math_ops.logical_and(
4318        is_id_valid,
4319        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
4320  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
4321  if sparse_weights is not None:
4322    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
4323  return sparse_ids, sparse_weights
4324
4325
4326def _prune_invalid_weights(sparse_ids, sparse_weights):
4327  """Prune invalid weights (< 0) from the input ids and weights."""
4328  if sparse_weights is not None:
4329    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
4330    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
4331    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
4332  return sparse_ids, sparse_weights
4333
4334
4335class IndicatorColumn(
4336    DenseColumn,
4337    SequenceDenseColumn,
4338    fc_old._DenseColumn,  # pylint: disable=protected-access
4339    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
4340    collections.namedtuple('IndicatorColumn', ('categorical_column'))):
4341  """Represents a one-hot column for use in deep networks.
4342
4343  Args:
4344    categorical_column: A `CategoricalColumn` which is created by
4345      `categorical_column_with_*` function.
4346  """
4347
4348  @property
4349  def _is_v2_column(self):
4350    return (isinstance(self.categorical_column, FeatureColumn) and
4351            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
4352
4353  @property
4354  def name(self):
4355    """See `FeatureColumn` base class."""
4356    return '{}_indicator'.format(self.categorical_column.name)
4357
4358  def _transform_id_weight_pair(self, id_weight_pair, size):
4359    id_tensor = id_weight_pair.id_tensor
4360    weight_tensor = id_weight_pair.weight_tensor
4361
4362    # If the underlying column is weighted, return the input as a dense tensor.
4363    if weight_tensor is not None:
4364      weighted_column = sparse_ops.sparse_merge(
4365          sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size))
4366      # Remove (?, -1) index.
4367      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
4368                                                weighted_column.dense_shape)
4369      # Use scatter_nd to merge duplicated indices if existed,
4370      # instead of sparse_tensor_to_dense.
4371      return array_ops.scatter_nd(weighted_column.indices,
4372                                  weighted_column.values,
4373                                  weighted_column.dense_shape)
4374
4375    dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
4376        id_tensor, default_value=-1)
4377
4378    # One hot must be float for tf.concat reasons since all other inputs to
4379    # input_layer are float32.
4380    one_hot_id_tensor = array_ops.one_hot(
4381        dense_id_tensor, depth=size, on_value=1.0, off_value=0.0)
4382
4383    # Reduce to get a multi-hot per example.
4384    return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
4385
4386  def transform_feature(self, transformation_cache, state_manager):
4387    """Returns dense `Tensor` representing feature.
4388
4389    Args:
4390      transformation_cache: A `FeatureTransformationCache` object to access
4391        features.
4392      state_manager: A `StateManager` to create / access resources such as
4393        lookup tables.
4394
4395    Returns:
4396      Transformed feature `Tensor`.
4397
4398    Raises:
4399      ValueError: if input rank is not known at graph building time.
4400    """
4401    id_weight_pair = self.categorical_column.get_sparse_tensors(
4402        transformation_cache, state_manager)
4403    return self._transform_id_weight_pair(id_weight_pair,
4404                                          self.variable_shape[-1])
4405
4406  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4407                          _FEATURE_COLUMN_DEPRECATION)
4408  def _transform_feature(self, inputs):
4409    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
4410    return self._transform_id_weight_pair(id_weight_pair,
4411                                          self._variable_shape[-1])
4412
4413  @property
4414  def parse_example_spec(self):
4415    """See `FeatureColumn` base class."""
4416    return self.categorical_column.parse_example_spec
4417
4418  @property
4419  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4420                          _FEATURE_COLUMN_DEPRECATION)
4421  def _parse_example_spec(self):
4422    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
4423
4424  @property
4425  def variable_shape(self):
4426    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
4427    if isinstance(self.categorical_column, FeatureColumn):
4428      return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
4429    else:
4430      return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access
4431
4432  @property
4433  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4434                          _FEATURE_COLUMN_DEPRECATION)
4435  def _variable_shape(self):
4436    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access
4437
4438  def get_dense_tensor(self, transformation_cache, state_manager):
4439    """Returns dense `Tensor` representing feature.
4440
4441    Args:
4442      transformation_cache: A `FeatureTransformationCache` object to access
4443        features.
4444      state_manager: A `StateManager` to create / access resources such as
4445        lookup tables.
4446
4447    Returns:
4448      Dense `Tensor` created within `transform_feature`.
4449
4450    Raises:
4451      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
4452    """
4453    if isinstance(self.categorical_column, SequenceCategoricalColumn):
4454      raise ValueError(
4455          'In indicator_column: {}. '
4456          'categorical_column must not be of type SequenceCategoricalColumn. '
4457          'Suggested fix A: If you wish to use DenseFeatures, use a '
4458          'non-sequence categorical_column_with_*. '
4459          'Suggested fix B: If you wish to create sequence input, use '
4460          'SequenceFeatures instead of DenseFeatures. '
4461          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4462                                       self.categorical_column))
4463    # Feature has been already transformed. Return the intermediate
4464    # representation created by transform_feature.
4465    return transformation_cache.get(self, state_manager)
4466
4467  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4468                          _FEATURE_COLUMN_DEPRECATION)
4469  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
4470    del weight_collections
4471    del trainable
4472    if isinstance(
4473        self.categorical_column,
4474        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
4475      raise ValueError(
4476          'In indicator_column: {}. '
4477          'categorical_column must not be of type _SequenceCategoricalColumn. '
4478          'Suggested fix A: If you wish to use DenseFeatures, use a '
4479          'non-sequence categorical_column_with_*. '
4480          'Suggested fix B: If you wish to create sequence input, use '
4481          'SequenceFeatures instead of DenseFeatures. '
4482          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4483                                       self.categorical_column))
4484    # Feature has been already transformed. Return the intermediate
4485    # representation created by transform_feature.
4486    return inputs.get(self)
4487
4488  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
4489    """See `SequenceDenseColumn` base class."""
4490    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
4491      raise ValueError(
4492          'In indicator_column: {}. '
4493          'categorical_column must be of type SequenceCategoricalColumn '
4494          'to use SequenceFeatures. '
4495          'Suggested fix: Use one of sequence_categorical_column_with_*. '
4496          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4497                                       self.categorical_column))
4498    # Feature has been already transformed. Return the intermediate
4499    # representation created by transform_feature.
4500    dense_tensor = transformation_cache.get(self, state_manager)
4501    sparse_tensors = self.categorical_column.get_sparse_tensors(
4502        transformation_cache, state_manager)
4503    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
4504        sparse_tensors.id_tensor)
4505    return SequenceDenseColumn.TensorSequenceLengthPair(
4506        dense_tensor=dense_tensor, sequence_length=sequence_length)
4507
4508  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4509                          _FEATURE_COLUMN_DEPRECATION)
4510  def _get_sequence_dense_tensor(self,
4511                                 inputs,
4512                                 weight_collections=None,
4513                                 trainable=None):
4514    # Do nothing with weight_collections and trainable since no variables are
4515    # created in this function.
4516    del weight_collections
4517    del trainable
4518    if not isinstance(
4519        self.categorical_column,
4520        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
4521      raise ValueError(
4522          'In indicator_column: {}. '
4523          'categorical_column must be of type _SequenceCategoricalColumn '
4524          'to use SequenceFeatures. '
4525          'Suggested fix: Use one of sequence_categorical_column_with_*. '
4526          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4527                                       self.categorical_column))
4528    # Feature has been already transformed. Return the intermediate
4529    # representation created by _transform_feature.
4530    dense_tensor = inputs.get(self)
4531    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
4532    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
4533        sparse_tensors.id_tensor)
4534    return SequenceDenseColumn.TensorSequenceLengthPair(
4535        dense_tensor=dense_tensor, sequence_length=sequence_length)
4536
4537  @property
4538  def parents(self):
4539    """See 'FeatureColumn` base class."""
4540    return [self.categorical_column]
4541
4542  def get_config(self):
4543    """See 'FeatureColumn` base class."""
4544    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
4545    config = dict(zip(self._fields, self))
4546    config['categorical_column'] = serialize_feature_column(
4547        self.categorical_column)
4548    return config
4549
4550  @classmethod
4551  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4552    """See 'FeatureColumn` base class."""
4553    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
4554    _check_config_keys(config, cls._fields)
4555    kwargs = _standardize_and_copy_config(config)
4556    kwargs['categorical_column'] = deserialize_feature_column(
4557        config['categorical_column'], custom_objects, columns_by_name)
4558    return cls(**kwargs)
4559
4560
4561def _verify_static_batch_size_equality(tensors, columns):
4562  """Verify equality between static batch sizes.
4563
4564  Args:
4565    tensors: iterable of input tensors.
4566    columns: Corresponding feature columns.
4567
4568  Raises:
4569    ValueError: in case of mismatched batch sizes.
4570  """
4571  # bath_size is a Dimension object.
4572  expected_batch_size = None
4573  for i in range(0, len(tensors)):
4574    batch_size = tensor_shape.Dimension(tensor_shape.dimension_value(
4575        tensors[i].shape[0]))
4576    if batch_size.value is not None:
4577      if expected_batch_size is None:
4578        bath_size_column_index = i
4579        expected_batch_size = batch_size
4580      elif not expected_batch_size.is_compatible_with(batch_size):
4581        raise ValueError(
4582            'Batch size (first dimension) of each feature must be same. '
4583            'Batch size of columns ({}, {}): ({}, {})'.format(
4584                columns[bath_size_column_index].name, columns[i].name,
4585                expected_batch_size, batch_size))
4586
4587
4588class SequenceCategoricalColumn(
4589    CategoricalColumn,
4590    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
4591    collections.namedtuple('SequenceCategoricalColumn',
4592                           ('categorical_column'))):
4593  """Represents sequences of categorical data."""
4594
4595  @property
4596  def _is_v2_column(self):
4597    return (isinstance(self.categorical_column, FeatureColumn) and
4598            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
4599
4600  @property
4601  def name(self):
4602    """See `FeatureColumn` base class."""
4603    return self.categorical_column.name
4604
4605  @property
4606  def parse_example_spec(self):
4607    """See `FeatureColumn` base class."""
4608    return self.categorical_column.parse_example_spec
4609
4610  @property
4611  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4612                          _FEATURE_COLUMN_DEPRECATION)
4613  def _parse_example_spec(self):
4614    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
4615
4616  def transform_feature(self, transformation_cache, state_manager):
4617    """See `FeatureColumn` base class."""
4618    return self.categorical_column.transform_feature(transformation_cache,
4619                                                     state_manager)
4620
4621  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4622                          _FEATURE_COLUMN_DEPRECATION)
4623  def _transform_feature(self, inputs):
4624    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access
4625
4626  @property
4627  def num_buckets(self):
4628    """Returns number of buckets in this sparse feature."""
4629    return self.categorical_column.num_buckets
4630
4631  @property
4632  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4633                          _FEATURE_COLUMN_DEPRECATION)
4634  def _num_buckets(self):
4635    return self.categorical_column._num_buckets  # pylint: disable=protected-access
4636
4637  def _get_sparse_tensors_helper(self, sparse_tensors):
4638    id_tensor = sparse_tensors.id_tensor
4639    weight_tensor = sparse_tensors.weight_tensor
4640    # Expands third dimension, if necessary so that embeddings are not
4641    # combined during embedding lookup. If the tensor is already 3D, leave
4642    # as-is.
4643    shape = array_ops.shape(id_tensor)
4644    # Compute the third dimension explicitly instead of setting it to -1, as
4645    # that doesn't work for dynamically shaped tensors with 0-length at runtime.
4646    # This happens for empty sequences.
4647    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
4648    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
4649    if weight_tensor is not None:
4650      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
4651    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
4652
4653  def get_sparse_tensors(self, transformation_cache, state_manager):
4654    """Returns an IdWeightPair.
4655
4656    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
4657    weights.
4658
4659    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
4660    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
4661    `SparseTensor` of `float` or `None` to indicate all weights should be
4662    taken to be 1. If specified, `weight_tensor` must have exactly the same
4663    shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
4664    output of a `VarLenFeature` which is a ragged matrix.
4665
4666    Args:
4667      transformation_cache: A `FeatureTransformationCache` object to access
4668        features.
4669      state_manager: A `StateManager` to create / access resources such as
4670        lookup tables.
4671    """
4672    sparse_tensors = self.categorical_column.get_sparse_tensors(
4673        transformation_cache, state_manager)
4674    return self._get_sparse_tensors_helper(sparse_tensors)
4675
4676  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4677                          _FEATURE_COLUMN_DEPRECATION)
4678  def _get_sparse_tensors(self, inputs, weight_collections=None,
4679                          trainable=None):
4680    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
4681    return self._get_sparse_tensors_helper(sparse_tensors)
4682
4683  @property
4684  def parents(self):
4685    """See 'FeatureColumn` base class."""
4686    return [self.categorical_column]
4687
4688  def get_config(self):
4689    """See 'FeatureColumn` base class."""
4690    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
4691    config = dict(zip(self._fields, self))
4692    config['categorical_column'] = serialize_feature_column(
4693        self.categorical_column)
4694    return config
4695
4696  @classmethod
4697  def from_config(cls, config, custom_objects=None, columns_by_name=None):
4698    """See 'FeatureColumn` base class."""
4699    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
4700    _check_config_keys(config, cls._fields)
4701    kwargs = _standardize_and_copy_config(config)
4702    kwargs['categorical_column'] = deserialize_feature_column(
4703        config['categorical_column'], custom_objects, columns_by_name)
4704    return cls(**kwargs)
4705
4706
4707def _check_config_keys(config, expected_keys):
4708  """Checks that a config has all expected_keys."""
4709  if set(config.keys()) != set(expected_keys):
4710    raise ValueError('Invalid config: {}, expected keys: {}'.format(
4711        config, expected_keys))
4712
4713
4714def _standardize_and_copy_config(config):
4715  """Returns a shallow copy of config with lists turned to tuples.
4716
4717  Keras serialization uses nest to listify everything.
4718  This causes problems with the NumericColumn shape, which becomes
4719  unhashable. We could try to solve this on the Keras side, but that
4720  would require lots of tracking to avoid changing existing behavior.
4721  Instead, we ensure here that we revive correctly.
4722
4723  Args:
4724    config: dict that will be used to revive a Feature Column
4725
4726  Returns:
4727    Shallow copy of config with lists turned to tuples.
4728  """
4729  kwargs = config.copy()
4730  for k, v in kwargs.items():
4731    if isinstance(v, list):
4732      kwargs[k] = tuple(v)
4733
4734  return kwargs
4735
4736
4737def _sanitize_column_name_for_variable_scope(name):
4738  """Sanitizes user-provided feature names for use as variable scopes."""
4739  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
4740  return invalid_char.sub('_', name)
4741