• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""TargetColumn abstracts a single head in the model.
"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.contrib.framework import deprecated
24from tensorflow.contrib.losses.python.losses import loss_ops
25from tensorflow.contrib.metrics.python.ops import metric_ops
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn
32
33
@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def regression_target(label_name=None,
                      weight_column_name=None,
                      label_dimension=1):
  """Creates a _TargetColumn for linear regression.

  The returned column uses `_mean_squared_loss` as its loss function.

  Args:
    label_name: String, name of the key in label dict. Can be None if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Integer, dimension of the regression target. It is
      forwarded to the column as its number of label columns.

  Returns:
    An instance of _RegressionTargetColumn.
  """
  return _RegressionTargetColumn(
      loss_fn=_mean_squared_loss,
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension)
59
60
61# TODO(zakaria): Add logistic_regression_target
62
63
@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def multi_class_target(n_classes, label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for multi class single label classification.

  The target column uses sigmoid cross entropy (log loss) when `n_classes` is
  2 and sparse softmax cross entropy loss otherwise.

  Args:
    n_classes: Integer, number of classes, must be >= 2
    label_name: String, name of the key in label dict. Can be None if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _MultiClassTargetColumn.

  Raises:
    ValueError: if n_classes is < 2
  """
  if n_classes < 2:
    raise ValueError("n_classes must be > 1 for classification.")
  # The binary case is modelled with a single logit, hence the sigmoid loss.
  loss_fn = (_log_loss_with_two_classes
             if n_classes == 2 else _softmax_cross_entropy_loss)
  return _MultiClassTargetColumn(
      loss_fn=loss_fn,
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name)
98
99
@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def binary_svm_target(label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for binary classification with SVMs.

  The target column uses binary hinge loss.

  Args:
    label_name: String, name of the key in label dict. Can be None if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _BinarySvmTargetColumn.
  """
  return _BinarySvmTargetColumn(
      label_name=label_name, weight_column_name=weight_column_name)
122
123
@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
class ProblemType(object):
  """Integer constants naming the kind of problem a target column solves.

  Within this file, regression columns are tagged LINEAR_REGRESSION and
  classification columns (including the binary SVM column) are tagged
  CLASSIFICATION.
  """
  UNSPECIFIED = 0
  CLASSIFICATION = 1
  LINEAR_REGRESSION = 2
  LOGISTIC_REGRESSION = 3
133
134
class _TargetColumn(object):
  """_TargetColumn is the abstraction for a single head in a model.

    Args:
      loss_fn: a function that returns the loss tensor. It is invoked as
        `loss_fn(logits, target)`.
      num_label_columns: Integer, number of label columns.
      label_name: String, name of the key in label dict. Can be None if label
          is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      problem_type: The kind of problem this head solves. Subclasses in this
        file pass one of the `ProblemType` constants.

    Raises:
      ValueError: if loss_fn or num_label_columns are missing.
  """

  def __init__(self, loss_fn, num_label_columns, label_name, weight_column_name,
               problem_type):
    if not loss_fn:
      raise ValueError("loss_fn must be provided")
    # Only None is rejected: 0 is a legitimate number of label columns.
    if num_label_columns is None:
      raise ValueError("num_label_columns must be provided")

    self._loss_fn = loss_fn
    self._num_label_columns = num_label_columns
    self._label_name = label_name
    self._weight_column_name = weight_column_name
    self._problem_type = problem_type

  def logits_to_predictions(self, logits, proba=False):
    """Converts logits to predictions. Abstract, subclasses must implement."""
    raise NotImplementedError()

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns eval op. Abstract, subclasses must implement."""
    raise NotImplementedError

  @property
  def label_name(self):
    return self._label_name

  @property
  def weight_column_name(self):
    return self._weight_column_name

  @property
  def num_label_columns(self):
    return self._num_label_columns

  def get_weight_tensor(self, features):
    """Returns the example weights as a flat float32 tensor, or None.

    Args:
      features: features dict; read only when a weight column name was given.

    Returns:
      None when no weight column name is configured, otherwise
      `features[weight_column_name]` cast to float32 and reshaped to rank 1.
    """
    if not self._weight_column_name:
      return None
    else:
      return array_ops.reshape(
          math_ops.cast(features[self._weight_column_name], dtypes.float32),
          shape=(-1,))

  @property
  def problem_type(self):
    return self._problem_type

  def _select_target(self, target):
    """Returns this head's target from a tensor or a multihead target dict."""
    if isinstance(target, dict):
      # Multihead targets are keyed by label name. This class defines no
      # `name` attribute, so the lookup has to go through `label_name`.
      return target[self.label_name]
    return target

  def _weighted_loss(self, loss, weight_tensor):
    """Returns the element-wise product of the flattened loss and weights."""
    unweighted_loss = array_ops.reshape(loss, shape=(-1,))
    weighted_loss = math_ops.multiply(unweighted_loss,
                                      array_ops.reshape(
                                          weight_tensor, shape=(-1,)))
    return weighted_loss

  def training_loss(self, logits, target, features, name="training_loss"):
    """Returns training loss tensor for this head.

    Training loss is different from the loss reported on the tensorboard as we
    should respect the example weights when computing the gradient.

      L = sum_{i} w_{i} * l_{i} / B

    where B is the number of examples in the batch, l_{i}, w_{i} are individual
    losses, and example weight.

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor.
      features: features dict.
      name: Op name.

    Returns:
      Loss tensor.
    """
    target = self._select_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name=name)
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.reduce_mean(loss_weighted, name=name)

  def loss(self, logits, target, features):
    """Returns loss tensor for this head.

    The loss returned is the weighted average.

      L = sum_{i} w_{i} * l_{i} / sum_{i} w_{i}

    Without a weight column this is the plain mean of the per-example losses.

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor.
      features: features dict.

    Returns:
      Loss tensor.
    """
    target = self._select_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name="loss")
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.div(
        math_ops.reduce_sum(loss_weighted),
        math_ops.cast(math_ops.reduce_sum(weight_tensor), dtypes.float32),
        name="loss")
