1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Abstractions for the head(s) of a model. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25from tensorflow.contrib import framework as framework_lib 26from tensorflow.contrib import layers as layers_lib 27from tensorflow.contrib.learn.python.learn.estimators import constants 28from tensorflow.contrib.learn.python.learn.estimators import model_fn 29from tensorflow.contrib.learn.python.learn.estimators import prediction_key 30from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey as mkey 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import sparse_tensor 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import lookup_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import metrics as metrics_lib 39from tensorflow.python.ops import nn 40from tensorflow.python.ops import sparse_ops 41from tensorflow.python.ops import string_ops 42from tensorflow.python.ops import variable_scope 43from tensorflow.python.ops import weights_broadcast_ops 44from tensorflow.python.ops.losses import losses as losses_lib 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.python.summary import summary 47from tensorflow.python.training import training 48from tensorflow.python.util import tf_decorator 49from tensorflow.python.util import tf_inspect 50 51 52class Head(object): 53 """Interface for the head/top of a model. 54 55 Given logits (or output of a hidden layer), a Head knows how to compute 56 predictions, loss, default metric and export signature. It is meant to, 57 58 1) Simplify writing model_fn and to make model_fn more configurable 59 2) Support wide range of machine learning models. Since most heads can work 60 with logits, they can support DNN, RNN, Wide, Wide&Deep, 61 Global objectives, Gradient boosted trees and many other types 62 of machine learning models. 63 2) To allow users to seamlessly switch between 1 to n heads for multi 64 objective learning (See _MultiHead implementation for more details) 65 66 Common usage: 67 Here is simplified model_fn to build a multiclass DNN model. 68 ```python 69 def _my_dnn_model_fn(features, labels, mode, params, config=None): 70 # Optionally your callers can pass head to model_fn as a param. 71 head = tf.contrib.learn.multi_class_head(...) 72 input = tf.contrib.layers.input_from_feature_columns(features, ...) 73 last_hidden_layer_out = tf.contrib.layers.stack( 74 input, tf.contrib.layers.fully_connected, [1000, 500]) 75 logits = tf.contrib.layers.fully_connected( 76 last_hidden_layer_out, head.logits_dimension, activation_fn=None) 77 78 def _train_op_fn(loss): 79 return optimizer.minimize(loss) 80 81 return head.create_model_fn_ops( 82 features=features, 83 labels=labels, 84 mode=mode, 85 train_op_fn=_train_op_fn, 86 logits=logits, 87 scope=...) 88 ``` 89 90 Most heads also support logits_input which is typically the output of the last 91 hidden layer. Some heads (like heads responsible for candidate sampling or 92 hierarchical softmax) intrinsically will not support logits and you have 93 to pass logits_input. Here is a common usage, 94 ```python 95 return head.create_model_fn_ops( 96 features=features, 97 labels=labels, 98 mode=mode, 99 train_op_fn=_train_op_fn, 100 logits_input=last_hidden_layer_out, 101 scope=...) 102 ```python 103 104 There are cases where computing and applying gradients can not be meaningfully 105 captured with train_op_fn we support (for example, with sync optimizer). In 106 such case, you can take the responsibility on your own. Here is a common 107 use case, 108 ```python 109 model_fn_ops = head.create_model_fn_ops( 110 features=features, 111 labels=labels, 112 mode=mode, 113 train_op_fn=tf.contrib.learn.no_op_train_fn, 114 logits=logits, 115 scope=...) 116 if mode == tf.contrib.learn.ModeKeys.TRAIN: 117 optimizer = ... 118 sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) 119 update_op = tf.contrib.layers.optimize_loss(optimizer=sync, 120 loss=model_fn_ops.loss, ...) 121 hooks = [sync.make_session_run_hook(is_chief)] 122 ... update train_op and hooks in ModelFnOps and return 123 ``` 124 """ 125 __metaclass__ = abc.ABCMeta 126 127 @abc.abstractproperty 128 def logits_dimension(self): 129 """Size of the last dimension of the logits `Tensor`. 130 131 Typically, logits is of shape `[batch_size, logits_dimension]`. 132 133 Returns: 134 The expected size of the `logits` tensor. 135 """ 136 raise NotImplementedError("Calling an abstract method.") 137 138 @abc.abstractmethod 139 def create_model_fn_ops(self, 140 features, 141 mode, 142 labels=None, 143 train_op_fn=None, 144 logits=None, 145 logits_input=None, 146 scope=None): 147 """Returns `ModelFnOps` that a model_fn can return. 148 149 Please note that, 150 + Exactly one of `logits` and `logits_input` must be provided. 151 + All args must be passed via name. 152 153 Args: 154 features: Input `dict` of `Tensor` objects. 155 mode: Estimator's `ModeKeys`. 156 labels: Labels `Tensor`, or `dict` of same. 157 train_op_fn: Function that takes a scalar loss `Tensor` and returns an op 158 to optimize the model with the loss. This is used in TRAIN mode and 159 must not be None. None is allowed in other modes. If you want to 160 optimize loss yourself you can pass `no_op_train_fn` and then use 161 ModeFnOps.loss to compute and apply gradients. 162 logits: logits `Tensor` to be used by the head. 163 logits_input: `Tensor` from which to build logits, often needed when you 164 don't want to compute the logits. Typically this is the activation of 165 the last hidden layer in a DNN. Some heads (like the ones responsible 166 for candidate sampling) intrinsically avoid computing full logits and 167 only accepts logits_input. 168 scope: Optional scope for `variable_scope`. 169 170 Returns: 171 An instance of `ModelFnOps`. 172 173 Raises: 174 ValueError: If `mode` is not recognized. 175 ValueError: If neither or both of `logits` and `logits_input` is provided. 176 """ 177 raise NotImplementedError("Calling an abstract method.") 178 179 180def regression_head(label_name=None, 181 weight_column_name=None, 182 label_dimension=1, 183 enable_centered_bias=False, 184 head_name=None, 185 link_fn=None): 186 """Creates a `Head` for linear regression. 187 188 Args: 189 label_name: String, name of the key in label dict. Can be null if label 190 is a tensor (single headed models). 191 weight_column_name: A string defining feature column name representing 192 weights. It is used to down weight or boost examples during training. It 193 will be multiplied by the loss of the example. 194 label_dimension: Number of regression labels per example. This is the size 195 of the last dimension of the labels `Tensor` (typically, this has shape 196 `[batch_size, label_dimension]`). 197 enable_centered_bias: A bool. If True, estimator will learn a centered 198 bias variable for each class. Rest of the model structure learns the 199 residual after centered bias. 200 head_name: name of the head. If provided, predictions, summary and metrics 201 keys will be suffixed by `"/" + head_name` and the default variable scope 202 will be `head_name`. 203 link_fn: link function to convert logits to predictions. If provided, 204 this link function will be used instead of identity. 205 206 Returns: 207 An instance of `Head` for linear regression. 208 """ 209 return _RegressionHead( 210 label_name=label_name, 211 weight_column_name=weight_column_name, 212 label_dimension=label_dimension, 213 enable_centered_bias=enable_centered_bias, 214 head_name=head_name, 215 loss_fn=_mean_squared_loss, 216 link_fn=(link_fn if link_fn is not None else array_ops.identity)) 217 218 219def poisson_regression_head(label_name=None, 220 weight_column_name=None, 221 label_dimension=1, 222 enable_centered_bias=False, 223 head_name=None): 224 """Creates a `Head` for poisson regression. 225 226 Args: 227 label_name: String, name of the key in label dict. Can be null if label 228 is a tensor (single headed models). 229 weight_column_name: A string defining feature column name representing 230 weights. It is used to down weight or boost examples during training. It 231 will be multiplied by the loss of the example. 232 label_dimension: Number of regression labels per example. This is the size 233 of the last dimension of the labels `Tensor` (typically, this has shape 234 `[batch_size, label_dimension]`). 235 enable_centered_bias: A bool. If True, estimator will learn a centered 236 bias variable for each class. Rest of the model structure learns the 237 residual after centered bias. 238 head_name: name of the head. If provided, predictions, summary and metrics 239 keys will be suffixed by `"/" + head_name` and the default variable scope 240 will be `head_name`. 241 242 Returns: 243 An instance of `Head` for poisson regression. 244 """ 245 return _RegressionHead( 246 label_name=label_name, 247 weight_column_name=weight_column_name, 248 label_dimension=label_dimension, 249 enable_centered_bias=enable_centered_bias, 250 head_name=head_name, 251 loss_fn=_poisson_loss, 252 link_fn=math_ops.exp) 253 254# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression 255 256 257def multi_class_head(n_classes, 258 label_name=None, 259 weight_column_name=None, 260 enable_centered_bias=False, 261 head_name=None, 262 thresholds=None, 263 metric_class_ids=None, 264 loss_fn=None, 265 label_keys=None): 266 """Creates a `Head` for multi class single label classification. 267 268 The Head uses softmax cross entropy loss. 269 270 This head expects to be fed integer labels specifying the class index. But 271 if `label_keys` is specified, then labels must be strings from this 272 vocabulary, and the predicted classes will be strings from the same 273 vocabulary. 274 275 Args: 276 n_classes: Integer, number of classes, must be >= 2 277 label_name: String, name of the key in label dict. Can be null if label 278 is a tensor (single headed models). 279 weight_column_name: A string defining feature column name representing 280 weights. It is used to down weight or boost examples during training. It 281 will be multiplied by the loss of the example. 282 enable_centered_bias: A bool. If True, estimator will learn a centered 283 bias variable for each class. Rest of the model structure learns the 284 residual after centered bias. 285 head_name: name of the head. If provided, predictions, summary and metrics 286 keys will be suffixed by `"/" + head_name` and the default variable scope 287 will be `head_name`. 288 thresholds: thresholds for eval metrics, defaults to [.5] 289 metric_class_ids: List of class IDs for which we should report per-class 290 metrics. Must all be in the range `[0, n_classes)`. Invalid if 291 `n_classes` is 2. 292 loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as 293 parameter and returns a weighted scalar loss. `weights` should be 294 optional. See `tf.losses` 295 label_keys: Optional list of strings with size `[n_classes]` defining the 296 label vocabulary. Only supported for `n_classes` > 2. 297 298 Returns: 299 An instance of `Head` for multi class classification. 300 301 Raises: 302 ValueError: if `n_classes` is < 2. 303 ValueError: If `metric_class_ids` is provided when `n_classes` is 2. 304 ValueError: If `len(label_keys) != n_classes`. 305 """ 306 if (n_classes is None) or (n_classes < 2): 307 raise ValueError("n_classes must be > 1 for classification: %s." % 308 n_classes) 309 if loss_fn: 310 _verify_loss_fn_args(loss_fn) 311 312 loss_fn = _wrap_custom_loss_fn(loss_fn) if loss_fn else None 313 if n_classes == 2: 314 if metric_class_ids: 315 raise ValueError("metric_class_ids invalid for n_classes==2.") 316 if label_keys: 317 raise ValueError("label_keys is not supported for n_classes=2.") 318 return _BinaryLogisticHead( 319 label_name=label_name, 320 weight_column_name=weight_column_name, 321 enable_centered_bias=enable_centered_bias, 322 head_name=head_name, 323 thresholds=thresholds, 324 loss_fn=loss_fn) 325 326 return _MultiClassHead( 327 n_classes=n_classes, 328 label_name=label_name, 329 weight_column_name=weight_column_name, 330 enable_centered_bias=enable_centered_bias, 331 head_name=head_name, 332 thresholds=thresholds, 333 metric_class_ids=metric_class_ids, 334 loss_fn=loss_fn, 335 label_keys=label_keys) 336 337 338def binary_svm_head( 339 label_name=None, 340 weight_column_name=None, 341 enable_centered_bias=False, 342 head_name=None, 343 thresholds=None,): 344 """Creates a `Head` for binary classification with SVMs. 345 346 The head uses binary hinge loss. 347 348 Args: 349 label_name: String, name of the key in label dict. Can be null if label 350 is a tensor (single headed models). 351 weight_column_name: A string defining feature column name representing 352 weights. It is used to down weight or boost examples during training. It 353 will be multiplied by the loss of the example. 354 enable_centered_bias: A bool. If True, estimator will learn a centered 355 bias variable for each class. Rest of the model structure learns the 356 residual after centered bias. 357 head_name: name of the head. If provided, predictions, summary and metrics 358 keys will be suffixed by `"/" + head_name` and the default variable scope 359 will be `head_name`. 360 thresholds: thresholds for eval metrics, defaults to [.5] 361 362 Returns: 363 An instance of `Head` for binary classification with SVM. 364 """ 365 return _BinarySvmHead( 366 label_name=label_name, 367 weight_column_name=weight_column_name, 368 enable_centered_bias=enable_centered_bias, 369 head_name=head_name, 370 thresholds=thresholds) 371 372 373def multi_label_head(n_classes, 374 label_name=None, 375 weight_column_name=None, 376 enable_centered_bias=False, 377 head_name=None, 378 thresholds=None, 379 metric_class_ids=None, 380 loss_fn=None): 381 """Creates a Head for multi label classification. 382 383 Multi-label classification handles the case where each example may have zero 384 or more associated labels, from a discrete set. This is distinct from 385 `multi_class_head` which has exactly one label from a discrete set. 386 387 This head by default uses sigmoid cross entropy loss, which expects as input 388 a multi-hot tensor of shape `(batch_size, num_classes)`. 389 390 Args: 391 n_classes: Integer, number of classes, must be >= 2 392 label_name: String, name of the key in label dict. Can be null if label 393 is a tensor (single headed models). 394 weight_column_name: A string defining feature column name representing 395 weights. It is used to down weight or boost examples during training. It 396 will be multiplied by the loss of the example. 397 enable_centered_bias: A bool. If True, estimator will learn a centered 398 bias variable for each class. Rest of the model structure learns the 399 residual after centered bias. 400 head_name: name of the head. If provided, predictions, summary and metrics 401 keys will be suffixed by `"/" + head_name` and the default variable scope 402 will be `head_name`. 403 thresholds: thresholds for eval metrics, defaults to [.5] 404 metric_class_ids: List of class IDs for which we should report per-class 405 metrics. Must all be in the range `[0, n_classes)`. 406 loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as 407 parameter and returns a weighted scalar loss. `weights` should be 408 optional. See `tf.losses` 409 410 Returns: 411 An instance of `Head` for multi label classification. 412 413 Raises: 414 ValueError: If n_classes is < 2 415 ValueError: If loss_fn does not have expected signature. 416 """ 417 if n_classes < 2: 418 raise ValueError("n_classes must be > 1 for classification.") 419 if loss_fn: 420 _verify_loss_fn_args(loss_fn) 421 422 return _MultiLabelHead( 423 n_classes=n_classes, 424 label_name=label_name, 425 weight_column_name=weight_column_name, 426 enable_centered_bias=enable_centered_bias, 427 head_name=head_name, 428 thresholds=thresholds, 429 metric_class_ids=metric_class_ids, 430 loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) 431 432 433def loss_only_head(loss_fn, head_name=None): 434 """Creates a Head that contains only loss terms. 435 436 Loss only head holds additional loss terms to be added to other heads and 437 usually represents additional regularization terms in the objective function. 438 439 Args: 440 loss_fn: a function that takes no argument and returns a list of 441 scalar tensors. 442 head_name: a name for the head. 443 444 Returns: 445 An instance of `Head` to hold the additional losses. 446 """ 447 return _LossOnlyHead(loss_fn, head_name=head_name) 448 449 450def multi_head(heads, loss_weights=None): 451 """Creates a MultiHead stemming from same logits/hidden layer. 452 453 Args: 454 heads: list of Head objects. 455 loss_weights: optional list of weights to be used to merge losses from 456 each head. All losses are weighted equally if not provided. 457 458 Returns: 459 A instance of `Head` that merges multiple heads. 460 461 Raises: 462 ValueError: if heads and loss_weights have different size. 463 """ 464 if loss_weights: 465 if len(loss_weights) != len(heads): 466 raise ValueError("heads and loss_weights must have same size") 467 468 def _weighted_loss_merger(losses): 469 if loss_weights: 470 if len(losses) != len(loss_weights): 471 raise ValueError("losses and loss_weights must have same size") 472 weighted_losses = [] 473 for loss, weight in zip(losses, loss_weights): 474 weighted_losses.append(math_ops.multiply(loss, weight)) 475 return math_ops.add_n(weighted_losses) 476 else: 477 return math_ops.add_n(losses) 478 479 return _MultiHead(heads, loss_merger=_weighted_loss_merger) 480 481 482def no_op_train_fn(loss): 483 del loss 484 return control_flow_ops.no_op() 485 486 487class _SingleHead(Head): 488 """Interface for a single head/top of a model.""" 489 __metaclass__ = abc.ABCMeta 490 491 def __init__( 492 self, problem_type, logits_dimension, label_name=None, 493 weight_column_name=None, head_name=None): 494 if problem_type is None: 495 raise ValueError("Invalid problem_type %s." % problem_type) 496 if logits_dimension is None or logits_dimension < 1: 497 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 498 self._problem_type = problem_type 499 self._logits_dimension = logits_dimension 500 self._label_name = label_name 501 self._weight_column_name = weight_column_name 502 self._head_name = head_name 503 504 @property 505 def logits_dimension(self): 506 return self._logits_dimension 507 508 @property 509 def label_name(self): 510 return self._label_name 511 512 @property 513 def weight_column_name(self): 514 return self._weight_column_name 515 516 @property 517 def head_name(self): 518 return self._head_name 519 520 def _create_output_alternatives(self, predictions): 521 """Creates output alternative for the Head. 522 523 Args: 524 predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a 525 symbolic name for an output Tensor possibly but not necessarily taken 526 from `PredictionKey`, and 'Tensor' is the corresponding output Tensor 527 itself. 528 529 Returns: 530 `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where 531 'submodel_name' is a submodel identifier that should be consistent across 532 the pipeline (here likely taken from the head_name), 533 'problem_type' is a `ProblemType`, 534 'tensor_name' is a symbolic name for an output Tensor possibly but not 535 necessarily taken from `PredictionKey`, and 536 'Tensor' is the corresponding output Tensor itself. 537 """ 538 return {self._head_name: (self._problem_type, predictions)} 539 540 541# TODO(zakaria): use contrib losses. 542def _mean_squared_loss(labels, logits, weights=None): 543 with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name: 544 logits = ops.convert_to_tensor(logits) 545 labels = ops.convert_to_tensor(labels) 546 # To prevent broadcasting inside "-". 547 if len(labels.get_shape()) == 1: 548 labels = array_ops.expand_dims(labels, dim=(1,)) 549 # TODO(zakaria): make sure it does not recreate the broadcast bug. 550 if len(logits.get_shape()) == 1: 551 logits = array_ops.expand_dims(logits, dim=(1,)) 552 logits.get_shape().assert_is_compatible_with(labels.get_shape()) 553 loss = math_ops.square(logits - math_ops.to_float(labels), name=name) 554 return _compute_weighted_loss(loss, weights) 555 556 557def _poisson_loss(labels, logits, weights=None): 558 """Computes poisson loss from logits.""" 559 with ops.name_scope(None, "_poisson_loss", (logits, labels)) as name: 560 logits = ops.convert_to_tensor(logits) 561 labels = ops.convert_to_tensor(labels) 562 # To prevent broadcasting inside "-". 563 if len(labels.get_shape()) == 1: 564 labels = array_ops.expand_dims(labels, dim=(1,)) 565 # TODO(zakaria): make sure it does not recreate the broadcast bug. 566 if len(logits.get_shape()) == 1: 567 logits = array_ops.expand_dims(logits, dim=(1,)) 568 logits.get_shape().assert_is_compatible_with(labels.get_shape()) 569 loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, 570 name=name) 571 return _compute_weighted_loss(loss, weights) 572 573 574def _logits(logits_input, logits, logits_dimension): 575 """Validate logits args, and create `logits` if necessary. 576 577 Exactly one of `logits_input` and `logits` must be provided. 578 579 Args: 580 logits_input: `Tensor` input to `logits`. 581 logits: `Tensor` output. 582 logits_dimension: Integer, last dimension of `logits`. This is used to 583 create `logits` from `logits_input` if `logits` is `None`; otherwise, it's 584 used to validate `logits`. 585 586 Returns: 587 `logits` `Tensor`. 588 589 Raises: 590 ValueError: if neither or both of `logits` and `logits_input` are supplied. 591 """ 592 if (logits_dimension is None) or (logits_dimension < 1): 593 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 594 595 # If not provided, create logits. 596 if logits is None: 597 if logits_input is None: 598 raise ValueError("Neither logits nor logits_input supplied.") 599 return layers_lib.linear(logits_input, logits_dimension, scope="logits") 600 601 if logits_input is not None: 602 raise ValueError("Both logits and logits_input supplied.") 603 604 logits = ops.convert_to_tensor(logits, name="logits") 605 logits_dims = logits.get_shape().dims 606 if logits_dims is not None: 607 logits_dims[-1].assert_is_compatible_with(logits_dimension) 608 609 return logits 610 611 612def _create_model_fn_ops(features, 613 mode, 614 loss_fn, 615 logits_to_predictions_fn, 616 metrics_fn, 617 create_output_alternatives_fn, 618 labels=None, 619 train_op_fn=None, 620 logits=None, 621 logits_dimension=None, 622 head_name=None, 623 weight_column_name=None, 624 enable_centered_bias=False): 625 """Returns a `ModelFnOps` object.""" 626 _check_mode_valid(mode) 627 628 centered_bias = None 629 if enable_centered_bias: 630 centered_bias = _centered_bias(logits_dimension, head_name) 631 logits = nn.bias_add(logits, centered_bias) 632 633 predictions = logits_to_predictions_fn(logits) 634 loss = None 635 train_op = None 636 eval_metric_ops = None 637 if (mode != model_fn.ModeKeys.INFER) and (labels is not None): 638 weight_tensor = _weight_tensor(features, weight_column_name) 639 loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) 640 # The name_scope escapism is needed to maintain the same summary tag 641 # after switching away from the now unsupported API. 642 with ops.name_scope(""): 643 summary_loss = array_ops.identity(weighted_average_loss) 644 summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss) 645 646 if mode == model_fn.ModeKeys.TRAIN: 647 if train_op_fn is None: 648 raise ValueError("train_op_fn can not be None in TRAIN mode") 649 batch_size = array_ops.shape(logits)[0] 650 train_op = _train_op(loss, labels, train_op_fn, centered_bias, 651 batch_size, loss_fn, weight_tensor) 652 eval_metric_ops = metrics_fn( 653 weighted_average_loss, predictions, labels, weight_tensor) 654 return model_fn.ModelFnOps( 655 mode=mode, 656 predictions=predictions, 657 loss=loss, 658 train_op=train_op, 659 eval_metric_ops=eval_metric_ops, 660 output_alternatives=create_output_alternatives_fn(predictions)) 661 662 663class _RegressionHead(_SingleHead): 664 """`Head` for regression with a generalized linear model.""" 665 666 def __init__(self, 667 label_dimension, 668 loss_fn, 669 link_fn, 670 logits_dimension=None, 671 label_name=None, 672 weight_column_name=None, 673 enable_centered_bias=False, 674 head_name=None): 675 """`Head` for regression. 676 677 Args: 678 label_dimension: Number of regression labels per example. This is the 679 size of the last dimension of the labels `Tensor` (typically, this has 680 shape `[batch_size, label_dimension]`). 681 loss_fn: Loss function, takes logits and labels and returns loss. 682 link_fn: Link function, takes a logits tensor and returns the output. 683 logits_dimension: Number of logits per example. This is the 684 size of the last dimension of the logits `Tensor` (typically, this has 685 shape `[batch_size, label_dimension]`). 686 Default value: `label_dimension`. 687 label_name: String, name of the key in label dict. Can be null if label 688 is a tensor (single headed models). 689 weight_column_name: A string defining feature column name representing 690 weights. It is used to down weight or boost examples during training. It 691 will be multiplied by the loss of the example. 692 enable_centered_bias: A bool. If True, estimator will learn a centered 693 bias variable for each class. Rest of the model structure learns the 694 residual after centered bias. 695 head_name: name of the head. Predictions, summary and metrics keys are 696 suffixed by `"/" + head_name` and the default variable scope is 697 `head_name`. 698 """ 699 super(_RegressionHead, self).__init__( 700 problem_type=constants.ProblemType.LINEAR_REGRESSION, 701 logits_dimension=(logits_dimension if logits_dimension is not None 702 else label_dimension), 703 label_name=label_name, 704 weight_column_name=weight_column_name, 705 head_name=head_name) 706 707 self._loss_fn = loss_fn 708 self._link_fn = link_fn 709 self._enable_centered_bias = enable_centered_bias 710 711 def create_model_fn_ops(self, 712 features, 713 mode, 714 labels=None, 715 train_op_fn=None, 716 logits=None, 717 logits_input=None, 718 scope=None): 719 """See `Head`.""" 720 with variable_scope.variable_scope( 721 scope, 722 default_name=self.head_name or "regression_head", 723 values=(tuple(six.itervalues(features)) + 724 (labels, logits, logits_input))): 725 labels = self._transform_labels(mode=mode, labels=labels) 726 logits = _logits(logits_input, logits, self.logits_dimension) 727 return _create_model_fn_ops( 728 features=features, 729 mode=mode, 730 loss_fn=self._loss_fn, 731 logits_to_predictions_fn=self._logits_to_predictions, 732 metrics_fn=self._metrics, 733 create_output_alternatives_fn=self._create_output_alternatives, 734 labels=labels, 735 train_op_fn=train_op_fn, 736 logits=logits, 737 logits_dimension=self.logits_dimension, 738 head_name=self.head_name, 739 weight_column_name=self.weight_column_name, 740 enable_centered_bias=self._enable_centered_bias) 741 742 def _transform_labels(self, mode, labels): 743 """Applies transformations to labels tensor.""" 744 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 745 return None 746 labels_tensor = _to_labels_tensor(labels, self._label_name) 747 _check_no_sparse_tensor(labels_tensor) 748 return labels_tensor 749 750 def _logits_to_predictions(self, logits): 751 """Returns a dict of predictions. 752 753 Args: 754 logits: logits `Tensor` after applying possible centered bias. 755 756 Returns: 757 Dict of prediction `Tensor` keyed by `PredictionKey`. 758 """ 759 key = prediction_key.PredictionKey.SCORES 760 with ops.name_scope(None, "predictions", (logits,)): 761 if self.logits_dimension == 1: 762 logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key) 763 return {key: self._link_fn(logits)} 764 765 def _metrics(self, eval_loss, predictions, labels, weights): 766 """Returns a dict of metrics keyed by name.""" 767 del predictions, labels, weights # Unused by this head. 768 with ops.name_scope("metrics", values=[eval_loss]): 769 return { 770 _summary_key(self.head_name, mkey.LOSS): 771 metrics_lib.mean(eval_loss)} 772 773 774def _log_loss_with_two_classes(labels, logits, weights=None): 775 with ops.name_scope(None, "log_loss_with_two_classes", 776 (logits, labels)) as name: 777 logits = ops.convert_to_tensor(logits) 778 labels = math_ops.to_float(labels) 779 # TODO(ptucker): This will break for dynamic shapes. 780 # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. 781 if len(labels.get_shape()) == 1: 782 labels = array_ops.expand_dims(labels, dim=(1,)) 783 loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, 784 name=name) 785 return _compute_weighted_loss(loss, weights) 786 787 788def _one_class_to_two_class_logits(logits): 789 return array_ops.concat((array_ops.zeros_like(logits), logits), 1) 790 791 792class _BinaryLogisticHead(_SingleHead): 793 """`Head` for binary classification with logistic regression.""" 794 795 def __init__(self, 796 label_name=None, 797 weight_column_name=None, 798 enable_centered_bias=False, 799 head_name=None, 800 loss_fn=None, 801 thresholds=None): 802 """`Head` for binary classification with logistic regression. 803 804 Args: 805 label_name: String, name of the key in label dict. Can be `None` if label 806 is a tensor (single headed models). 807 weight_column_name: A string defining feature column name representing 808 weights. It is used to down weight or boost examples during training. It 809 will be multiplied by the loss of the example. 810 enable_centered_bias: A bool. If True, estimator will learn a centered 811 bias variable for each class. Rest of the model structure learns the 812 residual after centered bias. 813 head_name: name of the head. Predictions, summary, metrics keys are 814 suffixed by `"/" + head_name` and the default variable scope is 815 `head_name`. 816 loss_fn: Loss function. 817 thresholds: thresholds for eval. 818 819 Raises: 820 ValueError: if n_classes is invalid. 821 """ 822 super(_BinaryLogisticHead, self).__init__( 823 problem_type=constants.ProblemType.LOGISTIC_REGRESSION, 824 logits_dimension=1, 825 label_name=label_name, 826 weight_column_name=weight_column_name, 827 head_name=head_name) 828 self._thresholds = thresholds if thresholds else (.5,) 829 self._loss_fn = loss_fn if loss_fn else _log_loss_with_two_classes 830 self._enable_centered_bias = enable_centered_bias 831 832 def create_model_fn_ops(self, 833 features, 834 mode, 835 labels=None, 836 train_op_fn=None, 837 logits=None, 838 logits_input=None, 839 scope=None): 840 """See `Head`.""" 841 with variable_scope.variable_scope( 842 scope, 843 default_name=self.head_name or "binary_logistic_head", 844 values=(tuple(six.itervalues(features)) + 845 (labels, logits, logits_input))): 846 labels = self._transform_labels(mode=mode, labels=labels) 847 logits = _logits(logits_input, logits, self.logits_dimension) 848 return _create_model_fn_ops( 849 features=features, 850 mode=mode, 851 loss_fn=self._loss_fn, 852 logits_to_predictions_fn=self._logits_to_predictions, 853 metrics_fn=self._metrics, 854 create_output_alternatives_fn=_classification_output_alternatives( 855 self.head_name, self._problem_type), 856 labels=labels, 857 train_op_fn=train_op_fn, 858 logits=logits, 859 logits_dimension=self.logits_dimension, 860 head_name=self.head_name, 861 weight_column_name=self.weight_column_name, 862 enable_centered_bias=self._enable_centered_bias) 863 864 def _transform_labels(self, mode, labels): 865 """Applies transformations to labels tensor.""" 866 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 867 return None 868 labels_tensor = _to_labels_tensor(labels, self._label_name) 869 _check_no_sparse_tensor(labels_tensor) 870 return labels_tensor 871 872 def _logits_to_predictions(self, logits): 873 """Returns a dict of predictions. 874 875 Args: 876 logits: logits `Output` after applying possible centered bias. 877 878 Returns: 879 Dict of prediction `Output` keyed by `PredictionKey`. 880 """ 881 with ops.name_scope(None, "predictions", (logits,)): 882 two_class_logits = _one_class_to_two_class_logits(logits) 883 return { 884 prediction_key.PredictionKey.LOGITS: 885 logits, 886 prediction_key.PredictionKey.LOGISTIC: 887 math_ops.sigmoid( 888 logits, name=prediction_key.PredictionKey.LOGISTIC), 889 prediction_key.PredictionKey.PROBABILITIES: 890 nn.softmax( 891 two_class_logits, 892 name=prediction_key.PredictionKey.PROBABILITIES), 893 prediction_key.PredictionKey.CLASSES: 894 math_ops.argmax( 895 two_class_logits, 896 1, 897 name=prediction_key.PredictionKey.CLASSES) 898 } 899 900 def _metrics(self, eval_loss, predictions, labels, weights): 901 """Returns a dict of metrics keyed by name.""" 902 with ops.name_scope("metrics", values=( 903 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 904 classes = predictions[prediction_key.PredictionKey.CLASSES] 905 logistic = predictions[prediction_key.PredictionKey.LOGISTIC] 906 907 metrics = {_summary_key(self.head_name, mkey.LOSS): 908 metrics_lib.mean(eval_loss)} 909 # TODO(b/29366811): This currently results in both an "accuracy" and an 910 # "accuracy/threshold_0.500000_mean" metric for binary classification. 911 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 912 metrics_lib.accuracy(labels, classes, weights)) 913 metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = ( 914 _predictions_streaming_mean(logistic, weights)) 915 metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = ( 916 _indicator_labels_streaming_mean(labels, weights)) 917 918 # Also include the streaming mean of the label as an accuracy baseline, as 919 # a reminder to users. 920 metrics[_summary_key(self.head_name, mkey.ACCURACY_BASELINE)] = ( 921 _indicator_labels_streaming_mean(labels, weights)) 922 metrics[_summary_key(self.head_name, mkey.AUC)] = ( 923 _streaming_auc(logistic, labels, weights)) 924 metrics[_summary_key(self.head_name, mkey.AUC_PR)] = ( 925 _streaming_auc(logistic, labels, weights, curve="PR")) 926 927 for threshold in self._thresholds: 928 metrics[_summary_key( 929 self.head_name, mkey.ACCURACY_MEAN % threshold)] = ( 930 _streaming_accuracy_at_threshold(logistic, labels, weights, 931 threshold)) 932 # Precision for positive examples. 933 metrics[_summary_key( 934 self.head_name, mkey.PRECISION_MEAN % threshold)] = ( 935 _streaming_precision_at_threshold(logistic, labels, weights, 936 threshold)) 937 # Recall for positive examples. 938 metrics[_summary_key( 939 self.head_name, mkey.RECALL_MEAN % threshold)] = ( 940 _streaming_recall_at_threshold(logistic, labels, weights, 941 threshold)) 942 943 return metrics 944 945 946def _softmax_cross_entropy_loss(labels, logits, weights=None): 947 with ops.name_scope( 948 None, "softmax_cross_entropy_loss", (logits, labels,)) as name: 949 labels = ops.convert_to_tensor(labels) 950 # Check that we got integer for classification. 951 if not labels.dtype.is_integer: 952 raise ValueError("Labels dtype should be integer " 953 "Instead got %s." % labels.dtype) 954 955 # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. 956 is_squeezed_labels = False 957 # TODO(ptucker): This will break for dynamic shapes. 958 if len(labels.get_shape()) == 2: 959 labels = array_ops.squeeze(labels, squeeze_dims=(1,)) 960 is_squeezed_labels = True 961 962 loss = nn.sparse_softmax_cross_entropy_with_logits( 963 labels=labels, logits=logits, name=name) 964 965 # Restore squeezed dimension, if necessary, so loss matches weights shape. 966 if is_squeezed_labels: 967 loss = array_ops.expand_dims(loss, axis=(1,)) 968 969 return _compute_weighted_loss(loss, weights) 970 971 972class _MultiClassHead(_SingleHead): 973 """'Head' for multi class classification.""" 974 975 def __init__(self, 976 n_classes, 977 label_name=None, 978 weight_column_name=None, 979 enable_centered_bias=False, 980 head_name=None, 981 loss_fn=None, 982 thresholds=None, 983 metric_class_ids=None, 984 label_keys=None): 985 """'Head' for multi class classification. 986 987 This head expects to be fed integer labels specifying the class index. But 988 if `label_keys` is specified, then labels must be strings from this 989 vocabulary, and the predicted classes will be strings from the same 990 vocabulary. 991 992 Args: 993 n_classes: Number of classes, must be greater than 2 (for 2 classes, use 994 `_BinaryLogisticHead`). 995 label_name: String, name of the key in label dict. Can be null if label 996 is a tensor (single headed models). 997 weight_column_name: A string defining feature column name representing 998 weights. It is used to down weight or boost examples during training. It 999 will be multiplied by the loss of the example. 1000 enable_centered_bias: A bool. If True, estimator will learn a centered 1001 bias variable for each class. Rest of the model structure learns the 1002 residual after centered bias. 1003 head_name: name of the head. If provided, predictions, summary, metrics 1004 keys will be suffixed by `"/" + head_name` and the default variable 1005 scope will be `head_name`. 1006 loss_fn: Loss function. Defaults to softmax cross entropy loss. 1007 thresholds: thresholds for eval. 1008 metric_class_ids: List of class IDs for which we should report per-class 1009 metrics. Must all be in the range `[0, n_classes)`. 1010 label_keys: Optional list of strings with size `[n_classes]` defining the 1011 label vocabulary. 1012 1013 Raises: 1014 ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. 1015 """ 1016 super(_MultiClassHead, self).__init__( 1017 problem_type=constants.ProblemType.CLASSIFICATION, 1018 logits_dimension=n_classes, 1019 label_name=label_name, 1020 weight_column_name=weight_column_name, 1021 head_name=head_name) 1022 1023 if (n_classes is None) or (n_classes <= 2): 1024 raise ValueError("n_classes must be > 2: %s." % n_classes) 1025 self._thresholds = thresholds if thresholds else (.5,) 1026 self._loss_fn = loss_fn if loss_fn else _softmax_cross_entropy_loss 1027 self._enable_centered_bias = enable_centered_bias 1028 self._metric_class_ids = tuple([] if metric_class_ids is None else 1029 metric_class_ids) 1030 for class_id in self._metric_class_ids: 1031 if (class_id < 0) or (class_id >= n_classes): 1032 raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes)) 1033 if label_keys and len(label_keys) != n_classes: 1034 raise ValueError("Length of label_keys must equal n_classes.") 1035 self._label_keys = label_keys 1036 1037 def create_model_fn_ops(self, 1038 features, 1039 mode, 1040 labels=None, 1041 train_op_fn=None, 1042 logits=None, 1043 logits_input=None, 1044 scope=None): 1045 """See `Head`.""" 1046 with variable_scope.variable_scope( 1047 scope, 1048 default_name=self.head_name or "multi_class_head", 1049 values=(tuple(six.itervalues(features)) + 1050 (labels, logits, logits_input))): 1051 labels = self._transform_labels(mode=mode, labels=labels) 1052 logits = _logits(logits_input, logits, self.logits_dimension) 1053 return _create_model_fn_ops( 1054 features=features, 1055 mode=mode, 1056 loss_fn=self._wrapped_loss_fn, 1057 logits_to_predictions_fn=self._logits_to_predictions, 1058 metrics_fn=self._metrics, 1059 create_output_alternatives_fn=_classification_output_alternatives( 1060 self.head_name, self._problem_type, self._label_keys), 1061 labels=labels, 1062 train_op_fn=train_op_fn, 1063 logits=logits, 1064 logits_dimension=self.logits_dimension, 1065 head_name=self.head_name, 1066 weight_column_name=self.weight_column_name, 1067 enable_centered_bias=self._enable_centered_bias) 1068 1069 def _transform_labels(self, mode, labels): 1070 """Returns a dict that contains both the original labels and label IDs.""" 1071 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1072 return None 1073 labels_tensor = _to_labels_tensor(labels, self._label_name) 1074 _check_no_sparse_tensor(labels_tensor) 1075 if self._label_keys: 1076 table = lookup_ops.index_table_from_tensor( 1077 self._label_keys, name="label_id_lookup") 1078 return { 1079 "labels": labels_tensor, 1080 "label_ids": table.lookup(labels_tensor), 1081 } 1082 return { 1083 "labels": labels_tensor, 1084 "label_ids": labels_tensor, 1085 } 1086 1087 def _labels(self, labels_dict): 1088 """Returns labels `Tensor` of the same type as classes.""" 1089 return labels_dict["labels"] 1090 1091 def _label_ids(self, labels_dict): 1092 """Returns integer label ID `Tensor`.""" 1093 return labels_dict["label_ids"] 1094 1095 def _wrapped_loss_fn(self, labels, logits, weights=None): 1096 return self._loss_fn(self._label_ids(labels), logits, weights=weights) 1097 1098 def _logits_to_predictions(self, logits): 1099 """Returns a dict of predictions. 1100 1101 Args: 1102 logits: logits `Tensor` after applying possible centered bias. 1103 1104 Returns: 1105 Dict of prediction `Tensor` keyed by `PredictionKey`. 1106 """ 1107 with ops.name_scope(None, "predictions", (logits,)): 1108 class_ids = math_ops.argmax( 1109 logits, 1, name=prediction_key.PredictionKey.CLASSES) 1110 if self._label_keys: 1111 table = lookup_ops.index_to_string_table_from_tensor( 1112 self._label_keys, name="class_string_lookup") 1113 classes = table.lookup(class_ids) 1114 else: 1115 classes = class_ids 1116 return { 1117 prediction_key.PredictionKey.LOGITS: logits, 1118 prediction_key.PredictionKey.PROBABILITIES: 1119 nn.softmax( 1120 logits, name=prediction_key.PredictionKey.PROBABILITIES), 1121 prediction_key.PredictionKey.CLASSES: classes 1122 } 1123 1124 def _metrics(self, eval_loss, predictions, labels, weights): 1125 """Returns a dict of metrics keyed by name.""" 1126 with ops.name_scope( 1127 "metrics", 1128 values=((eval_loss, self._labels(labels), self._label_ids(labels), 1129 weights) + tuple(six.itervalues(predictions)))): 1130 logits = predictions[prediction_key.PredictionKey.LOGITS] 1131 probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] 1132 classes = predictions[prediction_key.PredictionKey.CLASSES] 1133 1134 metrics = {_summary_key(self.head_name, mkey.LOSS): 1135 metrics_lib.mean(eval_loss)} 1136 # TODO(b/29366811): This currently results in both an "accuracy" and an 1137 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1138 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1139 metrics_lib.accuracy(self._labels(labels), classes, weights)) 1140 1141 if not self._label_keys: 1142 # Classes are IDs. Add some metrics. 1143 for class_id in self._metric_class_ids: 1144 metrics[_summary_key( 1145 self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = ( 1146 _class_predictions_streaming_mean(classes, weights, class_id)) 1147 # TODO(ptucker): Add per-class accuracy, precision, recall. 1148 metrics[_summary_key( 1149 self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = ( 1150 _class_labels_streaming_mean( 1151 self._label_ids(labels), weights, class_id)) 1152 metrics[_summary_key( 1153 self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = ( 1154 _predictions_streaming_mean(probabilities, weights, class_id)) 1155 metrics[_summary_key( 1156 self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = ( 1157 _predictions_streaming_mean(logits, weights, class_id)) 1158 1159 return metrics 1160 1161 1162def _to_labels_tensor(labels, label_name): 1163 """Returns label as a tensor. 1164 1165 Args: 1166 labels: Label `Tensor` or `SparseTensor` or a dict containing labels. 1167 label_name: Label name if labels is a dict. 1168 1169 Returns: 1170 Label `Tensor` or `SparseTensor`. 1171 """ 1172 labels = labels[label_name] if isinstance(labels, dict) else labels 1173 return framework_lib.convert_to_tensor_or_sparse_tensor(labels) 1174 1175 1176def _check_no_sparse_tensor(x): 1177 """Raises ValueError if the given tensor is `SparseTensor`.""" 1178 if isinstance(x, sparse_tensor.SparseTensor): 1179 raise ValueError("SparseTensor is not supported.") 1180 1181 1182def _sparse_labels_to_indicator(labels, num_classes): 1183 """If labels is `SparseTensor`, converts it to indicator `Tensor`. 1184 1185 Args: 1186 labels: Label `Tensor` or `SparseTensor`. 1187 num_classes: Number of classes. 1188 1189 Returns: 1190 Dense label `Tensor`. 1191 1192 Raises: 1193 ValueError: If labels is `SparseTensor` and `num_classes` < 2. 1194 """ 1195 if isinstance(labels, sparse_tensor.SparseTensor): 1196 if num_classes < 2: 1197 raise ValueError("Must set num_classes >= 2 when passing labels as a " 1198 "SparseTensor.") 1199 return math_ops.to_int64( 1200 sparse_ops.sparse_to_indicator(labels, num_classes)) 1201 return labels 1202 1203 1204def _assert_labels_rank(labels): 1205 return control_flow_ops.Assert( 1206 math_ops.less_equal(array_ops.rank(labels), 2), 1207 ("labels shape should be either [batch_size, 1] or [batch_size]",)) 1208 1209 1210class _BinarySvmHead(_SingleHead): 1211 """`Head` for binary classification using SVM.""" 1212 1213 def __init__(self, label_name, weight_column_name, enable_centered_bias, 1214 head_name, thresholds): 1215 1216 def _loss_fn(labels, logits, weights=None): 1217 with ops.name_scope(None, "hinge_loss", (logits, labels)) as name: 1218 with ops.control_dependencies((_assert_labels_rank(labels),)): 1219 labels = array_ops.reshape(labels, shape=(-1, 1)) 1220 loss = losses_lib.hinge_loss(labels=labels, logits=logits, scope=name, 1221 reduction=losses_lib.Reduction.NONE) 1222 return _compute_weighted_loss(loss, weights) 1223 1224 super(_BinarySvmHead, self).__init__( 1225 problem_type=constants.ProblemType.LOGISTIC_REGRESSION, 1226 logits_dimension=1, 1227 label_name=label_name, 1228 weight_column_name=weight_column_name, 1229 head_name=head_name) 1230 self._thresholds = thresholds if thresholds else (.5,) 1231 self._loss_fn = _loss_fn 1232 self._enable_centered_bias = enable_centered_bias 1233 1234 def create_model_fn_ops(self, 1235 features, 1236 mode, 1237 labels=None, 1238 train_op_fn=None, 1239 logits=None, 1240 logits_input=None, 1241 scope=None): 1242 """See `Head`.""" 1243 with variable_scope.variable_scope( 1244 scope, 1245 default_name=self.head_name or "binary_svm_head", 1246 values=(tuple(six.itervalues(features)) + 1247 (labels, logits, logits_input))): 1248 labels = self._transform_labels(mode=mode, labels=labels) 1249 logits = _logits(logits_input, logits, self.logits_dimension) 1250 return _create_model_fn_ops( 1251 features=features, 1252 mode=mode, 1253 loss_fn=self._loss_fn, 1254 logits_to_predictions_fn=self._logits_to_predictions, 1255 metrics_fn=self._metrics, 1256 # TODO(zakaria): Handle labels for export. 1257 create_output_alternatives_fn=self._create_output_alternatives, 1258 labels=labels, 1259 train_op_fn=train_op_fn, 1260 logits=logits, 1261 logits_dimension=self.logits_dimension, 1262 head_name=self.head_name, 1263 weight_column_name=self.weight_column_name, 1264 enable_centered_bias=self._enable_centered_bias) 1265 1266 def _transform_labels(self, mode, labels): 1267 """Applies transformations to labels tensor.""" 1268 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1269 return None 1270 labels_tensor = _to_labels_tensor(labels, self._label_name) 1271 _check_no_sparse_tensor(labels_tensor) 1272 return labels_tensor 1273 1274 def _logits_to_predictions(self, logits): 1275 """See `_MultiClassHead`.""" 1276 with ops.name_scope(None, "predictions", (logits,)): 1277 return { 1278 prediction_key.PredictionKey.LOGITS: 1279 logits, 1280 prediction_key.PredictionKey.CLASSES: 1281 math_ops.argmax( 1282 _one_class_to_two_class_logits(logits), 1283 1, 1284 name=prediction_key.PredictionKey.CLASSES) 1285 } 1286 1287 def _metrics(self, eval_loss, predictions, labels, weights): 1288 """See `_MultiClassHead`.""" 1289 with ops.name_scope("metrics", values=( 1290 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 1291 metrics = {_summary_key(self.head_name, mkey.LOSS): 1292 metrics_lib.mean(eval_loss)} 1293 1294 # TODO(b/29366811): This currently results in both an "accuracy" and an 1295 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1296 classes = predictions[prediction_key.PredictionKey.CLASSES] 1297 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1298 metrics_lib.accuracy(labels, classes, weights)) 1299 # TODO(sibyl-vie3Poto): add more metrics relevant for svms. 1300 1301 return metrics 1302 1303 1304class _MultiLabelHead(_SingleHead): 1305 """`Head` for multi-label classification.""" 1306 1307 # TODO(zakaria): add signature and metric for multilabel. 1308 def __init__(self, 1309 n_classes, 1310 label_name, 1311 weight_column_name, 1312 enable_centered_bias, 1313 head_name, 1314 thresholds, 1315 metric_class_ids=None, 1316 loss_fn=None): 1317 1318 super(_MultiLabelHead, self).__init__( 1319 problem_type=constants.ProblemType.CLASSIFICATION, 1320 logits_dimension=n_classes, 1321 label_name=label_name, 1322 weight_column_name=weight_column_name, 1323 head_name=head_name) 1324 1325 self._thresholds = thresholds if thresholds else (.5,) 1326 self._loss_fn = loss_fn if loss_fn else _sigmoid_cross_entropy_loss 1327 self._enable_centered_bias = enable_centered_bias 1328 self._metric_class_ids = tuple([] if metric_class_ids is None else 1329 metric_class_ids) 1330 for class_id in self._metric_class_ids: 1331 if (class_id < 0) or (class_id >= n_classes): 1332 raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes)) 1333 1334 def create_model_fn_ops(self, 1335 features, 1336 mode, 1337 labels=None, 1338 train_op_fn=None, 1339 logits=None, 1340 logits_input=None, 1341 scope=None): 1342 """See `Head`.""" 1343 with variable_scope.variable_scope( 1344 scope, 1345 default_name=self.head_name or "multi_label_head", 1346 values=(tuple(six.itervalues(features)) + 1347 (labels, logits, logits_input))): 1348 labels = self._transform_labels(mode=mode, labels=labels) 1349 logits = _logits(logits_input, logits, self.logits_dimension) 1350 return _create_model_fn_ops( 1351 features=features, 1352 mode=mode, 1353 loss_fn=self._loss_fn, 1354 logits_to_predictions_fn=self._logits_to_predictions, 1355 metrics_fn=self._metrics, 1356 create_output_alternatives_fn=_classification_output_alternatives( 1357 self.head_name, self._problem_type), 1358 labels=labels, 1359 train_op_fn=train_op_fn, 1360 logits=logits, 1361 logits_dimension=self.logits_dimension, 1362 head_name=self.head_name, 1363 weight_column_name=self.weight_column_name, 1364 enable_centered_bias=self._enable_centered_bias) 1365 1366 def _transform_labels(self, mode, labels): 1367 """Applies transformations to labels tensor.""" 1368 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1369 return None 1370 labels_tensor = _to_labels_tensor(labels, self._label_name) 1371 labels_tensor = _sparse_labels_to_indicator(labels_tensor, 1372 self._logits_dimension) 1373 return labels_tensor 1374 1375 def _logits_to_predictions(self, logits): 1376 """See `_MultiClassHead`.""" 1377 with ops.name_scope(None, "predictions", (logits,)): 1378 return { 1379 prediction_key.PredictionKey.LOGITS: 1380 logits, 1381 prediction_key.PredictionKey.PROBABILITIES: 1382 math_ops.sigmoid( 1383 logits, name=prediction_key.PredictionKey.PROBABILITIES), 1384 prediction_key.PredictionKey.CLASSES: 1385 math_ops.to_int64( 1386 math_ops.greater(logits, 0), 1387 name=prediction_key.PredictionKey.CLASSES) 1388 } 1389 1390 def _metrics(self, eval_loss, predictions, labels, weights): 1391 """Returns a dict of metrics keyed by name.""" 1392 with ops.name_scope("metrics", values=( 1393 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 1394 classes = predictions[prediction_key.PredictionKey.CLASSES] 1395 probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] 1396 logits = predictions[prediction_key.PredictionKey.LOGITS] 1397 1398 metrics = {_summary_key(self.head_name, mkey.LOSS): 1399 metrics_lib.mean(eval_loss)} 1400 # TODO(b/29366811): This currently results in both an "accuracy" and an 1401 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1402 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1403 metrics_lib.accuracy(labels, classes, weights)) 1404 metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc( 1405 probabilities, labels, weights) 1406 metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc( 1407 probabilities, labels, weights, curve="PR") 1408 1409 for class_id in self._metric_class_ids: 1410 # TODO(ptucker): Add per-class accuracy, precision, recall. 1411 metrics[_summary_key( 1412 self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = ( 1413 _predictions_streaming_mean(classes, weights, class_id)) 1414 metrics[_summary_key( 1415 self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = ( 1416 _indicator_labels_streaming_mean(labels, weights, class_id)) 1417 metrics[_summary_key( 1418 self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = ( 1419 _predictions_streaming_mean(probabilities, weights, class_id)) 1420 metrics[_summary_key( 1421 self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = ( 1422 _predictions_streaming_mean(logits, weights, class_id)) 1423 metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = ( 1424 _streaming_auc(probabilities, labels, weights, class_id)) 1425 metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = ( 1426 _streaming_auc(probabilities, labels, weights, class_id, 1427 curve="PR")) 1428 1429 return metrics 1430 1431 1432class _LossOnlyHead(Head): 1433 """`Head` implementation for additional loss terms. 1434 1435 This class only holds loss terms unrelated to any other heads (labels), 1436 e.g. regularization. 1437 1438 Common usage: 1439 This is oftem combine with other heads in a multi head setup. 1440 ```python 1441 head = multi_head([ 1442 head1, head2, loss_only_head('regularizer', regularizer)]) 1443 ``` 1444 """ 1445 1446 def __init__(self, loss_fn, head_name=None): 1447 self._loss_fn = loss_fn 1448 self.head_name = head_name or "loss_only_head" 1449 1450 @property 1451 def logits_dimension(self): 1452 return 0 1453 1454 def create_model_fn_ops(self, 1455 features, 1456 mode, 1457 labels=None, 1458 train_op_fn=None, 1459 logits=None, 1460 logits_input=None, 1461 scope=None): 1462 """See `_Head.create_model_fn_ops`. 1463 1464 Args: 1465 features: Not been used. 1466 mode: Estimator's `ModeKeys`. 1467 labels: Labels `Tensor`, or `dict` of same. 1468 train_op_fn: Function that takes a scalar loss and returns an op to 1469 optimize with the loss. 1470 logits: Not been used. 1471 logits_input: Not been used. 1472 scope: Optional scope for variable_scope. If provided, will be passed to 1473 all heads. Most users will want to set this to `None`, so each head 1474 constructs a separate variable_scope according to its `head_name`. 1475 1476 Returns: 1477 A `ModelFnOps` object. 1478 1479 Raises: 1480 ValueError: if `mode` is not recognition. 1481 """ 1482 _check_mode_valid(mode) 1483 loss = None 1484 train_op = None 1485 if mode != model_fn.ModeKeys.INFER: 1486 with variable_scope.variable_scope(scope, default_name=self.head_name): 1487 loss = self._loss_fn() 1488 if isinstance(loss, list): 1489 loss = math_ops.add_n(loss) 1490 # The name_scope escapism is needed to maintain the same summary tag 1491 # after switching away from the now unsupported API. 1492 with ops.name_scope(""): 1493 summary_loss = array_ops.identity(loss) 1494 summary.scalar(_summary_key(self.head_name, mkey.LOSS), 1495 summary_loss) 1496 if mode == model_fn.ModeKeys.TRAIN: 1497 if train_op_fn is None: 1498 raise ValueError("train_op_fn can not be None in TRAIN mode") 1499 with ops.name_scope(None, "train_op", (loss,)): 1500 train_op = train_op_fn(loss) 1501 1502 return model_fn.ModelFnOps( 1503 mode=mode, 1504 loss=loss, 1505 train_op=train_op, 1506 predictions={}, 1507 eval_metric_ops={}) 1508 1509 1510class _MultiHead(Head): 1511 """`Head` implementation for multi objective learning. 1512 1513 This class is responsible for using and merging the output of multiple 1514 `Head` objects. 1515 1516 All heads stem from the same logits/logit_input tensor. 1517 1518 Common usage: 1519 For simple use cases you can pass the activation of hidden layer like 1520 this from your model_fn, 1521 ```python 1522 last_hidden_layer_activation = ... Build your model. 1523 multi_head = ... 1524 return multi_head.create_model_fn_ops( 1525 ..., logits_input=last_hidden_layer_activation, ...) 1526 ``` 1527 1528 Or you can create a logits tensor of 1529 [batch_size, multi_head.logits_dimension] shape. _MultiHead will split the 1530 logits for you. 1531 return multi_head.create_model_fn_ops(..., logits=logits, ...) 1532 1533 For more complex use cases like a multi-task/multi-tower model or when logits 1534 for each head has to be created separately, you can pass a dict of logits 1535 where the keys match the name of the single heads. 1536 ```python 1537 logits = {"head1": logits1, "head2": logits2} 1538 return multi_head.create_model_fn_ops(..., logits=logits, ...) 1539 ``` 1540 1541 Here is what this class does, 1542 + For training, merges losses of each heads according a function provided by 1543 user, calls user provided train_op_fn with this final loss. 1544 + For eval, merges metrics by adding head_name suffix to the keys in eval 1545 metrics. 1546 + For inference, updates keys in prediction dict to a 2-tuple, 1547 (head_name, prediction_key) 1548 """ 1549 1550 def __init__(self, heads, loss_merger): 1551 """_Head to merges multiple _Head objects. 1552 1553 Args: 1554 heads: list of _Head objects. 1555 loss_merger: function that takes a list of loss tensors for the heads 1556 and returns the final loss tensor for the multi head. 1557 1558 Raises: 1559 ValueError: if any head does not have a name. 1560 """ 1561 self._logits_dimension = 0 1562 for head in heads: 1563 if not head.head_name: 1564 raise ValueError("Members of MultiHead must have names.") 1565 self._logits_dimension += head.logits_dimension 1566 1567 self._heads = heads 1568 self._loss_merger = loss_merger 1569 1570 @property 1571 def logits_dimension(self): 1572 return self._logits_dimension 1573 1574 def create_model_fn_ops(self, 1575 features, 1576 mode, 1577 labels=None, 1578 train_op_fn=None, 1579 logits=None, 1580 logits_input=None, 1581 scope=None): 1582 """See `_Head.create_model_fn_ops`. 1583 1584 Args: 1585 features: Input `dict` of `Tensor` objects. 1586 mode: Estimator's `ModeKeys`. 1587 labels: Labels `Tensor`, or `dict` of same. 1588 train_op_fn: Function that takes a scalar loss and returns an op to 1589 optimize with the loss. 1590 logits: Concatenated logits for all heads or a dict of head name to logits 1591 tensor. If concatenated logits, it should have (batchsize, x) shape 1592 where x is the sum of `logits_dimension` of all the heads, 1593 i.e., same as `logits_dimension` of this class. create_model_fn_ops 1594 will split the logits tensor and pass logits of proper size to each 1595 head. This is useful if we want to be agnostic about whether you 1596 creating a single versus multihead. logits can also be a dict for 1597 convenience where you are creating the head specific logits explicitly 1598 and don't want to concatenate them yourself. 1599 logits_input: tensor to build logits from. 1600 scope: Optional scope for variable_scope. If provided, will be passed to 1601 all heads. Most users will want to set this to `None`, so each head 1602 constructs a separate variable_scope according to its `head_name`. 1603 1604 Returns: 1605 `ModelFnOps`. 1606 1607 Raises: 1608 ValueError: if `mode` is not recognized, or neither or both of `logits` 1609 and `logits_input` is provided. 1610 """ 1611 _check_mode_valid(mode) 1612 all_model_fn_ops = [] 1613 if logits is None: 1614 # Use logits_input. 1615 for head in self._heads: 1616 all_model_fn_ops.append( 1617 head.create_model_fn_ops( 1618 features=features, 1619 mode=mode, 1620 labels=labels, 1621 train_op_fn=no_op_train_fn, 1622 logits_input=logits_input, 1623 scope=scope)) 1624 else: 1625 head_logits_pairs = [] 1626 if isinstance(logits, dict): 1627 head_logits_pairs = [] 1628 for head in self._heads: 1629 if isinstance(head, _LossOnlyHead): 1630 head_logits_pairs.append((head, None)) 1631 else: 1632 head_logits_pairs.append((head, logits[head.head_name])) 1633 else: 1634 # Split logits for each head. 1635 head_logits_pairs = zip(self._heads, self._split_logits(logits)) 1636 1637 for head, head_logits in head_logits_pairs: 1638 all_model_fn_ops.append( 1639 head.create_model_fn_ops( 1640 features=features, 1641 mode=mode, 1642 labels=labels, 1643 train_op_fn=no_op_train_fn, 1644 logits=head_logits, 1645 scope=scope)) 1646 1647 if mode == model_fn.ModeKeys.TRAIN: 1648 if train_op_fn is None: 1649 raise ValueError("train_op_fn can not be None in TRAIN mode.") 1650 return self._merge_train(all_model_fn_ops, train_op_fn) 1651 if mode == model_fn.ModeKeys.INFER: 1652 return self._merge_infer(all_model_fn_ops) 1653 if mode == model_fn.ModeKeys.EVAL: 1654 return self._merge_eval(all_model_fn_ops) 1655 raise ValueError("mode=%s unrecognized" % str(mode)) 1656 1657 def _split_logits(self, logits): 1658 """Splits logits for heads. 1659 1660 Args: 1661 logits: the logits tensor. 1662 1663 Returns: 1664 A list of logits for the individual heads. 1665 """ 1666 all_logits = [] 1667 begin = 0 1668 for head in self._heads: 1669 current_logits_size = head.logits_dimension 1670 current_logits = array_ops.slice(logits, [0, begin], 1671 [-1, current_logits_size]) 1672 all_logits.append(current_logits) 1673 begin += current_logits_size 1674 return all_logits 1675 1676 def _merge_train(self, all_model_fn_ops, train_op_fn): 1677 """Merges list of ModelFnOps for training. 1678 1679 Args: 1680 all_model_fn_ops: list of ModelFnOps for the individual heads. 1681 train_op_fn: Function to create train op. See `create_model_fn_ops` 1682 documentation for more details. 1683 1684 Returns: 1685 ModelFnOps that merges all heads for TRAIN. 1686 """ 1687 losses = [] 1688 metrics = {} 1689 additional_train_ops = [] 1690 for m in all_model_fn_ops: 1691 losses.append(m.loss) 1692 if m.eval_metric_ops is not None: 1693 for k, v in six.iteritems(m.eval_metric_ops): 1694 # metrics["%s/%s" % (k, head_name)] = v 1695 metrics[k] = v 1696 additional_train_ops.append(m.train_op) 1697 loss = self._loss_merger(losses) 1698 1699 train_op = train_op_fn(loss) 1700 train_op = control_flow_ops.group(train_op, *additional_train_ops) 1701 return model_fn.ModelFnOps( 1702 mode=model_fn.ModeKeys.TRAIN, 1703 loss=loss, 1704 train_op=train_op, 1705 eval_metric_ops=metrics) 1706 1707 def _merge_infer(self, all_model_fn_ops): 1708 """Merges list of ModelFnOps for inference. 1709 1710 Args: 1711 all_model_fn_ops: list of ModelFnOps for the individual heads. 1712 1713 Returns: 1714 ModelFnOps that Merges all the heads for INFER. 1715 """ 1716 predictions = {} 1717 output_alternatives = {} 1718 for head, m in zip(self._heads, all_model_fn_ops): 1719 if isinstance(head, _LossOnlyHead): 1720 continue 1721 head_name = head.head_name 1722 output_alternatives[head_name] = m.output_alternatives[head_name] 1723 for k, v in m.predictions.items(): 1724 predictions[(head_name, k)] = v 1725 1726 return model_fn.ModelFnOps( 1727 mode=model_fn.ModeKeys.INFER, 1728 predictions=predictions, 1729 output_alternatives=output_alternatives) 1730 1731 def _merge_eval(self, all_model_fn_ops): 1732 """Merges list of ModelFnOps for eval. 1733 1734 Args: 1735 all_model_fn_ops: list of ModelFnOps for the individual heads. 1736 1737 Returns: 1738 ModelFnOps that merges all the heads for EVAL. 1739 """ 1740 predictions = {} 1741 metrics = {} 1742 losses = [] 1743 for head, m in zip(self._heads, all_model_fn_ops): 1744 losses.append(m.loss) 1745 head_name = head.head_name 1746 for k, v in m.predictions.items(): 1747 predictions[(head_name, k)] = v 1748 for k, v in m.eval_metric_ops.items(): 1749 # metrics["%s/%s" % (k, head_name)] = v 1750 metrics[k] = v 1751 loss = self._loss_merger(losses) 1752 1753 return model_fn.ModelFnOps( 1754 mode=model_fn.ModeKeys.EVAL, 1755 predictions=predictions, 1756 loss=loss, 1757 eval_metric_ops=metrics) 1758 1759 1760def _weight_tensor(features, weight_column_name): 1761 """Returns weights as `Tensor` of rank 0, or at least 2.""" 1762 if not weight_column_name: 1763 return None 1764 if weight_column_name not in features: 1765 raise ValueError("Weights {} missing from features.".format( 1766 weight_column_name)) 1767 with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))): 1768 weight_tensor = math_ops.to_float(features[weight_column_name]) 1769 shape = weight_tensor.get_shape() 1770 rank = shape.ndims 1771 # We don't bother with expanding dims of non-staticly shaped tensors or 1772 # scalars, and >1d is already in a good format. 1773 if rank == 1: 1774 logging.warning("Weights {} has shape {}, expanding to make it 2d.". 1775 format(weight_column_name, shape)) 1776 return ( 1777 sparse_ops.sparse_reshape(weight_tensor, (-1, 1)) 1778 if isinstance(weight_tensor, sparse_tensor.SparseTensor) else 1779 array_ops.reshape(weight_tensor, (-1, 1))) 1780 return weight_tensor 1781 1782 1783# TODO(zakaria): This function is needed for backward compatibility and should 1784# be removed when we migrate to core. 1785def _compute_weighted_loss(loss_unweighted, weight, name="loss"): 1786 """Returns a tuple of (loss_train, loss_report). 1787 1788 loss is used for gradient descent while weighted_average_loss is used for 1789 summaries to be backward compatible. 1790 1791 loss is different from the loss reported on the tensorboard as we 1792 should respect the example weights when computing the gradient. 1793 1794 L = sum_{i} w_{i} * l_{i} / B 1795 1796 where B is the number of examples in the batch, l_{i}, w_{i} are individual 1797 losses, and example weight. 1798 1799 Args: 1800 loss_unweighted: Unweighted loss 1801 weight: Weight tensor 1802 name: Optional name 1803 1804 Returns: 1805 A tuple of losses. First one for training and the second one for reporting. 1806 """ 1807 with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope: 1808 if weight is None: 1809 loss = math_ops.reduce_mean(loss_unweighted, name=name_scope) 1810 return loss, loss 1811 weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted) 1812 with ops.name_scope(None, "weighted_loss", 1813 (loss_unweighted, weight)) as name: 1814 weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name) 1815 weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) 1816 weighted_loss_normalized = math_ops.div( 1817 math_ops.reduce_sum(weighted_loss), 1818 math_ops.to_float(math_ops.reduce_sum(weight)), 1819 name="weighted_average_loss") 1820 1821 return weighted_loss_mean, weighted_loss_normalized 1822 1823 1824def _wrap_custom_loss_fn(loss_fn): 1825 def _wrapper(labels, logits, weights=None): 1826 if weights is None: 1827 loss = loss_fn(labels, logits) 1828 else: 1829 loss = loss_fn(labels, logits, weights) 1830 return loss, loss 1831 return _wrapper 1832 1833 1834def _check_mode_valid(mode): 1835 """Raises ValueError if the given mode is invalid.""" 1836 if (mode != model_fn.ModeKeys.TRAIN and mode != model_fn.ModeKeys.INFER and 1837 mode != model_fn.ModeKeys.EVAL): 1838 raise ValueError("mode=%s unrecognized." % str(mode)) 1839 1840 1841def _get_arguments(func): 1842 """Returns a spec of given func.""" 1843 _, func = tf_decorator.unwrap(func) 1844 if hasattr(func, "__code__"): 1845 # Regular function. 1846 return tf_inspect.getargspec(func) 1847 elif hasattr(func, "__call__"): 1848 # Callable object. 1849 return _get_arguments(func.__call__) 1850 elif hasattr(func, "func"): 1851 # Partial function. 1852 return _get_arguments(func.func) 1853 1854 1855def _verify_loss_fn_args(loss_fn): 1856 args = _get_arguments(loss_fn).args 1857 for arg_name in ["labels", "logits", "weights"]: 1858 if arg_name not in args: 1859 raise ValueError("Argument %s not found in loss_fn." % arg_name) 1860 1861 1862def _centered_bias(logits_dimension, head_name=None): 1863 """Returns centered_bias `Variable`. 1864 1865 Args: 1866 logits_dimension: Last dimension of `logits`. Must be >= 1. 1867 head_name: Optional name of the head. 1868 1869 Returns: 1870 `Variable` with shape `[logits_dimension]`. 1871 1872 Raises: 1873 ValueError: if `logits_dimension` is invalid. 1874 """ 1875 if (logits_dimension is None) or (logits_dimension < 1): 1876 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 1877 # Do not create a variable with variable_scope.get_variable, because that may 1878 # create a PartitionedVariable, which does not support indexing, so 1879 # summary.scalar will not work. 1880 centered_bias = variable_scope.variable( 1881 name="centered_bias_weight", 1882 initial_value=array_ops.zeros(shape=(logits_dimension,)), 1883 trainable=True) 1884 for dim in range(logits_dimension): 1885 if head_name: 1886 summary.scalar("centered_bias/bias_%d/%s" % (dim, head_name), 1887 centered_bias[dim]) 1888 else: 1889 summary.scalar("centered_bias/bias_%d" % dim, centered_bias[dim]) 1890 return centered_bias 1891 1892 1893def _centered_bias_step(centered_bias, batch_size, labels, loss_fn, weights): 1894 """Creates and returns training op for centered bias.""" 1895 with ops.name_scope(None, "centered_bias_step", (labels,)) as name: 1896 logits_dimension = array_ops.shape(centered_bias)[0] 1897 logits = array_ops.reshape( 1898 array_ops.tile(centered_bias, (batch_size,)), 1899 (batch_size, logits_dimension)) 1900 with ops.name_scope(None, "centered_bias", (labels, logits)): 1901 centered_bias_loss = math_ops.reduce_mean( 1902 loss_fn(labels, logits, weights), name="training_loss") 1903 # Learn central bias by an optimizer. 0.1 is a convervative lr for a 1904 # single variable. 1905 return training.AdagradOptimizer(0.1).minimize( 1906 centered_bias_loss, var_list=(centered_bias,), name=name) 1907 1908 1909def _summary_key(head_name, val): 1910 return "%s/%s" % (val, head_name) if head_name else val 1911 1912 1913def _train_op(loss, labels, train_op_fn, centered_bias, batch_size, loss_fn, 1914 weights): 1915 """Returns op for the training step.""" 1916 if centered_bias is not None: 1917 centered_bias_step = _centered_bias_step( 1918 centered_bias=centered_bias, 1919 batch_size=batch_size, 1920 labels=labels, 1921 loss_fn=loss_fn, 1922 weights=weights) 1923 else: 1924 centered_bias_step = None 1925 with ops.name_scope(None, "train_op", (loss, labels)): 1926 train_op = train_op_fn(loss) 1927 if centered_bias_step is not None: 1928 train_op = control_flow_ops.group(train_op, centered_bias_step) 1929 return train_op 1930 1931 1932def _sigmoid_cross_entropy_loss(labels, logits, weights=None): 1933 with ops.name_scope(None, "sigmoid_cross_entropy_loss", 1934 (logits, labels)) as name: 1935 # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. 1936 loss = nn.sigmoid_cross_entropy_with_logits( 1937 labels=math_ops.to_float(labels), logits=logits, name=name) 1938 return _compute_weighted_loss(loss, weights) 1939 1940 1941def _float_weights_or_none(weights): 1942 if weights is None: 1943 return None 1944 with ops.name_scope(None, "float_weights", (weights,)) as name: 1945 return math_ops.to_float(weights, name=name) 1946 1947 1948def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): 1949 labels = math_ops.to_float(labels) 1950 weights = _float_weights_or_none(weights) 1951 if weights is not None: 1952 weights = weights_broadcast_ops.broadcast_weights(weights, labels) 1953 if class_id is not None: 1954 if weights is not None: 1955 weights = weights[:, class_id] 1956 labels = labels[:, class_id] 1957 return metrics_lib.mean(labels, weights) 1958 1959 1960def _predictions_streaming_mean(predictions, 1961 weights=None, 1962 class_id=None): 1963 predictions = math_ops.to_float(predictions) 1964 weights = _float_weights_or_none(weights) 1965 if weights is not None: 1966 weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 1967 if class_id is not None: 1968 if weights is not None: 1969 weights = weights[:, class_id] 1970 predictions = predictions[:, class_id] 1971 return metrics_lib.mean(predictions, weights) 1972 1973 1974# TODO(ptucker): Add support for SparseTensor labels. 1975def _class_id_labels_to_indicator(labels, num_classes): 1976 if (num_classes is None) or (num_classes < 2): 1977 raise ValueError("Invalid num_classes %s." % num_classes) 1978 with ops.control_dependencies((_assert_labels_rank(labels),)): 1979 labels = array_ops.reshape(labels, (-1,)) 1980 return array_ops.one_hot(labels, depth=num_classes, axis=-1) 1981 1982 1983def _class_predictions_streaming_mean(predictions, weights, class_id): 1984 return metrics_lib.mean( 1985 array_ops.where( 1986 math_ops.equal( 1987 math_ops.to_int32(class_id), math_ops.to_int32(predictions)), 1988 array_ops.ones_like(predictions), 1989 array_ops.zeros_like(predictions)), 1990 weights=weights) 1991 1992 1993def _class_labels_streaming_mean(labels, weights, class_id): 1994 return metrics_lib.mean( 1995 array_ops.where( 1996 math_ops.equal( 1997 math_ops.to_int32(class_id), math_ops.to_int32(labels)), 1998 array_ops.ones_like(labels), array_ops.zeros_like(labels)), 1999 weights=weights) 2000 2001 2002def _streaming_auc(predictions, labels, weights=None, class_id=None, 2003 curve="ROC"): 2004 # pylint: disable=missing-docstring 2005 predictions = math_ops.to_float(predictions) 2006 if labels.dtype.base_dtype != dtypes.bool: 2007 logging.warning("Casting %s labels to bool.", labels.dtype) 2008 labels = math_ops.cast(labels, dtypes.bool) 2009 weights = _float_weights_or_none(weights) 2010 if weights is not None: 2011 weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 2012 if class_id is not None: 2013 if weights is not None: 2014 weights = weights[:, class_id] 2015 predictions = predictions[:, class_id] 2016 labels = labels[:, class_id] 2017 return metrics_lib.auc(labels, predictions, weights, curve=curve) 2018 2019 2020def _assert_class_id(class_id, num_classes=None): 2021 """Average label value for class `class_id`.""" 2022 if (class_id is None) or (class_id < 0): 2023 raise ValueError("Invalid class_id %s." % class_id) 2024 if num_classes is not None: 2025 if num_classes < 2: 2026 raise ValueError("Invalid num_classes %s." % num_classes) 2027 if class_id >= num_classes: 2028 raise ValueError("Invalid class_id %s." % class_id) 2029 2030 2031def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): 2032 threshold_predictions = math_ops.to_float( 2033 math_ops.greater_equal(predictions, threshold)) 2034 return metrics_lib.accuracy(labels, threshold_predictions, weights) 2035 2036 2037def _streaming_precision_at_threshold(predictions, labels, weights, threshold): 2038 precision_tensor, update_op = metrics_lib.precision_at_thresholds( 2039 labels, predictions, (threshold,), _float_weights_or_none(weights)) 2040 return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 2041 2042 2043def _streaming_recall_at_threshold(predictions, labels, weights, threshold): 2044 precision_tensor, update_op = metrics_lib.recall_at_thresholds( 2045 labels, predictions, (threshold,), _float_weights_or_none(weights)) 2046 return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 2047 2048 2049def _classification_output_alternatives(head_name, problem_type, 2050 label_keys=None): 2051 """Creates a func to generate output alternatives for classification. 2052 2053 Servo expects classes to be a string tensor, and have the same dimensions 2054 as the probabilities tensor. It should contain the labels of the corresponding 2055 entries in probabilities. This function creates a new classes tensor that 2056 satisfies these conditions and can be exported. 2057 2058 Args: 2059 head_name: Name of the head. 2060 problem_type: `ProblemType` 2061 label_keys: Optional label keys 2062 2063 Returns: 2064 A function to generate output alternatives. 2065 """ 2066 def _create_output_alternatives(predictions): 2067 """Creates output alternative for the Head. 2068 2069 Args: 2070 predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a 2071 symbolic name for an output Tensor possibly but not necessarily taken 2072 from `PredictionKey`, and 'Tensor' is the corresponding output Tensor 2073 itself. 2074 2075 Returns: 2076 `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where 2077 'submodel_name' is a submodel identifier that should be consistent across 2078 the pipeline (here likely taken from the head_name), 2079 'problem_type' is a `ProblemType`, 2080 'tensor_name' is a symbolic name for an output Tensor possibly but not 2081 necessarily taken from `PredictionKey`, and 2082 'Tensor' is the corresponding output Tensor itself. 2083 2084 Raises: 2085 ValueError: if predictions does not have PredictionKey.PROBABILITIES key. 2086 """ 2087 probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES) 2088 if probabilities is None: 2089 raise ValueError("%s missing in predictions" % 2090 prediction_key.PredictionKey.PROBABILITIES) 2091 2092 with ops.name_scope(None, "_classification_output_alternatives", 2093 (probabilities,)): 2094 batch_size = array_ops.shape(probabilities)[0] 2095 if label_keys: 2096 classes = array_ops.tile( 2097 input=array_ops.expand_dims(input=label_keys, axis=0), 2098 multiples=[batch_size, 1], 2099 name="classes_tensor") 2100 else: 2101 n = array_ops.shape(probabilities)[1] 2102 classes = array_ops.tile( 2103 input=array_ops.expand_dims(input=math_ops.range(n), axis=0), 2104 multiples=[batch_size, 1]) 2105 classes = string_ops.as_string(classes, name="classes_tensor") 2106 2107 exported_predictions = { 2108 prediction_key.PredictionKey.PROBABILITIES: probabilities, 2109 prediction_key.PredictionKey.CLASSES: classes} 2110 return {head_name: (problem_type, exported_predictions)} 2111 2112 return _create_output_alternatives 2113 2114# Aliases 2115# TODO(zakaria): Remove these aliases, See b/34751732 2116_regression_head = regression_head 2117_poisson_regression_head = poisson_regression_head 2118_multi_class_head = multi_class_head 2119_binary_svm_head = binary_svm_head 2120_multi_label_head = multi_label_head 2121_multi_head = multi_head 2122_Head = Head 2123