• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Embedding layer.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import sharded_variable
22from tensorflow.python.eager import context
23from tensorflow.python.framework import config as tf_config
24from tensorflow.python.framework import ops
25from tensorflow.python.keras import backend as K
26from tensorflow.python.keras import constraints
27from tensorflow.python.keras import initializers
28from tensorflow.python.keras import regularizers
29from tensorflow.python.keras.engine import base_layer_utils
30from tensorflow.python.keras.engine.base_layer import Layer
31from tensorflow.python.keras.utils import tf_utils
32from tensorflow.python.ops import embedding_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.util.tf_export import keras_export
35
36
37@keras_export('keras.layers.Embedding')
38class Embedding(Layer):
39  """Turns positive integers (indexes) into dense vectors of fixed size.
40
41  e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`
42
43  This layer can only be used as the first layer in a model.
44
45  Example:
46
47  >>> model = tf.keras.Sequential()
48  >>> model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
49  >>> # The model will take as input an integer matrix of size (batch,
50  >>> # input_length), and the largest integer (i.e. word index) in the input
51  >>> # should be no larger than 999 (vocabulary size).
52  >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch
53  >>> # dimension.
54  >>> input_array = np.random.randint(1000, size=(32, 10))
55  >>> model.compile('rmsprop', 'mse')
56  >>> output_array = model.predict(input_array)
57  >>> print(output_array.shape)
58  (32, 10, 64)
59
60  Args:
61    input_dim: Integer. Size of the vocabulary,
62      i.e. maximum integer index + 1.
63    output_dim: Integer. Dimension of the dense embedding.
64    embeddings_initializer: Initializer for the `embeddings`
65      matrix (see `keras.initializers`).
66    embeddings_regularizer: Regularizer function applied to
67      the `embeddings` matrix (see `keras.regularizers`).
68    embeddings_constraint: Constraint function applied to
69      the `embeddings` matrix (see `keras.constraints`).
70    mask_zero: Boolean, whether or not the input value 0 is a special "padding"
71      value that should be masked out.
72      This is useful when using recurrent layers
73      which may take variable length input.
74      If this is `True`, then all subsequent layers
75      in the model need to support masking or an exception will be raised.
76      If mask_zero is set to True, as a consequence, index 0 cannot be
77      used in the vocabulary (input_dim should equal size of
78      vocabulary + 1).
79    input_length: Length of input sequences, when it is constant.
80      This argument is required if you are going to connect
81      `Flatten` then `Dense` layers upstream
82      (without it, the shape of the dense outputs cannot be computed).
83
84  Input shape:
85    2D tensor with shape: `(batch_size, input_length)`.
86
87  Output shape:
88    3D tensor with shape: `(batch_size, input_length, output_dim)`.
89  """
90
91  def __init__(self,
92               input_dim,
93               output_dim,
94               embeddings_initializer='uniform',
95               embeddings_regularizer=None,
96               activity_regularizer=None,
97               embeddings_constraint=None,
98               mask_zero=False,
99               input_length=None,
100               **kwargs):
101    if 'input_shape' not in kwargs:
102      if input_length:
103        kwargs['input_shape'] = (input_length,)
104      else:
105        kwargs['input_shape'] = (None,)
106    if input_dim <= 0 or output_dim <= 0:
107      raise ValueError('Both `input_dim` and `output_dim` should be positive, '
108                       'found input_dim {} and output_dim {}'.format(
109                           input_dim, output_dim))
110    if (not base_layer_utils.v2_dtype_behavior_enabled() and
111        'dtype' not in kwargs):
112      # In TF1, the dtype defaults to the input dtype which is typically int32,
113      # so explicitly set it to floatx
114      kwargs['dtype'] = K.floatx()
115    # We set autocast to False, as we do not want to cast floating- point inputs
116    # to self.dtype. In call(), we cast to int32, and casting to self.dtype
117    # before casting to int32 might cause the int32 values to be different due
118    # to a loss of precision.
119    kwargs['autocast'] = False
120    super(Embedding, self).__init__(**kwargs)
121
122    self.input_dim = input_dim
123    self.output_dim = output_dim
124    self.embeddings_initializer = initializers.get(embeddings_initializer)
125    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
126    self.activity_regularizer = regularizers.get(activity_regularizer)
127    self.embeddings_constraint = constraints.get(embeddings_constraint)
128    self.mask_zero = mask_zero
129    self.supports_masking = mask_zero
130    self.input_length = input_length
131
132  @tf_utils.shape_type_conversion
133  def build(self, input_shape):
134    # Note: most sparse optimizers do not have GPU kernels defined. When
135    # building graphs, the placement algorithm is able to place variables on CPU
136    # since it knows all kernels using the variable only exist on CPU.
137    # When eager execution is enabled, the placement decision has to be made
138    # right now. Checking for the presence of GPUs to avoid complicating the
139    # TPU codepaths which can handle sparse optimizers.
140    if context.executing_eagerly() and tf_config.list_logical_devices('GPU'):
141      with ops.device('cpu:0'):
142        self.embeddings = self.add_weight(
143            shape=(self.input_dim, self.output_dim),
144            initializer=self.embeddings_initializer,
145            name='embeddings',
146            regularizer=self.embeddings_regularizer,
147            constraint=self.embeddings_constraint,
148            experimental_autocast=False)
149    else:
150      self.embeddings = self.add_weight(
151          shape=(self.input_dim, self.output_dim),
152          initializer=self.embeddings_initializer,
153          name='embeddings',
154          regularizer=self.embeddings_regularizer,
155          constraint=self.embeddings_constraint,
156          experimental_autocast=False)
157    self.built = True
158
159  def compute_mask(self, inputs, mask=None):
160    if not self.mask_zero:
161      return None
162
163    return math_ops.not_equal(inputs, 0)
164
165  @tf_utils.shape_type_conversion
166  def compute_output_shape(self, input_shape):
167    if self.input_length is None:
168      return input_shape + (self.output_dim,)
169    else:
170      # input_length can be tuple if input is 3D or higher
171      if isinstance(self.input_length, (list, tuple)):
172        in_lens = list(self.input_length)
173      else:
174        in_lens = [self.input_length]
175      if len(in_lens) != len(input_shape) - 1:
176        raise ValueError('"input_length" is %s, '
177                         'but received input has shape %s' % (str(
178                             self.input_length), str(input_shape)))
179      else:
180        for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
181          if s1 is not None and s2 is not None and s1 != s2:
182            raise ValueError('"input_length" is %s, '
183                             'but received input has shape %s' % (str(
184                                 self.input_length), str(input_shape)))
185          elif s1 is None:
186            in_lens[i] = s2
187      return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
188
189  def call(self, inputs):
190    dtype = K.dtype(inputs)
191    if dtype != 'int32' and dtype != 'int64':
192      inputs = math_ops.cast(inputs, 'int32')
193    if isinstance(self.embeddings, sharded_variable.ShardedVariable):
194      out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs)
195    else:
196      out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
197    if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
198      # Instead of casting the variable as in most layers, cast the output, as
199      # this is mathematically equivalent but is faster.
200      out = math_ops.cast(out, self._dtype_policy.compute_dtype)
201    return out
202
203  def get_config(self):
204    config = {
205        'input_dim': self.input_dim,
206        'output_dim': self.output_dim,
207        'embeddings_initializer':
208            initializers.serialize(self.embeddings_initializer),
209        'embeddings_regularizer':
210            regularizers.serialize(self.embeddings_regularizer),
211        'activity_regularizer':
212            regularizers.serialize(self.activity_regularizer),
213        'embeddings_constraint':
214            constraints.serialize(self.embeddings_constraint),
215        'mask_zero': self.mask_zero,
216        'input_length': self.input_length
217    }
218    base_config = super(Embedding, self).get_config()
219    return dict(list(base_config.items()) + list(config.items()))
220