# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TargetColumn abstracts a single head in the model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


# Deprecation arguments shared by every public symbol in this module. Note the
# trailing space after "date.": the adjacent literals are concatenated, and
# without it the message would read "date.Please".
_DEPRECATION_DATE = "2016-11-12"
_DEPRECATION_INSTRUCTIONS = (
    "This file will be removed after the deprecation date. "
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")


@deprecated(_DEPRECATION_DATE, _DEPRECATION_INSTRUCTIONS)
def regression_target(label_name=None,
                      weight_column_name=None,
                      label_dimension=1):
  """Creates a _TargetColumn for linear regression.

  The target column uses a squared-error loss (see `_mean_squared_loss`).

  Args:
    label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: dimension of the target for multilabels.

  Returns:
    An instance of _TargetColumn
  """
  return _RegressionTargetColumn(
      loss_fn=_mean_squared_loss,
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension)


# TODO(zakaria): Add logistic_regression_target


@deprecated(_DEPRECATION_DATE, _DEPRECATION_INSTRUCTIONS)
def multi_class_target(n_classes, label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for multi class single label classification.

  For `n_classes == 2` the target column uses sigmoid cross entropy (log loss)
  on a single logit; for more classes it uses sparse softmax cross entropy.

  Args:
    n_classes: Integer, number of classes, must be >= 2
    label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _MultiClassTargetColumn.

  Raises:
    ValueError: if n_classes is < 2
  """
  if n_classes < 2:
    raise ValueError("n_classes must be > 1 for classification.")
  if n_classes == 2:
    loss_fn = _log_loss_with_two_classes
  else:
    loss_fn = _softmax_cross_entropy_loss
  return _MultiClassTargetColumn(
      loss_fn=loss_fn,
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name)


@deprecated(_DEPRECATION_DATE, _DEPRECATION_INSTRUCTIONS)
def binary_svm_target(label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for binary classification with SVMs.

  The target column uses binary hinge loss.

  Args:
    label_name: String, name of the key in label dict. Can be null if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _TargetColumn.

  """
  return _BinarySvmTargetColumn(
      label_name=label_name, weight_column_name=weight_column_name)


@deprecated(_DEPRECATION_DATE, _DEPRECATION_INSTRUCTIONS)
class ProblemType(object):
  # Integer tags stored on each _TargetColumn via its `problem_type` property.
  UNSPECIFIED = 0
  CLASSIFICATION = 1
  LINEAR_REGRESSION = 2
  LOGISTIC_REGRESSION = 3


class _TargetColumn(object):
  """_TargetColumn is the abstraction for a single head in a model.

  Args:
    loss_fn: a function `(logits, target) -> per-example loss tensor`.
    num_label_columns: Integer, number of label columns.
    label_name: String, name of the key in label dict. Can be null if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    problem_type: one of the `ProblemType` constants describing this head.

  Raises:
    ValueError: if loss_fn or num_label_columns is missing.
  """

  def __init__(self, loss_fn, num_label_columns, label_name, weight_column_name,
               problem_type):
    if not loss_fn:
      raise ValueError("loss_fn must be provided")
    # Compare against None explicitly: 0 is a permitted value.
    if num_label_columns is None:
      raise ValueError("num_label_columns must be provided")

    self._loss_fn = loss_fn
    self._num_label_columns = num_label_columns
    self._label_name = label_name
    self._weight_column_name = weight_column_name
    self._problem_type = problem_type

  def logits_to_predictions(self, logits, proba=False):
    # Abstract; subclasses must implement.
    raise NotImplementedError()

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns eval op."""
    raise NotImplementedError

  @property
  def label_name(self):
    return self._label_name

  @property
  def weight_column_name(self):
    return self._weight_column_name

  @property
  def num_label_columns(self):
    return self._num_label_columns

  def get_weight_tensor(self, features):
    """Returns the example weights as a flat float32 tensor, or None.

    Args:
      features: features dict; must contain `weight_column_name` if one was
        configured.

    Returns:
      None when no weight column is configured, otherwise
      `features[weight_column_name]` cast to float32 and reshaped to rank 1.
    """
    if not self._weight_column_name:
      return None
    else:
      return array_ops.reshape(
          math_ops.cast(features[self._weight_column_name], dtypes.float32),
          shape=(-1,))

  @property
  def problem_type(self):
    return self._problem_type

  def _weighted_loss(self, loss, weight_tensor):
    """Returns the per-example loss multiplied elementwise by the weights.

    Both `loss` and `weight_tensor` are flattened to rank 1 before the
    multiplication; no reduction is performed here.
    """
    unweighted_loss = array_ops.reshape(loss, shape=(-1,))
    weighted_loss = math_ops.multiply(unweighted_loss,
                                      array_ops.reshape(
                                          weight_tensor, shape=(-1,)))
    return weighted_loss

  def _get_target(self, target):
    """Selects this head's target when `target` is a dict keyed by label name.

    The dict key is `label_name` (the key documented for the label dict); the
    previous code used a non-existent `self.name` attribute, which raised
    AttributeError for every dict target.
    """
    if isinstance(target, dict):
      return target[self._label_name]
    return target

  def training_loss(self, logits, target, features, name="training_loss"):
    """Returns training loss tensor for this head.

    Training loss is different from the loss reported on the tensorboard as we
    should respect the example weights when computing the gradient.

      L = sum_{i} w_{i} * l_{i} / B

    where B is the number of examples in the batch, and l_{i}, w_{i} are the
    individual losses and example weights.

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor (keyed by `label_name`).
      features: features dict.
      name: Op name.

    Returns:
      Loss tensor.
    """
    target = self._get_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name=name)
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    # Divide by the batch size B (not by sum of weights) -- see docstring.
    return math_ops.reduce_mean(loss_weighted, name=name)

  def loss(self, logits, target, features):
    """Returns loss tensor for this head.

    The loss returned is the weighted average.

      L = sum_{i} w_{i} * l_{i} / sum_{i} w_{i}

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor (keyed by `label_name`).
      features: features dict.

    Returns:
      Loss tensor.
    """
    target = self._get_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name="loss")
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.div(
        math_ops.reduce_sum(loss_weighted),
        math_ops.cast(math_ops.reduce_sum(weight_tensor), dtypes.float32),
        name="loss")


class _RegressionTargetColumn(_TargetColumn):
  """_TargetColumn for regression."""

  def __init__(self, loss_fn, label_name, weight_column_name, label_dimension):
    super(_RegressionTargetColumn, self).__init__(
        loss_fn=loss_fn,
        num_label_columns=label_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.LINEAR_REGRESSION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns the logits, squeezing axis 1 when there is a single label."""
    if self.num_label_columns == 1:
      return array_ops.squeeze(logits, axis=[1])
    return logits

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict of streaming eval metrics, always including "loss"."""
    loss = self.loss(logits, labels, features)
    result = {"loss": metric_ops.streaming_mean(loss)}
    if metrics:
      predictions = self.logits_to_predictions(logits, proba=False)
      result.update(
          _run_metrics(predictions, labels, metrics,
                       self.get_weight_tensor(features)))
    return result


class _MultiClassTargetColumn(_TargetColumn):
  """_TargetColumn for classification."""

  # TODO(zakaria): support multilabel.
  def __init__(self, loss_fn, n_classes, label_name, weight_column_name):
    if n_classes < 2:
      raise ValueError("n_classes must be >= 2")
    super(_MultiClassTargetColumn, self).__init__(
        loss_fn=loss_fn,
        # Binary classification is modelled with a single logit.
        num_label_columns=1 if n_classes == 2 else n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.CLASSIFICATION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns class probabilities (proba=True) or argmax class ids."""
    if self.num_label_columns == 1:
      # Prepend a zero logit for class 0: softmax([0, z])[1] == sigmoid(z),
      # and argmax picks class 1 when z is positive.
      logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)

    if proba:
      return nn.softmax(logits)
    else:
      return math_ops.argmax(logits, 1)

  def _default_eval_metrics(self):
    if self._num_label_columns == 1:
      return get_default_binary_metrics_for_eval(thresholds=[.5])
    return {}

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict of streaming eval metrics for this head.

    Args:
      features: features dict.
      logits: logits tensor.
      labels: labels tensor.
      metrics: optional dict mapping either a string name (run on class
        predictions) or a `(name, "classes" | "probabilities")` tuple to a
        metric function. Defaults to streaming accuracy on classes.

    Returns:
      Dict of metric name to metric result; always contains "loss".

    Raises:
      ValueError: if a metric key is not in one of the accepted forms.
    """
    loss = self.loss(logits, labels, features)
    result = {"loss": metric_ops.streaming_mean(loss)}

    # Adds default metrics.
    if metrics is None:
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy}

    predictions = math_ops.sigmoid(logits)
    labels_float = math_ops.cast(labels, dtypes.float32)

    # Only non-empty for the binary (single logit) case.
    default_metrics = self._default_eval_metrics()
    for metric_name, metric_op in default_metrics.items():
      result[metric_name] = metric_op(predictions, labels_float)

    class_metrics = {}
    proba_metrics = {}
    for name, metric_op in six.iteritems(metrics):
      if isinstance(name, tuple):
        if len(name) != 2:
          raise ValueError("Ignoring metric {}. It returned a tuple with "
                           "len {}, expected 2.".format(name, len(name)))
        else:
          if name[1] not in ["classes", "probabilities"]:
            raise ValueError("Ignoring metric {}. The 2nd element of its "
                             "name should be either 'classes' or "
                             "'probabilities'.".format(name))
          elif name[1] == "classes":
            class_metrics[name[0]] = metric_op
          else:
            proba_metrics[name[0]] = metric_op
      # six.string_types also accepts unicode names on Python 2.
      elif isinstance(name, six.string_types):
        class_metrics[name] = metric_op
      else:
        raise ValueError("Ignoring metric {}. Its name is not in the correct "
                         "form.".format(name))
    if class_metrics:
      class_predictions = self.logits_to_predictions(logits, proba=False)
      result.update(
          _run_metrics(class_predictions, labels, class_metrics,
                       self.get_weight_tensor(features)))
    if proba_metrics:
      predictions = self.logits_to_predictions(logits, proba=True)
      result.update(
          _run_metrics(predictions, labels, proba_metrics,
                       self.get_weight_tensor(features)))
    return result


class _BinarySvmTargetColumn(_MultiClassTargetColumn):
  """_TargetColumn for binary classification using SVMs."""

  def __init__(self, label_name, weight_column_name):

    def loss_fn(logits, target):
      # Accept rank-1 or rank-2 targets and reshape them to [batch_size, 1]
      # before computing the hinge loss.
      check_shape_op = control_flow_ops.Assert(
          math_ops.less_equal(array_ops.rank(target), 2),
          ["target's shape should be either [batch_size, 1] or [batch_size]"])
      with ops.control_dependencies([check_shape_op]):
        target = array_ops.reshape(
            target, shape=[array_ops.shape(target)[0], 1])
      return loss_ops.hinge_loss(logits, target)

    super(_BinarySvmTargetColumn, self).__init__(
        loss_fn=loss_fn,
        n_classes=2,
        label_name=label_name,
        weight_column_name=weight_column_name)

  def logits_to_predictions(self, logits, proba=False):
    """Returns argmax class ids; probabilities are not supported."""
    if proba:
      raise ValueError(
          "logits to probabilities is not supported for _BinarySvmTargetColumn")

    logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)
    return math_ops.argmax(logits, 1)


# TODO(zakaria): use contrib losses.
def _mean_squared_loss(logits, target):
  """Returns the elementwise squared difference between logits and target."""
  # To prevent broadcasting inside "-".
  if len(target.get_shape()) == 1:
    target = array_ops.expand_dims(target, axis=1)

  logits.get_shape().assert_is_compatible_with(target.get_shape())
  return math_ops.squared_difference(logits,
                                     math_ops.cast(target, dtypes.float32))


def _log_loss_with_two_classes(logits, target):
  """Returns per-example sigmoid cross entropy for a single binary logit."""
  # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
  if len(target.get_shape()) == 1:
    target = array_ops.expand_dims(target, axis=1)
  loss_vec = nn.sigmoid_cross_entropy_with_logits(
      labels=math_ops.cast(target, dtypes.float32), logits=logits)
  return loss_vec


def _softmax_cross_entropy_loss(logits, target):
  """Returns per-example sparse softmax cross entropy.

  Raises:
    ValueError: if `target` does not have an integer dtype.
  """
  # Check that we got integer for classification.
  if not target.dtype.is_integer:
    raise ValueError("Target's dtype should be integer. "
                     "Instead got %s." % target.dtype)
  # sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
  if len(target.get_shape()) == 2:
    target = array_ops.squeeze(target, axis=[1])
  loss_vec = nn.sparse_softmax_cross_entropy_with_logits(
      labels=target, logits=logits)
  return loss_vec


def _run_metrics(predictions, labels, metrics, weights):
  """Applies each metric fn to (predictions, labels[, weights]).

  Labels are cast to the dtype of `predictions` first. `weights` is passed as
  a keyword argument only when it is not None.
  """
  result = {}
  labels = math_ops.cast(labels, predictions.dtype)
  for name, metric in six.iteritems(metrics or {}):
    if weights is not None:
      result[name] = metric(predictions, labels, weights=weights)
    else:
      result[name] = metric(predictions, labels)

  return result


@deprecated(_DEPRECATION_DATE, _DEPRECATION_INSTRUCTIONS)
def get_default_binary_metrics_for_eval(thresholds=None):
  """Returns a dictionary of basic metrics for logistic regression.

  Args:
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If None, defaults to [0.5].

  Returns:
    Dictionary mapping metrics string names to metrics functions.
  """
  # Honor the documented default instead of failing on `for ... in None`.
  if thresholds is None:
    thresholds = [.5]

  metrics = {}
  metrics[_MetricKeys.PREDICTION_MEAN] = _predictions_streaming_mean
  metrics[_MetricKeys.TARGET_MEAN] = _labels_streaming_mean
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[_MetricKeys.ACCURACY_BASELINE] = _labels_streaming_mean

  metrics[_MetricKeys.AUC] = _streaming_auc

  for threshold in thresholds:
    metrics[_MetricKeys.ACCURACY_MEAN %
            threshold] = _accuracy_at_threshold(threshold)
    # Precision for positive examples.
    metrics[_MetricKeys.PRECISION_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_precision_at_thresholds, threshold)
    # Recall for positive examples.
    metrics[_MetricKeys.RECALL_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_recall_at_thresholds, threshold)

  return metrics


def _float_weights_or_none(weights):
  """Casts `weights` to float32, passing None through unchanged."""
  if weights is None:
    return None
  return math_ops.cast(weights, dtypes.float32)


def _labels_streaming_mean(unused_predictions, labels, weights=None):
  return metric_ops.streaming_mean(labels, weights=weights)


def _predictions_streaming_mean(predictions, unused_labels, weights=None):
  return metric_ops.streaming_mean(predictions, weights=weights)


def _streaming_auc(predictions, labels, weights=None):
  return metric_ops.streaming_auc(
      predictions, labels, weights=_float_weights_or_none(weights))


def _accuracy_at_threshold(threshold):
  """Returns a metric fn computing accuracy of `predictions >= threshold`."""

  def _accuracy_metric(predictions, labels, weights=None):
    threshold_predictions = math_ops.cast(
        math_ops.greater_equal(predictions, threshold), dtypes.float32)
    return metric_ops.streaming_accuracy(
        predictions=threshold_predictions, labels=labels, weights=weights)

  return _accuracy_metric


def _streaming_at_threshold(streaming_metrics_fn, threshold):
  """Binds a `*_at_thresholds` streaming metric fn to a single threshold.

  The returned fn squeezes the single-threshold result to a scalar and
  returns it together with the update op.
  """

  def _streaming_metrics(predictions, labels, weights=None):
    metric_tensor, update_op = streaming_metrics_fn(
        predictions,
        labels=labels,
        thresholds=[threshold],
        weights=_float_weights_or_none(weights))
    return array_ops.squeeze(metric_tensor), update_op

  return _streaming_metrics


class _MetricKeys(object):
  # Metric names; the *_MEAN keys are %-formatted with the threshold.
  AUC = "auc"
  PREDICTION_MEAN = "labels/prediction_mean"
  TARGET_MEAN = "labels/actual_target_mean"
  ACCURACY_BASELINE = "accuracy/baseline_target_mean"
  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
  RECALL_MEAN = "recall/positive_threshold_%f_mean"