# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction.

FeatureColumns provide a high level abstraction for ingesting and representing
features. FeatureColumns are also the primary way of encoding features for
canned `tf.estimator.Estimator`s.

When using FeatureColumns with `Estimators`, the type of feature column you
should choose depends on (1) the feature type and (2) the model type.

1. Feature type:

  * Continuous features can be represented by `numeric_column`.
  * Categorical features can be represented by any `categorical_column_with_*`
  column:
    - `categorical_column_with_vocabulary_list`
    - `categorical_column_with_vocabulary_file`
    - `categorical_column_with_hash_bucket`
    - `categorical_column_with_identity`
    - `weighted_categorical_column`

2. Model type:

  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).

    Continuous features can be directly fed into deep neural network models.

      age_column = numeric_column("age")

    To feed sparse features into DNN models, wrap the column with
    `embedding_column` or `indicator_column`. `indicator_column` is recommended
    for features with only a few possible values. For features with many
    possible values, to reduce the size of your model, `embedding_column` is
    recommended.

      embedded_dept_column = embedding_column(
          categorical_column_with_vocabulary_list(
              "department", ["math", "philosophy", ...]), dimension=10)

  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).

    Sparse features can be fed directly into linear models. They behave like an
    indicator column but with an efficient implementation.

      dept_column = categorical_column_with_vocabulary_list("department",
          ["math", "philosophy", "english"])

    It is recommended that continuous features be bucketized before being
    fed into linear models.

      bucketized_age_column = bucketized_column(
          source_column=age_column,
          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
    Sparse features can be crossed (also known as conjoined or combined) in
    order to form non-linearities, and then fed into linear models.

      cross_dept_age_column = crossed_column(
          columns=["department", bucketized_age_column],
          hash_bucket_size=1000)

Example of building canned `Estimator`s using FeatureColumns:

  ```python
  # Define features and transformations
  deep_feature_columns = [age_column, embedded_dept_column]
  wide_feature_columns = [dept_column, bucketized_age_column,
      cross_dept_age_column]

  # Build deep model
  estimator = DNNClassifier(
      feature_columns=deep_feature_columns,
      hidden_units=[500, 250, 50])
  estimator.train(...)

  # Or build a wide model
  estimator = LinearClassifier(
      feature_columns=wide_feature_columns)
  estimator.train(...)

  # Or build a wide and deep model!
  estimator = DNNLinearCombinedClassifier(
      linear_feature_columns=wide_feature_columns,
      dnn_feature_columns=deep_feature_columns,
      dnn_hidden_units=[500, 250, 50])
  estimator.train(...)
  ```


FeatureColumns can also be transformed into a generic input layer for
custom models using `input_layer`.

Example of building a model using FeatureColumns; this can be used in a
`model_fn` that is given to the `tf.estimator.Estimator`:

  ```python
  # Building model via layers

  deep_feature_columns = [age_column, embedded_dept_column]
  columns_to_tensor = parse_feature_columns_from_examples(
      serialized=my_data,
      feature_columns=deep_feature_columns)
  first_layer = input_layer(
      features=columns_to_tensor,
      feature_columns=deep_feature_columns)
  second_layer = fully_connected(first_layer, ...)
  ```

NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
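
  As an illustration only (a hypothetical sketch, not part of this API), a
  minimal in-memory manager could back `create_variable`/`get_variable` with a
  plain dict keyed by `(feature_column, name)`:

  ```python
  class DictStateManager(StateManager):

    def __init__(self):
      self._vars = {}  # Maps (feature_column, name) -> variable.

    def create_variable(self, feature_column, name, shape, dtype=None,
                        trainable=True, use_resource=True, initializer=None):
      # Assumes `initializer` is provided and callable, as documented below.
      var = tf.Variable(
          initial_value=initializer(shape=shape, dtype=dtype),
          trainable=trainable, name=name)
      self._vars[(feature_column, name)] = var
      return var

    def get_variable(self, feature_column, name):
      return self._vars[(feature_column, name)]
  ```

  The canonical implementation is `_StateManagerImpl` below, which
  additionally handles object tracking, partitioned variables, and resources.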
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Creates a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.

    Returns:
      The created resource.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def has_resource(self, feature_column, name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.has_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')


@tf_export('__internal__.feature_column.StateManager', v1=[])
class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def __init__(self, layer, trainable):
    """Creates an _StateManagerImpl object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether by default, variables created are trainable or not.
    """
    self._trainable = trainable
    self._layer = layer
    if self._layer is not None and not hasattr(self._layer, '_resources'):
      self._layer._resources = data_structures.Mapping()  # pylint: disable=protected-access
    self._cols_to_vars_map = collections.defaultdict(lambda: {})
    self._cols_to_resources_map = collections.defaultdict(lambda: {})

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    if name in self._cols_to_vars_map[feature_column]:
      raise ValueError('Variable already exists.')

    # We explicitly track these variables since `name` is not guaranteed to be
    # unique and disable manual tracking that the add_weight call does.
    with trackable.no_manual_dependency_tracking_scope(self._layer):
      var = self._layer.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)
    if isinstance(var, variables.PartitionedVariable):
      for v in var:
        part_name = name + '/' + str(v._get_save_slice_info().var_offset[0])  # pylint: disable=protected-access
        self._layer._track_trackable(v, feature_column.name + '/' + part_name)  # pylint: disable=protected-access
    elif isinstance(var, trackable.Trackable):
      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access

    self._cols_to_vars_map[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    if name in self._cols_to_vars_map[feature_column]:
      return self._cols_to_vars_map[feature_column][name]
    raise ValueError('Variable does not exist.')

  def add_resource(self, feature_column, resource_name, resource):
    """Creates a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
      resource: The resource.

    Returns:
      The created resource.
    """
    self._cols_to_resources_map[feature_column][resource_name] = resource
    # pylint: disable=protected-access
    if self._layer is not None and isinstance(resource, trackable.Trackable):
      # Add trackable resources to the layer for serialization.
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      resource_name: Name of the resource.
    """
    return resource_name in self._cols_to_resources_map[feature_column]

  def get_resource(self, feature_column, resource_name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      resource_name: Name of the resource.
    """
    if (feature_column not in self._cols_to_resources_map or
        resource_name not in self._cols_to_resources_map[feature_column]):
      raise ValueError('Resource does not exist.')
    return self._cols_to_resources_map[feature_column][resource_name]


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on feature columns passed in.

  Note that you will most likely not need to use this function directly. Check
  `input_layer` and `linear_model` to see whether they satisfy your use case
  first.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor`, depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs


@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as arg 'features' in
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```
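
  Note that `feature_a` appears in the returned spec even though it was not
  passed to `make_parse_example_spec` directly: the crossed column references
  it by its string key, so it contributes a parsing spec for that key.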

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
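      # Two feature columns may reference the same raw input key (for example
      # a crossed column and the categorical column it is built from); their
      # parsing specs must agree for that key.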
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```
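
  For TF 2.x Keras models, the embedding column can be applied through a
  `tf.keras.layers.DenseFeatures` layer; a minimal sketch (the feature values
  below are made up for illustration):

  ```python
  video_id = tf.feature_column.categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [tf.feature_column.embedding_column(video_id, 9)]
  feature_layer = tf.keras.layers.DenseFeatures(columns)
  # Two examples, each with two video ids; the ids in a row are mean-combined
  # by default, so the output has shape [2, 9].
  dense_tensor = feature_layer({'video_id': tf.constant([[0, 1], [2, 3]])})
  ```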

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
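  # Unwrap weighted or sequence wrappers so the type and bucket-count checks
  # below compare the underlying categorical columns.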
  while isinstance(
      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0.num_buckets
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
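  # Unwrap weighted or sequence wrappers so the type and bucket-count checks
  # below compare the underlying categorical columns.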
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result


@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  Assume we have data with two features `a` and `b`.

  >>> data = {'a': [15, 9, 17, 19, 21, 18, 25, 30],
  ...    'b': [5.0, 6.4, 10.5, 13.6, 15.7, 19.9, 20.3, 0.0]}

  Let us represent the features `a` and `b` as numerical features.

  >>> a = tf.feature_column.numeric_column('a')
  >>> b = tf.feature_column.numeric_column('b')

  Feature columns describe a set of transformations to the inputs.

  For example, to "bucketize" feature `a`, wrap the `a` column in a
  `feature_column.bucketized_column`.
  Providing `5` bucket boundaries, the bucketized_column API
  will bucket this feature into a total of `6` buckets.

  >>> a_buckets = tf.feature_column.bucketized_column(a,
  ...    boundaries=[10, 15, 20, 25, 30])

  Create a `DenseFeatures` layer which will apply the transformations
  described by the set of `tf.feature_column` objects:

  >>> feature_layer = tf.keras.layers.DenseFeatures([a_buckets, b])
  >>> print(feature_layer(data))
  tf.Tensor(
  [[ 0.   0.   1.   0.   0.   0.   5. ]
   [ 1.   0.   0.   0.   0.   0.   6.4]
   [ 0.   0.   1.   0.   0.   0.  10.5]
   [ 0.   0.   1.   0.   0.   0.  13.6]
   [ 0.   0.   0.   1.   0.   0.  15.7]
   [ 0.   0.   1.   0.   0.   0.  19.9]
   [ 0.   0.   0.   0.   1.   0.  20.3]
   [ 0.   0.   0.   0.   0.   1.   0. ]], shape=(8, 7), dtype=float32)
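
  A `normalizer_fn` can be applied to scale values as they are read; for
  example (the shift and scale constants below are arbitrary, for illustration
  only):

  >>> b_scaled = tf.feature_column.numeric_column(
  ...     'b', normalizer_fn=lambda x: (x - 10.0) / 10.0)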

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers that specifies the shape of the `Tensor`. An
      integer can be given, which means a single-dimension `Tensor` with the
      given width. The `Tensor` representing the column will have the shape of
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization,
      it can be used for any kind of TensorFlow transformation.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int
    ValueError: if any dimension in shape is not a positive integer
    TypeError: if `default_value` is an iterable but not compatible with `shape`
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)


@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input bucketed by `boundaries`.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000],
                  [150,   10],
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3],
            [3, 2],
            [1, 3]]
  ```
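
  To spell out the first row: `-5` falls in `(-inf, 0)`, so it maps to bucket
  `0`, and `10000` is at or above the last boundary `100`, so it maps to
  bucket `3`. With `n` boundaries, an input is mapped to one of `n + 1`
  bucket ids.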

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not isinstance(boundaries, (list, tuple)):
    raise ValueError('boundaries must be a sorted list.')
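  # Boundaries must be strictly increasing; equal or decreasing neighbors
  # would create empty or ill-defined buckets.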
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))


@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents a sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing:
  `output_id = Hash(input_feature_string) % bucket_size` for string-type input.
  For int-type input, the value is converted to its string representation first
  and then hashed by the same formula.
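
  The bucket id for a given string can be reproduced with
  `tf.strings.to_hash_bucket_fast`, which corresponds to the hashing op this
  column uses (a sketch, assuming TF 2.x eager execution):

  ```python
  ids = tf.strings.to_hash_bucket_fast(['Tensorflow', 'Keras'],
                                       num_buckets=10000)
  ```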

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  columns = [keywords]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)


@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
def categorical_column_with_vocabulary_file(key,
                                            vocabulary_file,
                                            vocabulary_size=None,
                                            num_oov_buckets=0,
                                            default_value=None,
                                            dtype=dtypes.string):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.
1264
  Example with `num_oov_buckets`:
  File `'states.txt'` contains a list of state names, one per line. With
  `vocabulary_size=5`, only the first 5 lines are used; inputs matching one of
  those lines are assigned an ID 0-4 corresponding to the line number. All
  other values are hashed and assigned ID 5 (the single OOV bucket).
1270
1271  ```python
1272  import tensorflow as tf
1273  states = tf.feature_column.categorical_column_with_vocabulary_file(
1274    key='states', vocabulary_file='states.txt', vocabulary_size=5,
1275    num_oov_buckets=1)
1276  columns = [states]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
1278  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1279  'texas']])}
1280  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1281  columns)
1282  ```
1283
  Example with `default_value`:
  File `'states.txt'` contains a list of state names, one per line; the first
  line is `'XX'`. With `vocabulary_size=6`, only the first 6 lines are used.
  Both a literal `'XX'` in the input and any value not among those lines are
  assigned ID 0. All others are assigned the corresponding line number 1-5.
1289
1290  ```python
1291  import tensorflow as tf
1292  states = tf.feature_column.categorical_column_with_vocabulary_file(
1293    key='states', vocabulary_file='states.txt', vocabulary_size=6,
1294    default_value=0)
1295  columns = [states]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
1297  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1298  'texas']])}
1299  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1300  columns)
1301  ```
1302
1303  And to make an embedding with either:
1304
1305  ```python
1306  import tensorflow as tf
1307  states = tf.feature_column.categorical_column_with_vocabulary_file(
1308    key='states', vocabulary_file='states.txt', vocabulary_size=5,
1309    num_oov_buckets=1)
1310  columns = [tf.feature_column.embedding_column(states, 3)]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
1312  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1313  'texas']])}
1314  input_layer = tf.keras.layers.DenseFeatures(columns)
1315  dense_tensor = input_layer(features)
1316  ```
1317
1318  Args:
1319    key: A unique string identifying the input feature. It is used as the
1320      column name and the dictionary key for feature parsing configs, feature
1321      `Tensor` objects, and feature columns.
1322    vocabulary_file: The vocabulary file name.
    vocabulary_size: The number of elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if it is less, later values
      are ignored. If `None`, it is set to the length of `vocabulary_file`.
1326    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1327      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1328      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
1330      `default_value`.
1331    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
1333      `num_oov_buckets`.
1334    dtype: The type of features. Only string and integer types are supported.
1335
1336  Returns:
1337    A `CategoricalColumn` with a vocabulary file.
1338
1339  Raises:
1340    ValueError: `vocabulary_file` is missing or cannot be opened.
1341    ValueError: `vocabulary_size` is missing or < 1.
1342    ValueError: `num_oov_buckets` is a negative integer.
1343    ValueError: `num_oov_buckets` and `default_value` are both specified.
1344    ValueError: `dtype` is neither string nor integer.
1345  """
1346  return categorical_column_with_vocabulary_file_v2(
1347      key, vocabulary_file, vocabulary_size,
1348      dtype, default_value,
1349      num_oov_buckets)
1350
1351
1352@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
1353def categorical_column_with_vocabulary_file_v2(key,
1354                                               vocabulary_file,
1355                                               vocabulary_size=None,
1356                                               dtype=dtypes.string,
1357                                               default_value=None,
1358                                               num_oov_buckets=0,
1359                                               file_format=None):
1360  """A `CategoricalColumn` with a vocabulary file.
1361
1362  Use this when your inputs are in string or integer format, and you have a
1363  vocabulary file that maps each value to an integer ID. By default,
1364  out-of-vocabulary values are ignored. Use either (but not both) of
1365  `num_oov_buckets` and `default_value` to specify how to include
1366  out-of-vocabulary values.
1367
1368  For input dictionary `features`, `features[key]` is either `Tensor` or
1369  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1370  and `''` for string, which will be dropped by this feature column.
1371
1372  Example with `num_oov_buckets`:
1373  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
1374  abbreviation. All inputs with values in that file are assigned an ID 0-49,
1375  corresponding to its line number. All other values are hashed and assigned an
1376  ID 50-54.
1377
1378  ```python
1379  states = categorical_column_with_vocabulary_file(
1380      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
1381      num_oov_buckets=5)
1382  columns = [states, ...]
1383  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1384  linear_prediction = linear_model(features, columns)
1385  ```
1386
1387  Example with `default_value`:
1388  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
1389  other 50 each have a 2-character U.S. state abbreviation. Both a literal
1390  `'XX'` in input, and other values missing from the file, will be assigned
1391  ID 0. All others are assigned the corresponding line number 1-50.
1392
1393  ```python
1394  states = categorical_column_with_vocabulary_file(
1395      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
1396      default_value=0)
1397  columns = [states, ...]
1398  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1400  ```
1401
1402  And to make an embedding with either:
1403
1404  ```python
  columns = [embedding_column(states, 3), ...]
1406  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1407  dense_tensor = input_layer(features, columns)
1408  ```
1409
1410  Args:
1411    key: A unique string identifying the input feature. It is used as the
1412      column name and the dictionary key for feature parsing configs, feature
1413      `Tensor` objects, and feature columns.
1414    vocabulary_file: The vocabulary file name.
    vocabulary_size: The number of elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if it is less, later values
      are ignored. If `None`, it is set to the length of `vocabulary_file`.
1418    dtype: The type of features. Only string and integer types are supported.
1419    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
1421      `num_oov_buckets`.
1422    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1423      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1424      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
1426      `default_value`.
    file_format: The format of the vocabulary file. The format is 'text' by
      default unless `vocabulary_file` is a string which ends in 'tfrecord.gz',
      in which case 'tfrecord_gzip' is inferred. The only accepted alternative
      value for `file_format` is 'tfrecord_gzip'.
1430
1431  Returns:
1432    A `CategoricalColumn` with a vocabulary file.
1433
1434  Raises:
1435    ValueError: `vocabulary_file` is missing or cannot be opened.
1436    ValueError: `vocabulary_size` is missing or < 1.
1437    ValueError: `num_oov_buckets` is a negative integer.
1438    ValueError: `num_oov_buckets` and `default_value` are both specified.
1439    ValueError: `dtype` is neither string nor integer.
1440  """
1441  if not vocabulary_file:
1442    raise ValueError('Missing vocabulary_file in {}.'.format(key))
1443
1444  if file_format is None and vocabulary_file.endswith('tfrecord.gz'):
1445    file_format = 'tfrecord_gzip'
1446
1447  if vocabulary_size is None:
1448    if not gfile.Exists(vocabulary_file):
1449      raise ValueError('vocabulary_file in {} does not exist.'.format(key))
1450
1451    if file_format == 'tfrecord_gzip':
1452      ds = readers.TFRecordDataset(vocabulary_file, 'GZIP')
1453      vocabulary_size = ds.reduce(0, lambda x, _: x + 1)
1454      if context.executing_eagerly():
1455        vocabulary_size = vocabulary_size.numpy()
1456    else:
1457      with gfile.GFile(vocabulary_file, mode='rb') as f:
1458        vocabulary_size = sum(1 for _ in f)
1459    logging.info(
1460        'vocabulary_size = %d in %s is inferred from the number of elements '
1461        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
1462
1463  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
1464  if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1:
1465    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
1466  if num_oov_buckets:
1467    if default_value is not None:
1468      raise ValueError(
1469          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1470              key))
1471    if num_oov_buckets < 0:
1472      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1473          num_oov_buckets, key))
1474  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1475  fc_utils.assert_key_is_string(key)
1476  return VocabularyFileCategoricalColumn(
1477      key=key,
1478      vocabulary_file=vocabulary_file,
1479      vocabulary_size=vocabulary_size,
1480      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
1481      default_value=-1 if default_value is None else default_value,
1482      dtype=dtype,
1483      file_format=file_format)
1484
1485
1486@tf_export('feature_column.categorical_column_with_vocabulary_list')
1487def categorical_column_with_vocabulary_list(key,
1488                                            vocabulary_list,
1489                                            dtype=None,
1490                                            default_value=-1,
1491                                            num_oov_buckets=0):
1492  """A `CategoricalColumn` with in-memory vocabulary.
1493
1494  Use this when your inputs are in string or integer format, and you have an
1495  in-memory vocabulary mapping each value to an integer ID. By default,
1496  out-of-vocabulary values are ignored. Use either (but not both) of
1497  `num_oov_buckets` and `default_value` to specify how to include
1498  out-of-vocabulary values.
1499
1500  For input dictionary `features`, `features[key]` is either `Tensor` or
1501  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1502  and `''` for string, which will be dropped by this feature column.
1503
1504  Example with `num_oov_buckets`:
1505  In the following example, each input in `vocabulary_list` is assigned an ID
1506  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1507  inputs are hashed and assigned an ID 4-5.
1508
1509  ```python
1510  colors = categorical_column_with_vocabulary_list(
1511      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1512      num_oov_buckets=2)
1513  columns = [colors, ...]
1514  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1516  ```
1517
1518  Example with `default_value`:
1519  In the following example, each input in `vocabulary_list` is assigned an ID
1520  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1521  inputs are assigned `default_value` 0.
1522
1524  ```python
1525  colors = categorical_column_with_vocabulary_list(
1526      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1527  columns = [colors, ...]
1528  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1530  ```
1531
1532  And to make an embedding with either:
1533
1534  ```python
  columns = [embedding_column(colors, 3), ...]
1536  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1537  dense_tensor = input_layer(features, columns)
1538  ```
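
  A self-contained TF2-style sketch of the OOV behavior described above
  (assumed inputs; an indicator column is used here only to make the assigned
  ids visible):

  ```python
  import tensorflow as tf

  colors = tf.feature_column.categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), num_oov_buckets=2)
  features = {'colors': tf.constant([['R'], ['G'], ['purple']])}
  dense_tensor = tf.keras.layers.DenseFeatures(
      [tf.feature_column.indicator_column(colors)])(features)
  # Rows for 'R' and 'G' are one-hot at indices 0 and 1; 'purple' falls in
  # one of the two OOV buckets (index 4 or 5).
  ```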
1539
1540  Args:
1541    key: A unique string identifying the input feature. It is used as the column
1542      name and the dictionary key for feature parsing configs, feature `Tensor`
1543      objects, and feature columns.
1544    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1545      is mapped to the index of its value (if present) in `vocabulary_list`.
1546      Must be castable to `dtype`.
1547    dtype: The type of features. Only string and integer types are supported. If
1548      `None`, it will be inferred from `vocabulary_list`.
1549    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
1551      `num_oov_buckets`.
1552    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1553      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1554      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` cannot be specified
1556      with `default_value`.
1557
1558  Returns:
1559    A `CategoricalColumn` with in-memory vocabulary.
1560
1561  Raises:
1562    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1563    ValueError: `num_oov_buckets` is a negative integer.
1564    ValueError: `num_oov_buckets` and `default_value` are both specified.
1565    ValueError: if `dtype` is not integer or string.
1566  """
1567  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1568    raise ValueError(
1569        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1570            vocabulary_list, key))
1571  if len(set(vocabulary_list)) != len(vocabulary_list):
1572    raise ValueError(
1573        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1574            vocabulary_list, key))
1575  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1576  if num_oov_buckets:
1577    if default_value != -1:
1578      raise ValueError(
1579          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1580              key))
1581    if num_oov_buckets < 0:
1582      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1583          num_oov_buckets, key))
1584  fc_utils.assert_string_or_int(
1585      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1586  if dtype is None:
1587    dtype = vocabulary_dtype
1588  elif dtype.is_integer != vocabulary_dtype.is_integer:
1589    raise ValueError(
1590        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1591            dtype, vocabulary_dtype, key))
1592  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1593  fc_utils.assert_key_is_string(key)
1594
1595  return VocabularyListCategoricalColumn(
1596      key=key,
1597      vocabulary_list=tuple(vocabulary_list),
1598      dtype=dtype,
1599      default_value=default_value,
1600      num_oov_buckets=num_oov_buckets)
1601
1602
1603@tf_export('feature_column.categorical_column_with_identity')
1604def categorical_column_with_identity(key, num_buckets, default_value=None):
1605  """A `CategoricalColumn` that returns identity values.
1606
  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values outside
  this range will be mapped to `default_value` if specified; otherwise the
  lookup will fail.

  Typically, this is used for contiguous ranges of integer indexes, but it
  doesn't have to be. It might be inefficient, however, if many IDs are
  unused. Consider `categorical_column_with_hash_bucket` in that case.
1615
1616  For input dictionary `features`, `features[key]` is either `Tensor` or
1617  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1618  and `''` for string, which will be dropped by this feature column.
1619
  In the following examples, each input in the range `[0, 1000000)` is used as
  its own categorical ID. All other inputs are assigned `default_value` 0. Note
  that a literal 0 in the inputs therefore maps to the same ID as out-of-range
  values.
1623
1624  Linear model:
1625
1626  ```python
1627  import tensorflow as tf
1628  video_id = tf.feature_column.categorical_column_with_identity(
1629      key='video_id', num_buckets=1000000, default_value=0)
1630  columns = [video_id]
1631  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
  [33, 78, 2, 73, 1]])}
1633  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1634  columns)
1635  ```
1636
1637  Embedding for a DNN model:
1638
1639  ```python
1640  import tensorflow as tf
1641  video_id = tf.feature_column.categorical_column_with_identity(
1642      key='video_id', num_buckets=1000000, default_value=0)
1643  columns = [tf.feature_column.embedding_column(video_id, 9)]
1644  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
  [33, 78, 2, 73, 1]])}
1646  input_layer = tf.keras.layers.DenseFeatures(columns)
1647  dense_tensor = input_layer(features)
1648  ```
1649
1650  Args:
1651    key: A unique string identifying the input feature. It is used as the
1652      column name and the dictionary key for feature parsing configs, feature
1653      `Tensor` objects, and feature columns.
1654    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1655    default_value: If set, values outside of range `[0, num_buckets)` will
1656      be replaced with this value. If not set, values >= num_buckets will
1657      cause a failure while values < 0 will be dropped.
1658
1659  Returns:
1660    A `CategoricalColumn` that returns identity values.
1661
1662  Raises:
1663    ValueError: if `num_buckets` is less than one.
1664    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1665  """
1666  if num_buckets < 1:
1667    raise ValueError(
1668        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1669  if (default_value is not None) and (
1670      (default_value < 0) or (default_value >= num_buckets)):
1671    raise ValueError(
1672        'default_value {} not in range [0, {}), column_name {}'.format(
1673            default_value, num_buckets, key))
1674  fc_utils.assert_key_is_string(key)
1675  return IdentityCategoricalColumn(
1676      key=key, number_buckets=num_buckets, default_value=default_value)
1677
1678
1679@tf_export('feature_column.indicator_column')
1680def indicator_column(categorical_column):
  """Represents the multi-hot representation of a given categorical column.

  - For a DNN model, `indicator_column` can be used to wrap any
    `categorical_column_*` (e.g., to feed to a DNN). Consider using
    `embedding_column` if the number of buckets/unique values is large.

  - For a wide (aka linear) model, `indicator_column` is the internal
    representation used when a categorical column is passed directly (as any
    element of `feature_columns`) to `linear_model`. See `linear_model` for
    details.
1691
1692  ```python
1693  name = indicator_column(categorical_column_with_vocabulary_list(
1694      'name', ['bob', 'george', 'wanda']))
1695  columns = [name, ...]
1696  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1697  dense_tensor = input_layer(features, columns)
1698
1699  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
1700  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
1701  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
1702  ```
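
  A runnable TF2-style sketch of the same multi-hot behavior (assumed inputs,
  using `tf.keras.layers.DenseFeatures` as elsewhere in this file):

  ```python
  import tensorflow as tf

  name = tf.feature_column.indicator_column(
      tf.feature_column.categorical_column_with_vocabulary_list(
          'name', ['bob', 'george', 'wanda']))
  features = {'name': tf.constant([['bob'], ['wanda']])}
  dense_tensor = tf.keras.layers.DenseFeatures([name])(features)
  # dense_tensor -> [[1., 0., 0.], [0., 0., 1.]]
  ```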
1703
1704  Args:
1705    categorical_column: A `CategoricalColumn` which is created by
1706      `categorical_column_with_*` or `crossed_column` functions.
1707
1708  Returns:
1709    An `IndicatorColumn`.
1710
1711  Raises:
1712    ValueError: If `categorical_column` is not CategoricalColumn type.
1713  """
1714  if not isinstance(categorical_column,
1715                    (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
1716    raise ValueError(
1717        'Unsupported input type. Input must be a CategoricalColumn. '
1718        'Given: {}'.format(categorical_column))
1719  return IndicatorColumn(categorical_column)
1720
1721
1722@tf_export('feature_column.weighted_categorical_column')
1723def weighted_categorical_column(categorical_column,
1724                                weight_feature_key,
1725                                dtype=dtypes.float32):
1726  """Applies weight values to a `CategoricalColumn`.
1727
1728  Use this when each of your sparse inputs has both an ID and a value. For
1729  example, if you're representing text documents as a collection of word
1730  frequencies, you can provide 2 parallel sparse input features ('terms' and
1731  'frequencies' below).
1732
1733  Example:
1734
1735  Input `tf.Example` objects:
1736
1737  ```proto
1738  [
1739    features {
1740      feature {
1741        key: "terms"
1742        value {bytes_list {value: "very" value: "model"}}
1743      }
1744      feature {
1745        key: "frequencies"
1746        value {float_list {value: 0.3 value: 0.1}}
1747      }
1748    },
1749    features {
1750      feature {
1751        key: "terms"
1752        value {bytes_list {value: "when" value: "course" value: "human"}}
1753      }
1754      feature {
1755        key: "frequencies"
1756        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
1757      }
1758    }
1759  ]
1760  ```
1761
1762  ```python
1763  categorical_column = categorical_column_with_hash_bucket(
      key='terms', hash_bucket_size=1000)
1765  weighted_column = weighted_categorical_column(
1766      categorical_column=categorical_column, weight_feature_key='frequencies')
1767  columns = [weighted_column, ...]
1768  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1770  ```
1771
1772  This assumes the input dictionary contains a `SparseTensor` for key
1773  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
1774  the same indices and dense shape.
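
  For illustration, one way to construct such parallel inputs eagerly
  (hypothetical values; `tf.sparse.from_dense` drops `''` and `0.0` cells, so
  the two tensors end up with matching indices and dense shape):

  ```python
  import tensorflow as tf

  features = {
      'terms': tf.sparse.from_dense([['very', 'model', ''],
                                     ['when', 'course', 'human']]),
      'frequencies': tf.sparse.from_dense([[0.3, 0.1, 0.0],
                                           [0.4, 0.1, 0.2]]),
  }
  ```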
1775
1776  Args:
1777    categorical_column: A `CategoricalColumn` created by
1778      `categorical_column_with_*` functions.
1779    weight_feature_key: String key for weight values.
1780    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
1781      are supported.
1782
1783  Returns:
    A `CategoricalColumn` composed of two sparse features: one represents the
    IDs, the other the weight (value) of each ID in that example.
1786
1787  Raises:
1788    ValueError: if `dtype` is not convertible to float.
1789  """
1790  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
1791    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
1792  return WeightedCategoricalColumn(
1793      categorical_column=categorical_column,
1794      weight_feature_key=weight_feature_key,
1795      dtype=dtype)
1796
1797
1798@tf_export('feature_column.crossed_column')
1799def crossed_column(keys, hash_bucket_size, hash_key=None):
1800  """Returns a column for performing crosses of categorical features.
1801
1802  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
1803  the transformation can be thought of as:
1804    Hash(cartesian product of features) % `hash_bucket_size`
1805
1806  For example, if the input features are:
1807
1808  * SparseTensor referred by first key:
1809
1810    ```python
1811    shape = [2, 2]
1812    {
1813        [0, 0]: "a"
1814        [1, 0]: "b"
1815        [1, 1]: "c"
1816    }
1817    ```
1818
1819  * SparseTensor referred by second key:
1820
1821    ```python
1822    shape = [2, 1]
1823    {
1824        [0, 0]: "d"
1825        [1, 0]: "e"
1826    }
1827    ```
1828
  then the crossed feature will look like:

  ```python
  shape = [2, 2]
  {
      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
  }
  ```
1839
1840  Here is an example to create a linear model with crosses of string features:
1841
1842  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
1844  columns = [keywords_x_doc_terms, ...]
1845  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1846  linear_prediction = linear_model(features, columns)
1847  ```
1848
1849  You could also use vocabulary lookup before crossing:
1850
1851  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
1855  columns = [keywords_x_doc_terms, ...]
1856  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1857  linear_prediction = linear_model(features, columns)
1858  ```
1859
1860  If an input feature is of numeric type, you can use
1861  `categorical_column_with_identity`, or `bucketized_column`, as in the example:
1862
1863  ```python
1864  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10000)
1866  price = numeric_column('price')
1867  # bucketized_column converts numerical feature to a categorical one.
1868  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1870  columns = [vertical_id_x_price, ...]
1871  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1872  linear_prediction = linear_model(features, columns)
1873  ```
1874
  To use a crossed column in a DNN model, you need to wrap it in an embedding
  column, as in this example:
1877
1878  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1880  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1881  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1882  ```
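
  A minimal self-contained sketch of the crossing itself (assumed feature
  names; the indicator column is used here only to make the hashed ids
  visible):

  ```python
  import tensorflow as tf

  a_x_b = tf.feature_column.crossed_column(['a', 'b'], hash_bucket_size=10)
  features = {'a': tf.constant([['x'], ['y']]),
              'b': tf.constant([['p'], ['q']])}
  dense_tensor = tf.keras.layers.DenseFeatures(
      [tf.feature_column.indicator_column(a_x_b)])(features)
  # dense_tensor has shape [2, 10]; each row is one-hot at
  # Hash64(cross) % 10 for that example's ('a', 'b') pair.
  ```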
1883
1884  Args:
1885    keys: An iterable identifying the features to be crossed. Each element can
1886      be either:
1887      * string: Will use the corresponding feature which must be of string type.
1888      * `CategoricalColumn`: Will use the transformed tensor produced by this
1889        column. Does not support hashed categorical column.
    hash_bucket_size: An int >= 1. The number of buckets.
    hash_key: Optional. Specifies the hash_key that will be used by the
      `FingerprintCat64` function to combine the cross fingerprints in
      `SparseCrossOp`.
1893
1894  Returns:
1895    A `CrossedColumn`.
1896
1897  Raises:
1898    ValueError: If `len(keys) < 2`.
1899    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
1900    ValueError: If any of the keys is `HashedCategoricalColumn`.
1901    ValueError: If `hash_bucket_size < 1`.
1902  """
1903  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
1906  if not keys or len(keys) < 2:
1907    raise ValueError(
1908        'keys must be a list with length > 1. Given: {}'.format(keys))
1909  for key in keys:
1910    if (not isinstance(key, six.string_types) and
1911        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
1912      raise ValueError(
          'Unsupported key type. All keys must be either a string or a '
          'categorical column, except HashedCategoricalColumn. '
          'Given: {}'.format(key))
1916    if isinstance(key,
1917                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
1918      raise ValueError(
1919          'categorical_column_with_hash_bucket is not supported for crossing. '
1920          'Hashing before crossing will increase probability of collision. '
1921          'Instead, use the feature name as a string. Given: {}'.format(key))
1922  return CrossedColumn(
1923      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
1924
1925
1926# TODO(b/181853833): Add a tf.type for instance type checking.
1927@tf_export('__internal__.feature_column.FeatureColumn', v1=[])
1928@six.add_metaclass(abc.ABCMeta)
1929class FeatureColumn(object):
1930  """Represents a feature column abstraction.
1931
  WARNING: Do not subclass this class unless you know what you are doing:
1933  the API is subject to future changes.
1934
1935  To distinguish between the concept of a feature family and a specific binary
1936  feature within a family, we refer to a feature family like "country" as a
1937  feature column. For example, we can have a feature in a `tf.Example` format:
1938    {key: "country",  value: [ "US" ]}
  In this example the value of the feature is "US" and "country" refers to the
  column of the feature.

  This class is an abstract class. Users should not create instances of it.
1943  """
1944
1945  @abc.abstractproperty
1946  def name(self):
1947    """Returns string. Used for naming."""
1948    pass
1949
1950  def __lt__(self, other):
1951    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1952
1953    Feature columns need to occasionally be sortable, for example when used as
1954    keys in a features dictionary passed to a layer.
1955
1956    In CPython, `__lt__` must be defined for all objects in the
1957    sequence being sorted.
1958
1959    If any objects in the sequence being sorted do not have an `__lt__` method
1960    compatible with feature column objects (such as strings), then CPython will
1961    fall back to using the `__gt__` method below.
1962    https://docs.python.org/3/library/stdtypes.html#list.sort
1963
1964    Args:
1965      other: The other object to compare to.
1966
1967    Returns:
1968      True if the string representation of this object is lexicographically less
1969      than the string representation of `other`. For FeatureColumn objects,
1970      this looks like "<__main__.FeatureColumn object at 0xa>".
1971    """
1972    return str(self) < str(other)
1973
1974  def __gt__(self, other):
1975    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1976
1977    Feature columns need to occasionally be sortable, for example when used as
1978    keys in a features dictionary passed to a layer.
1979
1980    `__gt__` is called when the "other" object being compared during the sort
1981    does not have `__lt__` defined.
1982    Example:
1983    ```
1984    # __lt__ only class
1985    class A():
1986      def __lt__(self, other): return str(self) < str(other)
1987
1988    a = A()
1989    a < "b" # True
1990    "0" < a # Error
1991
1992    # __lt__ and __gt__ class
1993    class B():
1994      def __lt__(self, other): return str(self) < str(other)
1995      def __gt__(self, other): return str(self) > str(other)
1996
1997    b = B()
1998    b < "c" # True
1999    "0" < b # True
2000    ```
2001
2002    Args:
2003      other: The other object to compare to.
2004
2005    Returns:
2006      True if the string representation of this object is lexicographically
2007      greater than the string representation of `other`. For FeatureColumn
2008      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
2009    """
2010    return str(self) > str(other)
2011
2012  @abc.abstractmethod
2013  def transform_feature(self, transformation_cache, state_manager):
2014    """Returns intermediate representation (usually a `Tensor`).
2015
2016    Uses `transformation_cache` to create an intermediate representation
2017    (usually a `Tensor`) that other feature columns can use.
2018
2019    Example usage of `transformation_cache`:
    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). To access the corresponding `Tensor`s,
    `transformation_cache` is used as follows:
2023
2024    ```python
2025    raw_tensor = transformation_cache.get('raw', state_manager)
2026    fc_tensor = transformation_cache.get(input_fc, state_manager)
2027    ```
2028
2029    Args:
2030      transformation_cache: A `FeatureTransformationCache` object to access
2031        features.
2032      state_manager: A `StateManager` to create / access resources such as
2033        lookup tables.
2034
2035    Returns:
2036      Transformed feature `Tensor`.
2037    """
2038    pass
2039
2040  @abc.abstractproperty
2041  def parse_example_spec(self):
2042    """Returns a `tf.Example` parsing spec as dict.
2043
    It is used to generate the parsing spec for `tf.io.parse_example`. The
    returned spec is a dict from keys ('string') to `VarLenFeature`,
    `FixedLenFeature`, and other supported objects. Please check the
    documentation of `tf.io.parse_example` for all supported spec objects.
2048
    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). One possible implementation of
    `parse_example_spec` is as follows:
2052
2053    ```python
2054    spec = {'raw': tf.io.FixedLenFeature(...)}
2055    spec.update(input_fc.parse_example_spec)
2056    return spec
2057    ```
2058    """
2059    pass
2060
2061  def create_state(self, state_manager):
2062    """Uses the `state_manager` to create state for the FeatureColumn.
2063
2064    Args:
2065      state_manager: A `StateManager` to create / access resources such as
2066        lookup tables and variables.
2067    """
2068    pass
2069
2070  @abc.abstractproperty
2071  def _is_v2_column(self):
2072    """Returns whether this FeatureColumn is fully conformant to the new API.
2073
    This is needed for composition-type cases where an `EmbeddingColumn` etc.
    might take in old categorical columns as input, in which case we want to
    use the old API.
2077    """
2078    pass
2079
2080  @abc.abstractproperty
2081  def parents(self):
2082    """Returns a list of immediate raw feature and FeatureColumn dependencies.
2083
    For example:

    ```python
    # For the following feature columns
    a = numeric_column('f1')
    b = bucketized_column(a, boundaries=[0, 1])
    c = crossed_column([b, 'f2'], hash_bucket_size=10)
    # The expected parents are:
    a.parents = ['f1']
    b.parents = [a]
    c.parents = [b, 'f2']
    ```
2092    pass
2093
2094  def get_config(self):
2095    """Returns the config of the feature column.
2096
2097    A FeatureColumn config is a Python dictionary (serializable) containing the
2098    configuration of a FeatureColumn. The same FeatureColumn can be
2099    reinstantiated later from this configuration.
2100
    The config of a feature column does not include information about feature
    columns depending on it, nor the FeatureColumn class name.
2103
2104    Example with (de)serialization practices followed in this file:
2105    ```python
2106    class SerializationExampleFeatureColumn(
2107        FeatureColumn, collections.namedtuple(
2108            'SerializationExampleFeatureColumn',
2109            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):
2110
2111      def get_config(self):
2112        # Create a dict from the namedtuple.
2113        # Python attribute literals can be directly copied from / to the config.
2114        # For example 'dimension', assuming it is an integer literal.
2115        config = dict(zip(self._fields, self))
2116
2117        # (De)serialization of parent FeatureColumns should use the provided
2118        # (de)serialize_feature_column() methods that take care of de-duping.
2119        config['parent'] = serialize_feature_column(self.parent)
2120
2121        # Many objects provide custom (de)serialization e.g: for tf.DType
2122        # tf.DType.name, tf.as_dtype() can be used.
2123        config['dtype'] = self.dtype.name
2124
2125        # Non-trivial dependencies should be Keras-(de)serializable.
2126        config['normalizer_fn'] = generic_utils.serialize_keras_object(
2127            self.normalizer_fn)
2128
2129        return config
2130
2131      @classmethod
2132      def from_config(cls, config, custom_objects=None, columns_by_name=None):
2133        # This should do the inverse transform from `get_config` and construct
2134        # the namedtuple.
2135        kwargs = config.copy()
2136        kwargs['parent'] = deserialize_feature_column(
2137            config['parent'], custom_objects, columns_by_name)
2138        kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2139        kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
            config['normalizer_fn'], custom_objects=custom_objects)
2141        return cls(**kwargs)
2142
2143    ```
2144    Returns:
2145      A serializable Dict that can be used to deserialize the object with
2146      from_config.
2147    """
2148    return self._get_config()
2149
2150  def _get_config(self):
2151    raise NotImplementedError('Must be implemented in subclasses.')
2152
2153  @classmethod
2154  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2155    """Creates a FeatureColumn from its config.
2156
2157    This method should be the reverse of `get_config`, capable of instantiating
2158    the same FeatureColumn from the config dictionary. See `get_config` for an
2159    example of common (de)serialization practices followed in this file.
2160
2161    TODO(b/118939620): This is a private method until consensus is reached on
2162    supporting object deserialization deduping within Keras.
2163
2164    Args:
2165      config: A Dict config acquired with `get_config`.
2166      custom_objects: Optional dictionary mapping names (strings) to custom
2167        classes or functions to be considered during deserialization.
2168      columns_by_name: A Dict[String, FeatureColumn] of existing columns in
2169        order to avoid duplication. Should be passed to any calls to
2170        deserialize_feature_column().
2171
2172    Returns:
2173      A FeatureColumn for the input config.
2174    """
2175    return cls._from_config(config, custom_objects, columns_by_name)
2176
2177  @classmethod
2178  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2179    raise NotImplementedError('Must be implemented in subclasses.')
2180
2181
2182# TODO(b/181853833): Add a tf.type for instance type checking.
2183@tf_export('__internal__.feature_column.DenseColumn', v1=[])
2184class DenseColumn(FeatureColumn):
2185  """Represents a column which can be represented as `Tensor`.
2186
2187  Some examples of this type are: numeric_column, embedding_column,
2188  indicator_column.
2189  """
2190
2191  @abc.abstractproperty
2192  def variable_shape(self):
2193    """`TensorShape` of `get_dense_tensor`, without batch dimension."""
2194    pass
2195
2196  @abc.abstractmethod
2197  def get_dense_tensor(self, transformation_cache, state_manager):
2198    """Returns a `Tensor`.
2199
    The output of this function will be used by model-builder functions. For
    example, the pseudocode of `input_layer` looks like:
2202
2203    ```python
2204    def input_layer(features, feature_columns, ...):
2205      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
2206      return tf.concat(outputs)
2207    ```
2208
2209    Args:
2210      transformation_cache: A `FeatureTransformationCache` object to access
2211        features.
2212      state_manager: A `StateManager` to create / access resources such as
2213        lookup tables.
2214
2215    Returns:
2216      `Tensor` of shape [batch_size] + `variable_shape`.
2217    """
2218    pass
2219
2220
2221def is_feature_column_v2(feature_columns):
2222  """Returns True if all feature columns are V2."""
2223  for feature_column in feature_columns:
2224    if not isinstance(feature_column, FeatureColumn):
2225      return False
2226    if not feature_column._is_v2_column:  # pylint: disable=protected-access
2227      return False
2228  return True
2229
2230
2231def _create_weighted_sum(column, transformation_cache, state_manager,
2232                         sparse_combiner, weight_var):
2233  """Creates a weighted sum for a dense/categorical column for linear_model."""
2234  if isinstance(column, CategoricalColumn):
2235    return _create_categorical_column_weighted_sum(
2236        column=column,
2237        transformation_cache=transformation_cache,
2238        state_manager=state_manager,
2239        sparse_combiner=sparse_combiner,
2240        weight_var=weight_var)
2241  else:
2242    return _create_dense_column_weighted_sum(
2243        column=column,
2244        transformation_cache=transformation_cache,
2245        state_manager=state_manager,
2246        weight_var=weight_var)
2247
2248
2249def _create_dense_column_weighted_sum(column, transformation_cache,
2250                                      state_manager, weight_var):
2251  """Create a weighted sum of a dense column for linear_model."""
2252  tensor = column.get_dense_tensor(transformation_cache, state_manager)
2253  num_elements = column.variable_shape.num_elements()
2254  batch_size = array_ops.shape(tensor)[0]
2255  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
2256  return math_ops.matmul(tensor, weight_var, name='weighted_sum')
2257
2258
2259class CategoricalColumn(FeatureColumn):
2260  """Represents a categorical feature.
2261
  A categorical feature is typically handled with a `tf.sparse.SparseTensor`
  of IDs.
2264  """
2265
2266  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
2267      'IdWeightPair', ('id_tensor', 'weight_tensor'))
2268
2269  @abc.abstractproperty
2270  def num_buckets(self):
2271    """Returns number of buckets in this sparse feature."""
2272    pass
2273
2274  @abc.abstractmethod
2275  def get_sparse_tensors(self, transformation_cache, state_manager):
2276    """Returns an IdWeightPair.
2277
2278    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
2279    weights.
2280
    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.
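
    For illustration, an `IdWeightPair` for a batch of two examples, where
    the second example has two ids and all weights default to 1:

    ```python
    ids = sparse_tensor_lib.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]], values=[3, 1, 4],
        dense_shape=[2, 2])
    pair = CategoricalColumn.IdWeightPair(id_tensor=ids, weight_tensor=None)
    ```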
2287
2288    Args:
2289      transformation_cache: A `FeatureTransformationCache` object to access
2290        features.
2291      state_manager: A `StateManager` to create / access resources such as
2292        lookup tables.
2293    """
2294    pass
2295
2296
2297def _create_categorical_column_weighted_sum(
2298    column, transformation_cache, state_manager, sparse_combiner, weight_var):
2299  # pylint: disable=g-doc-return-or-yield,g-doc-args
2300  """Create a weighted sum of a categorical column for linear_model.
2301
  Note to maintainers: as an implementation detail, the weighted sum is
  implemented via `embedding_lookup_sparse` for efficiency. Mathematically,
  the two are the same.

  To be specific, conceptually a categorical column can be treated as a
  multi-hot vector. Say:
2308
2309  ```python
2310    x = [0 0 1]  # categorical column input
2311    w = [a b c]  # weights
2312  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.
2314
2315  Another example is
2316
2317  ```python
2318    x = [0 1 1]  # categorical column input
2319    w = [a b c]  # weights
2320  ```
  The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
2322
2323  For both cases, we can implement weighted sum via embedding_lookup with
2324  sparse_combiner = "sum".
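
  An illustrative numeric check of the second case (this is not the exact
  code path below, which uses `safe_embedding_lookup_sparse` on the cached
  sparse tensors, but the arithmetic is the same):

  ```python
  import tensorflow as tf

  w = tf.constant([[1.], [2.], [3.]])  # weights a, b, c
  x_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1]],
      values=tf.constant([1, 2], dtype=tf.int64),
      dense_shape=[1, 2])
  tf.nn.safe_embedding_lookup_sparse(w, x_ids, combiner='sum')
  # -> [[5.]], i.e. b + c == w[1] + w[2]
  ```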
2325  """
2326
2327  sparse_tensors = column.get_sparse_tensors(transformation_cache,
2328                                             state_manager)
2329  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
2330      array_ops.shape(sparse_tensors.id_tensor)[0], -1
2331  ])
2332  weight_tensor = sparse_tensors.weight_tensor
2333  if weight_tensor is not None:
2334    weight_tensor = sparse_ops.sparse_reshape(
2335        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2336
2337  return embedding_ops.safe_embedding_lookup_sparse(
2338      weight_var,
2339      id_tensor,
2340      sparse_weights=weight_tensor,
2341      combiner=sparse_combiner,
2342      name='weighted_sum')
2343
2344
2345# TODO(b/181853833): Add a tf.type for instance type checking.
2346@tf_export('__internal__.feature_column.SequenceDenseColumn', v1=[])
2347class SequenceDenseColumn(FeatureColumn):
2348  """Represents dense sequence data."""
2349
2350  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
2351      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
2352
2353  @abc.abstractmethod
2354  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
2355    """Returns a `TensorSequenceLengthPair`.
2356
2357    Args:
2358      transformation_cache: A `FeatureTransformationCache` object to access
2359        features.
2360      state_manager: A `StateManager` to create / access resources such as
2361        lookup tables.
2362    """
2363    pass
2364
2365
2366@tf_export('__internal__.feature_column.FeatureTransformationCache', v1=[])
2367class FeatureTransformationCache(object):
2368  """Handles caching of transformations while building the model.
2369
2370  `FeatureColumn` specifies how to digest an input column to the network. Some
2371  feature columns require data transformations. This class caches those
2372  transformations.
2373
  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and also in a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.
2379
2380  Example:
2381  We're trying to use the following `FeatureColumn`s:
2382
2383  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"], ...)
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
2389  ```
2390
  If we transform each column independently, then we'll get duplication of
  bucketization (one for the cross, one for the bucketized column itself).
2393  The `FeatureTransformationCache` eliminates this duplication.
2394  """
2395
2396  def __init__(self, features):
2397    """Creates a `FeatureTransformationCache`.
2398
2399    Args:
      features: A mapping from feature keys to objects that are `Tensor` or
2401        `SparseTensor`, or can be converted to same via
2402        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
2403        signifies a base feature (not-transformed). A `FeatureColumn` key
2404        means that this `Tensor` is the output of an existing `FeatureColumn`
2405        which can be reused.
2406    """
2407    self._features = features.copy()
2408    self._feature_tensors = {}
2409
2410  def get(self, key, state_manager, training=None):
2411    """Returns a `Tensor` for the given key.
2412
2413    A `str` key is used to access a base feature (not-transformed). When a
2414    `FeatureColumn` is passed, the transformed feature is returned if it
2415    already exists, otherwise the given `FeatureColumn` is asked to provide its
2416    transformed output, which is then cached.
2417
2418    Args:
2419      key: a `str` or a `FeatureColumn`.
2420      state_manager: A StateManager object that holds the FeatureColumn state.
      training: Boolean indicating whether the column is being used in
2422        training mode. This argument is passed to the transform_feature method
2423        of any `FeatureColumn` that takes a `training` argument. For example, if
2424        a `FeatureColumn` performed dropout, it could expose a `training`
2425        argument to control whether the dropout should be applied.
2426
2427    Returns:
2428      The transformed `Tensor` corresponding to the `key`.
2429
2430    Raises:
2431      ValueError: if key is not found or a transformed `Tensor` cannot be
2432        computed.
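
    For illustration (assumed inputs; `None` is acceptable as the
    `state_manager` here because a raw numeric feature needs no state):

    ```python
    cache = FeatureTransformationCache({'age': tf.constant([[23.], [41.]])})
    raw = cache.get('age', None)        # base feature, returned as-is
    age = numeric_column('age')
    transformed = cache.get(age, None)  # transformed once, then cached
    ```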
2433    """
2434    if key in self._feature_tensors:
2435      # FeatureColumn is already transformed or converted.
2436      return self._feature_tensors[key]
2437
2438    if key in self._features:
2439      feature_tensor = self._get_raw_feature_as_tensor(key)
2440      self._feature_tensors[key] = feature_tensor
2441      return feature_tensor
2442
2443    if isinstance(key, six.string_types):
2444      raise ValueError('Feature {} is not in features dictionary.'.format(key))
2445
2446    if not isinstance(key, FeatureColumn):
2447      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
2448                      'Provided: {}'.format(key))
2449
2450    column = key
2451    logging.debug('Transforming feature_column %s.', column)
2452
2453    # Some columns may need information about whether the transformation is
2454    # happening in training or prediction mode, but not all columns expose this
2455    # argument.
2456    try:
2457      transformed = column.transform_feature(
2458          self, state_manager, training=training)
2459    except TypeError:
2460      transformed = column.transform_feature(self, state_manager)
2461    if transformed is None:
2462      raise ValueError('Column {} is not supported.'.format(column.name))
2463    self._feature_tensors[column] = transformed
2464    return transformed
2465
2466  def _get_raw_feature_as_tensor(self, key):
2467    """Gets the raw_feature (keyed by `key`) as `tensor`.
2468
    The raw feature is converted to a (sparse) tensor and, if needed, its rank
    is expanded.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. Dynamic rank is supported as well. A rank-0 raw feature is
    not supported and will raise an error.
2474
2475    Args:
2476      key: A `str` key to access the raw feature.
2477
2478    Returns:
2479      A `Tensor` or `SparseTensor`.
2480
2481    Raises:
2482      ValueError: if the raw feature has rank 0.
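
    For illustration, under the rank-expansion rule described above:

    ```python
    cache = FeatureTransformationCache({'age': tf.constant([23., 41.])})
    # The rank-1 feature of shape [2] is expanded to rank 2, shape [2, 1]:
    expanded = cache.get('age', None)
    ```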
2483    """
2484    raw_feature = self._features[key]
2485    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2486        raw_feature)
2487
2488    def expand_dims(input_tensor):
2489      # Input_tensor must have rank 1.
2490      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2491        return sparse_ops.sparse_reshape(
2492            input_tensor, [array_ops.shape(input_tensor)[0], 1])
2493      else:
2494        return array_ops.expand_dims(input_tensor, -1)
2495
2496    rank = feature_tensor.get_shape().ndims
2497    if rank is not None:
2498      if rank == 0:
2499        raise ValueError(
2500            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2501                key, feature_tensor))
2502      return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2503
2504    # Handle dynamic rank.
2505    with ops.control_dependencies([
2506        check_ops.assert_positive(
2507            array_ops.rank(feature_tensor),
2508            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2509                key, feature_tensor))]):
2510      return control_flow_ops.cond(
2511          math_ops.equal(1, array_ops.rank(feature_tensor)),
2512          lambda: expand_dims(feature_tensor),
2513          lambda: feature_tensor)
2514
2515
2516# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2517def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2518  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2519
2520  If `input_tensor` is already a `SparseTensor`, just return it.
2521
2522  Args:
2523    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, the default value of
      `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2527
2528  Returns:
2529    A `SparseTensor` with the same shape as `input_tensor`.
2530
2531  Raises:
2532    ValueError: when `input_tensor`'s rank is `None`.
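
  For illustration, with the default `ignore_value` for strings (`''`):

  ```python
  dense = tf.constant([['a', ''], ['', 'b']])
  sp = _to_sparse_input_and_drop_ignore_values(dense)
  # sp.indices -> [[0, 0], [1, 1]]; sp.values -> [b'a', b'b']
  ```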
2533  """
2534  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2535      input_tensor)
2536  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2537    return input_tensor
2538  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
2539    if ignore_value is None:
2540      if input_tensor.dtype == dtypes.string:
        # Special case: TF strings are converted to numpy objects by default.
2542        ignore_value = ''
2543      elif input_tensor.dtype.is_integer:
2544        ignore_value = -1  # -1 has a special meaning of missing feature
2545      else:
2546        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2547        # constructing a new numpy object of the given type, which yields the
2548        # default value for that type.
2549        ignore_value = input_tensor.dtype.as_numpy_dtype()
2550    ignore_value = math_ops.cast(
2551        ignore_value, input_tensor.dtype, name='ignore_value')
2552    indices = array_ops.where_v2(
2553        math_ops.not_equal(input_tensor, ignore_value), name='indices')
2554    return sparse_tensor_lib.SparseTensor(
2555        indices=indices,
2556        values=array_ops.gather_nd(input_tensor, indices, name='values'),
2557        dense_shape=array_ops.shape(
2558            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2559
2560
2561def _normalize_feature_columns(feature_columns):
2562  """Normalizes the `feature_columns` input.
2563
  This method converts `feature_columns` to a list, as best it can. In
  addition, it verifies the types and other properties of `feature_columns`
  required by downstream libraries.
2567
2568  Args:
2569    feature_columns: The raw feature columns, usually passed by users.
2570
2571  Returns:
2572    The normalized feature column list.
2573
2574  Raises:
2575    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2576  """
2577  if isinstance(feature_columns, FeatureColumn):
2578    feature_columns = [feature_columns]
2579
2580  if isinstance(feature_columns, collections_abc.Iterator):
2581    feature_columns = list(feature_columns)
2582
2583  if isinstance(feature_columns, dict):
2584    raise ValueError('Expected feature_columns to be iterable, found dict.')
2585
2586  for column in feature_columns:
2587    if not isinstance(column, FeatureColumn):
2588      raise ValueError('Items of feature_columns must be a FeatureColumn. '
2589                       'Given (type {}): {}.'.format(type(column), column))
2590  if not feature_columns:
2591    raise ValueError('feature_columns must not be empty.')
2592  name_to_column = {}
2593  for column in feature_columns:
2594    if column.name in name_to_column:
2595      raise ValueError('Duplicate feature column name found for columns: {} '
2596                       'and {}. This usually means that these columns refer to '
2597                       'same base feature. Either one must be discarded or a '
2598                       'duplicated but renamed item must be inserted in '
2599                       'features dict.'.format(column,
2600                                               name_to_column[column.name]))
2601    name_to_column[column.name] = column
2602
2603  return sorted(feature_columns, key=lambda x: x.name)
2604
2605
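# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `_normalize_feature_columns`: a single column is wrapped in a list,
# and the result is sorted by column name. Uses the `numeric_column` factory
# defined earlier in this file.
def _example_normalize_feature_columns():
  """Minimal sketch: normalization wraps single columns and sorts by name."""
  b = numeric_column('b')
  a = numeric_column('a')
  assert _normalize_feature_columns(b) == [b]  # A bare column becomes a list.
  assert _normalize_feature_columns([b, a]) == [a, b]  # Sorted by name.

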
class NumericColumn(
    DenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'NumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """See `numeric_column`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {
        self.key:
            parsing_ops.FixedLenFeature(self.shape, self.dtype,
                                        self.default_value)
    }

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError(
          'The corresponding Tensor of a numeric column must be a dense '
          'Tensor. SparseTensor is not supported. key: {}'.format(self.key))
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return math_ops.cast(input_tensor, dtypes.float32)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    return self._transform_input_tensor(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.

    Raises:
      ValueError: If a SparseTensor is passed in.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(self.shape)

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing numeric feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.
    """
    # Feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    return inputs.get(self)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config['normalizer_fn'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
        self.normalizer_fn)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    kwargs = _standardize_and_copy_config(config)
    kwargs['normalizer_fn'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
        config['normalizer_fn'], custom_objects=custom_objects)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])

    return cls(**kwargs)


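# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `NumericColumn` in use: the column is built via the `numeric_column`
# factory defined earlier in this file, and `normalizer_fn` is applied during
# feature transformation. Eager execution is assumed for the noted values.
def _example_numeric_column():
  """Minimal sketch: a normalized 2-element numeric feature."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  price = numeric_column('price', shape=(2,),
                         normalizer_fn=lambda x: x / 100.0)
  cache = FeatureTransformationCache(
      {'price': constant_op.constant([[10.0, 25.0]])})
  # No variables or tables are needed, so `state_manager` may be None here.
  return price.transform_feature(cache, None)  # -> [[0.1, 0.25]]

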
class BucketizedColumn(
    DenseColumn,
    CategoricalColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('BucketizedColumn',
                           ('source_column', 'boundaries'))):
  """See `bucketized_column`."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.source_column, FeatureColumn) and
            self.source_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_bucketized'.format(self.source_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.source_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.source_column._parse_example_spec  # pylint: disable=protected-access

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = inputs.get(self.source_column)
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns bucketized categorical `source_column` tensor."""
    source_tensor = transformation_cache.get(self.source_column, state_manager)
    return math_ops._bucketize(  # pylint: disable=protected-access
        source_tensor,
        boundaries=self.boundaries)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(
        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def _get_dense_tensor_for_input_tensor(self, input_tensor):
    return array_ops.one_hot(
        indices=math_ops.cast(input_tensor, dtypes.int64),
        depth=len(self.boundaries) + 1,
        on_value=1.,
        off_value=0.)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns one-hot encoded dense `Tensor`."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_dense_tensor_for_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    # By construction, source_column is always one-dimensional.
    return (len(self.boundaries) + 1) * self.source_column.shape[0]

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
    batch_size = array_ops.shape(input_tensor)[0]
    # By construction, source_column is always one-dimensional.
    source_dimension = self.source_column.shape[0]

    i1 = array_ops.reshape(
        array_ops.tile(
            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
            [1, source_dimension]),
        (-1,))
    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
    # Flatten the bucket indices and make them unique across dimensions, e.g.
    # 2nd-dimension indices will range from k to 2*k-1 with k buckets.
    bucket_indices = (
        array_ops.reshape(input_tensor, (-1,)) +
        (len(self.boundaries) + 1) * i2)

    indices = math_ops.cast(
        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
    dense_shape = math_ops.cast(
        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
    sparse_tensor = sparse_tensor_lib.SparseTensor(
        indices=indices,
        values=bucket_indices,
        dense_shape=dense_shape)
    return CategoricalColumn.IdWeightPair(sparse_tensor, None)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    input_tensor = transformation_cache.get(self, state_manager)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """Converts dense inputs to SparseTensor so downstream code can use it."""
    del weight_collections
    del trainable
    input_tensor = inputs.get(self)
    return self._get_sparse_tensors_for_input_tensor(input_tensor)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.source_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['source_column'] = serialize_feature_column(self.source_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['source_column'] = deserialize_feature_column(
        config['source_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


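# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `BucketizedColumn`: the same column can serve as a dense one-hot
# input for DNNs or as a categorical column with
# `(len(boundaries) + 1) * source_dimension` buckets for linear models.
def _example_bucketized_column():
  """Minimal sketch: bucketizing a scalar numeric feature."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  age = numeric_column('age')
  age_buckets = bucketized_column(age, boundaries=[18, 25, 65])
  assert age_buckets.num_buckets == 4  # 3 boundaries -> 4 buckets.
  cache = FeatureTransformationCache(
      {'age': constant_op.constant([[17.0], [30.0], [70.0]])})
  # Bucket ids: 17 -> 0, 30 -> 2, 70 -> 3.
  return age_buckets.transform_feature(cache, None)

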
class EmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'EmbeddingColumn',
        ('categorical_column', 'dimension', 'combiner', 'initializer',
         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
         'use_safe_embedding_lookup'))):
  """See `embedding_column`."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner,
              initializer,
              ckpt_to_load_from,
              tensor_name_in_ckpt,
              max_norm,
              trainable,
              use_safe_embedding_lookup=True):
    return super(EmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """Transforms underlying `categorical_column`."""
    return transformation_cache.get(self.categorical_column, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return inputs.get(self.categorical_column)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape([self.dimension])

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def create_state(self, state_manager):
    """Creates the embedding lookup variable."""
    default_num_buckets = (self.categorical_column.num_buckets
                           if self._is_v2_column
                           else self.categorical_column._num_buckets)  # pylint: disable=protected-access
    num_buckets = getattr(self.categorical_column, 'num_buckets',
                          default_num_buckets)
    embedding_shape = (num_buckets, self.dimension)
    state_manager.create_variable(
        self,
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        trainable=self.trainable,
        use_resource=True,
        initializer=self.initializer)

  def _get_dense_tensor_internal_helper(self, sparse_tensors,
                                        embedding_weights):
    sparse_ids = sparse_tensors.id_tensor
    sparse_weights = sparse_tensors.weight_tensor

    if self.ckpt_to_load_from is not None:
      to_restore = embedding_weights
      if isinstance(to_restore, variables.PartitionedVariable):
        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
          self.tensor_name_in_ckpt: to_restore
      })

    sparse_id_rank = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
        sparse_id_rank <= 2):
      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
    # Return embedding lookup result.
    return embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)

  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
    """Private method that follows the signature of get_dense_tensor."""
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
                                     trainable):
    """Private method that follows the signature of _get_dense_tensor."""
    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns tensor after doing the embedding lookup.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Embedding lookup tensor.

    Raises:
      ValueError: `categorical_column` is SequenceCategoricalColumn.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_dense_tensor_internal(sparse_tensors, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
        inputs, weight_collections, trainable)
    return self._old_get_dense_tensor_internal(sparse_tensors,
                                               weight_collections, trainable)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
                                                   state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    dense_tensor = self._old_get_dense_tensor_internal(
        sparse_tensors,
        weight_collections=weight_collections,
        trainable=trainable)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialization.serialize_feature_column(
        self.categorical_column)
    config['initializer'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
        self.initializer)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    if 'use_safe_embedding_lookup' not in config:
      config['use_safe_embedding_lookup'] = True
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = serialization.deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
    kwargs['initializer'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
        config['initializer'],
        module_objects=all_initializers,
        custom_objects=custom_objects)
    return cls(**kwargs)


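# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `EmbeddingColumn`, built via the `embedding_column` factory defined
# earlier in this file. Only cheap, variable-free properties are exercised.
def _example_embedding_column():
  """Minimal sketch: embedding a vocabulary-list categorical feature."""
  dept = categorical_column_with_vocabulary_list(
      'department', ('math', 'philosophy', 'english'))
  dept_embedding = embedding_column(dept, dimension=4)
  assert dept_embedding.name == 'department_embedding'
  # The per-example output shape is the embedding dimension...
  assert dept_embedding.variable_shape.as_list() == [4]
  # ...while the lookup table created by `create_state` is
  # [num_buckets, dimension], here (3, 4).
  assert dept_embedding.categorical_column.num_buckets == 3
  return dept_embedding

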
def _raise_shared_embedding_column_error():
  raise ValueError('SharedEmbeddingColumns are not supported in '
                   '`linear_model` or `input_layer`. Please use '
                   '`DenseFeatures` or `LinearModel` instead.')


class SharedEmbeddingColumnCreator(tracking.AutoTrackable):
  """Factory for `SharedEmbeddingColumn`s sharing one embedding table."""

  def __init__(self,
               dimension,
               initializer,
               ckpt_to_load_from,
               tensor_name_in_ckpt,
               num_buckets,
               trainable,
               name='shared_embedding_column_creator',
               use_safe_embedding_lookup=True):
    self._dimension = dimension
    self._initializer = initializer
    self._ckpt_to_load_from = ckpt_to_load_from
    self._tensor_name_in_ckpt = tensor_name_in_ckpt
    self._num_buckets = num_buckets
    self._trainable = trainable
    self._name = name
    self._use_safe_embedding_lookup = use_safe_embedding_lookup
    # Map from graph keys to embedding_weights variables.
    self._embedding_weights = {}

  def __call__(self, categorical_column, combiner, max_norm):
    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
                                 self._use_safe_embedding_lookup)

  @property
  def embedding_weights(self):
    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if key not in self._embedding_weights:
      embedding_shape = (self._num_buckets, self._dimension)
      var = variable_scope.get_variable(
          name=self._name,
          shape=embedding_shape,
          dtype=dtypes.float32,
          initializer=self._initializer,
          trainable=self._trainable)

      if self._ckpt_to_load_from is not None:
        to_restore = var
        if isinstance(to_restore, variables.PartitionedVariable):
          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
        checkpoint_utils.init_from_checkpoint(
            self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
      self._embedding_weights[key] = var
    return self._embedding_weights[key]

  @property
  def dimension(self):
    return self._dimension


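# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `SharedEmbeddingColumnCreator`: a single creator yields multiple
# `SharedEmbeddingColumn`s that resolve to the same embedding table. The
# constructor arguments below are illustrative assumptions.
def _example_shared_embedding_creator():
  """Minimal sketch: two columns sharing one embedding table."""
  watched = categorical_column_with_identity(
      'watched_video_id', num_buckets=1000)
  impression = categorical_column_with_identity(
      'impression_video_id', num_buckets=1000)
  creator = SharedEmbeddingColumnCreator(
      dimension=8, initializer=None, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=1000, trainable=True)
  watched_emb = creator(watched, combiner='mean', max_norm=None)
  impression_emb = creator(impression, combiner='mean', max_norm=None)
  # Both columns look up into the same `embedding_weights` variable.
  assert (watched_emb.shared_embedding_column_creator
          is impression_emb.shared_embedding_column_creator)
  return watched_emb, impression_emb

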
class SharedEmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'SharedEmbeddingColumn',
        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
         'max_norm', 'use_safe_embedding_lookup'))):
  """See `shared_embedding_columns`."""

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner,
              max_norm,
              use_safe_embedding_lookup=True):
    return super(SharedEmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        shared_embedding_column_creator=shared_embedding_column_creator,
        combiner=combiner,
        max_norm=max_norm,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_shared_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  def _parse_example_spec(self):
    return _raise_shared_embedding_column_error()

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return transformation_cache.get(self.categorical_column, state_manager)

  def _transform_feature(self, inputs):
    return _raise_shared_embedding_column_error()

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(
        [self.shared_embedding_column_creator.dimension])

  @property
  def _variable_shape(self):
    return _raise_shared_embedding_column_error()

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows the signature of _get_dense_tensor."""
    # This method is called from a variable_scope with name _var_scope_name,
    # which is shared among all shared embeddings. Open a name_scope here, so
    # that the ops for different columns have distinct names.
    with ops.name_scope(None, default_name=self.name):
      # Get sparse IDs and weights.
      sparse_tensors = self.categorical_column.get_sparse_tensors(
          transformation_cache, state_manager)
      sparse_ids = sparse_tensors.id_tensor
      sparse_weights = sparse_tensors.weight_tensor

      embedding_weights = self.shared_embedding_column_creator.embedding_weights

      sparse_id_rank = tensor_shape.dimension_value(
          sparse_ids.dense_shape.get_shape()[0])
      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
          sparse_id_rank <= 2):
        embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
      # Return embedding lookup result.
      return embedding_lookup_sparse(
          embedding_weights,
          sparse_ids,
          sparse_weights,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embedding lookup result."""
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    return self._get_dense_tensor_internal(transformation_cache, state_manager)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    return _raise_shared_embedding_column_error()

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
                                                   state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    return _raise_shared_embedding_column_error()

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]


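# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `SharedEmbeddingColumn` properties: the name is derived from the
# wrapped categorical column, and the variable shape comes from the creator's
# dimension.
def _example_shared_embedding_column():
  """Minimal sketch: inspecting a shared embedding column."""
  video = categorical_column_with_identity('video_id', num_buckets=1000)
  creator = SharedEmbeddingColumnCreator(
      dimension=8, initializer=None, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=1000, trainable=True)
  column = creator(video, combiner='mean', max_norm=None)
  assert column.name == 'video_id_shared_embedding'
  assert column.variable_shape.as_list() == [8]
  return column

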
def _check_shape(shape, key):
  """Returns shape if it's valid, raises error otherwise."""
  assert shape is not None
  if not nest.is_sequence(shape):
    shape = [shape]
  shape = tuple(shape)
  for dimension in shape:
    if not isinstance(dimension, int):
      raise TypeError('shape dimensions must be integer. '
                      'shape: {}, key: {}'.format(shape, key))
    if dimension < 1:
      raise ValueError('shape dimensions must be greater than 0. '
                       'shape: {}, key: {}'.format(shape, key))
  return shape


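# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `_check_shape`: scalars are promoted to 1-tuples and non-positive
# dimensions raise.
def _example_check_shape():
  """Minimal sketch: shape normalization for numeric columns."""
  assert _check_shape(3, key='age') == (3,)  # int -> (3,)
  assert _check_shape([2, 4], key='image') == (2, 4)
  try:
    _check_shape([0], key='bad')  # Dimension 0 is rejected.
  except ValueError:
    pass

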
class HashedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('HashedCategoricalColumn',
                           ('key', 'hash_bucket_size', 'dtype'))):
  """See `categorical_column_with_hash_bucket`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Hashes the values in the feature_column."""
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    if self.dtype == dtypes.string:
      sparse_values = input_tensor.values
    else:
      sparse_values = string_ops.as_string(input_tensor.values)

    sparse_id_values = string_ops.string_to_hash_bucket_fast(
        sparse_values, self.hash_bucket_size, name='lookup')
    return sparse_tensor_lib.SparseTensor(
        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Hashes the values in the feature_column."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)


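# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `HashedCategoricalColumn`: string values are hashed into
# `hash_bucket_size` buckets, so no vocabulary is required (at the cost of
# possible collisions). Eager execution is assumed.
def _example_hashed_column():
  """Minimal sketch: hashing a string feature into id buckets."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  keywords = categorical_column_with_hash_bucket(
      'keywords', hash_bucket_size=100)
  assert keywords.num_buckets == 100
  cache = FeatureTransformationCache(
      {'keywords': constant_op.constant([['tensorflow', ''], ['keras', '']])})
  # Empty strings are dropped by _to_sparse_input_and_drop_ignore_values;
  # the remaining values hash to ids in [0, 100).
  return keywords.get_sparse_tensors(cache, None).id_tensor

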
class VocabularyFileCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyFileCategoricalColumn',
        ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets',
         'dtype', 'default_value', 'file_format'))):
  """See `categorical_column_with_vocabulary_file`."""

  def __new__(cls,
              key,
              vocabulary_file,
              vocabulary_size,
              num_oov_buckets,
              dtype,
              default_value,
              file_format=None):
    return super(VocabularyFileCategoricalColumn, cls).__new__(
        cls,
        key=key,
        vocabulary_file=vocabulary_file,
        vocabulary_size=vocabulary_size,
        num_oov_buckets=num_oov_buckets,
        dtype=dtype,
        default_value=default_value,
        file_format=file_format)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _make_table_from_tfrecord_gzip_file(self, key_dtype, name):
    dataset = readers.TFRecordDataset(
        self.vocabulary_file, compression_type='GZIP')

    def key_dtype_fn(key):
      return key if key_dtype is dtypes.string else string_ops.string_to_number(
          key, out_type=key_dtype)

    return data_lookup_ops.index_table_from_dataset(
        dataset.map(key_dtype_fn),
        num_oov_buckets=self.num_oov_buckets,
        vocab_size=self.vocabulary_size,
        default_value=self.default_value,
        key_dtype=key_dtype,
        name=name)

  def _make_table(self, key_dtype, state_manager):
    name = '{}_lookup'.format(self.key)
    if state_manager is None or not state_manager.has_resource(self, name):
      with ops.init_scope():
        if self.file_format == 'tfrecord_gzip':
          table = self._make_table_from_tfrecord_gzip_file(key_dtype, name)
        else:
          table = lookup_ops.index_table_from_file(
              vocabulary_file=self.vocabulary_file,
              num_oov_buckets=self.num_oov_buckets,
              vocab_size=self.vocabulary_size,
              default_value=self.default_value,
              key_dtype=key_dtype,
              name=name)
      if state_manager is not None:
        state_manager.add_resource(self, name, table)
    else:
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, name)
    return table

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_file` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
    return self._make_table(key_dtype, state_manager).lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.vocabulary_size + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)


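# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `VocabularyFileCategoricalColumn`: ids 0..vocabulary_size-1 come
# from the file and out-of-vocabulary values hash into the trailing
# `num_oov_buckets` ids. '/tmp/states.txt' is an assumed one-term-per-line
# vocabulary file; it is not read at construction time.
def _example_vocabulary_file_column():
  """Minimal sketch: a file-backed categorical column with OOV buckets."""
  states = categorical_column_with_vocabulary_file(
      key='state', vocabulary_file='/tmp/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  assert states.num_buckets == 55  # 50 in-vocabulary ids + 5 OOV buckets.
  return states

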
class VocabularyListCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyListCategoricalColumn',
        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
):
  """See `categorical_column_with_vocabulary_list`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary list."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_tensor` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)

    name = '{}_lookup'.format(self.key)
    if state_manager is None or not state_manager.has_resource(self, name):
      with ops.init_scope():
        table = lookup_ops.index_table_from_tensor(
            vocabulary_list=tuple(self.vocabulary_list),
            default_value=self.default_value,
            num_oov_buckets=self.num_oov_buckets,
            dtype=key_dtype,
            name=name)
      if state_manager is not None:
        state_manager.add_resource(self, name, table)
    else:
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, name)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary list."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return len(self.vocabulary_list) + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)


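# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `VocabularyListCategoricalColumn`: list entries map to ids by
# position and out-of-vocabulary values map to `default_value` (-1 here,
# which downstream lookups treat as missing). Eager execution is assumed.
def _example_vocabulary_list_column():
  """Minimal sketch: id lookup over an in-memory vocabulary."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  colors = categorical_column_with_vocabulary_list(
      'color', vocabulary_list=('R', 'G', 'B'))
  assert colors.num_buckets == 3
  cache = FeatureTransformationCache(
      {'color': constant_op.constant([['R'], ['B'], ['?']])})
  # Ids: 'R' -> 0, 'B' -> 2, '?' -> -1 (default_value).
  return colors.get_sparse_tensors(cache, None).id_tensor

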
class IdentityCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('IdentityCategoricalColumn',
                           ('key', 'number_buckets', 'default_value'))):
  """See `categorical_column_with_identity`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Returns a SparseTensor with identity values."""
    if not input_tensor.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, input_tensor.dtype))
    values = input_tensor.values
    if input_tensor.values.dtype != dtypes.int64:
      values = math_ops.cast(values, dtypes.int64, name='values')
    if self.default_value is not None:
      num_buckets = math_ops.cast(
          self.num_buckets, dtypes.int64, name='num_buckets')
      zero = math_ops.cast(0, dtypes.int64, name='zero')
      # Assign the default for out-of-range values.
      values = array_ops.where_v2(
          math_ops.logical_or(
              values < zero, values >= num_buckets, name='out_of_range'),
          array_ops.fill(
              dims=array_ops.shape(values),
              value=math_ops.cast(self.default_value, dtypes.int64),
              name='default_values'), values)

    return sparse_tensor_lib.SparseTensor(
        indices=input_tensor.indices,
        values=values,
        dense_shape=input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a SparseTensor with identity values."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.number_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    return dict(zip(self._fields, self))

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    return cls(**kwargs)


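# A documentation-only sketch (hypothetical helper, not part of this module's
# API) of `IdentityCategoricalColumn`: integer inputs are used directly as
# category ids, and with a `default_value` set, out-of-range ids are remapped
# instead of failing downstream. Eager execution is assumed.
def _example_identity_column():
  """Minimal sketch: identity ids with out-of-range remapping."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  video = categorical_column_with_identity(
      'video_id', num_buckets=4, default_value=0)
  cache = FeatureTransformationCache(
      {'video_id': constant_op.constant([[1], [3], [9]], dtype=dtypes.int64)})
  # Ids: 1 -> 1, 3 -> 3, 9 -> 0 (out of range, remapped to default_value).
  return video.get_sparse_tensors(cache, None).id_tensor

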
3866class WeightedCategoricalColumn(
3867    CategoricalColumn,
3868    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3869    collections.namedtuple(
3870        'WeightedCategoricalColumn',
3871        ('categorical_column', 'weight_feature_key', 'dtype'))):
3872  """See `weighted_categorical_column`."""
3873
3874  @property
3875  def _is_v2_column(self):
3876    return (isinstance(self.categorical_column, FeatureColumn) and
3877            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
3878
3879  @property
3880  def name(self):
3881    """See `FeatureColumn` base class."""
3882    return '{}_weighted_by_{}'.format(
3883        self.categorical_column.name, self.weight_feature_key)
3884
3885  @property
3886  def parse_example_spec(self):
3887    """See `FeatureColumn` base class."""
3888    config = self.categorical_column.parse_example_spec
3889    if self.weight_feature_key in config:
3890      raise ValueError('Parse config {} already exists for {}.'.format(
3891          config[self.weight_feature_key], self.weight_feature_key))
3892    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3893    return config
3894
3895  @property
3896  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3897                          _FEATURE_COLUMN_DEPRECATION)
3898  def _parse_example_spec(self):
3899    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
3900    if self.weight_feature_key in config:
3901      raise ValueError('Parse config {} already exists for {}.'.format(
3902          config[self.weight_feature_key], self.weight_feature_key))
3903    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3904    return config
3905
3906  @property
3907  def num_buckets(self):
3908    """See `DenseColumn` base class."""
3909    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _transform_weight_tensor(self, weight_tensor):
    if weight_tensor is None:
      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        weight_tensor)
    if self.dtype != weight_tensor.dtype.base_dtype:
      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
          self.dtype, weight_tensor.dtype))
    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
      # The weight tensor can be a regular Tensor. In this case, sparsify it.
      weight_tensor = _to_sparse_input_and_drop_ignore_values(
          weight_tensor, ignore_value=0.0)
    if not weight_tensor.dtype.is_floating:
      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
    return weight_tensor

  def transform_feature(self, transformation_cache, state_manager):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = transformation_cache.get(self.weight_feature_key,
                                             state_manager)
    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.categorical_column, state_manager))
    return (sparse_categorical_tensor, sparse_weight_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = inputs.get(self.weight_feature_key)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (inputs.get(self.categorical_column), weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    tensors = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    tensors = inputs.get(self)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column, self.weight_feature_key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
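
# Illustrative note (not part of this module): `WeightedCategoricalColumn`
# pairs a categorical column with a second feature that carries per-id
# weights. A hedged sketch using the public endpoint (feature and key names
# are made up):
#
#   import tensorflow as tf
#
#   colors = tf.feature_column.categorical_column_with_vocabulary_list(
#       'colors', ['R', 'G', 'B'])
#   weighted = tf.feature_column.weighted_categorical_column(
#       categorical_column=colors, weight_feature_key='color_frequencies')
#   # As parse_example_spec above shows, this adds a VarLenFeature for
#   # 'color_frequencies'; its values replace the implicit 1.0 weight of
#   # each id.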


class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`."""

  @property
  def _is_v2_column(self):
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        continue
      if not isinstance(key, FeatureColumn):
        return False
      if not key._is_v2_column:  # pylint: disable=protected-access
        return False
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    feature_names = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
        feature_names.append(key.name)
      else:  # key must be a string
        feature_names.append(key)
    return '_X_'.join(sorted(feature_names))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        config.update(key.parse_example_spec)
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        config.update(key._parse_example_spec)  # pylint: disable=protected-access
      else:  # key must be a string
        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(transformation_cache.get(key, state_manager))
      elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key.get_sparse_tensors(transformation_cache,
                                                 state_manager)
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)
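
  # Illustrative note (not part of this module): the hashed cross never
  # materializes the full cartesian vocabulary; each co-occurring combination
  # of leaf values is fingerprinted and mapped into [0, hash_bucket_size).
  # A hedged sketch using the public TF2 op (values are made up):
  #
  #   import tensorflow as tf
  #
  #   a = tf.sparse.from_dense([['a1', 'a2']])
  #   b = tf.sparse.from_dense([['b1']])
  #   crossed = tf.sparse.cross_hashed([a, b], num_buckets=100)
  #   # crossed.values holds one id in [0, 100) per combination:
  #   # (a1, b1) and (a2, b1).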

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(inputs.get(key))
      elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return list(self.keys)

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['keys'] = tuple([
        deserialize_feature_column(c, custom_objects, columns_by_name)
        for c in config['keys']
    ])
    return cls(**kwargs)
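
# Illustrative note (not part of this module): a hedged sketch of building a
# cross via the public `tf.feature_column.crossed_column` endpoint (feature
# names are made up):
#
#   import tensorflow as tf
#
#   country = tf.feature_column.categorical_column_with_hash_bucket(
#       'country', hash_bucket_size=100)
#   country_x_language = tf.feature_column.crossed_column(
#       [country, 'language'], hash_bucket_size=5000)
#   # Crossed columns are typically fed to linear models, often alongside the
#   # base columns so the model can also learn the marginal effects.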


def _collect_leaf_level_keys(cross):
  """Collects base keys by expanding all nested crosses.

  Args:
    cross: A `CrossedColumn`.

  Returns:
    A list of strings or `CategoricalColumn` instances.
  """
  leaf_level_keys = []
  for k in cross.keys:
    if isinstance(k, CrossedColumn):
      leaf_level_keys.extend(_collect_leaf_level_keys(k))
    else:
      leaf_level_keys.append(k)
  return leaf_level_keys
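
# Illustrative note (not part of this module): given a nested cross such as
# crossed_column([crossed_column(['a', 'b'], ...), 'c'], ...), the helper
# above flattens the keys to ['a', 'b', 'c'], so hashing always operates on
# leaf-level features rather than on intermediate crossed ids.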


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prunes invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prunes invalid weights (< 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
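
# Illustrative note (not part of this module): a hedged example of the
# pruning helpers, with made-up values. An id of -1 (e.g. an out-of-range
# entry) is dropped from the ids and from the aligned weights:
#
#   ids:     SparseTensor(values=[3, -1, 7], ...)
#   weights: SparseTensor(values=[0.5, 2.0, 1.0], ...)
#   after _prune_invalid_ids: ids.values == [3, 7],
#                             weights.values == [0.5, 1.0]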


class IndicatorColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple('IndicatorColumn', ('categorical_column'))):
  """Represents a one-hot column for use in deep networks.

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function.
  """

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_indicator'.format(self.categorical_column.name)

  def _transform_id_weight_pair(self, id_weight_pair, size):
    id_tensor = id_weight_pair.id_tensor
    weight_tensor = id_weight_pair.weight_tensor

    # If the underlying column is weighted, return the input as a dense tensor.
    if weight_tensor is not None:
      weighted_column = sparse_ops.sparse_merge(
          sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size))
      # Remove (?, -1) index.
      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
                                                weighted_column.dense_shape)
      # Use scatter_nd instead of sparse_tensor_to_dense so that duplicated
      # indices, if any exist, are summed together rather than rejected.
      return array_ops.scatter_nd(weighted_column.indices,
                                  weighted_column.values,
                                  weighted_column.dense_shape)

    dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
        id_tensor, default_value=-1)

    # One hot must be float for tf.concat reasons since all other inputs to
    # input_layer are float32.
    one_hot_id_tensor = array_ops.one_hot(
        dense_id_tensor, depth=size, on_value=1.0, off_value=0.0)

    # Reduce to get a multi-hot per example.
    return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
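
  # Illustrative note (not part of this module): for the unweighted path above
  # with size=4, a row of ids [0, 2] becomes the one-hot rows
  # [[1., 0., 0., 0.], [0., 0., 1., 0.]], and the reduce_sum over axis -2
  # collapses them into the multi-hot vector [1., 0., 1., 0.]. Padded
  # positions get id -1, which one_hot maps to an all-zero row, so padding
  # contributes nothing.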

  def transform_feature(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.

    Raises:
      ValueError: if input rank is not known at graph building time.
    """
    id_weight_pair = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._transform_id_weight_pair(id_weight_pair,
                                          self.variable_shape[-1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._transform_id_weight_pair(id_weight_pair,
                                          self._variable_shape[-1])

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
    if isinstance(self.categorical_column, FeatureColumn):
      return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
    else:
      return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns dense `Tensor` representing feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    dense_tensor = transformation_cache.get(self, state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable since no variables are
    # created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # Feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
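
# Illustrative note (not part of this module): a hedged sketch of the public
# `tf.feature_column.indicator_column` endpoint (names are made up):
#
#   import tensorflow as tf
#
#   pets = tf.feature_column.categorical_column_with_vocabulary_list(
#       'pets', ['cat', 'dog', 'bird'])
#   pets_one_hot = tf.feature_column.indicator_column(pets)
#   dense = tf.keras.layers.DenseFeatures([pets_one_hot])
#   out = dense({
#       'pets': tf.ragged.constant([['cat', 'dog'], ['bird']]).to_sparse()
#   })
#   # out -> [[1., 1., 0.], [0., 0., 1.]]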


def _verify_static_batch_size_equality(tensors, columns):
  """Verifies equality between static batch sizes.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # batch_size is a Dimension object.
  expected_batch_size = None
  for i in range(0, len(tensors)):
    batch_size = tensor_shape.Dimension(tensor_shape.dimension_value(
        tensors[i].shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be the same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))


class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column'))):
  """Represents sequences of categorical data."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expands the third dimension, if necessary, so that embeddings are not
    # combined during the embedding lookup. If the tensor is already 3D, it is
    # left as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # -1 does not work for dynamically shaped tensors with a 0-length dimension
    # at runtime, which happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
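
  # Illustrative note (not part of this module): for a sequence feature parsed
  # as a [batch, seq_len] SparseTensor of ids, the reshape above produces
  # shape [batch, seq_len, 1] (one id per step), so a downstream embedding
  # lookup yields one embedding per step instead of combining the whole
  # sequence into a single vector.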

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
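
# Illustrative note (not part of this module): `SequenceCategoricalColumn`
# backs the public `tf.feature_column.sequence_categorical_column_with_*`
# endpoints. A hedged sketch (names are made up; `SequenceFeatures` lived
# under `tf.keras.experimental` in the TF 2.x releases this file targets):
#
#   import tensorflow as tf
#
#   tokens = (
#       tf.feature_column.sequence_categorical_column_with_vocabulary_list(
#           'tokens', ['<pad>', 'hello', 'world']))
#   token_emb = tf.feature_column.embedding_column(tokens, dimension=8)
#   seq_layer = tf.keras.experimental.SequenceFeatures([token_emb])
#   # Calling seq_layer on parsed features returns a (dense_tensor,
#   # sequence_length) pair, with one embedding per step.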


def _check_config_keys(config, expected_keys):
  """Checks that a config has all expected_keys."""
  if set(config.keys()) != set(expected_keys):
    raise ValueError('Invalid config: {}, expected keys: {}'.format(
        config, expected_keys))


def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned to tuples.

  Keras serialization uses nest to listify everything.
  This causes problems with the NumericColumn shape, which becomes
  unhashable. We could try to solve this on the Keras side, but that
  would require lots of tracking to avoid changing existing behavior.
  Instead, we ensure here that we revive correctly.

  Args:
    config: dict that will be used to revive a Feature Column

  Returns:
    Shallow copy of config with lists turned to tuples.
  """
  kwargs = config.copy()
  for k, v in kwargs.items():
    if isinstance(v, list):
      kwargs[k] = tuple(v)

  return kwargs
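
# Illustrative note (not part of this module): Keras deserialization listifies
# nested structures, so a NumericColumn config like {'shape': [1]} must be
# standardized back to a hashable tuple before reviving the namedtuple:
#
#   _standardize_and_copy_config({'shape': [1], 'key': 'age'})
#   # -> {'shape': (1,), 'key': 'age'}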


def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes."""
  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
  return invalid_char.sub('_', name)
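
# Illustrative note (not part of this module): characters outside
# [A-Za-z0-9_.-] are replaced with underscores, for example:
#
#   _sanitize_column_name_for_variable_scope('user location')  # 'user_location'
#   _sanitize_column_name_for_variable_scope('price($)')       # 'price___'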