261
262
class _RegressionTargetColumn(_TargetColumn):
  """_TargetColumn specialization for regression heads.

  The label dimension is stored as the number of label columns and the problem
  type is fixed to `ProblemType.LINEAR_REGRESSION`.
  """

  def __init__(self, loss_fn, label_name, weight_column_name, label_dimension):
    base = super(_RegressionTargetColumn, self)
    base.__init__(
        loss_fn=loss_fn,
        num_label_columns=label_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.LINEAR_REGRESSION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns logits as predictions; `proba` is not used.

    With a single label column, axis 1 of the logits is squeezed away.
    """
    if self.num_label_columns != 1:
      return logits
    return array_ops.squeeze(logits, axis=[1])

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict with the streaming mean loss plus any given metrics."""
    head_loss = self.loss(logits, labels, features)
    eval_ops = {"loss": metric_ops.streaming_mean(head_loss)}
    if not metrics:
      return eval_ops

    predictions = self.logits_to_predictions(logits, proba=False)
    weights = self.get_weight_tensor(features)
    eval_ops.update(_run_metrics(predictions, labels, metrics, weights))
    return eval_ops
288
289
class _MultiClassTargetColumn(_TargetColumn):
  """_TargetColumn for classification.

  Binary classification (`n_classes == 2`) is represented with a single label
  column; otherwise there is one label column per class.
  """

  # TODO(zakaria): support multilabel.
  def __init__(self, loss_fn, n_classes, label_name, weight_column_name):
    if n_classes < 2:
      raise ValueError("n_classes must be >= 2")
    super(_MultiClassTargetColumn, self).__init__(
        loss_fn=loss_fn,
        num_label_columns=1 if n_classes == 2 else n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.CLASSIFICATION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns class probabilities if `proba`, else the argmax class ids."""
    if self.num_label_columns == 1:
      # Single-logit (binary) case: prepend a zero column so that softmax and
      # argmax below operate over two columns.
      logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)

    if proba:
      return nn.softmax(logits)
    else:
      return math_ops.argmax(logits, 1)

  def _default_eval_metrics(self):
    """Returns the default binary metrics, or {} for more than two classes."""
    if self._num_label_columns == 1:
      return get_default_binary_metrics_for_eval(thresholds=[.5])
    return {}

  @staticmethod
  def _split_metrics(metrics):
    """Splits `metrics` into class-based and probability-based dicts.

    Args:
      metrics: dict whose keys are either a string name (treated as a class
        metric) or a `(name, kind)` tuple with kind in
        {"classes", "probabilities"}.

    Returns:
      A `(class_metrics, proba_metrics)` pair of dicts keyed by metric name.

    Raises:
      ValueError: if a key is neither a string nor a valid 2-tuple.
    """
    class_metrics = {}
    proba_metrics = {}
    for name, metric_op in six.iteritems(metrics):
      if isinstance(name, tuple):
        if len(name) != 2:
          raise ValueError("Ignoring metric {}. It returned a tuple with "
                           "len {}, expected 2.".format(name, len(name)))
        if name[1] not in ["classes", "probabilities"]:
          raise ValueError("Ignoring metric {}. The 2nd element of its "
                           "name should be either 'classes' or "
                           "'probabilities'.".format(name))
        if name[1] == "classes":
          class_metrics[name[0]] = metric_op
        else:
          proba_metrics[name[0]] = metric_op
      # six.string_types also accepts unicode names on Python 2, where a bare
      # `str` check would reject them.
      elif isinstance(name, six.string_types):
        class_metrics[name] = metric_op
      else:
        raise ValueError("Ignoring metric {}. Its name is not in the correct "
                         "form.".format(name))
    return class_metrics, proba_metrics

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict of eval ops for this head.

    The result always contains "loss". For the binary case the default binary
    metrics are added; they are computed on `sigmoid(logits)` and are called
    without example weights. User metrics are run on class predictions or on
    probabilities depending on their key, and do receive the example weights
    when a weight column is configured.

    Args:
      features: features dict.
      logits: logits, a float tensor.
      labels: labels tensor.
      metrics: optional dict of metric functions; see `_split_metrics` for the
        accepted key forms. Defaults to streaming accuracy on classes.

    Returns:
      Dict mapping metric names to the values returned by the metric functions.

    Raises:
      ValueError: if a metric key has an unsupported form.
    """
    loss = self.loss(logits, labels, features)
    result = {"loss": metric_ops.streaming_mean(loss)}

    # Adds default metrics.
    if metrics is None:
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy}

    predictions = math_ops.sigmoid(logits)
    labels_float = math_ops.cast(labels, dtypes.float32)

    default_metrics = self._default_eval_metrics()
    for metric_name, metric_op in default_metrics.items():
      result[metric_name] = metric_op(predictions, labels_float)

    class_metrics, proba_metrics = self._split_metrics(metrics)
    if class_metrics:
      class_predictions = self.logits_to_predictions(logits, proba=False)
      result.update(
          _run_metrics(class_predictions, labels, class_metrics,
                       self.get_weight_tensor(features)))
    if proba_metrics:
      predictions = self.logits_to_predictions(logits, proba=True)
      result.update(
          _run_metrics(predictions, labels, proba_metrics,
                       self.get_weight_tensor(features)))
    return result
367
368
class _BinarySvmTargetColumn(_MultiClassTargetColumn):
  """Two-class _TargetColumn trained with hinge loss (SVM)."""

  def __init__(self, label_name, weight_column_name):

    def _hinge_loss(logits, target):
      # Graph-time check that the target is at most rank 2 before it is
      # reshaped into a [batch_size, 1] column.
      rank_ok = math_ops.less_equal(array_ops.rank(target), 2)
      rank_assert = control_flow_ops.Assert(
          rank_ok,
          ["target's shape should be either [batch_size, 1] or [batch_size]"])
      with ops.control_dependencies([rank_assert]):
        batch_size = array_ops.shape(target)[0]
        target = array_ops.reshape(target, shape=[batch_size, 1])
      return loss_ops.hinge_loss(logits, target)

    parent = super(_BinarySvmTargetColumn, self)
    parent.__init__(
        loss_fn=_hinge_loss,
        n_classes=2,
        label_name=label_name,
        weight_column_name=weight_column_name)

  def logits_to_predictions(self, logits, proba=False):
    """Returns predicted class ids; probabilities are not supported.

    Raises:
      ValueError: if `proba` is truthy.
    """
    if proba:
      raise ValueError(
          "logits to probabilities is not supported for _BinarySvmTargetColumn")

    # Prepend a zero column so argmax chooses between two columns.
    zero_column = array_ops.zeros_like(logits)
    two_column_logits = array_ops.concat([zero_column, logits], 1)
    return math_ops.argmax(two_column_logits, 1)
396
397
398# TODO(zakaria): use contrib losses.
def _mean_squared_loss(logits, target):
  """Returns the element-wise squared difference of logits and target.

  A rank-1 target gets a trailing dimension added, and the target is cast to
  float32 before the difference is taken.
  """
  target_rank = len(target.get_shape())
  if target_rank == 1:
    # Add the column axis explicitly so that "-" does not broadcast.
    target = array_ops.expand_dims(target, axis=1)

  logits.get_shape().assert_is_compatible_with(target.get_shape())
  float_target = math_ops.cast(target, dtypes.float32)
  return math_ops.squared_difference(logits, float_target)
407
408
def _log_loss_with_two_classes(logits, target):
  """Returns the sigmoid cross entropy between logits and a float32 target."""
  target_rank = len(target.get_shape())
  if target_rank == 1:
    # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
    target = array_ops.expand_dims(target, axis=1)
  float_target = math_ops.cast(target, dtypes.float32)
  return nn.sigmoid_cross_entropy_with_logits(
      labels=float_target, logits=logits)
416
417
def _softmax_cross_entropy_loss(logits, target):
  """Returns the sparse softmax cross entropy of logits against class ids.

  Raises:
    ValueError: if the target's dtype is not an integer type.
  """
  # Classification targets must be integer class ids.
  target_dtype = target.dtype
  if not target_dtype.is_integer:
    raise ValueError("Target's dtype should be integer "
                     "Instead got %s." % target_dtype)

  target_rank = len(target.get_shape())
  if target_rank == 2:
    # sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
    target = array_ops.squeeze(target, axis=[1])
  return nn.sparse_softmax_cross_entropy_with_logits(
      labels=target, logits=logits)
429
430
def _run_metrics(predictions, labels, metrics, weights):
  """Applies each metric function to the predictions and labels.

  Labels are first cast to the dtype of `predictions`. The `weights` keyword
  is passed to the metric functions only when `weights` is not None.

  Returns:
    Dict mapping each metric name to the value its function returned; empty
    when `metrics` is falsy.
  """
  labels = math_ops.cast(labels, predictions.dtype)
  weight_kwargs = {} if weights is None else {"weights": weights}
  return {
      name: metric(predictions, labels, **weight_kwargs)
      for name, metric in six.iteritems(metrics or {})
  }
441
442
@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def get_default_binary_metrics_for_eval(thresholds):
  """Returns a dictionary of basic metrics for logistic regression.

  Args:
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If None, defaults to [0.5].

  Returns:
    Dictionary mapping metrics string names to metrics functions.
  """
  if thresholds is None:
    # Honor the documented default rather than failing on iteration.
    thresholds = [0.5]

  metrics = {
      _MetricKeys.PREDICTION_MEAN: _predictions_streaming_mean,
      _MetricKeys.TARGET_MEAN: _labels_streaming_mean,
      # Also include the streaming mean of the label as an accuracy baseline,
      # as a reminder to users.
      _MetricKeys.ACCURACY_BASELINE: _labels_streaming_mean,
      _MetricKeys.AUC: _streaming_auc,
  }

  for threshold in thresholds:
    metrics[_MetricKeys.ACCURACY_MEAN %
            threshold] = _accuracy_at_threshold(threshold)
    # Precision for positive examples.
    metrics[_MetricKeys.PRECISION_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_precision_at_thresholds, threshold)
    # Recall for positive examples.
    metrics[_MetricKeys.RECALL_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_recall_at_thresholds, threshold)

  return metrics
477
478
def _float_weights_or_none(weights):
  """Returns `weights` cast to float32, passing None through unchanged."""
  return None if weights is None else math_ops.cast(weights, dtypes.float32)
483
484
def _labels_streaming_mean(unused_predictions, labels, weights=None):
  """Metric function: streaming mean of the labels; predictions are ignored."""
  label_mean = metric_ops.streaming_mean(labels, weights=weights)
  return label_mean
487
488
def _predictions_streaming_mean(predictions, unused_labels, weights=None):
  """Metric function: streaming mean of the predictions; labels are ignored."""
  prediction_mean = metric_ops.streaming_mean(predictions, weights=weights)
  return prediction_mean
491
492
def _streaming_auc(predictions, labels, weights=None):
  """Metric function: streaming AUC, with weights (if any) cast to float32."""
  float_weights = _float_weights_or_none(weights)
  return metric_ops.streaming_auc(predictions, labels, weights=float_weights)
496
497
def _accuracy_at_threshold(threshold):
  """Returns a metric function computing accuracy at the given threshold.

  The returned function marks predictions `>= threshold` as 1.0 and the rest
  as 0.0, then feeds them to streaming accuracy.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    threshold_predictions = math_ops.cast(is_positive, dtypes.float32)
    return metric_ops.streaming_accuracy(
        predictions=threshold_predictions, labels=labels, weights=weights)

  return _accuracy_metric
507
508
def _streaming_at_threshold(streaming_metrics_fn, threshold):
  """Adapts an "at thresholds" metric to a single fixed threshold.

  Args:
    streaming_metrics_fn: metric function accepting `labels`, `thresholds`
      and `weights` keywords and returning a `(value, update_op)` pair.
    threshold: the one threshold to evaluate at.

  Returns:
    A metric function returning `(squeezed value, update_op)`.
  """

  def _streaming_metrics(predictions, labels, weights=None):
    float_weights = _float_weights_or_none(weights)
    value_tensor, update_op = streaming_metrics_fn(
        predictions,
        labels=labels,
        thresholds=[threshold],
        weights=float_weights)
    # Only one threshold is supplied; squeeze the value tensor accordingly.
    return array_ops.squeeze(value_tensor), update_op

  return _streaming_metrics
520
521
class _MetricKeys(object):
  """Names under which the default binary eval metrics are reported.

  Keys containing `%f` are templates that are formatted with a threshold.
  """
  # Fixed metric names.
  AUC = "auc"
  ACCURACY_BASELINE = "accuracy/baseline_target_mean"
  TARGET_MEAN = "labels/actual_target_mean"
  PREDICTION_MEAN = "labels/prediction_mean"

  # Per-threshold templates.
  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
  RECALL_MEAN = "recall/positive_threshold_%f_mean"
530