• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Built-in WideNDeep model classes."""
16
17from tensorflow.python.eager import backprop
18from tensorflow.python.keras import activations
19from tensorflow.python.keras import backend
20from tensorflow.python.keras import layers as layer_module
21from tensorflow.python.keras.engine import base_layer
22from tensorflow.python.keras.engine import data_adapter
23from tensorflow.python.keras.engine import training as keras_training
24from tensorflow.python.keras.utils import generic_utils
25from tensorflow.python.util import nest
26from tensorflow.python.util.tf_export import keras_export
27
28
@keras_export('keras.experimental.WideDeepModel')
class WideDeepModel(keras_training.Model):
  r"""Wide & Deep Model for regression and classification problems.

  This model jointly trains a linear and a dnn model. Their outputs are summed
  (and optionally passed through an activation) to form the prediction.

  When `optimizer` is a list or tuple, its first element updates the linear
  model's variables and its second element updates the dnn model's variables.
  Otherwise a single optimizer updates all trainable variables.

  Example:

  ```python
  linear_model = LinearModel()
  dnn_model = keras.Sequential([keras.layers.Dense(units=64),
                               keras.layers.Dense(units=1)])
  combined_model = WideDeepModel(linear_model, dnn_model)
  combined_model.compile(optimizer=['sgd', 'adam'],
                         loss='mse', metrics=['mse'])
  # define dnn_inputs and linear_inputs as separate numpy arrays or
  # a single numpy array if dnn_inputs is same as linear_inputs.
  combined_model.fit([linear_inputs, dnn_inputs], y, epochs=epochs)
  # or define a single `tf.data.Dataset` that contains a single tensor or
  # separate tensors for dnn_inputs and linear_inputs.
  dataset = tf.data.Dataset.from_tensors(([linear_inputs, dnn_inputs], y))
  combined_model.fit(dataset, epochs=epochs)
  ```

  Both linear and dnn model can be pre-compiled and trained separately
  before jointly training:

  Example:
  ```python
  linear_model = LinearModel()
  linear_model.compile('adagrad', 'mse')
  linear_model.fit(linear_inputs, y, epochs=epochs)
  dnn_model = keras.Sequential([keras.layers.Dense(units=1)])
  dnn_model.compile('rmsprop', 'mse')
  dnn_model.fit(dnn_inputs, y, epochs=epochs)
  combined_model = WideDeepModel(linear_model, dnn_model)
  combined_model.compile(optimizer=['sgd', 'adam'],
                         loss='mse', metrics=['mse'])
  combined_model.fit([linear_inputs, dnn_inputs], y, epochs=epochs)
  ```

  """

  def __init__(self, linear_model, dnn_model, activation=None, **kwargs):
    """Create a Wide & Deep Model.

    Args:
      linear_model: a premade LinearModel, its output must match the output of
        the dnn model.
      dnn_model: a `tf.keras.Model`, its output must match the output of the
        linear model.
      activation: Activation function. Set it to None to maintain a linear
        activation.
      **kwargs: The keyword arguments that are passed on to BaseLayer.__init__.
        Allowed keyword arguments include `name`.
    """
    super(WideDeepModel, self).__init__(**kwargs)
    self.linear_model = linear_model
    self.dnn_model = dnn_model
    # `activations.get` resolves strings/None/callables to a callable.
    self.activation = activations.get(activation)

  def call(self, inputs, training=None):
    """Runs both sub-models and combines their outputs.

    Args:
      inputs: either a list/tuple of exactly two elements
        `[linear_inputs, dnn_inputs]`, or any other structure, which is then
        fed unchanged to both the linear and the dnn model.
      training: forwarded to the dnn model when it accepts a `training`
        argument; when None, `backend.learning_phase()` is used instead.

    Returns:
      The element-wise sum of the linear and dnn outputs (summed structure by
      structure via `nest.map_structure`), with `self.activation` applied.
    """
    if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
      linear_inputs = dnn_inputs = inputs
    else:
      linear_inputs, dnn_inputs = inputs
    linear_output = self.linear_model(linear_inputs)
    # pylint: disable=protected-access
    if self.dnn_model._expects_training_arg:
      if training is None:
        training = backend.learning_phase()
      dnn_output = self.dnn_model(dnn_inputs, training=training)
    else:
      dnn_output = self.dnn_model(dnn_inputs)
    # pylint: enable=protected-access
    output = nest.map_structure(lambda x, y: (x + y), linear_output, dnn_output)
    if self.activation:
      return nest.map_structure(self.activation, output)
    return output

  # This does not support gradient scaling and LossScaleOptimizer.
  def train_step(self, data):
    """Runs one eager/`tf.function` training step.

    With a list/tuple of optimizers, gradients for the linear and dnn
    variables are applied by `self.optimizer[0]` and `self.optimizer[1]`
    respectively; otherwise one optimizer updates all trainable variables.

    Args:
      data: a batch accepted by `data_adapter.unpack_x_y_sample_weight`.

    Returns:
      A dict mapping each metric name in `self.metrics` to its current result.
    """
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    x, y, sample_weight = data_adapter.expand_1d((x, y, sample_weight))

    with backprop.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    if isinstance(self.optimizer, (list, tuple)):
      linear_vars = self.linear_model.trainable_variables
      dnn_vars = self.dnn_model.trainable_variables
      linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))

      linear_optimizer = self.optimizer[0]
      dnn_optimizer = self.optimizer[1]
      linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
      dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars))
    else:
      trainable_variables = self.trainable_variables
      grads = tape.gradient(loss, trainable_variables)
      self.optimizer.apply_gradients(zip(grads, trainable_variables))

    return {m.name: m.result() for m in self.metrics}

  def _make_train_function(self):
    """Builds `self.train_function`; only needed for graph mode and
    model_to_estimator.

    The function is (re)built when it does not exist yet or when the loss and
    weighted-metric sub-graphs have been recompiled.
    """
    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
    self._check_trainable_weights_consistency()
    # If we have re-compiled the loss/weighted metric sub-graphs then create
    # train function even if one exists already. This is because
    # `_feed_sample_weights` list has been updated on re-compile.
    if getattr(self, 'train_function', None) is not None and not has_recompiled:
      return

    # Restore the compiled trainable state while the update ops are built.
    current_trainable_state = self._get_trainable_state()
    self._set_trainable_state(self._compiled_trainable_state)
    try:
      inputs = (
          self._feed_inputs + self._feed_targets + self._feed_sample_weights)
      if not isinstance(backend.symbolic_learning_phase(), int):
        inputs += [backend.symbolic_learning_phase()]

      with backend.get_graph().as_default():
        with backend.name_scope('training'):
          # Training updates
          updates = []
          if isinstance(self.optimizer, (list, tuple)):
            # Separate optimizers: each one only updates its own sub-model.
            linear_optimizer = self.optimizer[0]
            dnn_optimizer = self.optimizer[1]
            updates += linear_optimizer.get_updates(
                params=self.linear_model.trainable_weights,
                loss=self.total_loss)
            updates += dnn_optimizer.get_updates(
                params=self.dnn_model.trainable_weights,
                loss=self.total_loss)
          else:
            # A single shared optimizer is applied once over all trainable
            # weights, mirroring `train_step`. Calling `get_updates` once per
            # sub-model on the same optimizer would build two apply steps
            # (each with its own iteration-counter update) per training step.
            updates += self.optimizer.get_updates(
                params=self.trainable_weights, loss=self.total_loss)
          # Unconditional updates
          updates += self.get_updates_for(None)
          # Conditional updates relevant to this model
          updates += self.get_updates_for(self.inputs)

        metrics = self._get_training_eval_metrics()
        metrics_tensors = [
            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
        ]

      with backend.name_scope('training'):
        # Gets loss and metrics. Updates weights at each call.
        fn = backend.function(
            inputs, [self.total_loss] + metrics_tensors,
            updates=updates,
            name='train_function',
            **self._function_kwargs)
        setattr(self, 'train_function', fn)
    finally:
      # Restore the current trainable state even if building the function
      # raised.
      self._set_trainable_state(current_trainable_state)

  def get_config(self):
    """Returns the layer config plus the serialized sub-models and activation."""
    linear_config = generic_utils.serialize_keras_object(self.linear_model)
    dnn_config = generic_utils.serialize_keras_object(self.dnn_model)
    config = {
        'linear_model': linear_config,
        'dnn_model': dnn_config,
        'activation': activations.serialize(self.activation),
    }
    base_config = base_layer.Layer.get_config(self)
    # Keys from `config` take precedence over the base layer config.
    return dict(base_config, **config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates a `WideDeepModel` from the output of `get_config`.

    Args:
      config: dict produced by `get_config`. It is not modified.
      custom_objects: optional mapping of names to custom classes/functions
        used while deserializing the sub-models and the activation.

    Returns:
      A new `WideDeepModel` instance.
    """
    # Work on a shallow copy so the caller's dict is left intact and can be
    # reused (e.g. deserializing the same config more than once).
    config = dict(config)
    linear_config = config.pop('linear_model')
    linear_model = layer_module.deserialize(linear_config, custom_objects)
    dnn_config = config.pop('dnn_model')
    dnn_model = layer_module.deserialize(dnn_config, custom_objects)
    activation = activations.deserialize(
        config.pop('activation', None), custom_objects=custom_objects)
    return cls(
        linear_model=linear_model,
        dnn_model=dnn_model,
        activation=activation,
        **config)
215