# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstractions for the head(s) of a model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey as mkey
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


class Head(object):
  """Interface for the head/top of a model.

  Given logits (or the output of a hidden layer), a Head knows how to compute
  predictions, loss, default metrics, and the export signature. It is meant to:

  1) Simplify writing model_fn and make model_fn more configurable.
  2) Support a wide range of machine learning models. Since most heads can
      work with logits, they can support DNN, RNN, Wide, Wide&Deep,
      Global objectives, Gradient boosted trees and many other types
      of machine learning models.
  3) Allow users to seamlessly switch between one and many heads for
      multi-objective learning (see the _MultiHead implementation for more
      details).

  Common usage:
  Here is a simplified model_fn to build a multiclass DNN model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally, your callers can pass the head to model_fn as a param.
      head = tf.contrib.learn.multi_class_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_model_fn_ops(
          features=features,
          labels=labels,
          mode=mode,
          train_op_fn=_train_op_fn,
          logits=logits,
          scope=...)
    ```

  Most heads also support logits_input, which is typically the output of the
  last hidden layer. Some heads (like those responsible for candidate sampling
  or hierarchical softmax) intrinsically do not support logits, so you have to
  pass logits_input. Here is a common usage:
    ```python
    return head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=_train_op_fn,
        logits_input=last_hidden_layer_out,
        scope=...)
    ```

  There are cases where computing and applying gradients cannot be
  meaningfully captured by the train_op_fn we support (for example, with a
  sync optimizer). In such cases, you can take that responsibility yourself.
  Here is a common use case:
    ```python
    model_fn_ops = head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=tf.contrib.learn.no_op_train_fn,
        logits=logits,
        scope=...)
    if mode == tf.contrib.learn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=model_fn_ops.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      ... update train_op and hooks in ModelFnOps and return
    ```
  """
  __metaclass__ = abc.ABCMeta

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """Returns `ModelFnOps` that a model_fn can return.

    Please note that:
    + Exactly one of `logits` and `logits_input` must be provided.
    + All args must be passed via name.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
          to optimize the model with the loss. This is used in TRAIN mode and
          must not be None. None is allowed in other modes. If you want to
          optimize loss yourself, you can pass `no_op_train_fn` and then use
          `ModelFnOps.loss` to compute and apply gradients.
      logits: logits `Tensor` to be used by the head.
      logits_input: `Tensor` from which to build logits, often needed when you
        don't want to compute the logits. Typically this is the activation of
        the last hidden layer in a DNN. Some heads (like the ones responsible
        for candidate sampling) intrinsically avoid computing full logits and
        only accept logits_input.
      scope: Optional scope for `variable_scope`.

    Returns:
      An instance of `ModelFnOps`.

    Raises:
      ValueError: If `mode` is not recognized.
      ValueError: If neither or both of `logits` and `logits_input` are
        provided.
    """
    raise NotImplementedError("Calling an abstract method.")


def regression_head(label_name=None,
                    weight_column_name=None,
                    label_dimension=1,
                    enable_centered_bias=False,
                    head_name=None,
                    link_fn=None):
  """Creates a `Head` for linear regression.

  Args:
    label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down-weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    link_fn: link function to convert logits to predictions. If provided,
      this link function will be used instead of identity.

  Returns:
    An instance of `Head` for linear regression.
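
  Example:
    A minimal sketch of wiring this head into a model_fn; `my_optimizer`,
    `features`, `labels`, `mode` and `logits` are assumed to be defined by
    the caller:

    ```python
    head = regression_head(label_dimension=1)

    def _train_op_fn(loss):
      return my_optimizer.minimize(loss)

    model_fn_ops = head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=_train_op_fn,
        logits=logits)
    ```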
  """
  return _RegressionHead(
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      loss_fn=_mean_squared_loss,
      link_fn=(link_fn if link_fn is not None else array_ops.identity))


def poisson_regression_head(label_name=None,
                            weight_column_name=None,
                            label_dimension=1,
                            enable_centered_bias=False,
                            head_name=None):
  """Creates a `Head` for Poisson regression.

  Args:
    label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down-weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.

  Returns:
    An instance of `Head` for Poisson regression.
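
  Example:
    The head uses `tf.exp` as the link function, so logits are interpreted as
    the log of the predicted rate. A minimal sketch (assuming `features`,
    `labels`, `mode`, `logits` and `_train_op_fn` are defined by the caller):

    ```python
    head = poisson_regression_head(label_dimension=1)
    # The "scores" prediction returned by the head is exp(logits).
    model_fn_ops = head.create_model_fn_ops(
        features=features, labels=labels, mode=mode,
        train_op_fn=_train_op_fn, logits=logits)
    ```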
  """
  return _RegressionHead(
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      loss_fn=_poisson_loss,
      link_fn=math_ops.exp)

# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression


def multi_class_head(n_classes,
                     label_name=None,
                     weight_column_name=None,
                     enable_centered_bias=False,
                     head_name=None,
                     thresholds=None,
                     metric_class_ids=None,
                     loss_fn=None,
                     label_keys=None):
  """Creates a `Head` for multi-class, single-label classification.

  The Head uses softmax cross entropy loss.

  This head expects to be fed integer labels specifying the class index. But
  if `label_keys` is specified, then labels must be strings from this
  vocabulary, and the predicted classes will be strings from the same
  vocabulary.

  Args:
    n_classes: Integer, number of classes, must be >= 2.
    label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down-weight or boost examples during training. It
      will be multiplied by the loss of the example.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5].
    metric_class_ids: List of class IDs for which we should report per-class
      metrics. Must all be in the range `[0, n_classes)`. Invalid if
      `n_classes` is 2.
    loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
      parameters and returns a weighted scalar loss. `weights` should be
      optional. See `tf.losses`.
    label_keys: Optional list of strings with size `[n_classes]` defining the
      label vocabulary. Only supported for `n_classes` > 2.

  Returns:
    An instance of `Head` for multi-class classification.

  Raises:
    ValueError: If `n_classes` is < 2.
    ValueError: If `metric_class_ids` is provided when `n_classes` is 2.
    ValueError: If `len(label_keys) != n_classes`.
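
  Example:
    A minimal sketch with a string label vocabulary; `features`, `labels`,
    `mode`, `_train_op_fn` and a `[batch_size, 3]` `logits` tensor are assumed
    to be defined by the caller:

    ```python
    head = multi_class_head(
        n_classes=3, label_keys=["sports", "politics", "tech"])
    # Labels fed to the head are strings from label_keys, and the "classes"
    # prediction is a string from the same vocabulary.
    model_fn_ops = head.create_model_fn_ops(
        features=features, labels=labels, mode=mode,
        train_op_fn=_train_op_fn, logits=logits)
    ```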
  """
  if (n_classes is None) or (n_classes < 2):
    raise ValueError("n_classes must be > 1 for classification: %s." %
                     n_classes)
  if loss_fn:
    _verify_loss_fn_args(loss_fn)

  loss_fn = _wrap_custom_loss_fn(loss_fn) if loss_fn else None
  if n_classes == 2:
    if metric_class_ids:
      raise ValueError("metric_class_ids invalid for n_classes==2.")
    if label_keys:
      raise ValueError("label_keys is not supported for n_classes=2.")
    return _BinaryLogisticHead(
        label_name=label_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias,
        head_name=head_name,
        thresholds=thresholds,
        loss_fn=loss_fn)

  return _MultiClassHead(
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds,
      metric_class_ids=metric_class_ids,
      loss_fn=loss_fn,
      label_keys=label_keys)


def binary_svm_head(
    label_name=None,
    weight_column_name=None,
    enable_centered_bias=False,
    head_name=None,
    thresholds=None):
  """Creates a `Head` for binary classification with SVMs.

  The head uses binary hinge loss.

  Args:
    label_name: String, name of the key in label dict. Can be None if label
      is a tensor (single-headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down-weight or boost examples during training. It
      will be multiplied by the loss of the example.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5].

  Returns:
    An instance of `Head` for binary classification with SVM.
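
  Example:
    A minimal sketch; `features`, `labels` in {0, 1}, `mode`, `_train_op_fn`
    and a `[batch_size, 1]` `logits` tensor are assumed to be defined by the
    caller:

    ```python
    head = binary_svm_head(weight_column_name="example_weight")
    model_fn_ops = head.create_model_fn_ops(
        features=features, labels=labels, mode=mode,
        train_op_fn=_train_op_fn, logits=logits)
    ```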
  """
  return _BinarySvmHead(
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds)


def multi_label_head(n_classes,
                     label_name=None,
                     weight_column_name=None,
                     enable_centered_bias=False,
                     head_name=None,
                     thresholds=None,
                     metric_class_ids=None,
                     loss_fn=None):
  """Creates a Head for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head`, which has exactly one label per example from a discrete
  set.

  This head by default uses sigmoid cross entropy loss, which expects as input
  a multi-hot tensor of shape `(batch_size, num_classes)`.

  Args:
    n_classes: Integer, number of classes, must be >= 2.
    label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down-weight or boost examples during training. It
      will be multiplied by the loss of the example.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5].
    metric_class_ids: List of class IDs for which we should report per-class
      metrics. Must all be in the range `[0, n_classes)`.
    loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
      parameters and returns a weighted scalar loss. `weights` should be
      optional. See `tf.losses`.

  Returns:
    An instance of `Head` for multi-label classification.

  Raises:
    ValueError: If n_classes is < 2.
    ValueError: If loss_fn does not have expected signature.
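
  Example:
    A minimal sketch; labels may be a multi-hot `Tensor` of shape
    `[batch_size, n_classes]` or a `SparseTensor` of class IDs, which the head
    converts to a multi-hot indicator internally:

    ```python
    head = multi_label_head(n_classes=5, metric_class_ids=[0, 1])
    model_fn_ops = head.create_model_fn_ops(
        features=features, labels=labels, mode=mode,
        train_op_fn=_train_op_fn, logits=logits)
    ```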
  """
  if n_classes < 2:
    raise ValueError("n_classes must be > 1 for classification.")
  if loss_fn:
    _verify_loss_fn_args(loss_fn)

  return _MultiLabelHead(
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds,
      metric_class_ids=metric_class_ids,
      loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)


def loss_only_head(loss_fn, head_name=None):
  """Creates a Head that contains only loss terms.

  A loss-only head holds additional loss terms to be added to other heads and
  usually represents additional regularization terms in the objective function.

  Args:
    loss_fn: a function that takes no arguments and returns a list of
        scalar tensors.
    head_name: a name for the head.

  Returns:
    An instance of `Head` to hold the additional losses.
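
  Example:
    A sketch pairing a regularization term with another head inside
    `multi_head`; `some_weights` is a hypothetical variable standing in for
    whatever tensors the regularizer should penalize:

    ```python
    reg_head = loss_only_head(
        lambda: [0.01 * tf.reduce_sum(tf.abs(some_weights))],
        head_name="l1_regularizer")
    head = multi_head([multi_class_head(n_classes=3), reg_head])
    ```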
  """
  return _LossOnlyHead(loss_fn, head_name=head_name)


def multi_head(heads, loss_weights=None):
  """Creates a MultiHead stemming from the same logits/hidden layer.

  Args:
    heads: list of Head objects.
    loss_weights: optional list of weights to be used to merge losses from
        each head. All losses are weighted equally if not provided.

  Returns:
    An instance of `Head` that merges multiple heads.

  Raises:
    ValueError: if heads and loss_weights have different sizes.
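
  Example:
    A sketch merging a classification head and a regression head, weighting
    the classification loss twice as heavily:

    ```python
    head = multi_head(
        heads=[
            multi_class_head(n_classes=3, head_name="my_classifier"),
            regression_head(head_name="my_regressor"),
        ],
        loss_weights=[2.0, 1.0])
    ```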
  """
  if loss_weights:
    if len(loss_weights) != len(heads):
      raise ValueError("heads and loss_weights must have same size")

  def _weighted_loss_merger(losses):
    if loss_weights:
      if len(losses) != len(loss_weights):
        raise ValueError("losses and loss_weights must have same size")
      weighted_losses = []
      for loss, weight in zip(losses, loss_weights):
        weighted_losses.append(math_ops.multiply(loss, weight))
      return math_ops.add_n(weighted_losses)
    else:
      return math_ops.add_n(losses)

  return _MultiHead(heads, loss_merger=_weighted_loss_merger)


def no_op_train_fn(loss):
  """A train_op_fn that ignores the given loss and returns a no-op."""
  del loss
  return control_flow_ops.no_op()


class _SingleHead(Head):
  """Interface for a single head/top of a model."""
  __metaclass__ = abc.ABCMeta

  def __init__(
      self, problem_type, logits_dimension, label_name=None,
      weight_column_name=None, head_name=None):
    if problem_type is None:
      raise ValueError("Invalid problem_type %s." % problem_type)
    if logits_dimension is None or logits_dimension < 1:
      raise ValueError("Invalid logits_dimension %s." % logits_dimension)
    self._problem_type = problem_type
    self._logits_dimension = logits_dimension
    self._label_name = label_name
    self._weight_column_name = weight_column_name
    self._head_name = head_name

  @property
  def logits_dimension(self):
    return self._logits_dimension

  @property
  def label_name(self):
    return self._label_name

  @property
  def weight_column_name(self):
    return self._weight_column_name

  @property
  def head_name(self):
    return self._head_name

  def _create_output_alternatives(self, predictions):
    """Creates output alternative for the Head.

    Args:
      predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
        symbolic name for an output Tensor possibly but not necessarily taken
        from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
        itself.

    Returns:
      `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
      'submodel_name' is a submodel identifier that should be consistent across
      the pipeline (here likely taken from the head_name),
      'problem_type' is a `ProblemType`,
      'tensor_name' is a symbolic name for an output Tensor possibly but not
       necessarily taken from `PredictionKey`, and
      'Tensor' is the corresponding output Tensor itself.
    """
    return {self._head_name: (self._problem_type, predictions)}


# TODO(zakaria): use contrib losses.
def _mean_squared_loss(labels, logits, weights=None):
  with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name:
    logits = ops.convert_to_tensor(logits)
    labels = ops.convert_to_tensor(labels)
    # To prevent broadcasting inside "-".
    if len(labels.get_shape()) == 1:
      labels = array_ops.expand_dims(labels, dim=(1,))
    # TODO(zakaria): make sure it does not recreate the broadcast bug.
    if len(logits.get_shape()) == 1:
      logits = array_ops.expand_dims(logits, dim=(1,))
    logits.get_shape().assert_is_compatible_with(labels.get_shape())
    loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
    return _compute_weighted_loss(loss, weights)


def _poisson_loss(labels, logits, weights=None):
  """Computes Poisson loss from logits."""
  with ops.name_scope(None, "_poisson_loss", (logits, labels)) as name:
    logits = ops.convert_to_tensor(logits)
    labels = ops.convert_to_tensor(labels)
    # To prevent broadcasting inside "-".
    if len(labels.get_shape()) == 1:
      labels = array_ops.expand_dims(labels, dim=(1,))
    # TODO(zakaria): make sure it does not recreate the broadcast bug.
    if len(logits.get_shape()) == 1:
      logits = array_ops.expand_dims(logits, dim=(1,))
    logits.get_shape().assert_is_compatible_with(labels.get_shape())
    loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
                               name=name)
    return _compute_weighted_loss(loss, weights)


def _logits(logits_input, logits, logits_dimension):
  """Validates logits args, and creates `logits` if necessary.

  Exactly one of `logits_input` and `logits` must be provided.

  Args:
    logits_input: `Tensor` input to `logits`.
    logits: `Tensor` output.
    logits_dimension: Integer, last dimension of `logits`. This is used to
      create `logits` from `logits_input` if `logits` is `None`; otherwise, it's
      used to validate `logits`.

  Returns:
    `logits` `Tensor`.

  Raises:
    ValueError: if neither or both of `logits` and `logits_input` are supplied.
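
  For example, with hypothetical `hidden` (`[batch_size, 500]`) and
  `precomputed` (`[batch_size, 10]`) tensors:

    ```python
    # Builds a new linear layer; the result has shape [batch_size, 10].
    logits = _logits(logits_input=hidden, logits=None, logits_dimension=10)
    # Validates precomputed logits and passes them through unchanged.
    logits = _logits(logits_input=None, logits=precomputed, logits_dimension=10)
    ```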
  """
  if (logits_dimension is None) or (logits_dimension < 1):
    raise ValueError("Invalid logits_dimension %s." % logits_dimension)

  # If not provided, create logits.
  if logits is None:
    if logits_input is None:
      raise ValueError("Neither logits nor logits_input supplied.")
    return layers_lib.linear(logits_input, logits_dimension, scope="logits")

  if logits_input is not None:
    raise ValueError("Both logits and logits_input supplied.")

  logits = ops.convert_to_tensor(logits, name="logits")
  logits_dims = logits.get_shape().dims
  if logits_dims is not None:
    logits_dims[-1].assert_is_compatible_with(logits_dimension)

  return logits


def _create_model_fn_ops(features,
                         mode,
                         loss_fn,
                         logits_to_predictions_fn,
                         metrics_fn,
                         create_output_alternatives_fn,
                         labels=None,
                         train_op_fn=None,
                         logits=None,
                         logits_dimension=None,
                         head_name=None,
                         weight_column_name=None,
                         enable_centered_bias=False):
  """Returns a `ModelFnOps` object."""
  _check_mode_valid(mode)

  centered_bias = None
  if enable_centered_bias:
    centered_bias = _centered_bias(logits_dimension, head_name)
    logits = nn.bias_add(logits, centered_bias)

  predictions = logits_to_predictions_fn(logits)
  loss = None
  train_op = None
  eval_metric_ops = None
  if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
    weight_tensor = _weight_tensor(features, weight_column_name)
    loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor)
    # The name_scope escapism is needed to maintain the same summary tag
    # after switching away from the now unsupported API.
    with ops.name_scope(""):
      summary_loss = array_ops.identity(weighted_average_loss)
      summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss)

    if mode == model_fn.ModeKeys.TRAIN:
      if train_op_fn is None:
        raise ValueError("train_op_fn can not be None in TRAIN mode")
      batch_size = array_ops.shape(logits)[0]
      train_op = _train_op(loss, labels, train_op_fn, centered_bias,
                           batch_size, loss_fn, weight_tensor)
    eval_metric_ops = metrics_fn(
        weighted_average_loss, predictions, labels, weight_tensor)
  return model_fn.ModelFnOps(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      output_alternatives=create_output_alternatives_fn(predictions))


class _RegressionHead(_SingleHead):
  """`Head` for regression with a generalized linear model."""

  def __init__(self,
               label_dimension,
               loss_fn,
               link_fn,
               logits_dimension=None,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None):
    """`Head` for regression.

    Args:
      label_dimension: Number of regression labels per example. This is the
        size of the last dimension of the labels `Tensor` (typically, this has
        shape `[batch_size, label_dimension]`).
      loss_fn: Loss function, takes logits and labels and returns loss.
      link_fn: Link function, takes a logits tensor and returns the output.
      logits_dimension: Number of logits per example. This is the
        size of the last dimension of the logits `Tensor` (typically, this has
        shape `[batch_size, logits_dimension]`).
        Default value: `label_dimension`.
      label_name: String, name of the key in label dict. Can be None if label
          is a tensor (single-headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training.
        It will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. Predictions, summary and metrics keys are
        suffixed by `"/" + head_name` and the default variable scope is
        `head_name`.
    """
    super(_RegressionHead, self).__init__(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        logits_dimension=(logits_dimension if logits_dimension is not None
                          else label_dimension),
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    self._loss_fn = loss_fn
    self._link_fn = link_fn
    self._enable_centered_bias = enable_centered_bias

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`."""
    with variable_scope.variable_scope(
        scope,
        default_name=self.head_name or "regression_head",
        values=(tuple(six.itervalues(features)) +
                (labels, logits, logits_input))):
      labels = self._transform_labels(mode=mode, labels=labels)
      logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features,
          mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=self._create_output_alternatives,
          labels=labels,
          train_op_fn=train_op_fn,
          logits=logits,
          logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Applies transformations to labels tensor."""
    if (mode == model_fn.ModeKeys.INFER) or (labels is None):
      return None
    labels_tensor = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(labels_tensor)
    return labels_tensor

  def _logits_to_predictions(self, logits):
    """Returns a dict of predictions.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`.
    """
    key = prediction_key.PredictionKey.SCORES
    with ops.name_scope(None, "predictions", (logits,)):
      if self.logits_dimension == 1:
        logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
      return {key: self._link_fn(logits)}

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name."""
    del predictions, labels, weights  # Unused by this head.
    with ops.name_scope("metrics", values=[eval_loss]):
      return {
          _summary_key(self.head_name, mkey.LOSS):
              metrics_lib.mean(eval_loss)}

def _log_loss_with_two_classes(labels, logits, weights=None):
  with ops.name_scope(None, "log_loss_with_two_classes",
                      (logits, labels)) as name:
    logits = ops.convert_to_tensor(logits)
    labels = math_ops.to_float(labels)
    # TODO(ptucker): This will break for dynamic shapes.
    # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
    if len(labels.get_shape()) == 1:
      labels = array_ops.expand_dims(labels, dim=(1,))
    loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
                                                name=name)
    return _compute_weighted_loss(loss, weights)


def _one_class_to_two_class_logits(logits):
  # Maps [batch_size, 1] positive-class logits to [batch_size, 2] logits,
  # using an implicit logit of 0 for the negative class.
  return array_ops.concat((array_ops.zeros_like(logits), logits), 1)

class _BinaryLogisticHead(_SingleHead):
  """`Head` for binary classification with logistic regression."""

  def __init__(self,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None,
               loss_fn=None,
               thresholds=None):
    """`Head` for binary classification with logistic regression.

    Args:
      label_name: String, name of the key in label dict. Can be `None` if label
          is a tensor (single-headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training.
        It will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. Predictions, summary, metrics keys are
        suffixed by `"/" + head_name` and the default variable scope is
        `head_name`.
      loss_fn: Loss function.
      thresholds: thresholds for eval.
    """
    super(_BinaryLogisticHead, self).__init__(
        problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
        logits_dimension=1,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)
    self._thresholds = thresholds if thresholds else (.5,)
    self._loss_fn = loss_fn if loss_fn else _log_loss_with_two_classes
    self._enable_centered_bias = enable_centered_bias

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`."""
    with variable_scope.variable_scope(
        scope,
        default_name=self.head_name or "binary_logistic_head",
        values=(tuple(six.itervalues(features)) +
                (labels, logits, logits_input))):
      labels = self._transform_labels(mode=mode, labels=labels)
      logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features,
          mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type),
          labels=labels,
          train_op_fn=train_op_fn,
          logits=logits,
          logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Applies transformations to labels tensor."""
    if (mode == model_fn.ModeKeys.INFER) or (labels is None):
      return None
    labels_tensor = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(labels_tensor)
    return labels_tensor

  def _logits_to_predictions(self, logits):
    """Returns a dict of predictions.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`.
    """
    with ops.name_scope(None, "predictions", (logits,)):
      two_class_logits = _one_class_to_two_class_logits(logits)
      return {
          prediction_key.PredictionKey.LOGITS:
              logits,
          prediction_key.PredictionKey.LOGISTIC:
              math_ops.sigmoid(
                  logits, name=prediction_key.PredictionKey.LOGISTIC),
          prediction_key.PredictionKey.PROBABILITIES:
              nn.softmax(
                  two_class_logits,
                  name=prediction_key.PredictionKey.PROBABILITIES),
          prediction_key.PredictionKey.CLASSES:
              math_ops.argmax(
                  two_class_logits,
                  1,
                  name=prediction_key.PredictionKey.CLASSES)
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name."""
    with ops.name_scope("metrics", values=(
        [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
      classes = predictions[prediction_key.PredictionKey.CLASSES]
      logistic = predictions[prediction_key.PredictionKey.LOGISTIC]

      metrics = {_summary_key(self.head_name, mkey.LOSS):
                 metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
          metrics_lib.accuracy(labels, classes, weights))
      metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = (
          _predictions_streaming_mean(logistic, weights))
      metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = (
          _indicator_labels_streaming_mean(labels, weights))

      # Also include the streaming mean of the label as an accuracy baseline,
      # as a reminder to users.
      metrics[_summary_key(self.head_name, mkey.ACCURACY_BASELINE)] = (
          _indicator_labels_streaming_mean(labels, weights))
      metrics[_summary_key(self.head_name, mkey.AUC)] = (
          _streaming_auc(logistic, labels, weights))
      metrics[_summary_key(self.head_name, mkey.AUC_PR)] = (
          _streaming_auc(logistic, labels, weights, curve="PR"))

      for threshold in self._thresholds:
        metrics[_summary_key(
            self.head_name, mkey.ACCURACY_MEAN % threshold)] = (
                _streaming_accuracy_at_threshold(logistic, labels, weights,
                                                 threshold))
        # Precision for positive examples.
        metrics[_summary_key(
            self.head_name, mkey.PRECISION_MEAN % threshold)] = (
                _streaming_precision_at_threshold(logistic, labels, weights,
                                                  threshold))
        # Recall for positive examples.
        metrics[_summary_key(
            self.head_name, mkey.RECALL_MEAN % threshold)] = (
                _streaming_recall_at_threshold(logistic, labels, weights,
                                               threshold))

    return metrics

def _softmax_cross_entropy_loss(labels, logits, weights=None):
  with ops.name_scope(
      None, "softmax_cross_entropy_loss", (logits, labels,)) as name:
    labels = ops.convert_to_tensor(labels)
    # Check that we got integer labels for classification.
    if not labels.dtype.is_integer:
      raise ValueError("Labels dtype should be integer. "
                       "Instead got %s." % labels.dtype)

    # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
    is_squeezed_labels = False
    # TODO(ptucker): This will break for dynamic shapes.
    if len(labels.get_shape()) == 2:
      labels = array_ops.squeeze(labels, squeeze_dims=(1,))
      is_squeezed_labels = True

    loss = nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name=name)

    # Restore squeezed dimension, if necessary, so loss matches weights shape.
    if is_squeezed_labels:
      loss = array_ops.expand_dims(loss, axis=(1,))

    return _compute_weighted_loss(loss, weights)

class _MultiClassHead(_SingleHead):
  """`Head` for multi-class classification."""

  def __init__(self,
               n_classes,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None,
               loss_fn=None,
               thresholds=None,
               metric_class_ids=None,
               label_keys=None):
    """`Head` for multi-class classification.

    This head expects to be fed integer labels specifying the class index. But
    if `label_keys` is specified, then labels must be strings from this
    vocabulary, and the predicted classes will be strings from the same
    vocabulary.

    Args:
      n_classes: Number of classes, must be greater than 2 (for 2 classes, use
        `_BinaryLogisticHead`).
      label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training.
        It will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. If provided, predictions, summary, metrics
        keys will be suffixed by `"/" + head_name` and the default variable
        scope will be `head_name`.
      loss_fn: Loss function. Defaults to softmax cross entropy loss.
      thresholds: thresholds for eval.
      metric_class_ids: List of class IDs for which we should report per-class
        metrics. Must all be in the range `[0, n_classes)`.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary.

    Raises:
      ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
    """
    super(_MultiClassHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    if (n_classes is None) or (n_classes <= 2):
      raise ValueError("n_classes must be > 2: %s." % n_classes)
    self._thresholds = thresholds if thresholds else (.5,)
    self._loss_fn = loss_fn if loss_fn else _softmax_cross_entropy_loss
    self._enable_centered_bias = enable_centered_bias
    self._metric_class_ids = tuple([] if metric_class_ids is None else
                                   metric_class_ids)
    for class_id in self._metric_class_ids:
      if (class_id < 0) or (class_id >= n_classes):
        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))
    if label_keys and len(label_keys) != n_classes:
      raise ValueError("Length of label_keys must equal n_classes.")
    self._label_keys = label_keys

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`."""
    with variable_scope.variable_scope(
        scope,
        default_name=self.head_name or "multi_class_head",
        values=(tuple(six.itervalues(features)) +
                (labels, logits, logits_input))):
      labels = self._transform_labels(mode=mode, labels=labels)
      logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features,
          mode=mode,
          loss_fn=self._wrapped_loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type, self._label_keys),
          labels=labels,
          train_op_fn=train_op_fn,
          logits=logits,
          logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Returns a dict that contains both the original labels and label IDs."""
    if (mode == model_fn.ModeKeys.INFER) or (labels is None):
      return None
    labels_tensor = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(labels_tensor)
    if self._label_keys:
      table = lookup_ops.index_table_from_tensor(
          self._label_keys, name="label_id_lookup")
      return {
          "labels": labels_tensor,
          "label_ids": table.lookup(labels_tensor),
      }
    return {
        "labels": labels_tensor,
        "label_ids": labels_tensor,
    }

  def _labels(self, labels_dict):
    """Returns labels `Tensor` of the same type as classes."""
    return labels_dict["labels"]

  def _label_ids(self, labels_dict):
    """Returns integer label ID `Tensor`."""
    return labels_dict["label_ids"]

  def _wrapped_loss_fn(self, labels, logits, weights=None):
    return self._loss_fn(self._label_ids(labels), logits, weights=weights)

  def _logits_to_predictions(self, logits):
    """Returns a dict of predictions.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`.
    """
    with ops.name_scope(None, "predictions", (logits,)):
      class_ids = math_ops.argmax(
          logits, 1, name=prediction_key.PredictionKey.CLASSES)
      if self._label_keys:
        table = lookup_ops.index_to_string_table_from_tensor(
            self._label_keys, name="class_string_lookup")
        classes = table.lookup(class_ids)
      else:
        classes = class_ids
      return {
          prediction_key.PredictionKey.LOGITS: logits,
          prediction_key.PredictionKey.PROBABILITIES:
              nn.softmax(
                  logits, name=prediction_key.PredictionKey.PROBABILITIES),
          prediction_key.PredictionKey.CLASSES: classes
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name."""
    with ops.name_scope(
        "metrics",
        values=((eval_loss, self._labels(labels), self._label_ids(labels),
                 weights) + tuple(six.itervalues(predictions)))):
      logits = predictions[prediction_key.PredictionKey.LOGITS]
      probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
      classes = predictions[prediction_key.PredictionKey.CLASSES]

      metrics = {_summary_key(self.head_name, mkey.LOSS):
                 metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
          metrics_lib.accuracy(self._labels(labels), classes, weights))

      if not self._label_keys:
        # Classes are IDs. Add some metrics.
        for class_id in self._metric_class_ids:
          metrics[_summary_key(
              self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = (
                  _class_predictions_streaming_mean(classes, weights, class_id))
          # TODO(ptucker): Add per-class accuracy, precision, recall.
          metrics[_summary_key(
              self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = (
                  _class_labels_streaming_mean(
                      self._label_ids(labels), weights, class_id))
          metrics[_summary_key(
              self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = (
                  _predictions_streaming_mean(probabilities, weights, class_id))
          metrics[_summary_key(
              self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = (
                  _predictions_streaming_mean(logits, weights, class_id))

    return metrics

def _to_labels_tensor(labels, label_name):
  """Returns label as a tensor.

  Args:
    labels: Label `Tensor` or `SparseTensor` or a dict containing labels.
    label_name: Label name if labels is a dict.

  Returns:
    Label `Tensor` or `SparseTensor`.
  """
  labels = labels[label_name] if isinstance(labels, dict) else labels
  return framework_lib.convert_to_tensor_or_sparse_tensor(labels)


def _check_no_sparse_tensor(x):
  """Raises ValueError if the given tensor is `SparseTensor`."""
  if isinstance(x, sparse_tensor.SparseTensor):
    raise ValueError("SparseTensor is not supported.")


def _sparse_labels_to_indicator(labels, num_classes):
  """If labels is `SparseTensor`, converts it to indicator `Tensor`.

  Args:
    labels: Label `Tensor` or `SparseTensor`.
    num_classes: Number of classes.

  Returns:
    Dense label `Tensor`.

  Raises:
    ValueError: If labels is `SparseTensor` and `num_classes` < 2.
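
  Example:
    A sketch of the conversion; a `SparseTensor` holding class IDs per example
    becomes a dense multi-hot indicator:

    ```python
    # Example 0 has classes {0, 2}; example 1 has class {1}.
    labels = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]], values=[0, 2, 1],
        dense_shape=[2, 2])
    indicator = _sparse_labels_to_indicator(labels, num_classes=3)
    # indicator == [[1, 0, 1], [0, 1, 0]]
    ```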
1194  """
1195  if isinstance(labels, sparse_tensor.SparseTensor):
1196    if num_classes < 2:
1197      raise ValueError("Must set num_classes >= 2 when passing labels as a "
1198                       "SparseTensor.")
1199    return math_ops.to_int64(
1200        sparse_ops.sparse_to_indicator(labels, num_classes))
1201  return labels
1202
1203
1204def _assert_labels_rank(labels):
1205  return control_flow_ops.Assert(
1206      math_ops.less_equal(array_ops.rank(labels), 2),
1207      ("labels shape should be either [batch_size, 1] or [batch_size]",))
1208
1209
1210class _BinarySvmHead(_SingleHead):
1211  """`Head` for binary classification using SVM."""
1212
1213  def __init__(self, label_name, weight_column_name, enable_centered_bias,
1214               head_name, thresholds):
1215
1216    def _loss_fn(labels, logits, weights=None):
1217      with ops.name_scope(None, "hinge_loss", (logits, labels)) as name:
1218        with ops.control_dependencies((_assert_labels_rank(labels),)):
1219          labels = array_ops.reshape(labels, shape=(-1, 1))
1220        loss = losses_lib.hinge_loss(labels=labels, logits=logits, scope=name,
1221                                     reduction=losses_lib.Reduction.NONE)
1222        return _compute_weighted_loss(loss, weights)
1223
1224    super(_BinarySvmHead, self).__init__(
1225        problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
1226        logits_dimension=1,
1227        label_name=label_name,
1228        weight_column_name=weight_column_name,
1229        head_name=head_name)
1230    self._thresholds = thresholds if thresholds else (.5,)
1231    self._loss_fn = _loss_fn
1232    self._enable_centered_bias = enable_centered_bias
1233
1234  def create_model_fn_ops(self,
1235                          features,
1236                          mode,
1237                          labels=None,
1238                          train_op_fn=None,
1239                          logits=None,
1240                          logits_input=None,
1241                          scope=None):
1242    """See `Head`."""
1243    with variable_scope.variable_scope(
1244        scope,
1245        default_name=self.head_name or "binary_svm_head",
1246        values=(tuple(six.itervalues(features)) +
1247                (labels, logits, logits_input))):
1248      labels = self._transform_labels(mode=mode, labels=labels)
1249      logits = _logits(logits_input, logits, self.logits_dimension)
1250      return _create_model_fn_ops(
1251          features=features,
1252          mode=mode,
1253          loss_fn=self._loss_fn,
1254          logits_to_predictions_fn=self._logits_to_predictions,
1255          metrics_fn=self._metrics,
1256          # TODO(zakaria): Handle labels for export.
1257          create_output_alternatives_fn=self._create_output_alternatives,
1258          labels=labels,
1259          train_op_fn=train_op_fn,
1260          logits=logits,
1261          logits_dimension=self.logits_dimension,
1262          head_name=self.head_name,
1263          weight_column_name=self.weight_column_name,
1264          enable_centered_bias=self._enable_centered_bias)
1265
1266  def _transform_labels(self, mode, labels):
1267    """Applies transformations to labels tensor."""
1268    if (mode == model_fn.ModeKeys.INFER) or (labels is None):
1269      return None
1270    labels_tensor = _to_labels_tensor(labels, self._label_name)
1271    _check_no_sparse_tensor(labels_tensor)
1272    return labels_tensor
1273
1274  def _logits_to_predictions(self, logits):
1275    """See `_MultiClassHead`."""
1276    with ops.name_scope(None, "predictions", (logits,)):
1277      return {
1278          prediction_key.PredictionKey.LOGITS:
1279              logits,
1280          prediction_key.PredictionKey.CLASSES:
1281              math_ops.argmax(
1282                  _one_class_to_two_class_logits(logits),
1283                  1,
1284                  name=prediction_key.PredictionKey.CLASSES)
1285      }
1286
1287  def _metrics(self, eval_loss, predictions, labels, weights):
1288    """See `_MultiClassHead`."""
1289    with ops.name_scope("metrics", values=(
1290        [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
1291      metrics = {_summary_key(self.head_name, mkey.LOSS):
1292                 metrics_lib.mean(eval_loss)}
1293
1294      # TODO(b/29366811): This currently results in both an "accuracy" and an
1295      # "accuracy/threshold_0.500000_mean" metric for binary classification.
1296      classes = predictions[prediction_key.PredictionKey.CLASSES]
1297      metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
1298          metrics_lib.accuracy(labels, classes, weights))
1299      # TODO(sibyl-vie3Poto): add more metrics relevant for svms.
1300
1301    return metrics
1302
1303
1304class _MultiLabelHead(_SingleHead):
1305  """`Head` for multi-label classification."""

  # TODO(zakaria): add signature and metric for multilabel.
  def __init__(self,
               n_classes,
               label_name,
               weight_column_name,
               enable_centered_bias,
               head_name,
               thresholds,
               metric_class_ids=None,
               loss_fn=None):

    super(_MultiLabelHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    self._thresholds = thresholds if thresholds else (.5,)
    self._loss_fn = loss_fn if loss_fn else _sigmoid_cross_entropy_loss
    self._enable_centered_bias = enable_centered_bias
    self._metric_class_ids = tuple([] if metric_class_ids is None else
                                   metric_class_ids)
    for class_id in self._metric_class_ids:
      if (class_id < 0) or (class_id >= n_classes):
        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`."""
    with variable_scope.variable_scope(
        scope,
        default_name=self.head_name or "multi_label_head",
        values=(tuple(six.itervalues(features)) +
                (labels, logits, logits_input))):
      labels = self._transform_labels(mode=mode, labels=labels)
      logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features,
          mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type),
          labels=labels,
          train_op_fn=train_op_fn,
          logits=logits,
          logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Applies transformations to labels tensor."""
    if (mode == model_fn.ModeKeys.INFER) or (labels is None):
      return None
    labels_tensor = _to_labels_tensor(labels, self._label_name)
    labels_tensor = _sparse_labels_to_indicator(labels_tensor,
                                                self._logits_dimension)
    return labels_tensor

  def _logits_to_predictions(self, logits):
    """See `_MultiClassHead`."""
    with ops.name_scope(None, "predictions", (logits,)):
      return {
          prediction_key.PredictionKey.LOGITS:
              logits,
          prediction_key.PredictionKey.PROBABILITIES:
              math_ops.sigmoid(
                  logits, name=prediction_key.PredictionKey.PROBABILITIES),
          prediction_key.PredictionKey.CLASSES:
              math_ops.to_int64(
                  math_ops.greater(logits, 0),
                  name=prediction_key.PredictionKey.CLASSES)
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name."""
    with ops.name_scope("metrics", values=(
        [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
      classes = predictions[prediction_key.PredictionKey.CLASSES]
      probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
      logits = predictions[prediction_key.PredictionKey.LOGITS]

      metrics = {_summary_key(self.head_name, mkey.LOSS):
                 metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
          metrics_lib.accuracy(labels, classes, weights))
      metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
          probabilities, labels, weights)
      metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
          probabilities, labels, weights, curve="PR")

      for class_id in self._metric_class_ids:
        # TODO(ptucker): Add per-class accuracy, precision, recall.
        metrics[_summary_key(
            self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = (
                _predictions_streaming_mean(classes, weights, class_id))
        metrics[_summary_key(
            self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = (
                _indicator_labels_streaming_mean(labels, weights, class_id))
        metrics[_summary_key(
            self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = (
                _predictions_streaming_mean(probabilities, weights, class_id))
        metrics[_summary_key(
            self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = (
                _predictions_streaming_mean(logits, weights, class_id))
        metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
            _streaming_auc(probabilities, labels, weights, class_id))
        metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = (
            _streaming_auc(probabilities, labels, weights, class_id,
                           curve="PR"))

    return metrics


class _LossOnlyHead(Head):
  """`Head` implementation for additional loss terms.

  This class only holds loss terms unrelated to any other heads (labels),
  e.g. regularization.

  Common usage:
  This is often combined with other heads in a multi-head setup.
    ```python
    head = multi_head([
        head1, head2, loss_only_head(regularizer, head_name='regularizer')])
    ```
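
  The `loss_fn` takes no arguments and must return a scalar loss `Tensor`, or
  a list of loss `Tensor`s to be summed. A minimal sketch of such a function,
  assuming a hypothetical list `weights_list` of weight variables collected
  elsewhere in the model:
    ```python
    def regularizer():
      # `weights_list` is assumed to be collected elsewhere in the model.
      return tf.add_n([tf.nn.l2_loss(w) for w in weights_list]) * 1e-4
    ```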
  """

  def __init__(self, loss_fn, head_name=None):
    self._loss_fn = loss_fn
    self.head_name = head_name or "loss_only_head"

  @property
  def logits_dimension(self):
    return 0

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `_Head.create_model_fn_ops`.

    Args:
      features: Not used.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss and returns an op to
          optimize with the loss.
      logits: Not used.
      logits_input: Not used.
      scope: Optional scope for variable_scope. If provided, will be passed to
          all heads. Most users will want to set this to `None`, so each head
          constructs a separate variable_scope according to its `head_name`.

    Returns:
      A `ModelFnOps` object.

    Raises:
      ValueError: if `mode` is not recognized.
    """
    _check_mode_valid(mode)
    loss = None
    train_op = None
    if mode != model_fn.ModeKeys.INFER:
      with variable_scope.variable_scope(scope, default_name=self.head_name):
        loss = self._loss_fn()
        if isinstance(loss, list):
          loss = math_ops.add_n(loss)
        # The name_scope escapism is needed to maintain the same summary tag
        # after switching away from the now unsupported API.
        with ops.name_scope(""):
          summary_loss = array_ops.identity(loss)
          summary.scalar(_summary_key(self.head_name, mkey.LOSS),
                         summary_loss)
        if mode == model_fn.ModeKeys.TRAIN:
          if train_op_fn is None:
            raise ValueError("train_op_fn cannot be None in TRAIN mode.")
          with ops.name_scope(None, "train_op", (loss,)):
            train_op = train_op_fn(loss)

    return model_fn.ModelFnOps(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions={},
        eval_metric_ops={})


class _MultiHead(Head):
  """`Head` implementation for multi objective learning.

  This class is responsible for using and merging the output of multiple
  `Head` objects.

  All heads stem from the same logits/logits_input tensor.

  Common usage:
  For simple use cases, you can pass the activation of the last hidden layer
  from your model_fn like this:
    ```python
    last_hidden_layer_activation = ... Build your model.
    multi_head = ...
    return multi_head.create_model_fn_ops(
        ..., logits_input=last_hidden_layer_activation, ...)
    ```

  Or you can create a logits tensor of shape
  [batch_size, multi_head.logits_dimension]. _MultiHead will split the logits
  for you.
    ```python
    return multi_head.create_model_fn_ops(..., logits=logits, ...)
    ```

  For more complex use cases like a multi-task/multi-tower model, or when the
  logits for each head have to be created separately, you can pass a dict of
  logits where the keys match the names of the single heads.
    ```python
    logits = {"head1": logits1, "head2": logits2}
    return multi_head.create_model_fn_ops(..., logits=logits, ...)
    ```

  Here is what this class does:
  + For training, merges the losses of the individual heads according to a
      user-provided function, then calls the user-provided train_op_fn with
      the merged loss.
  + For eval, merges metrics by adding a head_name suffix to the keys of the
      eval metrics.
  + For inference, updates keys in the prediction dict to a 2-tuple,
      (head_name, prediction_key).
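
  The `loss_merger` passed to the constructor is any function that maps the
  list of per-head loss `Tensor`s to a single scalar loss. A minimal sketch,
  assuming equal weighting of the heads:
    ```python
    def sum_losses(losses):
      # Equal weighting of heads; any scalar merge policy works here.
      return tf.add_n(losses)

    multi_head = _MultiHead([head1, head2], loss_merger=sum_losses)
    ```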
  """

  def __init__(self, heads, loss_merger):
    """`Head` that merges multiple `Head` objects.

    Args:
      heads: List of `Head` objects.
      loss_merger: Function that takes a list of loss tensors for the heads
        and returns the final loss tensor for the multi head.

    Raises:
      ValueError: if any head does not have a name.
    """
    self._logits_dimension = 0
    for head in heads:
      if not head.head_name:
        raise ValueError("Members of MultiHead must have names.")
      self._logits_dimension += head.logits_dimension

    self._heads = heads
    self._loss_merger = loss_merger

  @property
  def logits_dimension(self):
    return self._logits_dimension

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `_Head.create_model_fn_ops`.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss and returns an op to
          optimize with the loss.
      logits: Concatenated logits for all heads, or a dict of head name to
          logits tensor. If concatenated, the tensor should have shape
          (batch_size, x), where x is the sum of `logits_dimension` over all
          the heads, i.e. the same as `logits_dimension` of this class.
          create_model_fn_ops will split the logits tensor and pass logits of
          the proper size to each head. This is useful if you want to be
          agnostic about whether you are creating a single head or a multi
          head. For convenience, logits can also be a dict, for when you
          create the head-specific logits explicitly and don't want to
          concatenate them yourself.
      logits_input: Tensor to build logits from.
      scope: Optional scope for variable_scope. If provided, will be passed to
        all heads. Most users will want to set this to `None`, so each head
        constructs a separate variable_scope according to its `head_name`.

    Returns:
      `ModelFnOps`.

    Raises:
      ValueError: if `mode` is not recognized, or neither or both of `logits`
          and `logits_input` are provided.
    """
    _check_mode_valid(mode)
    all_model_fn_ops = []
    if logits is None:
      # Use logits_input.
      for head in self._heads:
        all_model_fn_ops.append(
            head.create_model_fn_ops(
                features=features,
                mode=mode,
                labels=labels,
                train_op_fn=no_op_train_fn,
                logits_input=logits_input,
                scope=scope))
    else:
      head_logits_pairs = []
      if isinstance(logits, dict):
        for head in self._heads:
          if isinstance(head, _LossOnlyHead):
            head_logits_pairs.append((head, None))
          else:
            head_logits_pairs.append((head, logits[head.head_name]))
      else:
        # Split logits for each head.
        head_logits_pairs = zip(self._heads, self._split_logits(logits))

      for head, head_logits in head_logits_pairs:
        all_model_fn_ops.append(
            head.create_model_fn_ops(
                features=features,
                mode=mode,
                labels=labels,
                train_op_fn=no_op_train_fn,
                logits=head_logits,
                scope=scope))

    if mode == model_fn.ModeKeys.TRAIN:
      if train_op_fn is None:
        raise ValueError("train_op_fn cannot be None in TRAIN mode.")
      return self._merge_train(all_model_fn_ops, train_op_fn)
    if mode == model_fn.ModeKeys.INFER:
      return self._merge_infer(all_model_fn_ops)
    if mode == model_fn.ModeKeys.EVAL:
      return self._merge_eval(all_model_fn_ops)
    raise ValueError("mode=%s unrecognized" % str(mode))

  def _split_logits(self, logits):
    """Splits logits for heads.

    Args:
      logits: The logits tensor.

    Returns:
      A list of logits for the individual heads.
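
      For example, with two heads of `logits_dimension` 2 and 3 and a
      `logits` tensor of shape [batch_size, 5], this returns slices of
      shape [batch_size, 2] and [batch_size, 3].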
    """
    all_logits = []
    begin = 0
    for head in self._heads:
      current_logits_size = head.logits_dimension
      current_logits = array_ops.slice(logits, [0, begin],
                                       [-1, current_logits_size])
      all_logits.append(current_logits)
      begin += current_logits_size
    return all_logits

  def _merge_train(self, all_model_fn_ops, train_op_fn):
    """Merges list of ModelFnOps for training.

    Args:
      all_model_fn_ops: List of ModelFnOps for the individual heads.
      train_op_fn: Function to create train op. See `create_model_fn_ops`
          documentation for more details.

    Returns:
      ModelFnOps that merges all heads for TRAIN.
    """
    losses = []
    metrics = {}
    additional_train_ops = []
    for m in all_model_fn_ops:
      losses.append(m.loss)
      if m.eval_metric_ops is not None:
        for k, v in six.iteritems(m.eval_metric_ops):
          # metrics["%s/%s" % (k, head_name)] = v
          metrics[k] = v
      additional_train_ops.append(m.train_op)
    loss = self._loss_merger(losses)

    train_op = train_op_fn(loss)
    train_op = control_flow_ops.group(train_op, *additional_train_ops)
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)

  def _merge_infer(self, all_model_fn_ops):
    """Merges list of ModelFnOps for inference.

    Args:
      all_model_fn_ops: List of ModelFnOps for the individual heads.

    Returns:
      ModelFnOps that merges all heads for INFER.
    """
    predictions = {}
    output_alternatives = {}
    for head, m in zip(self._heads, all_model_fn_ops):
      if isinstance(head, _LossOnlyHead):
        continue
      head_name = head.head_name
      output_alternatives[head_name] = m.output_alternatives[head_name]
      for k, v in m.predictions.items():
        predictions[(head_name, k)] = v

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions=predictions,
        output_alternatives=output_alternatives)

  def _merge_eval(self, all_model_fn_ops):
    """Merges list of ModelFnOps for eval.

    Args:
      all_model_fn_ops: List of ModelFnOps for the individual heads.

    Returns:
      ModelFnOps that merges all heads for EVAL.
    """
    predictions = {}
    metrics = {}
    losses = []
    for head, m in zip(self._heads, all_model_fn_ops):
      losses.append(m.loss)
      head_name = head.head_name
      for k, v in m.predictions.items():
        predictions[(head_name, k)] = v
      for k, v in m.eval_metric_ops.items():
        # metrics["%s/%s" % (k, head_name)] = v
        metrics[k] = v
    loss = self._loss_merger(losses)

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions=predictions,
        loss=loss,
        eval_metric_ops=metrics)


def _weight_tensor(features, weight_column_name):
  """Returns weights as a `Tensor` of rank 0, or at least 2."""
  if not weight_column_name:
    return None
  if weight_column_name not in features:
    raise ValueError("Weights {} missing from features.".format(
        weight_column_name))
  with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))):
    weight_tensor = math_ops.to_float(features[weight_column_name])
    shape = weight_tensor.get_shape()
    rank = shape.ndims
    # We don't bother with expanding dims of non-statically shaped tensors or
    # scalars, and >1d is already in a good format.
    if rank == 1:
      logging.warning("Weights {} has shape {}, expanding to make it 2d.".
                      format(weight_column_name, shape))
      return (
          sparse_ops.sparse_reshape(weight_tensor, (-1, 1))
          if isinstance(weight_tensor, sparse_tensor.SparseTensor) else
          array_ops.reshape(weight_tensor, (-1, 1)))
    return weight_tensor


# TODO(zakaria): This function is needed for backward compatibility and should
#   be removed when we migrate to core.
def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
  """Returns a tuple of (loss_train, loss_report).

  loss_train is used for gradient descent, while loss_report
  (weighted_average_loss) is used for summaries, to be backward compatible.

  loss_train differs from the loss reported on TensorBoard because the
  example weights must be respected when computing the gradient:

    L = sum_{i} w_{i} * l_{i} / B

  where B is the number of examples in the batch, and l_{i} and w_{i} are the
  loss and weight of example i.
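
  For example, with losses l = (1.0, 3.0) and weights w = (1.0, 0.0), the
  training loss is (1 * 1.0 + 0 * 3.0) / 2 = 0.5, while the reported
  weighted_average_loss is (1 * 1.0 + 0 * 3.0) / (1.0 + 0.0) = 1.0.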

  Args:
    loss_unweighted: Unweighted loss.
    weight: Weight tensor.
    name: Optional name.

  Returns:
    A tuple of losses. The first one is for training and the second one is
    for reporting.
  """
  with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope:
    if weight is None:
      loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
      return loss, loss
    weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted)
    with ops.name_scope(None, "weighted_loss",
                        (loss_unweighted, weight)) as name:
      weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name)
    weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope)
    weighted_loss_normalized = math_ops.div(
        math_ops.reduce_sum(weighted_loss),
        math_ops.to_float(math_ops.reduce_sum(weight)),
        name="weighted_average_loss")

    return weighted_loss_mean, weighted_loss_normalized


def _wrap_custom_loss_fn(loss_fn):
  def _wrapper(labels, logits, weights=None):
    if weights is None:
      loss = loss_fn(labels, logits)
    else:
      loss = loss_fn(labels, logits, weights)
    return loss, loss
  return _wrapper


def _check_mode_valid(mode):
  """Raises ValueError if the given mode is invalid."""
  if (mode != model_fn.ModeKeys.TRAIN and mode != model_fn.ModeKeys.INFER and
      mode != model_fn.ModeKeys.EVAL):
    raise ValueError("mode=%s unrecognized." % str(mode))


def _get_arguments(func):
  """Returns an argument spec for the given func."""
  _, func = tf_decorator.unwrap(func)
  if hasattr(func, "__code__"):
    # Regular function.
    return tf_inspect.getargspec(func)
  elif hasattr(func, "__call__"):
    # Callable object.
    return _get_arguments(func.__call__)
  elif hasattr(func, "func"):
    # Partial function.
    return _get_arguments(func.func)


def _verify_loss_fn_args(loss_fn):
  args = _get_arguments(loss_fn).args
  for arg_name in ["labels", "logits", "weights"]:
    if arg_name not in args:
      raise ValueError("Argument %s not found in loss_fn." % arg_name)


def _centered_bias(logits_dimension, head_name=None):
  """Returns centered_bias `Variable`.

  Args:
    logits_dimension: Last dimension of `logits`. Must be >= 1.
    head_name: Optional name of the head.

  Returns:
    `Variable` with shape `[logits_dimension]`.
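    For example, with `logits_dimension=3` this creates a zero-initialized
    variable of shape (3,) and emits one scalar summary per dimension.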

  Raises:
    ValueError: if `logits_dimension` is invalid.
  """
  if (logits_dimension is None) or (logits_dimension < 1):
    raise ValueError("Invalid logits_dimension %s." % logits_dimension)
  # Do not create a variable with variable_scope.get_variable, because that may
  # create a PartitionedVariable, which does not support indexing, so
  # summary.scalar will not work.
  centered_bias = variable_scope.variable(
      name="centered_bias_weight",
      initial_value=array_ops.zeros(shape=(logits_dimension,)),
      trainable=True)
  for dim in range(logits_dimension):
    if head_name:
      summary.scalar("centered_bias/bias_%d/%s" % (dim, head_name),
                     centered_bias[dim])
    else:
      summary.scalar("centered_bias/bias_%d" % dim, centered_bias[dim])
  return centered_bias


def _centered_bias_step(centered_bias, batch_size, labels, loss_fn, weights):
  """Creates and returns training op for centered bias."""
  with ops.name_scope(None, "centered_bias_step", (labels,)) as name:
    logits_dimension = array_ops.shape(centered_bias)[0]
    logits = array_ops.reshape(
        array_ops.tile(centered_bias, (batch_size,)),
        (batch_size, logits_dimension))
    with ops.name_scope(None, "centered_bias", (labels, logits)):
      centered_bias_loss = math_ops.reduce_mean(
          loss_fn(labels, logits, weights), name="training_loss")
  # Learn the centered bias with an optimizer. 0.1 is a conservative learning
  # rate for a single variable.
  return training.AdagradOptimizer(0.1).minimize(
      centered_bias_loss, var_list=(centered_bias,), name=name)


def _summary_key(head_name, val):
  return "%s/%s" % (val, head_name) if head_name else val


def _train_op(loss, labels, train_op_fn, centered_bias, batch_size, loss_fn,
              weights):
  """Returns op for the training step."""
  if centered_bias is not None:
    centered_bias_step = _centered_bias_step(
        centered_bias=centered_bias,
        batch_size=batch_size,
        labels=labels,
        loss_fn=loss_fn,
        weights=weights)
  else:
    centered_bias_step = None
  with ops.name_scope(None, "train_op", (loss, labels)):
    train_op = train_op_fn(loss)
    if centered_bias_step is not None:
      train_op = control_flow_ops.group(train_op, centered_bias_step)
    return train_op


def _sigmoid_cross_entropy_loss(labels, logits, weights=None):
  with ops.name_scope(None, "sigmoid_cross_entropy_loss",
                      (logits, labels)) as name:
    # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
    loss = nn.sigmoid_cross_entropy_with_logits(
        labels=math_ops.to_float(labels), logits=logits, name=name)
    return _compute_weighted_loss(loss, weights)


def _float_weights_or_none(weights):
  if weights is None:
    return None
  with ops.name_scope(None, "float_weights", (weights,)) as name:
    return math_ops.to_float(weights, name=name)


def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
  labels = math_ops.to_float(labels)
  weights = _float_weights_or_none(weights)
  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(weights, labels)
  if class_id is not None:
    if weights is not None:
      weights = weights[:, class_id]
    labels = labels[:, class_id]
  return metrics_lib.mean(labels, weights)


def _predictions_streaming_mean(predictions,
                                weights=None,
                                class_id=None):
  predictions = math_ops.to_float(predictions)
  weights = _float_weights_or_none(weights)
  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
  if class_id is not None:
    if weights is not None:
      weights = weights[:, class_id]
    predictions = predictions[:, class_id]
  return metrics_lib.mean(predictions, weights)


# TODO(ptucker): Add support for SparseTensor labels.
def _class_id_labels_to_indicator(labels, num_classes):
  if (num_classes is None) or (num_classes < 2):
    raise ValueError("Invalid num_classes %s." % num_classes)
  with ops.control_dependencies((_assert_labels_rank(labels),)):
    labels = array_ops.reshape(labels, (-1,))
  return array_ops.one_hot(labels, depth=num_classes, axis=-1)


def _class_predictions_streaming_mean(predictions, weights, class_id):
  return metrics_lib.mean(
      array_ops.where(
          math_ops.equal(
              math_ops.to_int32(class_id), math_ops.to_int32(predictions)),
          array_ops.ones_like(predictions),
          array_ops.zeros_like(predictions)),
      weights=weights)


def _class_labels_streaming_mean(labels, weights, class_id):
  return metrics_lib.mean(
      array_ops.where(
          math_ops.equal(
              math_ops.to_int32(class_id), math_ops.to_int32(labels)),
          array_ops.ones_like(labels), array_ops.zeros_like(labels)),
      weights=weights)


def _streaming_auc(predictions, labels, weights=None, class_id=None,
                   curve="ROC"):
  """Computes streaming AUC for the given `curve`, optionally per class."""
  predictions = math_ops.to_float(predictions)
  if labels.dtype.base_dtype != dtypes.bool:
    logging.warning("Casting %s labels to bool.", labels.dtype)
    labels = math_ops.cast(labels, dtypes.bool)
  weights = _float_weights_or_none(weights)
  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
  if class_id is not None:
    if weights is not None:
      weights = weights[:, class_id]
    predictions = predictions[:, class_id]
    labels = labels[:, class_id]
  return metrics_lib.auc(labels, predictions, weights, curve=curve)


def _assert_class_id(class_id, num_classes=None):
  """Raises ValueError if `class_id` is invalid."""
  if (class_id is None) or (class_id < 0):
    raise ValueError("Invalid class_id %s." % class_id)
  if num_classes is not None:
    if num_classes < 2:
      raise ValueError("Invalid num_classes %s." % num_classes)
    if class_id >= num_classes:
      raise ValueError("Invalid class_id %s." % class_id)


def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold):
  threshold_predictions = math_ops.to_float(
      math_ops.greater_equal(predictions, threshold))
  return metrics_lib.accuracy(labels, threshold_predictions, weights)


def _streaming_precision_at_threshold(predictions, labels, weights, threshold):
  precision_tensor, update_op = metrics_lib.precision_at_thresholds(
      labels, predictions, (threshold,), _float_weights_or_none(weights))
  return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)


def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
  recall_tensor, update_op = metrics_lib.recall_at_thresholds(
      labels, predictions, (threshold,), _float_weights_or_none(weights))
  return array_ops.squeeze(recall_tensor), array_ops.squeeze(update_op)


def _classification_output_alternatives(head_name, problem_type,
                                        label_keys=None):
  """Creates a func to generate output alternatives for classification.

  Servo expects classes to be a string tensor with the same dimensions
  as the probabilities tensor, containing the labels of the corresponding
  entries in probabilities. This function creates a new classes tensor that
  satisfies these conditions and can be exported.
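
  For example, with `label_keys=["cat", "dog"]` and a batch of two examples,
  the exported classes tensor is [["cat", "dog"], ["cat", "dog"]], matching
  the shape of the [2, 2] probabilities tensor.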

  Args:
    head_name: Name of the head.
    problem_type: `ProblemType`.
    label_keys: Optional label keys.

  Returns:
    A function to generate output alternatives.
  """
  def _create_output_alternatives(predictions):
    """Creates output alternative for the Head.

    Args:
      predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
        symbolic name for an output Tensor possibly but not necessarily taken
        from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
        itself.

    Returns:
      `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
      'submodel_name' is a submodel identifier that should be consistent across
      the pipeline (here likely taken from the head_name),
      'problem_type' is a `ProblemType`,
      'tensor_name' is a symbolic name for an output Tensor possibly but not
       necessarily taken from `PredictionKey`, and
      'Tensor' is the corresponding output Tensor itself.

    Raises:
      ValueError: if predictions does not have PredictionKey.PROBABILITIES key.
    """
    probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
    if probabilities is None:
      raise ValueError("%s missing in predictions" %
                       prediction_key.PredictionKey.PROBABILITIES)

    with ops.name_scope(None, "_classification_output_alternatives",
                        (probabilities,)):
      batch_size = array_ops.shape(probabilities)[0]
      if label_keys:
        classes = array_ops.tile(
            input=array_ops.expand_dims(input=label_keys, axis=0),
            multiples=[batch_size, 1],
            name="classes_tensor")
      else:
        n = array_ops.shape(probabilities)[1]
        classes = array_ops.tile(
            input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
            multiples=[batch_size, 1])
        classes = string_ops.as_string(classes, name="classes_tensor")

    exported_predictions = {
        prediction_key.PredictionKey.PROBABILITIES: probabilities,
        prediction_key.PredictionKey.CLASSES: classes}
    return {head_name: (problem_type, exported_predictions)}

  return _create_output_alternatives


# Aliases
# TODO(zakaria): Remove these aliases, See b/34751732
_regression_head = regression_head
_poisson_regression_head = poisson_regression_head
_multi_class_head = multi_class_head
_binary_svm_head = binary_svm_head
_multi_label_head = multi_label_head
_multi_head = multi_head
_Head = Head