1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Abstractions for the head(s) of a model (deprecated). 16 17This module and all its submodules are deprecated. See 18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 19for migration instructions. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import abc 27 28import six 29 30from tensorflow.contrib import framework as framework_lib 31from tensorflow.contrib import layers as layers_lib 32from tensorflow.contrib.learn.python.learn.estimators import constants 33from tensorflow.contrib.learn.python.learn.estimators import model_fn 34from tensorflow.contrib.learn.python.learn.estimators import prediction_key 35from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey as mkey 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import sparse_tensor 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import control_flow_ops 41from tensorflow.python.ops import lookup_ops 42from tensorflow.python.ops import math_ops 43from tensorflow.python.ops import metrics as metrics_lib 44from tensorflow.python.ops import nn 45from tensorflow.python.ops import sparse_ops 46from tensorflow.python.ops import string_ops 47from tensorflow.python.ops import variable_scope 48from tensorflow.python.ops import weights_broadcast_ops 49from tensorflow.python.ops.losses import losses as losses_lib 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.summary import summary 52from tensorflow.python.training import training 53from tensorflow.python.util import tf_decorator 54from tensorflow.python.util import tf_inspect 55from tensorflow.python.util.deprecation import deprecated 56 57 58@six.add_metaclass(abc.ABCMeta) 59class Head(object): 60 """Interface for the head/top of a model. 61 62 THIS CLASS IS DEPRECATED. See 63 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 64 for general migration instructions. 65 66 Given logits (or output of a hidden layer), a Head knows how to compute 67 predictions, loss, default metric and export signature. It is meant to, 68 69 1) Simplify writing model_fn and to make model_fn more configurable 70 2) Support wide range of machine learning models. Since most heads can work 71 with logits, they can support DNN, RNN, Wide, Wide&Deep, 72 Global objectives, Gradient boosted trees and many other types 73 of machine learning models. 74 2) To allow users to seamlessly switch between 1 to n heads for multi 75 objective learning (See _MultiHead implementation for more details) 76 77 Common usage: 78 Here is simplified model_fn to build a multiclass DNN model. 79 ```python 80 def _my_dnn_model_fn(features, labels, mode, params, config=None): 81 # Optionally your callers can pass head to model_fn as a param. 82 head = tf.contrib.learn.multi_class_head(...) 83 input = tf.contrib.layers.input_from_feature_columns(features, ...) 84 last_hidden_layer_out = tf.contrib.layers.stack( 85 input, tf.contrib.layers.fully_connected, [1000, 500]) 86 logits = tf.contrib.layers.fully_connected( 87 last_hidden_layer_out, head.logits_dimension, activation_fn=None) 88 89 def _train_op_fn(loss): 90 return optimizer.minimize(loss) 91 92 return head.create_model_fn_ops( 93 features=features, 94 labels=labels, 95 mode=mode, 96 train_op_fn=_train_op_fn, 97 logits=logits, 98 scope=...) 99 ``` 100 101 Most heads also support logits_input which is typically the output of the last 102 hidden layer. Some heads (like heads responsible for candidate sampling or 103 hierarchical softmax) intrinsically will not support logits and you have 104 to pass logits_input. Here is a common usage, 105 ```python 106 return head.create_model_fn_ops( 107 features=features, 108 labels=labels, 109 mode=mode, 110 train_op_fn=_train_op_fn, 111 logits_input=last_hidden_layer_out, 112 scope=...) 113 ```python 114 115 There are cases where computing and applying gradients can not be meaningfully 116 captured with train_op_fn we support (for example, with sync optimizer). In 117 such case, you can take the responsibility on your own. Here is a common 118 use case, 119 ```python 120 model_fn_ops = head.create_model_fn_ops( 121 features=features, 122 labels=labels, 123 mode=mode, 124 train_op_fn=tf.contrib.learn.no_op_train_fn, 125 logits=logits, 126 scope=...) 127 if mode == tf.contrib.learn.ModeKeys.TRAIN: 128 optimizer = ... 129 sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) 130 update_op = tf.contrib.layers.optimize_loss(optimizer=sync, 131 loss=model_fn_ops.loss, ...) 132 hooks = [sync.make_session_run_hook(is_chief)] 133 ... update train_op and hooks in ModelFnOps and return 134 ``` 135 """ 136 137 @abc.abstractproperty 138 def logits_dimension(self): 139 """Size of the last dimension of the logits `Tensor`. 140 141 Typically, logits is of shape `[batch_size, logits_dimension]`. 142 143 Returns: 144 The expected size of the `logits` tensor. 145 """ 146 raise NotImplementedError("Calling an abstract method.") 147 148 @abc.abstractmethod 149 def create_model_fn_ops(self, 150 features, 151 mode, 152 labels=None, 153 train_op_fn=None, 154 logits=None, 155 logits_input=None, 156 scope=None): 157 """Returns `ModelFnOps` that a model_fn can return. 158 159 Please note that, 160 + Exactly one of `logits` and `logits_input` must be provided. 161 + All args must be passed via name. 162 163 Args: 164 features: Input `dict` of `Tensor` objects. 165 mode: Estimator's `ModeKeys`. 166 labels: Labels `Tensor`, or `dict` of same. 167 train_op_fn: Function that takes a scalar loss `Tensor` and returns an op 168 to optimize the model with the loss. This is used in TRAIN mode and 169 must not be None. None is allowed in other modes. If you want to 170 optimize loss yourself you can pass `no_op_train_fn` and then use 171 ModeFnOps.loss to compute and apply gradients. 172 logits: logits `Tensor` to be used by the head. 173 logits_input: `Tensor` from which to build logits, often needed when you 174 don't want to compute the logits. Typically this is the activation of 175 the last hidden layer in a DNN. Some heads (like the ones responsible 176 for candidate sampling) intrinsically avoid computing full logits and 177 only accepts logits_input. 178 scope: Optional scope for `variable_scope`. 179 180 Returns: 181 An instance of `ModelFnOps`. 182 183 Raises: 184 ValueError: If `mode` is not recognized. 185 ValueError: If neither or both of `logits` and `logits_input` is provided. 186 """ 187 raise NotImplementedError("Calling an abstract method.") 188 189 190@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 191def regression_head(label_name=None, 192 weight_column_name=None, 193 label_dimension=1, 194 enable_centered_bias=False, 195 head_name=None, 196 link_fn=None): 197 """Creates a `Head` for linear regression. 198 199 Args: 200 label_name: String, name of the key in label dict. Can be null if label 201 is a tensor (single headed models). 202 weight_column_name: A string defining feature column name representing 203 weights. It is used to down weight or boost examples during training. It 204 will be multiplied by the loss of the example. 205 label_dimension: Number of regression labels per example. This is the size 206 of the last dimension of the labels `Tensor` (typically, this has shape 207 `[batch_size, label_dimension]`). 208 enable_centered_bias: A bool. If True, estimator will learn a centered 209 bias variable for each class. Rest of the model structure learns the 210 residual after centered bias. 211 head_name: name of the head. If provided, predictions, summary and metrics 212 keys will be suffixed by `"/" + head_name` and the default variable scope 213 will be `head_name`. 214 link_fn: link function to convert logits to predictions. If provided, 215 this link function will be used instead of identity. 216 217 Returns: 218 An instance of `Head` for linear regression. 219 """ 220 return _RegressionHead( 221 label_name=label_name, 222 weight_column_name=weight_column_name, 223 label_dimension=label_dimension, 224 enable_centered_bias=enable_centered_bias, 225 head_name=head_name, 226 loss_fn=_mean_squared_loss, 227 link_fn=(link_fn if link_fn is not None else array_ops.identity)) 228 229 230@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 231def poisson_regression_head(label_name=None, 232 weight_column_name=None, 233 label_dimension=1, 234 enable_centered_bias=False, 235 head_name=None): 236 """Creates a `Head` for poisson regression. 237 238 Args: 239 label_name: String, name of the key in label dict. Can be null if label 240 is a tensor (single headed models). 241 weight_column_name: A string defining feature column name representing 242 weights. It is used to down weight or boost examples during training. It 243 will be multiplied by the loss of the example. 244 label_dimension: Number of regression labels per example. This is the size 245 of the last dimension of the labels `Tensor` (typically, this has shape 246 `[batch_size, label_dimension]`). 247 enable_centered_bias: A bool. If True, estimator will learn a centered 248 bias variable for each class. Rest of the model structure learns the 249 residual after centered bias. 250 head_name: name of the head. If provided, predictions, summary and metrics 251 keys will be suffixed by `"/" + head_name` and the default variable scope 252 will be `head_name`. 253 254 Returns: 255 An instance of `Head` for poisson regression. 256 """ 257 return _RegressionHead( 258 label_name=label_name, 259 weight_column_name=weight_column_name, 260 label_dimension=label_dimension, 261 enable_centered_bias=enable_centered_bias, 262 head_name=head_name, 263 loss_fn=_poisson_loss, 264 link_fn=math_ops.exp) 265 266# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression 267 268 269@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 270def multi_class_head(n_classes, 271 label_name=None, 272 weight_column_name=None, 273 enable_centered_bias=False, 274 head_name=None, 275 thresholds=None, 276 metric_class_ids=None, 277 loss_fn=None, 278 label_keys=None): 279 """Creates a `Head` for multi class single label classification. 280 281 The Head uses softmax cross entropy loss. 282 283 This head expects to be fed integer labels specifying the class index. But 284 if `label_keys` is specified, then labels must be strings from this 285 vocabulary, and the predicted classes will be strings from the same 286 vocabulary. 287 288 Args: 289 n_classes: Integer, number of classes, must be >= 2 290 label_name: String, name of the key in label dict. Can be null if label 291 is a tensor (single headed models). 292 weight_column_name: A string defining feature column name representing 293 weights. It is used to down weight or boost examples during training. It 294 will be multiplied by the loss of the example. 295 enable_centered_bias: A bool. If True, estimator will learn a centered 296 bias variable for each class. Rest of the model structure learns the 297 residual after centered bias. 298 head_name: name of the head. If provided, predictions, summary and metrics 299 keys will be suffixed by `"/" + head_name` and the default variable scope 300 will be `head_name`. 301 thresholds: thresholds for eval metrics, defaults to [.5] 302 metric_class_ids: List of class IDs for which we should report per-class 303 metrics. Must all be in the range `[0, n_classes)`. Invalid if 304 `n_classes` is 2. 305 loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as 306 parameter and returns a weighted scalar loss. `weights` should be 307 optional. See `tf.losses` 308 label_keys: Optional list of strings with size `[n_classes]` defining the 309 label vocabulary. Only supported for `n_classes` > 2. 310 311 Returns: 312 An instance of `Head` for multi class classification. 313 314 Raises: 315 ValueError: if `n_classes` is < 2. 316 ValueError: If `metric_class_ids` is provided when `n_classes` is 2. 317 ValueError: If `len(label_keys) != n_classes`. 318 """ 319 if (n_classes is None) or (n_classes < 2): 320 raise ValueError("n_classes must be > 1 for classification: %s." % 321 n_classes) 322 if loss_fn: 323 _verify_loss_fn_args(loss_fn) 324 325 loss_fn = _wrap_custom_loss_fn(loss_fn) if loss_fn else None 326 if n_classes == 2: 327 if metric_class_ids: 328 raise ValueError("metric_class_ids invalid for n_classes==2.") 329 if label_keys: 330 raise ValueError("label_keys is not supported for n_classes=2.") 331 return _BinaryLogisticHead( 332 label_name=label_name, 333 weight_column_name=weight_column_name, 334 enable_centered_bias=enable_centered_bias, 335 head_name=head_name, 336 thresholds=thresholds, 337 loss_fn=loss_fn) 338 339 return _MultiClassHead( 340 n_classes=n_classes, 341 label_name=label_name, 342 weight_column_name=weight_column_name, 343 enable_centered_bias=enable_centered_bias, 344 head_name=head_name, 345 thresholds=thresholds, 346 metric_class_ids=metric_class_ids, 347 loss_fn=loss_fn, 348 label_keys=label_keys) 349 350 351@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 352def binary_svm_head( 353 label_name=None, 354 weight_column_name=None, 355 enable_centered_bias=False, 356 head_name=None, 357 thresholds=None,): 358 """Creates a `Head` for binary classification with SVMs. 359 360 The head uses binary hinge loss. 361 362 Args: 363 label_name: String, name of the key in label dict. Can be null if label 364 is a tensor (single headed models). 365 weight_column_name: A string defining feature column name representing 366 weights. It is used to down weight or boost examples during training. It 367 will be multiplied by the loss of the example. 368 enable_centered_bias: A bool. If True, estimator will learn a centered 369 bias variable for each class. Rest of the model structure learns the 370 residual after centered bias. 371 head_name: name of the head. If provided, predictions, summary and metrics 372 keys will be suffixed by `"/" + head_name` and the default variable scope 373 will be `head_name`. 374 thresholds: thresholds for eval metrics, defaults to [.5] 375 376 Returns: 377 An instance of `Head` for binary classification with SVM. 378 """ 379 return _BinarySvmHead( 380 label_name=label_name, 381 weight_column_name=weight_column_name, 382 enable_centered_bias=enable_centered_bias, 383 head_name=head_name, 384 thresholds=thresholds) 385 386 387@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 388def multi_label_head(n_classes, 389 label_name=None, 390 weight_column_name=None, 391 enable_centered_bias=False, 392 head_name=None, 393 thresholds=None, 394 metric_class_ids=None, 395 loss_fn=None): 396 """Creates a Head for multi label classification. 397 398 Multi-label classification handles the case where each example may have zero 399 or more associated labels, from a discrete set. This is distinct from 400 `multi_class_head` which has exactly one label from a discrete set. 401 402 This head by default uses sigmoid cross entropy loss, which expects as input 403 a multi-hot tensor of shape `(batch_size, num_classes)`. 404 405 Args: 406 n_classes: Integer, number of classes, must be >= 2 407 label_name: String, name of the key in label dict. Can be null if label 408 is a tensor (single headed models). 409 weight_column_name: A string defining feature column name representing 410 weights. It is used to down weight or boost examples during training. It 411 will be multiplied by the loss of the example. 412 enable_centered_bias: A bool. If True, estimator will learn a centered 413 bias variable for each class. Rest of the model structure learns the 414 residual after centered bias. 415 head_name: name of the head. If provided, predictions, summary and metrics 416 keys will be suffixed by `"/" + head_name` and the default variable scope 417 will be `head_name`. 418 thresholds: thresholds for eval metrics, defaults to [.5] 419 metric_class_ids: List of class IDs for which we should report per-class 420 metrics. Must all be in the range `[0, n_classes)`. 421 loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as 422 parameter and returns a weighted scalar loss. `weights` should be 423 optional. See `tf.losses` 424 425 Returns: 426 An instance of `Head` for multi label classification. 427 428 Raises: 429 ValueError: If n_classes is < 2 430 ValueError: If loss_fn does not have expected signature. 431 """ 432 if n_classes < 2: 433 raise ValueError("n_classes must be > 1 for classification.") 434 if loss_fn: 435 _verify_loss_fn_args(loss_fn) 436 437 return _MultiLabelHead( 438 n_classes=n_classes, 439 label_name=label_name, 440 weight_column_name=weight_column_name, 441 enable_centered_bias=enable_centered_bias, 442 head_name=head_name, 443 thresholds=thresholds, 444 metric_class_ids=metric_class_ids, 445 loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) 446 447 448@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 449def loss_only_head(loss_fn, head_name=None): 450 """Creates a Head that contains only loss terms. 451 452 Loss only head holds additional loss terms to be added to other heads and 453 usually represents additional regularization terms in the objective function. 454 455 Args: 456 loss_fn: a function that takes no argument and returns a list of 457 scalar tensors. 458 head_name: a name for the head. 459 460 Returns: 461 An instance of `Head` to hold the additional losses. 462 """ 463 return _LossOnlyHead(loss_fn, head_name=head_name) 464 465 466@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") 467def multi_head(heads, loss_weights=None): 468 """Creates a MultiHead stemming from same logits/hidden layer. 469 470 Args: 471 heads: list of Head objects. 472 loss_weights: optional list of weights to be used to merge losses from 473 each head. All losses are weighted equally if not provided. 474 475 Returns: 476 A instance of `Head` that merges multiple heads. 477 478 Raises: 479 ValueError: if heads and loss_weights have different size. 480 """ 481 if loss_weights: 482 if len(loss_weights) != len(heads): 483 raise ValueError("heads and loss_weights must have same size") 484 485 def _weighted_loss_merger(losses): 486 if loss_weights: 487 if len(losses) != len(loss_weights): 488 raise ValueError("losses and loss_weights must have same size") 489 weighted_losses = [] 490 for loss, weight in zip(losses, loss_weights): 491 weighted_losses.append(math_ops.multiply(loss, weight)) 492 return math_ops.add_n(weighted_losses) 493 else: 494 return math_ops.add_n(losses) 495 496 return _MultiHead(heads, loss_merger=_weighted_loss_merger) 497 498 499@deprecated(None, "Use 'lambda _: tf.no_op()'.") 500def no_op_train_fn(loss): 501 del loss 502 return control_flow_ops.no_op() 503 504 505class _SingleHead(Head): 506 """Interface for a single head/top of a model.""" 507 508 def __init__( 509 self, problem_type, logits_dimension, label_name=None, 510 weight_column_name=None, head_name=None): 511 if problem_type is None: 512 raise ValueError("Invalid problem_type %s." % problem_type) 513 if logits_dimension is None or logits_dimension < 1: 514 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 515 self._problem_type = problem_type 516 self._logits_dimension = logits_dimension 517 self._label_name = label_name 518 self._weight_column_name = weight_column_name 519 self._head_name = head_name 520 521 @property 522 def logits_dimension(self): 523 return self._logits_dimension 524 525 @property 526 def label_name(self): 527 return self._label_name 528 529 @property 530 def weight_column_name(self): 531 return self._weight_column_name 532 533 @property 534 def head_name(self): 535 return self._head_name 536 537 def _create_output_alternatives(self, predictions): 538 """Creates output alternative for the Head. 539 540 Args: 541 predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a 542 symbolic name for an output Tensor possibly but not necessarily taken 543 from `PredictionKey`, and 'Tensor' is the corresponding output Tensor 544 itself. 545 546 Returns: 547 `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where 548 'submodel_name' is a submodel identifier that should be consistent across 549 the pipeline (here likely taken from the head_name), 550 'problem_type' is a `ProblemType`, 551 'tensor_name' is a symbolic name for an output Tensor possibly but not 552 necessarily taken from `PredictionKey`, and 553 'Tensor' is the corresponding output Tensor itself. 554 """ 555 return {self._head_name: (self._problem_type, predictions)} 556 557 558# TODO(zakaria): use contrib losses. 559def _mean_squared_loss(labels, logits, weights=None): 560 with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name: 561 logits = ops.convert_to_tensor(logits) 562 labels = ops.convert_to_tensor(labels) 563 # To prevent broadcasting inside "-". 564 if len(labels.get_shape()) == 1: 565 labels = array_ops.expand_dims(labels, axis=1) 566 # TODO(zakaria): make sure it does not recreate the broadcast bug. 567 if len(logits.get_shape()) == 1: 568 logits = array_ops.expand_dims(logits, axis=1) 569 logits.get_shape().assert_is_compatible_with(labels.get_shape()) 570 loss = math_ops.squared_difference( 571 logits, math_ops.cast(labels, dtypes.float32), name=name) 572 return _compute_weighted_loss(loss, weights) 573 574 575def _poisson_loss(labels, logits, weights=None): 576 """Computes poisson loss from logits.""" 577 with ops.name_scope(None, "_poisson_loss", (logits, labels)) as name: 578 logits = ops.convert_to_tensor(logits) 579 labels = ops.convert_to_tensor(labels) 580 # To prevent broadcasting inside "-". 581 if len(labels.get_shape()) == 1: 582 labels = array_ops.expand_dims(labels, axis=1) 583 # TODO(zakaria): make sure it does not recreate the broadcast bug. 584 if len(logits.get_shape()) == 1: 585 logits = array_ops.expand_dims(logits, axis=1) 586 logits.get_shape().assert_is_compatible_with(labels.get_shape()) 587 loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, 588 name=name) 589 return _compute_weighted_loss(loss, weights) 590 591 592def _logits(logits_input, logits, logits_dimension): 593 """Validate logits args, and create `logits` if necessary. 594 595 Exactly one of `logits_input` and `logits` must be provided. 596 597 Args: 598 logits_input: `Tensor` input to `logits`. 599 logits: `Tensor` output. 600 logits_dimension: Integer, last dimension of `logits`. This is used to 601 create `logits` from `logits_input` if `logits` is `None`; otherwise, it's 602 used to validate `logits`. 603 604 Returns: 605 `logits` `Tensor`. 606 607 Raises: 608 ValueError: if neither or both of `logits` and `logits_input` are supplied. 609 """ 610 if (logits_dimension is None) or (logits_dimension < 1): 611 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 612 613 # If not provided, create logits. 614 if logits is None: 615 if logits_input is None: 616 raise ValueError("Neither logits nor logits_input supplied.") 617 return layers_lib.linear(logits_input, logits_dimension, scope="logits") 618 619 if logits_input is not None: 620 raise ValueError("Both logits and logits_input supplied.") 621 622 logits = ops.convert_to_tensor(logits, name="logits") 623 logits_dims = logits.get_shape().dims 624 if logits_dims is not None: 625 logits_dims[-1].assert_is_compatible_with(logits_dimension) 626 627 return logits 628 629 630def _create_model_fn_ops(features, 631 mode, 632 loss_fn, 633 logits_to_predictions_fn, 634 metrics_fn, 635 create_output_alternatives_fn, 636 labels=None, 637 train_op_fn=None, 638 logits=None, 639 logits_dimension=None, 640 head_name=None, 641 weight_column_name=None, 642 enable_centered_bias=False): 643 """Returns a `ModelFnOps` object.""" 644 _check_mode_valid(mode) 645 646 centered_bias = None 647 if enable_centered_bias: 648 centered_bias = _centered_bias(logits_dimension, head_name) 649 logits = nn.bias_add(logits, centered_bias) 650 651 predictions = logits_to_predictions_fn(logits) 652 loss = None 653 train_op = None 654 eval_metric_ops = None 655 if (mode != model_fn.ModeKeys.INFER) and (labels is not None): 656 weight_tensor = _weight_tensor(features, weight_column_name) 657 loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) 658 # The name_scope escapism is needed to maintain the same summary tag 659 # after switching away from the now unsupported API. 660 with ops.name_scope(""): 661 summary_loss = array_ops.identity(weighted_average_loss) 662 summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss) 663 664 if mode == model_fn.ModeKeys.TRAIN: 665 if train_op_fn is None: 666 raise ValueError("train_op_fn can not be None in TRAIN mode") 667 batch_size = array_ops.shape(logits)[0] 668 train_op = _train_op(loss, labels, train_op_fn, centered_bias, 669 batch_size, loss_fn, weight_tensor) 670 eval_metric_ops = metrics_fn( 671 weighted_average_loss, predictions, labels, weight_tensor) 672 return model_fn.ModelFnOps( 673 mode=mode, 674 predictions=predictions, 675 loss=loss, 676 train_op=train_op, 677 eval_metric_ops=eval_metric_ops, 678 output_alternatives=create_output_alternatives_fn(predictions)) 679 680 681class _RegressionHead(_SingleHead): 682 """`Head` for regression with a generalized linear model.""" 683 684 def __init__(self, 685 label_dimension, 686 loss_fn, 687 link_fn, 688 logits_dimension=None, 689 label_name=None, 690 weight_column_name=None, 691 enable_centered_bias=False, 692 head_name=None): 693 """`Head` for regression. 694 695 Args: 696 label_dimension: Number of regression labels per example. This is the 697 size of the last dimension of the labels `Tensor` (typically, this has 698 shape `[batch_size, label_dimension]`). 699 loss_fn: Loss function, takes logits and labels and returns loss. 700 link_fn: Link function, takes a logits tensor and returns the output. 701 logits_dimension: Number of logits per example. This is the 702 size of the last dimension of the logits `Tensor` (typically, this has 703 shape `[batch_size, label_dimension]`). 704 Default value: `label_dimension`. 705 label_name: String, name of the key in label dict. Can be null if label 706 is a tensor (single headed models). 707 weight_column_name: A string defining feature column name representing 708 weights. It is used to down weight or boost examples during training. It 709 will be multiplied by the loss of the example. 710 enable_centered_bias: A bool. If True, estimator will learn a centered 711 bias variable for each class. Rest of the model structure learns the 712 residual after centered bias. 713 head_name: name of the head. Predictions, summary and metrics keys are 714 suffixed by `"/" + head_name` and the default variable scope is 715 `head_name`. 716 """ 717 super(_RegressionHead, self).__init__( 718 problem_type=constants.ProblemType.LINEAR_REGRESSION, 719 logits_dimension=(logits_dimension if logits_dimension is not None 720 else label_dimension), 721 label_name=label_name, 722 weight_column_name=weight_column_name, 723 head_name=head_name) 724 725 self._loss_fn = loss_fn 726 self._link_fn = link_fn 727 self._enable_centered_bias = enable_centered_bias 728 729 def create_model_fn_ops(self, 730 features, 731 mode, 732 labels=None, 733 train_op_fn=None, 734 logits=None, 735 logits_input=None, 736 scope=None): 737 """See `Head`.""" 738 with variable_scope.variable_scope( 739 scope, 740 default_name=self.head_name or "regression_head", 741 values=(tuple(six.itervalues(features)) + 742 (labels, logits, logits_input))): 743 labels = self._transform_labels(mode=mode, labels=labels) 744 logits = _logits(logits_input, logits, self.logits_dimension) 745 return _create_model_fn_ops( 746 features=features, 747 mode=mode, 748 loss_fn=self._loss_fn, 749 logits_to_predictions_fn=self._logits_to_predictions, 750 metrics_fn=self._metrics, 751 create_output_alternatives_fn=self._create_output_alternatives, 752 labels=labels, 753 train_op_fn=train_op_fn, 754 logits=logits, 755 logits_dimension=self.logits_dimension, 756 head_name=self.head_name, 757 weight_column_name=self.weight_column_name, 758 enable_centered_bias=self._enable_centered_bias) 759 760 def _transform_labels(self, mode, labels): 761 """Applies transformations to labels tensor.""" 762 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 763 return None 764 labels_tensor = _to_labels_tensor(labels, self._label_name) 765 _check_no_sparse_tensor(labels_tensor) 766 return labels_tensor 767 768 def _logits_to_predictions(self, logits): 769 """Returns a dict of predictions. 770 771 Args: 772 logits: logits `Tensor` after applying possible centered bias. 773 774 Returns: 775 Dict of prediction `Tensor` keyed by `PredictionKey`. 776 """ 777 key = prediction_key.PredictionKey.SCORES 778 with ops.name_scope(None, "predictions", (logits,)): 779 if self.logits_dimension == 1: 780 logits = array_ops.squeeze(logits, axis=(1,), name=key) 781 return {key: self._link_fn(logits)} 782 783 def _metrics(self, eval_loss, predictions, labels, weights): 784 """Returns a dict of metrics keyed by name.""" 785 del predictions, labels, weights # Unused by this head. 786 with ops.name_scope("metrics", values=[eval_loss]): 787 return { 788 _summary_key(self.head_name, mkey.LOSS): 789 metrics_lib.mean(eval_loss)} 790 791 792def _log_loss_with_two_classes(labels, logits, weights=None): 793 with ops.name_scope(None, "log_loss_with_two_classes", 794 (logits, labels)) as name: 795 logits = ops.convert_to_tensor(logits) 796 labels = math_ops.cast(labels, dtypes.float32) 797 # TODO(ptucker): This will break for dynamic shapes. 798 # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. 799 if len(labels.get_shape()) == 1: 800 labels = array_ops.expand_dims(labels, axis=1) 801 loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, 802 name=name) 803 return _compute_weighted_loss(loss, weights) 804 805 806def _one_class_to_two_class_logits(logits): 807 return array_ops.concat((array_ops.zeros_like(logits), logits), 1) 808 809 810class _BinaryLogisticHead(_SingleHead): 811 """`Head` for binary classification with logistic regression.""" 812 813 def __init__(self, 814 label_name=None, 815 weight_column_name=None, 816 enable_centered_bias=False, 817 head_name=None, 818 loss_fn=None, 819 thresholds=None): 820 """`Head` for binary classification with logistic regression. 821 822 Args: 823 label_name: String, name of the key in label dict. Can be `None` if label 824 is a tensor (single headed models). 825 weight_column_name: A string defining feature column name representing 826 weights. It is used to down weight or boost examples during training. It 827 will be multiplied by the loss of the example. 828 enable_centered_bias: A bool. If True, estimator will learn a centered 829 bias variable for each class. Rest of the model structure learns the 830 residual after centered bias. 831 head_name: name of the head. Predictions, summary, metrics keys are 832 suffixed by `"/" + head_name` and the default variable scope is 833 `head_name`. 834 loss_fn: Loss function. 835 thresholds: thresholds for eval. 836 837 Raises: 838 ValueError: if n_classes is invalid. 839 """ 840 super(_BinaryLogisticHead, self).__init__( 841 problem_type=constants.ProblemType.LOGISTIC_REGRESSION, 842 logits_dimension=1, 843 label_name=label_name, 844 weight_column_name=weight_column_name, 845 head_name=head_name) 846 self._thresholds = thresholds if thresholds else (.5,) 847 self._loss_fn = loss_fn if loss_fn else _log_loss_with_two_classes 848 self._enable_centered_bias = enable_centered_bias 849 850 def create_model_fn_ops(self, 851 features, 852 mode, 853 labels=None, 854 train_op_fn=None, 855 logits=None, 856 logits_input=None, 857 scope=None): 858 """See `Head`.""" 859 with variable_scope.variable_scope( 860 scope, 861 default_name=self.head_name or "binary_logistic_head", 862 values=(tuple(six.itervalues(features)) + 863 (labels, logits, logits_input))): 864 labels = self._transform_labels(mode=mode, labels=labels) 865 logits = _logits(logits_input, logits, self.logits_dimension) 866 return _create_model_fn_ops( 867 features=features, 868 mode=mode, 869 loss_fn=self._loss_fn, 870 logits_to_predictions_fn=self._logits_to_predictions, 871 metrics_fn=self._metrics, 872 create_output_alternatives_fn=_classification_output_alternatives( 873 self.head_name, self._problem_type), 874 labels=labels, 875 train_op_fn=train_op_fn, 876 logits=logits, 877 logits_dimension=self.logits_dimension, 878 head_name=self.head_name, 879 weight_column_name=self.weight_column_name, 880 enable_centered_bias=self._enable_centered_bias) 881 882 def _transform_labels(self, mode, labels): 883 """Applies transformations to labels tensor.""" 884 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 885 return None 886 labels_tensor = _to_labels_tensor(labels, self._label_name) 887 _check_no_sparse_tensor(labels_tensor) 888 return labels_tensor 889 890 def _logits_to_predictions(self, logits): 891 """Returns a dict of predictions. 892 893 Args: 894 logits: logits `Output` after applying possible centered bias. 895 896 Returns: 897 Dict of prediction `Output` keyed by `PredictionKey`. 898 """ 899 with ops.name_scope(None, "predictions", (logits,)): 900 two_class_logits = _one_class_to_two_class_logits(logits) 901 return { 902 prediction_key.PredictionKey.LOGITS: 903 logits, 904 prediction_key.PredictionKey.LOGISTIC: 905 math_ops.sigmoid( 906 logits, name=prediction_key.PredictionKey.LOGISTIC), 907 prediction_key.PredictionKey.PROBABILITIES: 908 nn.softmax( 909 two_class_logits, 910 name=prediction_key.PredictionKey.PROBABILITIES), 911 prediction_key.PredictionKey.CLASSES: 912 math_ops.argmax( 913 two_class_logits, 914 1, 915 name=prediction_key.PredictionKey.CLASSES) 916 } 917 918 def _metrics(self, eval_loss, predictions, labels, weights): 919 """Returns a dict of metrics keyed by name.""" 920 with ops.name_scope("metrics", values=( 921 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 922 classes = predictions[prediction_key.PredictionKey.CLASSES] 923 logistic = predictions[prediction_key.PredictionKey.LOGISTIC] 924 925 metrics = {_summary_key(self.head_name, mkey.LOSS): 926 metrics_lib.mean(eval_loss)} 927 # TODO(b/29366811): This currently results in both an "accuracy" and an 928 # "accuracy/threshold_0.500000_mean" metric for binary classification. 929 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 930 metrics_lib.accuracy(labels, classes, weights)) 931 metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = ( 932 _predictions_streaming_mean(logistic, weights)) 933 metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = ( 934 _indicator_labels_streaming_mean(labels, weights)) 935 936 # Also include the streaming mean of the label as an accuracy baseline, as 937 # a reminder to users. 938 metrics[_summary_key(self.head_name, mkey.ACCURACY_BASELINE)] = ( 939 _indicator_labels_streaming_mean(labels, weights)) 940 metrics[_summary_key(self.head_name, mkey.AUC)] = ( 941 _streaming_auc(logistic, labels, weights)) 942 metrics[_summary_key(self.head_name, mkey.AUC_PR)] = ( 943 _streaming_auc(logistic, labels, weights, curve="PR")) 944 945 for threshold in self._thresholds: 946 metrics[_summary_key( 947 self.head_name, mkey.ACCURACY_MEAN % threshold)] = ( 948 _streaming_accuracy_at_threshold(logistic, labels, weights, 949 threshold)) 950 # Precision for positive examples. 951 metrics[_summary_key( 952 self.head_name, mkey.PRECISION_MEAN % threshold)] = ( 953 _streaming_precision_at_threshold(logistic, labels, weights, 954 threshold)) 955 # Recall for positive examples. 956 metrics[_summary_key( 957 self.head_name, mkey.RECALL_MEAN % threshold)] = ( 958 _streaming_recall_at_threshold(logistic, labels, weights, 959 threshold)) 960 961 return metrics 962 963 964def _softmax_cross_entropy_loss(labels, logits, weights=None): 965 with ops.name_scope( 966 None, "softmax_cross_entropy_loss", (logits, labels,)) as name: 967 labels = ops.convert_to_tensor(labels) 968 # Check that we got integer for classification. 969 if not labels.dtype.is_integer: 970 raise ValueError("Labels dtype should be integer " 971 "Instead got %s." % labels.dtype) 972 973 # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. 974 is_squeezed_labels = False 975 # TODO(ptucker): This will break for dynamic shapes. 976 if len(labels.get_shape()) == 2: 977 labels = array_ops.squeeze(labels, axis=(1,)) 978 is_squeezed_labels = True 979 980 loss = nn.sparse_softmax_cross_entropy_with_logits( 981 labels=labels, logits=logits, name=name) 982 983 # Restore squeezed dimension, if necessary, so loss matches weights shape. 984 if is_squeezed_labels: 985 loss = array_ops.expand_dims(loss, axis=(1,)) 986 987 return _compute_weighted_loss(loss, weights) 988 989 990class _MultiClassHead(_SingleHead): 991 """'Head' for multi class classification.""" 992 993 def __init__(self, 994 n_classes, 995 label_name=None, 996 weight_column_name=None, 997 enable_centered_bias=False, 998 head_name=None, 999 loss_fn=None, 1000 thresholds=None, 1001 metric_class_ids=None, 1002 label_keys=None): 1003 """'Head' for multi class classification. 1004 1005 This head expects to be fed integer labels specifying the class index. But 1006 if `label_keys` is specified, then labels must be strings from this 1007 vocabulary, and the predicted classes will be strings from the same 1008 vocabulary. 1009 1010 Args: 1011 n_classes: Number of classes, must be greater than 2 (for 2 classes, use 1012 `_BinaryLogisticHead`). 1013 label_name: String, name of the key in label dict. Can be null if label 1014 is a tensor (single headed models). 1015 weight_column_name: A string defining feature column name representing 1016 weights. It is used to down weight or boost examples during training. It 1017 will be multiplied by the loss of the example. 1018 enable_centered_bias: A bool. If True, estimator will learn a centered 1019 bias variable for each class. Rest of the model structure learns the 1020 residual after centered bias. 1021 head_name: name of the head. If provided, predictions, summary, metrics 1022 keys will be suffixed by `"/" + head_name` and the default variable 1023 scope will be `head_name`. 1024 loss_fn: Loss function. Defaults to softmax cross entropy loss. 1025 thresholds: thresholds for eval. 1026 metric_class_ids: List of class IDs for which we should report per-class 1027 metrics. Must all be in the range `[0, n_classes)`. 1028 label_keys: Optional list of strings with size `[n_classes]` defining the 1029 label vocabulary. 1030 1031 Raises: 1032 ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. 1033 """ 1034 super(_MultiClassHead, self).__init__( 1035 problem_type=constants.ProblemType.CLASSIFICATION, 1036 logits_dimension=n_classes, 1037 label_name=label_name, 1038 weight_column_name=weight_column_name, 1039 head_name=head_name) 1040 1041 if (n_classes is None) or (n_classes <= 2): 1042 raise ValueError("n_classes must be > 2: %s." % n_classes) 1043 self._thresholds = thresholds if thresholds else (.5,) 1044 self._loss_fn = loss_fn if loss_fn else _softmax_cross_entropy_loss 1045 self._enable_centered_bias = enable_centered_bias 1046 self._metric_class_ids = tuple([] if metric_class_ids is None else 1047 metric_class_ids) 1048 for class_id in self._metric_class_ids: 1049 if (class_id < 0) or (class_id >= n_classes): 1050 raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes)) 1051 if label_keys and len(label_keys) != n_classes: 1052 raise ValueError("Length of label_keys must equal n_classes.") 1053 self._label_keys = label_keys 1054 1055 def create_model_fn_ops(self, 1056 features, 1057 mode, 1058 labels=None, 1059 train_op_fn=None, 1060 logits=None, 1061 logits_input=None, 1062 scope=None): 1063 """See `Head`.""" 1064 with variable_scope.variable_scope( 1065 scope, 1066 default_name=self.head_name or "multi_class_head", 1067 values=(tuple(six.itervalues(features)) + 1068 (labels, logits, logits_input))): 1069 labels = self._transform_labels(mode=mode, labels=labels) 1070 logits = _logits(logits_input, logits, self.logits_dimension) 1071 return _create_model_fn_ops( 1072 features=features, 1073 mode=mode, 1074 loss_fn=self._wrapped_loss_fn, 1075 logits_to_predictions_fn=self._logits_to_predictions, 1076 metrics_fn=self._metrics, 1077 create_output_alternatives_fn=_classification_output_alternatives( 1078 self.head_name, self._problem_type, self._label_keys), 1079 labels=labels, 1080 train_op_fn=train_op_fn, 1081 logits=logits, 1082 logits_dimension=self.logits_dimension, 1083 head_name=self.head_name, 1084 weight_column_name=self.weight_column_name, 1085 enable_centered_bias=self._enable_centered_bias) 1086 1087 def _transform_labels(self, mode, labels): 1088 """Returns a dict that contains both the original labels and label IDs.""" 1089 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1090 return None 1091 labels_tensor = _to_labels_tensor(labels, self._label_name) 1092 _check_no_sparse_tensor(labels_tensor) 1093 if self._label_keys: 1094 table = lookup_ops.index_table_from_tensor( 1095 self._label_keys, name="label_id_lookup") 1096 return { 1097 "labels": labels_tensor, 1098 "label_ids": table.lookup(labels_tensor), 1099 } 1100 return { 1101 "labels": labels_tensor, 1102 "label_ids": labels_tensor, 1103 } 1104 1105 def _labels(self, labels_dict): 1106 """Returns labels `Tensor` of the same type as classes.""" 1107 return labels_dict["labels"] 1108 1109 def _label_ids(self, labels_dict): 1110 """Returns integer label ID `Tensor`.""" 1111 return labels_dict["label_ids"] 1112 1113 def _wrapped_loss_fn(self, labels, logits, weights=None): 1114 return self._loss_fn(self._label_ids(labels), logits, weights=weights) 1115 1116 def _logits_to_predictions(self, logits): 1117 """Returns a dict of predictions. 1118 1119 Args: 1120 logits: logits `Tensor` after applying possible centered bias. 1121 1122 Returns: 1123 Dict of prediction `Tensor` keyed by `PredictionKey`. 1124 """ 1125 with ops.name_scope(None, "predictions", (logits,)): 1126 class_ids = math_ops.argmax( 1127 logits, 1, name=prediction_key.PredictionKey.CLASSES) 1128 if self._label_keys: 1129 table = lookup_ops.index_to_string_table_from_tensor( 1130 self._label_keys, name="class_string_lookup") 1131 classes = table.lookup(class_ids) 1132 else: 1133 classes = class_ids 1134 return { 1135 prediction_key.PredictionKey.LOGITS: logits, 1136 prediction_key.PredictionKey.PROBABILITIES: 1137 nn.softmax( 1138 logits, name=prediction_key.PredictionKey.PROBABILITIES), 1139 prediction_key.PredictionKey.CLASSES: classes 1140 } 1141 1142 def _metrics(self, eval_loss, predictions, labels, weights): 1143 """Returns a dict of metrics keyed by name.""" 1144 with ops.name_scope( 1145 "metrics", 1146 values=((eval_loss, self._labels(labels), self._label_ids(labels), 1147 weights) + tuple(six.itervalues(predictions)))): 1148 logits = predictions[prediction_key.PredictionKey.LOGITS] 1149 probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] 1150 classes = predictions[prediction_key.PredictionKey.CLASSES] 1151 1152 metrics = {_summary_key(self.head_name, mkey.LOSS): 1153 metrics_lib.mean(eval_loss)} 1154 # TODO(b/29366811): This currently results in both an "accuracy" and an 1155 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1156 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1157 metrics_lib.accuracy(self._labels(labels), classes, weights)) 1158 1159 if not self._label_keys: 1160 # Classes are IDs. Add some metrics. 1161 for class_id in self._metric_class_ids: 1162 metrics[_summary_key( 1163 self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = ( 1164 _class_predictions_streaming_mean(classes, weights, class_id)) 1165 # TODO(ptucker): Add per-class accuracy, precision, recall. 1166 metrics[_summary_key( 1167 self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = ( 1168 _class_labels_streaming_mean( 1169 self._label_ids(labels), weights, class_id)) 1170 metrics[_summary_key( 1171 self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = ( 1172 _predictions_streaming_mean(probabilities, weights, class_id)) 1173 metrics[_summary_key( 1174 self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = ( 1175 _predictions_streaming_mean(logits, weights, class_id)) 1176 1177 return metrics 1178 1179 1180def _to_labels_tensor(labels, label_name): 1181 """Returns label as a tensor. 1182 1183 Args: 1184 labels: Label `Tensor` or `SparseTensor` or a dict containing labels. 1185 label_name: Label name if labels is a dict. 1186 1187 Returns: 1188 Label `Tensor` or `SparseTensor`. 1189 """ 1190 labels = labels[label_name] if isinstance(labels, dict) else labels 1191 return framework_lib.convert_to_tensor_or_sparse_tensor(labels) 1192 1193 1194def _check_no_sparse_tensor(x): 1195 """Raises ValueError if the given tensor is `SparseTensor`.""" 1196 if isinstance(x, sparse_tensor.SparseTensor): 1197 raise ValueError("SparseTensor is not supported.") 1198 1199 1200def _sparse_labels_to_indicator(labels, num_classes): 1201 """If labels is `SparseTensor`, converts it to indicator `Tensor`. 1202 1203 Args: 1204 labels: Label `Tensor` or `SparseTensor`. 1205 num_classes: Number of classes. 1206 1207 Returns: 1208 Dense label `Tensor`. 1209 1210 Raises: 1211 ValueError: If labels is `SparseTensor` and `num_classes` < 2. 1212 """ 1213 if isinstance(labels, sparse_tensor.SparseTensor): 1214 if num_classes < 2: 1215 raise ValueError("Must set num_classes >= 2 when passing labels as a " 1216 "SparseTensor.") 1217 return math_ops.cast( 1218 sparse_ops.sparse_to_indicator(labels, num_classes), dtypes.int64) 1219 return labels 1220 1221 1222def _assert_labels_rank(labels): 1223 return control_flow_ops.Assert( 1224 math_ops.less_equal(array_ops.rank(labels), 2), 1225 ("labels shape should be either [batch_size, 1] or [batch_size]",)) 1226 1227 1228class _BinarySvmHead(_SingleHead): 1229 """`Head` for binary classification using SVM.""" 1230 1231 def __init__(self, label_name, weight_column_name, enable_centered_bias, 1232 head_name, thresholds): 1233 1234 def _loss_fn(labels, logits, weights=None): 1235 with ops.name_scope(None, "hinge_loss", (logits, labels)) as name: 1236 with ops.control_dependencies((_assert_labels_rank(labels),)): 1237 labels = array_ops.reshape(labels, shape=(-1, 1)) 1238 loss = losses_lib.hinge_loss(labels=labels, logits=logits, scope=name, 1239 reduction=losses_lib.Reduction.NONE) 1240 return _compute_weighted_loss(loss, weights) 1241 1242 super(_BinarySvmHead, self).__init__( 1243 problem_type=constants.ProblemType.LOGISTIC_REGRESSION, 1244 logits_dimension=1, 1245 label_name=label_name, 1246 weight_column_name=weight_column_name, 1247 head_name=head_name) 1248 self._thresholds = thresholds if thresholds else (.5,) 1249 self._loss_fn = _loss_fn 1250 self._enable_centered_bias = enable_centered_bias 1251 1252 def create_model_fn_ops(self, 1253 features, 1254 mode, 1255 labels=None, 1256 train_op_fn=None, 1257 logits=None, 1258 logits_input=None, 1259 scope=None): 1260 """See `Head`.""" 1261 with variable_scope.variable_scope( 1262 scope, 1263 default_name=self.head_name or "binary_svm_head", 1264 values=(tuple(six.itervalues(features)) + 1265 (labels, logits, logits_input))): 1266 labels = self._transform_labels(mode=mode, labels=labels) 1267 logits = _logits(logits_input, logits, self.logits_dimension) 1268 return _create_model_fn_ops( 1269 features=features, 1270 mode=mode, 1271 loss_fn=self._loss_fn, 1272 logits_to_predictions_fn=self._logits_to_predictions, 1273 metrics_fn=self._metrics, 1274 # TODO(zakaria): Handle labels for export. 1275 create_output_alternatives_fn=self._create_output_alternatives, 1276 labels=labels, 1277 train_op_fn=train_op_fn, 1278 logits=logits, 1279 logits_dimension=self.logits_dimension, 1280 head_name=self.head_name, 1281 weight_column_name=self.weight_column_name, 1282 enable_centered_bias=self._enable_centered_bias) 1283 1284 def _transform_labels(self, mode, labels): 1285 """Applies transformations to labels tensor.""" 1286 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1287 return None 1288 labels_tensor = _to_labels_tensor(labels, self._label_name) 1289 _check_no_sparse_tensor(labels_tensor) 1290 return labels_tensor 1291 1292 def _logits_to_predictions(self, logits): 1293 """See `_MultiClassHead`.""" 1294 with ops.name_scope(None, "predictions", (logits,)): 1295 return { 1296 prediction_key.PredictionKey.LOGITS: 1297 logits, 1298 prediction_key.PredictionKey.CLASSES: 1299 math_ops.argmax( 1300 _one_class_to_two_class_logits(logits), 1301 1, 1302 name=prediction_key.PredictionKey.CLASSES) 1303 } 1304 1305 def _metrics(self, eval_loss, predictions, labels, weights): 1306 """See `_MultiClassHead`.""" 1307 with ops.name_scope("metrics", values=( 1308 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 1309 metrics = {_summary_key(self.head_name, mkey.LOSS): 1310 metrics_lib.mean(eval_loss)} 1311 1312 # TODO(b/29366811): This currently results in both an "accuracy" and an 1313 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1314 classes = predictions[prediction_key.PredictionKey.CLASSES] 1315 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1316 metrics_lib.accuracy(labels, classes, weights)) 1317 # TODO(sibyl-vie3Poto): add more metrics relevant for svms. 1318 1319 return metrics 1320 1321 1322class _MultiLabelHead(_SingleHead): 1323 """`Head` for multi-label classification.""" 1324 1325 # TODO(zakaria): add signature and metric for multilabel. 1326 def __init__(self, 1327 n_classes, 1328 label_name, 1329 weight_column_name, 1330 enable_centered_bias, 1331 head_name, 1332 thresholds, 1333 metric_class_ids=None, 1334 loss_fn=None): 1335 1336 super(_MultiLabelHead, self).__init__( 1337 problem_type=constants.ProblemType.CLASSIFICATION, 1338 logits_dimension=n_classes, 1339 label_name=label_name, 1340 weight_column_name=weight_column_name, 1341 head_name=head_name) 1342 1343 self._thresholds = thresholds if thresholds else (.5,) 1344 self._loss_fn = loss_fn if loss_fn else _sigmoid_cross_entropy_loss 1345 self._enable_centered_bias = enable_centered_bias 1346 self._metric_class_ids = tuple([] if metric_class_ids is None else 1347 metric_class_ids) 1348 for class_id in self._metric_class_ids: 1349 if (class_id < 0) or (class_id >= n_classes): 1350 raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes)) 1351 1352 def create_model_fn_ops(self, 1353 features, 1354 mode, 1355 labels=None, 1356 train_op_fn=None, 1357 logits=None, 1358 logits_input=None, 1359 scope=None): 1360 """See `Head`.""" 1361 with variable_scope.variable_scope( 1362 scope, 1363 default_name=self.head_name or "multi_label_head", 1364 values=(tuple(six.itervalues(features)) + 1365 (labels, logits, logits_input))): 1366 labels = self._transform_labels(mode=mode, labels=labels) 1367 logits = _logits(logits_input, logits, self.logits_dimension) 1368 return _create_model_fn_ops( 1369 features=features, 1370 mode=mode, 1371 loss_fn=self._loss_fn, 1372 logits_to_predictions_fn=self._logits_to_predictions, 1373 metrics_fn=self._metrics, 1374 create_output_alternatives_fn=_classification_output_alternatives( 1375 self.head_name, self._problem_type), 1376 labels=labels, 1377 train_op_fn=train_op_fn, 1378 logits=logits, 1379 logits_dimension=self.logits_dimension, 1380 head_name=self.head_name, 1381 weight_column_name=self.weight_column_name, 1382 enable_centered_bias=self._enable_centered_bias) 1383 1384 def _transform_labels(self, mode, labels): 1385 """Applies transformations to labels tensor.""" 1386 if (mode == model_fn.ModeKeys.INFER) or (labels is None): 1387 return None 1388 labels_tensor = _to_labels_tensor(labels, self._label_name) 1389 labels_tensor = _sparse_labels_to_indicator(labels_tensor, 1390 self._logits_dimension) 1391 return labels_tensor 1392 1393 def _logits_to_predictions(self, logits): 1394 """See `_MultiClassHead`.""" 1395 with ops.name_scope(None, "predictions", (logits,)): 1396 return { 1397 prediction_key.PredictionKey.LOGITS: 1398 logits, 1399 prediction_key.PredictionKey.PROBABILITIES: 1400 math_ops.sigmoid( 1401 logits, name=prediction_key.PredictionKey.PROBABILITIES), 1402 prediction_key.PredictionKey.CLASSES: 1403 math_ops.cast( 1404 math_ops.greater(logits, 0), 1405 dtypes.int64, 1406 name=prediction_key.PredictionKey.CLASSES) 1407 } 1408 1409 def _metrics(self, eval_loss, predictions, labels, weights): 1410 """Returns a dict of metrics keyed by name.""" 1411 with ops.name_scope("metrics", values=( 1412 [eval_loss, labels, weights] + list(six.itervalues(predictions)))): 1413 classes = predictions[prediction_key.PredictionKey.CLASSES] 1414 probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] 1415 logits = predictions[prediction_key.PredictionKey.LOGITS] 1416 1417 metrics = {_summary_key(self.head_name, mkey.LOSS): 1418 metrics_lib.mean(eval_loss)} 1419 # TODO(b/29366811): This currently results in both an "accuracy" and an 1420 # "accuracy/threshold_0.500000_mean" metric for binary classification. 1421 metrics[_summary_key(self.head_name, mkey.ACCURACY)] = ( 1422 metrics_lib.accuracy(labels, classes, weights)) 1423 metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc( 1424 probabilities, labels, weights) 1425 metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc( 1426 probabilities, labels, weights, curve="PR") 1427 1428 for class_id in self._metric_class_ids: 1429 # TODO(ptucker): Add per-class accuracy, precision, recall. 1430 metrics[_summary_key( 1431 self.head_name, mkey.CLASS_PREDICTION_MEAN % class_id)] = ( 1432 _predictions_streaming_mean(classes, weights, class_id)) 1433 metrics[_summary_key( 1434 self.head_name, mkey.CLASS_LABEL_MEAN % class_id)] = ( 1435 _indicator_labels_streaming_mean(labels, weights, class_id)) 1436 metrics[_summary_key( 1437 self.head_name, mkey.CLASS_PROBABILITY_MEAN % class_id)] = ( 1438 _predictions_streaming_mean(probabilities, weights, class_id)) 1439 metrics[_summary_key( 1440 self.head_name, mkey.CLASS_LOGITS_MEAN % class_id)] = ( 1441 _predictions_streaming_mean(logits, weights, class_id)) 1442 metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = ( 1443 _streaming_auc(probabilities, labels, weights, class_id)) 1444 metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = ( 1445 _streaming_auc(probabilities, labels, weights, class_id, 1446 curve="PR")) 1447 1448 return metrics 1449 1450 1451class _LossOnlyHead(Head): 1452 """`Head` implementation for additional loss terms. 1453 1454 This class only holds loss terms unrelated to any other heads (labels), 1455 e.g. regularization. 1456 1457 Common usage: 1458 This is oftem combine with other heads in a multi head setup. 1459 ```python 1460 head = multi_head([ 1461 head1, head2, loss_only_head('regularizer', regularizer)]) 1462 ``` 1463 """ 1464 1465 def __init__(self, loss_fn, head_name=None): 1466 self._loss_fn = loss_fn 1467 self.head_name = head_name or "loss_only_head" 1468 1469 @property 1470 def logits_dimension(self): 1471 return 0 1472 1473 def create_model_fn_ops(self, 1474 features, 1475 mode, 1476 labels=None, 1477 train_op_fn=None, 1478 logits=None, 1479 logits_input=None, 1480 scope=None): 1481 """See `_Head.create_model_fn_ops`. 1482 1483 Args: 1484 features: Not been used. 1485 mode: Estimator's `ModeKeys`. 1486 labels: Labels `Tensor`, or `dict` of same. 1487 train_op_fn: Function that takes a scalar loss and returns an op to 1488 optimize with the loss. 1489 logits: Not been used. 1490 logits_input: Not been used. 1491 scope: Optional scope for variable_scope. If provided, will be passed to 1492 all heads. Most users will want to set this to `None`, so each head 1493 constructs a separate variable_scope according to its `head_name`. 1494 1495 Returns: 1496 A `ModelFnOps` object. 1497 1498 Raises: 1499 ValueError: if `mode` is not recognition. 1500 """ 1501 _check_mode_valid(mode) 1502 loss = None 1503 train_op = None 1504 if mode != model_fn.ModeKeys.INFER: 1505 with variable_scope.variable_scope(scope, default_name=self.head_name): 1506 loss = self._loss_fn() 1507 if isinstance(loss, list): 1508 loss = math_ops.add_n(loss) 1509 # The name_scope escapism is needed to maintain the same summary tag 1510 # after switching away from the now unsupported API. 1511 with ops.name_scope(""): 1512 summary_loss = array_ops.identity(loss) 1513 summary.scalar(_summary_key(self.head_name, mkey.LOSS), 1514 summary_loss) 1515 if mode == model_fn.ModeKeys.TRAIN: 1516 if train_op_fn is None: 1517 raise ValueError("train_op_fn can not be None in TRAIN mode") 1518 with ops.name_scope(None, "train_op", (loss,)): 1519 train_op = train_op_fn(loss) 1520 1521 return model_fn.ModelFnOps( 1522 mode=mode, 1523 loss=loss, 1524 train_op=train_op, 1525 predictions={}, 1526 eval_metric_ops={}) 1527 1528 1529class _MultiHead(Head): 1530 """`Head` implementation for multi objective learning. 1531 1532 This class is responsible for using and merging the output of multiple 1533 `Head` objects. 1534 1535 All heads stem from the same logits/logit_input tensor. 1536 1537 Common usage: 1538 For simple use cases you can pass the activation of hidden layer like 1539 this from your model_fn, 1540 ```python 1541 last_hidden_layer_activation = ... Build your model. 1542 multi_head = ... 1543 return multi_head.create_model_fn_ops( 1544 ..., logits_input=last_hidden_layer_activation, ...) 1545 ``` 1546 1547 Or you can create a logits tensor of 1548 [batch_size, multi_head.logits_dimension] shape. _MultiHead will split the 1549 logits for you. 1550 return multi_head.create_model_fn_ops(..., logits=logits, ...) 1551 1552 For more complex use cases like a multi-task/multi-tower model or when logits 1553 for each head has to be created separately, you can pass a dict of logits 1554 where the keys match the name of the single heads. 1555 ```python 1556 logits = {"head1": logits1, "head2": logits2} 1557 return multi_head.create_model_fn_ops(..., logits=logits, ...) 1558 ``` 1559 1560 Here is what this class does, 1561 + For training, merges losses of each heads according a function provided by 1562 user, calls user provided train_op_fn with this final loss. 1563 + For eval, merges metrics by adding head_name suffix to the keys in eval 1564 metrics. 1565 + For inference, updates keys in prediction dict to a 2-tuple, 1566 (head_name, prediction_key) 1567 """ 1568 1569 def __init__(self, heads, loss_merger): 1570 """_Head to merges multiple _Head objects. 1571 1572 Args: 1573 heads: list of _Head objects. 1574 loss_merger: function that takes a list of loss tensors for the heads 1575 and returns the final loss tensor for the multi head. 1576 1577 Raises: 1578 ValueError: if any head does not have a name. 1579 """ 1580 self._logits_dimension = 0 1581 for head in heads: 1582 if not head.head_name: 1583 raise ValueError("Members of MultiHead must have names.") 1584 self._logits_dimension += head.logits_dimension 1585 1586 self._heads = heads 1587 self._loss_merger = loss_merger 1588 1589 @property 1590 def logits_dimension(self): 1591 return self._logits_dimension 1592 1593 def create_model_fn_ops(self, 1594 features, 1595 mode, 1596 labels=None, 1597 train_op_fn=None, 1598 logits=None, 1599 logits_input=None, 1600 scope=None): 1601 """See `_Head.create_model_fn_ops`. 1602 1603 Args: 1604 features: Input `dict` of `Tensor` objects. 1605 mode: Estimator's `ModeKeys`. 1606 labels: Labels `Tensor`, or `dict` of same. 1607 train_op_fn: Function that takes a scalar loss and returns an op to 1608 optimize with the loss. 1609 logits: Concatenated logits for all heads or a dict of head name to logits 1610 tensor. If concatenated logits, it should have (batchsize, x) shape 1611 where x is the sum of `logits_dimension` of all the heads, 1612 i.e., same as `logits_dimension` of this class. create_model_fn_ops 1613 will split the logits tensor and pass logits of proper size to each 1614 head. This is useful if we want to be agnostic about whether you 1615 creating a single versus multihead. logits can also be a dict for 1616 convenience where you are creating the head specific logits explicitly 1617 and don't want to concatenate them yourself. 1618 logits_input: tensor to build logits from. 1619 scope: Optional scope for variable_scope. If provided, will be passed to 1620 all heads. Most users will want to set this to `None`, so each head 1621 constructs a separate variable_scope according to its `head_name`. 1622 1623 Returns: 1624 `ModelFnOps`. 1625 1626 Raises: 1627 ValueError: if `mode` is not recognized, or neither or both of `logits` 1628 and `logits_input` is provided. 1629 """ 1630 _check_mode_valid(mode) 1631 all_model_fn_ops = [] 1632 if logits is None: 1633 # Use logits_input. 1634 for head in self._heads: 1635 all_model_fn_ops.append( 1636 head.create_model_fn_ops( 1637 features=features, 1638 mode=mode, 1639 labels=labels, 1640 train_op_fn=no_op_train_fn, 1641 logits_input=logits_input, 1642 scope=scope)) 1643 else: 1644 head_logits_pairs = [] 1645 if isinstance(logits, dict): 1646 head_logits_pairs = [] 1647 for head in self._heads: 1648 if isinstance(head, _LossOnlyHead): 1649 head_logits_pairs.append((head, None)) 1650 else: 1651 head_logits_pairs.append((head, logits[head.head_name])) 1652 else: 1653 # Split logits for each head. 1654 head_logits_pairs = zip(self._heads, self._split_logits(logits)) 1655 1656 for head, head_logits in head_logits_pairs: 1657 all_model_fn_ops.append( 1658 head.create_model_fn_ops( 1659 features=features, 1660 mode=mode, 1661 labels=labels, 1662 train_op_fn=no_op_train_fn, 1663 logits=head_logits, 1664 scope=scope)) 1665 1666 if mode == model_fn.ModeKeys.TRAIN: 1667 if train_op_fn is None: 1668 raise ValueError("train_op_fn can not be None in TRAIN mode.") 1669 return self._merge_train(all_model_fn_ops, train_op_fn) 1670 if mode == model_fn.ModeKeys.INFER: 1671 return self._merge_infer(all_model_fn_ops) 1672 if mode == model_fn.ModeKeys.EVAL: 1673 return self._merge_eval(all_model_fn_ops) 1674 raise ValueError("mode=%s unrecognized" % str(mode)) 1675 1676 def _split_logits(self, logits): 1677 """Splits logits for heads. 1678 1679 Args: 1680 logits: the logits tensor. 1681 1682 Returns: 1683 A list of logits for the individual heads. 1684 """ 1685 all_logits = [] 1686 begin = 0 1687 for head in self._heads: 1688 current_logits_size = head.logits_dimension 1689 current_logits = array_ops.slice(logits, [0, begin], 1690 [-1, current_logits_size]) 1691 all_logits.append(current_logits) 1692 begin += current_logits_size 1693 return all_logits 1694 1695 def _merge_train(self, all_model_fn_ops, train_op_fn): 1696 """Merges list of ModelFnOps for training. 1697 1698 Args: 1699 all_model_fn_ops: list of ModelFnOps for the individual heads. 1700 train_op_fn: Function to create train op. See `create_model_fn_ops` 1701 documentation for more details. 1702 1703 Returns: 1704 ModelFnOps that merges all heads for TRAIN. 1705 """ 1706 losses = [] 1707 metrics = {} 1708 additional_train_ops = [] 1709 for m in all_model_fn_ops: 1710 losses.append(m.loss) 1711 if m.eval_metric_ops is not None: 1712 for k, v in six.iteritems(m.eval_metric_ops): 1713 # metrics["%s/%s" % (k, head_name)] = v 1714 metrics[k] = v 1715 additional_train_ops.append(m.train_op) 1716 loss = self._loss_merger(losses) 1717 1718 train_op = train_op_fn(loss) 1719 train_op = control_flow_ops.group(train_op, *additional_train_ops) 1720 return model_fn.ModelFnOps( 1721 mode=model_fn.ModeKeys.TRAIN, 1722 loss=loss, 1723 train_op=train_op, 1724 eval_metric_ops=metrics) 1725 1726 def _merge_infer(self, all_model_fn_ops): 1727 """Merges list of ModelFnOps for inference. 1728 1729 Args: 1730 all_model_fn_ops: list of ModelFnOps for the individual heads. 1731 1732 Returns: 1733 ModelFnOps that Merges all the heads for INFER. 1734 """ 1735 predictions = {} 1736 output_alternatives = {} 1737 for head, m in zip(self._heads, all_model_fn_ops): 1738 if isinstance(head, _LossOnlyHead): 1739 continue 1740 head_name = head.head_name 1741 output_alternatives[head_name] = m.output_alternatives[head_name] 1742 for k, v in m.predictions.items(): 1743 predictions[(head_name, k)] = v 1744 1745 return model_fn.ModelFnOps( 1746 mode=model_fn.ModeKeys.INFER, 1747 predictions=predictions, 1748 output_alternatives=output_alternatives) 1749 1750 def _merge_eval(self, all_model_fn_ops): 1751 """Merges list of ModelFnOps for eval. 1752 1753 Args: 1754 all_model_fn_ops: list of ModelFnOps for the individual heads. 1755 1756 Returns: 1757 ModelFnOps that merges all the heads for EVAL. 1758 """ 1759 predictions = {} 1760 metrics = {} 1761 losses = [] 1762 for head, m in zip(self._heads, all_model_fn_ops): 1763 losses.append(m.loss) 1764 head_name = head.head_name 1765 for k, v in m.predictions.items(): 1766 predictions[(head_name, k)] = v 1767 for k, v in m.eval_metric_ops.items(): 1768 # metrics["%s/%s" % (k, head_name)] = v 1769 metrics[k] = v 1770 loss = self._loss_merger(losses) 1771 1772 return model_fn.ModelFnOps( 1773 mode=model_fn.ModeKeys.EVAL, 1774 predictions=predictions, 1775 loss=loss, 1776 eval_metric_ops=metrics) 1777 1778 1779def _weight_tensor(features, weight_column_name): 1780 """Returns weights as `Tensor` of rank 0, or at least 2.""" 1781 if not weight_column_name: 1782 return None 1783 if weight_column_name not in features: 1784 raise ValueError("Weights {} missing from features.".format( 1785 weight_column_name)) 1786 with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))): 1787 weight_tensor = math_ops.cast(features[weight_column_name], dtypes.float32) 1788 shape = weight_tensor.get_shape() 1789 rank = shape.ndims 1790 # We don't bother with expanding dims of non-staticly shaped tensors or 1791 # scalars, and >1d is already in a good format. 1792 if rank == 1: 1793 logging.warning("Weights {} has shape {}, expanding to make it 2d.". 1794 format(weight_column_name, shape)) 1795 return ( 1796 sparse_ops.sparse_reshape(weight_tensor, (-1, 1)) 1797 if isinstance(weight_tensor, sparse_tensor.SparseTensor) else 1798 array_ops.reshape(weight_tensor, (-1, 1))) 1799 return weight_tensor 1800 1801 1802# TODO(zakaria): This function is needed for backward compatibility and should 1803# be removed when we migrate to core. 1804def _compute_weighted_loss(loss_unweighted, weight, name="loss"): 1805 """Returns a tuple of (loss_train, loss_report). 1806 1807 loss is used for gradient descent while weighted_average_loss is used for 1808 summaries to be backward compatible. 1809 1810 loss is different from the loss reported on the tensorboard as we 1811 should respect the example weights when computing the gradient. 1812 1813 L = sum_{i} w_{i} * l_{i} / B 1814 1815 where B is the number of examples in the batch, l_{i}, w_{i} are individual 1816 losses, and example weight. 1817 1818 Args: 1819 loss_unweighted: Unweighted loss 1820 weight: Weight tensor 1821 name: Optional name 1822 1823 Returns: 1824 A tuple of losses. First one for training and the second one for reporting. 1825 """ 1826 with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope: 1827 if weight is None: 1828 loss = math_ops.reduce_mean(loss_unweighted, name=name_scope) 1829 return loss, loss 1830 weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted) 1831 with ops.name_scope(None, "weighted_loss", 1832 (loss_unweighted, weight)) as name: 1833 weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name) 1834 weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) 1835 weighted_loss_normalized = math_ops.div( 1836 math_ops.reduce_sum(weighted_loss), 1837 math_ops.cast(math_ops.reduce_sum(weight), dtypes.float32), 1838 name="weighted_average_loss") 1839 1840 return weighted_loss_mean, weighted_loss_normalized 1841 1842 1843def _wrap_custom_loss_fn(loss_fn): 1844 def _wrapper(labels, logits, weights=None): 1845 if weights is None: 1846 loss = loss_fn(labels, logits) 1847 else: 1848 loss = loss_fn(labels, logits, weights) 1849 return loss, loss 1850 return _wrapper 1851 1852 1853def _check_mode_valid(mode): 1854 """Raises ValueError if the given mode is invalid.""" 1855 if (mode != model_fn.ModeKeys.TRAIN and mode != model_fn.ModeKeys.INFER and 1856 mode != model_fn.ModeKeys.EVAL): 1857 raise ValueError("mode=%s unrecognized." % str(mode)) 1858 1859 1860def _get_arguments(func): 1861 """Returns a spec of given func.""" 1862 _, func = tf_decorator.unwrap(func) 1863 if hasattr(func, "__code__"): 1864 # Regular function. 1865 return tf_inspect.getargspec(func) 1866 elif hasattr(func, "func"): 1867 # Partial function. 1868 return _get_arguments(func.func) 1869 elif hasattr(func, "__call__"): 1870 # Callable object. 1871 return _get_arguments(func.__call__) 1872 1873 1874def _verify_loss_fn_args(loss_fn): 1875 args = _get_arguments(loss_fn).args 1876 for arg_name in ["labels", "logits", "weights"]: 1877 if arg_name not in args: 1878 raise ValueError("Argument %s not found in loss_fn." % arg_name) 1879 1880 1881def _centered_bias(logits_dimension, head_name=None): 1882 """Returns centered_bias `Variable`. 1883 1884 Args: 1885 logits_dimension: Last dimension of `logits`. Must be >= 1. 1886 head_name: Optional name of the head. 1887 1888 Returns: 1889 `Variable` with shape `[logits_dimension]`. 1890 1891 Raises: 1892 ValueError: if `logits_dimension` is invalid. 1893 """ 1894 if (logits_dimension is None) or (logits_dimension < 1): 1895 raise ValueError("Invalid logits_dimension %s." % logits_dimension) 1896 # Do not create a variable with variable_scope.get_variable, because that may 1897 # create a PartitionedVariable, which does not support indexing, so 1898 # summary.scalar will not work. 1899 centered_bias = variable_scope.variable( 1900 name="centered_bias_weight", 1901 initial_value=array_ops.zeros(shape=(logits_dimension,)), 1902 trainable=True) 1903 for dim in range(logits_dimension): 1904 if head_name: 1905 summary.scalar("centered_bias/bias_%d/%s" % (dim, head_name), 1906 centered_bias[dim]) 1907 else: 1908 summary.scalar("centered_bias/bias_%d" % dim, centered_bias[dim]) 1909 return centered_bias 1910 1911 1912def _centered_bias_step(centered_bias, batch_size, labels, loss_fn, weights): 1913 """Creates and returns training op for centered bias.""" 1914 with ops.name_scope(None, "centered_bias_step", (labels,)) as name: 1915 logits_dimension = array_ops.shape(centered_bias)[0] 1916 logits = array_ops.reshape( 1917 array_ops.tile(centered_bias, (batch_size,)), 1918 (batch_size, logits_dimension)) 1919 with ops.name_scope(None, "centered_bias", (labels, logits)): 1920 centered_bias_loss = math_ops.reduce_mean( 1921 loss_fn(labels, logits, weights), name="training_loss") 1922 # Learn central bias by an optimizer. 0.1 is a convervative lr for a 1923 # single variable. 1924 return training.AdagradOptimizer(0.1).minimize( 1925 centered_bias_loss, var_list=(centered_bias,), name=name) 1926 1927 1928def _summary_key(head_name, val): 1929 return "%s/%s" % (val, head_name) if head_name else val 1930 1931 1932def _train_op(loss, labels, train_op_fn, centered_bias, batch_size, loss_fn, 1933 weights): 1934 """Returns op for the training step.""" 1935 if centered_bias is not None: 1936 centered_bias_step = _centered_bias_step( 1937 centered_bias=centered_bias, 1938 batch_size=batch_size, 1939 labels=labels, 1940 loss_fn=loss_fn, 1941 weights=weights) 1942 else: 1943 centered_bias_step = None 1944 with ops.name_scope(None, "train_op", (loss, labels)): 1945 train_op = train_op_fn(loss) 1946 if centered_bias_step is not None: 1947 train_op = control_flow_ops.group(train_op, centered_bias_step) 1948 return train_op 1949 1950 1951def _sigmoid_cross_entropy_loss(labels, logits, weights=None): 1952 with ops.name_scope(None, "sigmoid_cross_entropy_loss", 1953 (logits, labels)) as name: 1954 # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. 1955 loss = nn.sigmoid_cross_entropy_with_logits( 1956 labels=math_ops.cast(labels, dtypes.float32), logits=logits, name=name) 1957 return _compute_weighted_loss(loss, weights) 1958 1959 1960def _float_weights_or_none(weights): 1961 if weights is None: 1962 return None 1963 with ops.name_scope(None, "float_weights", (weights,)) as name: 1964 return math_ops.cast(weights, dtypes.float32, name=name) 1965 1966 1967def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): 1968 labels = math_ops.cast(labels, dtypes.float32) 1969 weights = _float_weights_or_none(weights) 1970 if weights is not None: 1971 weights = weights_broadcast_ops.broadcast_weights(weights, labels) 1972 if class_id is not None: 1973 if weights is not None: 1974 weights = weights[:, class_id] 1975 labels = labels[:, class_id] 1976 return metrics_lib.mean(labels, weights) 1977 1978 1979def _predictions_streaming_mean(predictions, 1980 weights=None, 1981 class_id=None): 1982 predictions = math_ops.cast(predictions, dtypes.float32) 1983 weights = _float_weights_or_none(weights) 1984 if weights is not None: 1985 weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 1986 if class_id is not None: 1987 if weights is not None: 1988 weights = weights[:, class_id] 1989 predictions = predictions[:, class_id] 1990 return metrics_lib.mean(predictions, weights) 1991 1992 1993# TODO(ptucker): Add support for SparseTensor labels. 1994def _class_id_labels_to_indicator(labels, num_classes): 1995 if (num_classes is None) or (num_classes < 2): 1996 raise ValueError("Invalid num_classes %s." % num_classes) 1997 with ops.control_dependencies((_assert_labels_rank(labels),)): 1998 labels = array_ops.reshape(labels, (-1,)) 1999 return array_ops.one_hot(labels, depth=num_classes, axis=-1) 2000 2001 2002def _class_predictions_streaming_mean(predictions, weights, class_id): 2003 return metrics_lib.mean( 2004 array_ops.where( 2005 math_ops.equal( 2006 math_ops.cast(class_id, dtypes.int32), 2007 math_ops.cast(predictions, dtypes.int32)), 2008 array_ops.ones_like(predictions), array_ops.zeros_like(predictions)), 2009 weights=weights) 2010 2011 2012def _class_labels_streaming_mean(labels, weights, class_id): 2013 return metrics_lib.mean( 2014 array_ops.where( 2015 math_ops.equal( 2016 math_ops.cast(class_id, dtypes.int32), 2017 math_ops.cast(labels, dtypes.int32)), array_ops.ones_like(labels), 2018 array_ops.zeros_like(labels)), 2019 weights=weights) 2020 2021 2022def _streaming_auc(predictions, labels, weights=None, class_id=None, 2023 curve="ROC"): 2024 # pylint: disable=missing-docstring 2025 predictions = math_ops.cast(predictions, dtypes.float32) 2026 if labels.dtype.base_dtype != dtypes.bool: 2027 logging.warning("Casting %s labels to bool.", labels.dtype) 2028 labels = math_ops.cast(labels, dtypes.bool) 2029 weights = _float_weights_or_none(weights) 2030 if weights is not None: 2031 weights = weights_broadcast_ops.broadcast_weights(weights, predictions) 2032 if class_id is not None: 2033 if weights is not None: 2034 weights = weights[:, class_id] 2035 predictions = predictions[:, class_id] 2036 labels = labels[:, class_id] 2037 return metrics_lib.auc(labels, predictions, weights, curve=curve) 2038 2039 2040def _assert_class_id(class_id, num_classes=None): 2041 """Average label value for class `class_id`.""" 2042 if (class_id is None) or (class_id < 0): 2043 raise ValueError("Invalid class_id %s." % class_id) 2044 if num_classes is not None: 2045 if num_classes < 2: 2046 raise ValueError("Invalid num_classes %s." % num_classes) 2047 if class_id >= num_classes: 2048 raise ValueError("Invalid class_id %s." % class_id) 2049 2050 2051def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): 2052 threshold_predictions = math_ops.cast( 2053 math_ops.greater_equal(predictions, threshold), dtypes.float32) 2054 return metrics_lib.accuracy(labels, threshold_predictions, weights) 2055 2056 2057def _streaming_precision_at_threshold(predictions, labels, weights, threshold): 2058 precision_tensor, update_op = metrics_lib.precision_at_thresholds( 2059 labels, predictions, (threshold,), _float_weights_or_none(weights)) 2060 return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 2061 2062 2063def _streaming_recall_at_threshold(predictions, labels, weights, threshold): 2064 precision_tensor, update_op = metrics_lib.recall_at_thresholds( 2065 labels, predictions, (threshold,), _float_weights_or_none(weights)) 2066 return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) 2067 2068 2069def _classification_output_alternatives(head_name, problem_type, 2070 label_keys=None): 2071 """Creates a func to generate output alternatives for classification. 2072 2073 Servo expects classes to be a string tensor, and have the same dimensions 2074 as the probabilities tensor. It should contain the labels of the corresponding 2075 entries in probabilities. This function creates a new classes tensor that 2076 satisfies these conditions and can be exported. 2077 2078 Args: 2079 head_name: Name of the head. 2080 problem_type: `ProblemType` 2081 label_keys: Optional label keys 2082 2083 Returns: 2084 A function to generate output alternatives. 2085 """ 2086 def _create_output_alternatives(predictions): 2087 """Creates output alternative for the Head. 2088 2089 Args: 2090 predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a 2091 symbolic name for an output Tensor possibly but not necessarily taken 2092 from `PredictionKey`, and 'Tensor' is the corresponding output Tensor 2093 itself. 2094 2095 Returns: 2096 `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where 2097 'submodel_name' is a submodel identifier that should be consistent across 2098 the pipeline (here likely taken from the head_name), 2099 'problem_type' is a `ProblemType`, 2100 'tensor_name' is a symbolic name for an output Tensor possibly but not 2101 necessarily taken from `PredictionKey`, and 2102 'Tensor' is the corresponding output Tensor itself. 2103 2104 Raises: 2105 ValueError: if predictions does not have PredictionKey.PROBABILITIES key. 2106 """ 2107 probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES) 2108 if probabilities is None: 2109 raise ValueError("%s missing in predictions" % 2110 prediction_key.PredictionKey.PROBABILITIES) 2111 2112 with ops.name_scope(None, "_classification_output_alternatives", 2113 (probabilities,)): 2114 batch_size = array_ops.shape(probabilities)[0] 2115 if label_keys: 2116 classes = array_ops.tile( 2117 input=array_ops.expand_dims(input=label_keys, axis=0), 2118 multiples=[batch_size, 1], 2119 name="classes_tensor") 2120 else: 2121 n = array_ops.shape(probabilities)[1] 2122 classes = array_ops.tile( 2123 input=array_ops.expand_dims(input=math_ops.range(n), axis=0), 2124 multiples=[batch_size, 1]) 2125 classes = string_ops.as_string(classes, name="classes_tensor") 2126 2127 exported_predictions = { 2128 prediction_key.PredictionKey.PROBABILITIES: probabilities, 2129 prediction_key.PredictionKey.CLASSES: classes} 2130 return {head_name: (problem_type, exported_predictions)} 2131 2132 return _create_output_alternatives 2133 2134# Aliases 2135# TODO(zakaria): Remove these aliases, See b/34751732 2136_regression_head = regression_head 2137_poisson_regression_head = poisson_regression_head 2138_multi_class_head = multi_class_head 2139_binary_svm_head = binary_svm_head 2140_multi_label_head = multi_label_head 2141_multi_head = multi_head 2142_Head = Head 2143