• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Baseline estimators.
16
17Baseline estimators are bias-only estimators that can be used for debugging
18and as simple baselines.
19
20Example:
21
22```
23# Build BaselineClassifier
24classifier = BaselineClassifier(n_classes=3)
25
26# Input builders
27def input_fn_train: # returns x, y (where y represents label's class index).
28  pass
29
30def input_fn_eval: # returns x, y (where y represents label's class index).
31  pass
32
33# Fit model.
34classifier.train(input_fn=input_fn_train)
35
36# Evaluate cross entropy between the test and train labels.
37loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]
38
39# predict outputs the probability distribution of the classes as seen in
40# training.
41predictions = classifier.predict(new_samples)
42```
43"""
44from __future__ import absolute_import
45from __future__ import division
46from __future__ import print_function
47
48import six
49
50from tensorflow.python.estimator import estimator
51from tensorflow.python.estimator.canned import head as head_lib
52from tensorflow.python.estimator.canned import optimizers
53from tensorflow.python.feature_column import feature_column as feature_column_lib
54from tensorflow.python.framework import ops
55from tensorflow.python.ops import array_ops
56from tensorflow.python.ops import check_ops
57from tensorflow.python.ops import init_ops
58from tensorflow.python.ops import math_ops
59from tensorflow.python.ops import variable_scope
60from tensorflow.python.ops.losses import losses
61from tensorflow.python.training import training_util
62from tensorflow.python.util.tf_export import tf_export
63
64# The default learning rate of 0.3 is a historical artifact of the initial
65# implementation, but seems a reasonable choice.
66_LEARNING_RATE = 0.3
67
68
def _get_weight_column_key(weight_column):
  """Returns the feature key under which the example weights are stored.

  Args:
    weight_column: `None`, a string naming the weight feature, or a
      `_NumericColumn` created by `tf.feature_column.numeric_column`.

  Returns:
    `None` if `weight_column` is `None`, the string itself if it is a string,
    otherwise the `key` of the numeric column.

  Raises:
    TypeError: If `weight_column` is not `None`, a string, or a
      `_NumericColumn`.
  """
  if weight_column is None:
    return None
  if isinstance(weight_column, six.string_types):
    return weight_column
  if not isinstance(weight_column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
    raise TypeError('Weight column must be either a string or _NumericColumn.'
                    ' Given type: {}.'.format(type(weight_column)))
  # `key` on a numeric column is a plain string field, not a method; calling
  # it (`weight_column.key()`) would raise "'str' object is not callable".
  return weight_column.key
78
79
def _baseline_logit_fn_builder(num_outputs, weight_column=None):
  """Function builder for a baseline logit_fn.

  Args:
    num_outputs: Number of outputs for the model.
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It will be multiplied by the loss of the example.

  Returns:
    A logit_fn (see below).
  """

  def baseline_logit_fn(features):
    """Baseline model logit_fn.

    The baseline model simply learns a bias, so the output logits are a
    `Variable` with one weight for each output that learns the bias for the
    corresponding output. Feature values are ignored; the features are only
    used to determine (and cross-check) the batch size.

    Args:
      features: The first item returned from the `input_fn` passed to `train`,
        `evaluate`, and `predict`. This should be a single `Tensor` or dict with
        `Tensor` values.

    Returns:
      A `Tensor` of shape `[batch_size, num_outputs]` representing the logits.

    Raises:
      ValueError: If `features` contains no tensor other than the weight
        column, so the batch size cannot be inferred.
    """
    weight_column_key = _get_weight_column_key(weight_column)

    if isinstance(features, dict):
      # Skip weight_column to ensure we don't add size checks to it.
      # These would introduce a dependency on the weight at serving time.
      feature_tensors = [feature for key, feature in features.items()
                         if key != weight_column_key]
    else:
      # A single `Tensor` was passed instead of a dict of features.
      feature_tensors = [features]

    if not feature_tensors:
      raise ValueError(
          'Baseline logit_fn requires at least one feature other than the '
          'weight column in order to infer the batch size.')

    # The first dimension is assumed to be a batch size and must be consistent
    # among all of the features.
    batch_size = array_ops.shape(feature_tensors[0])[0]
    size_checks = [
        check_ops.assert_equal(batch_size, array_ops.shape(feature)[0])
        for feature in feature_tensors[1:]
    ]

    with ops.control_dependencies(size_checks):
      with variable_scope.variable_scope('baseline'):
        bias = variable_scope.get_variable('bias', shape=[num_outputs],
                                           initializer=init_ops.Zeros)
        # Broadcast the learned bias to one row per example in the batch.
        return math_ops.multiply(bias, array_ops.ones([batch_size,
                                                       num_outputs]))

  return baseline_logit_fn
132
133
def _baseline_model_fn(features, labels, mode, head, optimizer,
                       weight_column=None, config=None):
  """Builds the `EstimatorSpec` for a bias-only (baseline) model.

  The logits come from a baseline logit function sized to
  `head.logits_dimension`; the head turns them into predictions, loss and,
  when training, a train op built from `optimizer`.

  Args:
    features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).
    labels: `Tensor` of labels that are compatible with the `Head` instance.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `Head` instance.
    optimizer: String, `tf.Optimizer` object, or callable that creates the
      optimizer to use for training. It is resolved through
      `optimizers.get_optimizer_instance` with a learning rate of 0.3.
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It will be multiplied by the loss of the example.
    config: `RunConfig` object to configure the runtime settings. Ignored.

  Raises:
    KeyError: If weight column is specified but not present.
    ValueError: If features is an empty dictionary.

  Returns:
    An `EstimatorSpec` instance.
  """
  del config  # The baseline model has no use for the run configuration.

  build_logits = _baseline_logit_fn_builder(head.logits_dimension,
                                            weight_column)
  model_logits = build_logits(features)

  def _train_op_fn(loss):
    """Returns an op that takes one optimizer step to minimize `loss`."""
    optimizer_instance = optimizers.get_optimizer_instance(
        optimizer, learning_rate=_LEARNING_RATE)
    step = training_util.get_global_step()
    return optimizer_instance.minimize(loss, global_step=step)

  return head.create_estimator_spec(
      features=features,
      labels=labels,
      mode=mode,
      logits=model_logits,
      train_op_fn=_train_op_fn)
175
176
@tf_export('estimator.BaselineClassifier')
class BaselineClassifier(estimator.Estimator):
  """A classifier that can establish a simple baseline.

  This classifier ignores feature values and only learns a bias per logit, so
  it learns to predict the class distribution seen in the training labels.
  With `n_classes=2` (binary logistic head) this is the fraction of positive
  examples; with `n_classes > 2` (multi-class softmax head) it is the
  frequency of each class.

  Example:

  ```python

  # Build BaselineClassifier
  classifier = BaselineClassifier(n_classes=3)

  # Input builders
  def input_fn_train():  # returns x, y (where y represents label's class index).
    pass

  def input_fn_eval():  # returns x, y (where y represents label's class index).
    pass

  def input_fn_predict():  # returns x.
    pass

  # Fit model.
  classifier.train(input_fn=input_fn_train)

  # Evaluate the cross entropy between the evaluation labels and the class
  # distribution learned from the training labels.
  loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]

  # predict outputs the probability distribution of the classes as seen in
  # training.
  predictions = classifier.predict(input_fn=input_fn_predict)

  ```

  Input of `train` and `evaluate` should have following features,
    otherwise there will be a `KeyError`:

  * if `weight_column` is not `None`, a feature with
     `key=weight_column` whose value is a `Tensor`.
  """

  def __init__(self,
               model_dir=None,
               n_classes=2,
               weight_column=None,
               label_vocabulary=None,
               optimizer='Ftrl',
               config=None,
               loss_reduction=losses.Reduction.SUM):
    """Initializes a BaselineClassifier instance.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      n_classes: Number of label classes. Default is binary classification.
        It must be greater than 1. `n_classes == 2` uses a binary logistic
        (sigmoid cross-entropy) head; any other value uses a multi-class
        softmax cross-entropy head. Unless `label_vocabulary` is given,
        labels are class indices (i.e. values from 0 to n_classes-1).
      weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining feature column representing
         weights. It will be multiplied by the loss of the example.
      label_vocabulary: Optional list of strings with size `[n_classes]`
        defining the label vocabulary. It is forwarded to the head for both
        the binary and the multi-class case; when given, labels are expected
        to be values from this vocabulary instead of class indices.
      optimizer: String, `tf.Optimizer` object, or callable that creates the
        optimizer to use for training. If not specified, will use
        `FtrlOptimizer` with a default learning rate of 0.3.
      config: `RunConfig` object to configure the runtime settings.
      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
        to reduce training loss over batch. Defaults to `SUM`.

    Raises:
      ValueError: If `n_classes` < 2.
    """
    if n_classes == 2:
      head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint: disable=protected-access
          weight_column=weight_column,
          label_vocabulary=label_vocabulary,
          loss_reduction=loss_reduction)
    else:
      # Values other than 2 (including invalid ones such as 1) are handed to
      # the multi-class head.
      head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
          n_classes, weight_column=weight_column,
          label_vocabulary=label_vocabulary,
          loss_reduction=loss_reduction)

    def _model_fn(features, labels, mode, config):
      """Estimator `model_fn` delegating to the shared baseline model_fn."""
      return _baseline_model_fn(
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          optimizer=optimizer,
          weight_column=weight_column,
          config=config)

    super(BaselineClassifier, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config)
278
279
@tf_export('estimator.BaselineRegressor')
class BaselineRegressor(estimator.Estimator):
  """A regressor that can establish a simple baseline.

  This regressor ignores feature values and will learn to predict the average
  value of each label.

  Example:

  ```python

  # Build BaselineRegressor
  regressor = BaselineRegressor()

  # Input builders
  def input_fn_train():  # returns x, y (where y is the label).
    pass

  def input_fn_eval():  # returns x, y (where y is the label).
    pass

  def input_fn_predict():  # returns x.
    pass

  # Fit model.
  regressor.train(input_fn=input_fn_train)

  # Evaluate the squared loss between the evaluation targets and the mean
  # learned from the training targets.
  loss = regressor.evaluate(input_fn=input_fn_eval)["loss"]

  # predict outputs the mean value seen during training.
  predictions = regressor.predict(input_fn=input_fn_predict)
  ```

  Input of `train` and `evaluate` should have following features,
    otherwise there will be a `KeyError`:

  * if `weight_column` is not `None`, a feature with
     `key=weight_column` whose value is a `Tensor`.
  """

  def __init__(self,
               model_dir=None,
               label_dimension=1,
               weight_column=None,
               optimizer='Ftrl',
               config=None,
               loss_reduction=losses.Reduction.SUM):
    """Initializes a BaselineRegressor instance.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      label_dimension: Number of regression targets per example. This is the
        size of the last dimension of the labels and logits `Tensor` objects
        (typically, these have shape `[batch_size, label_dimension]`).
      weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining feature column representing
         weights. It will be multiplied by the loss of the example.
      optimizer: String, `tf.Optimizer` object, or callable that creates the
        optimizer to use for training. If not specified, will use
        `FtrlOptimizer` with a default learning rate of 0.3.
      config: `RunConfig` object to configure the runtime settings.
      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
        to reduce training loss over batch. Defaults to `SUM`.
    """
    head = head_lib._regression_head_with_mean_squared_error_loss(  # pylint: disable=protected-access
        label_dimension=label_dimension,
        weight_column=weight_column,
        loss_reduction=loss_reduction)

    def _model_fn(features, labels, mode, config):
      """Estimator `model_fn` delegating to the shared baseline model_fn."""
      return _baseline_model_fn(
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          optimizer=optimizer,
          # Forward the weight column so the logit_fn skips it when checking
          # batch sizes; otherwise serving would depend on the weight feature.
          weight_column=weight_column,
          config=config)

    super(BaselineRegressor, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config)
363