# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The metric spec class to flexibly connect models and metrics (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.deprecation import deprecated


def _assert_named_args(sentinel):
  """Raises `ValueError` if the wrapped metric fn got a positional arg.

  The adapted metric functions declare `_sentinel` as their first parameter,
  so any positional argument lands there; it must stay `None`.
  """
  if sentinel is not None:
    raise ValueError(
        '`metric_fn` requires named args: '
        '`labels`, `predictions`, and optionally `weights`.')


def _args(fn):
  """Get argument names for function-like object.

  Arguments already bound by a `functools.partial` (either positionally or by
  keyword) are excluded, since callers can no longer supply them.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
    # Handle functools.partial and similar objects. On Python 2, `keywords` is
    # `None` (not `{}`) when no keywords were bound, so guard against that.
    bound_keywords = set(fn.keywords or ())
    # Positionally bound args fill the leading parameters of the wrapped fn.
    num_bound_positional = len(getattr(fn, 'args', None) or ())
    return tuple(
        arg for arg in _args(fn.func)[num_bound_positional:]
        if arg not in bound_keywords)
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)


_CANONICAL_LABELS_ARG = 'labels'
_LABELS_ARGS = set((_CANONICAL_LABELS_ARG, 'label', 'targets', 'target'))
_CANONICAL_PREDICTIONS_ARG = 'predictions'
_PREDICTIONS_ARGS = set((_CANONICAL_PREDICTIONS_ARG, 'prediction',
                         'logits', 'logit'))
_CANONICAL_WEIGHTS_ARG = 'weights'
_WEIGHTS_ARGS = set((_CANONICAL_WEIGHTS_ARG, 'weight'))


def _matching_arg(
    fn_name, fn_args, candidate_args, canonical_arg, is_required=False):
  """Find single argument in `fn_args` from `candidate_args`.

  Args:
    fn_name: Function name, only used for error string.
    fn_args: String argument names to `fn_name` function.
    candidate_args: Candidate argument names to find in `fn_args`.
    canonical_arg: Canonical argument name in `candidate_args`. This is only
      used to log a warning if a non-canonical match is found.
    is_required: Whether function is required to have an arg in
      `candidate_args`.

  Returns:
    String argument name if found, or `None` if not found.

  Raises:
    ValueError: if 2 or more candidates are found, or 0 are found and
      `is_required` is set.
  """
  assert canonical_arg in candidate_args  # Sanity check.
  matching_args = candidate_args.intersection(fn_args)
  if len(matching_args) > 1:
    raise ValueError(
        'Ambiguous arguments %s, must provide only one of %s.' % (
            matching_args, candidate_args))
  matching_arg = matching_args.pop() if matching_args else None
  if matching_arg:
    if matching_arg != canonical_arg:
      logging.warning(
          'Canonical arg %s missing from %s(%s), using %s.',
          canonical_arg, fn_name, fn_args, matching_arg)
  elif is_required:
    raise ValueError(
        '%s missing from %s(%s).' % (candidate_args, fn_name, fn_args))
  return matching_arg


def _fn_name(fn):
  """Returns a human-readable name for `fn`, used in logs and `__str__`."""
  if hasattr(fn, '__name__'):
    return fn.__name__
  if hasattr(fn, 'func') and hasattr(fn.func, '__name__'):
    return fn.func.__name__  # If it's a functools.partial.
  return str(fn)


def _adapt_metric_fn(
    metric_fn, metric_fn_name, is_labels_required, is_weights_required):
  """Adapt `metric_fn` to take only named args.

  This returns a function that takes only named args `labels`, `predictions`,
  and `weights`, and invokes `metric_fn` according to the following rules:
  - If `metric_fn` args include exactly one of `_LABELS_ARGS`, that arg is
    passed (usually by name, but positionally if both it and `predictions` need
    to be passed positionally). Otherwise, `labels` are omitted.
  - If `metric_fn` args include exactly one of `_PREDICTIONS_ARGS`, that arg is
    passed by name. Otherwise, `predictions` are passed positionally as the
    first non-label argument.
  - If exactly one of `_WEIGHTS_ARGS` is provided, that arg is passed by
    name. Passing non-`None` `weights` to the returned function when
    `metric_fn` has no such arg raises `TypeError`.

  Args:
    metric_fn: Metric function to be wrapped.
    metric_fn_name: `metric_fn` name, only used for logging.
    is_labels_required: Whether `labels` is a required arg.
    is_weights_required: Whether `weights` is a required arg.

  Returns:
    Function accepting only named args `labels`, `predictions`, and `weights`,
    and passing those to `metric_fn`.

  Raises:
    ValueError: if one of the following is true:
    - `metric_fn` has more than one arg of `_LABELS_ARGS`, `_PREDICTIONS_ARGS`,
      or `_WEIGHTS_ARGS`
    - `is_labels_required` is true, and `metric_fn` has no arg from
      `_LABELS_ARGS`.
    - `is_weights_required` is true, and `metric_fn` has no arg from
      `_WEIGHTS_ARGS`.
  """
  args = _args(metric_fn)

  labels_arg = _matching_arg(
      metric_fn_name, args, _LABELS_ARGS, _CANONICAL_LABELS_ARG,
      is_labels_required)
  predictions_arg = _matching_arg(
      metric_fn_name, args, _PREDICTIONS_ARGS, _CANONICAL_PREDICTIONS_ARG)
  weights_arg = _matching_arg(
      metric_fn_name, args, _WEIGHTS_ARGS, _CANONICAL_WEIGHTS_ARG,
      is_weights_required)

  def _weights_kwargs(weights):
    """Returns a fresh kwargs dict forwarding `weights` to `metric_fn`."""
    if weights is None:
      return {}
    if not weights_arg:
      # Without this check the call below would be `metric_fn(**{None: w})`,
      # which fails with an unhelpful "keywords must be strings" TypeError.
      raise TypeError(
          'weights provided, but %s(%s) accepts none of %s.' % (
              metric_fn_name, args, sorted(_WEIGHTS_ARGS)))
    return {weights_arg: weights}

  # pylint: disable=invalid-name
  if labels_arg:
    if predictions_arg:
      # Both labels and predictions are named args.
      def _named_metric_fn(
          _sentinel=None, labels=None, predictions=None, weights=None):
        _assert_named_args(_sentinel)
        kwargs = _weights_kwargs(weights)
        kwargs[labels_arg] = labels
        kwargs[predictions_arg] = predictions
        return metric_fn(**kwargs)
      return _named_metric_fn

    if labels_arg == args[0]:
      # labels is a named arg, and first. predictions is not a named arg, so we
      # want to pass it as the 2nd positional arg (i.e., the first non-labels
      # position), which means passing both positionally.
      def _positional_metric_fn(
          _sentinel=None, labels=None, predictions=None, weights=None):
        _assert_named_args(_sentinel)
        # TODO(ptucker): Should we support metrics that take only labels?
        # Currently, if you want streaming mean of a label, you have to wrap it
        # in a fn that discards predictions.
        return metric_fn(labels, predictions, **_weights_kwargs(weights))
      return _positional_metric_fn

    # labels is a named arg, and not first, so we pass predictions positionally
    # and labels by name.
    def _positional_predictions_metric_fn(
        _sentinel=None, labels=None, predictions=None, weights=None):
      _assert_named_args(_sentinel)
      kwargs = _weights_kwargs(weights)
      kwargs[labels_arg] = labels
      return metric_fn(predictions, **kwargs)
    return _positional_predictions_metric_fn

  if predictions_arg:
    # No labels, and predictions is named, so we pass the latter as a named arg.
    def _named_no_labels_metric_fn(
        _sentinel=None, labels=None, predictions=None, weights=None):
      del labels
      _assert_named_args(_sentinel)
      # TODO(ptucker): Should we allow weights with no labels?
      kwargs = _weights_kwargs(weights)
      kwargs[predictions_arg] = predictions
      return metric_fn(**kwargs)
    return _named_no_labels_metric_fn

  # Neither labels nor predictions are named, so we just pass predictions as the
  # first arg.
  def _positional_no_labels_metric_fn(
      _sentinel=None, labels=None, predictions=None, weights=None):
    del labels
    _assert_named_args(_sentinel)
    # TODO(ptucker): Should we allow weights with no labels?
    return metric_fn(predictions, **_weights_kwargs(weights))
  return _positional_no_labels_metric_fn


class MetricSpec(object):
  """MetricSpec connects a model to metric functions.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  The MetricSpec class contains all information necessary to connect the
  output of a `model_fn` to the metrics (usually, streaming metrics) that are
  used in evaluation.

  It is passed in the `metrics` argument of `Estimator.evaluate`. The
  `Estimator` then knows which predictions, labels, and weight to use to call a
  given metric function.

  When building the ops to run in evaluation, an `Estimator` will call
  `create_metric_ops`, which will connect the given `metric_fn` to the model
  as detailed in the docstring for `create_metric_ops`, and return the metric.

  Example:

  Assuming a model has an input function which returns inputs containing
  (among other things) a tensor with key "input_key", and a labels dictionary
  containing "label_key". Let's assume that the `model_fn` for this model
  returns a prediction with key "prediction_key".

  In order to compute the accuracy of the "prediction_key" prediction, we
  would add

  ```
  "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                    prediction_key="prediction_key",
                                    label_key="label_key")
  ```

  to the metrics argument to `evaluate`. `prediction_accuracy_fn` can be either
  a predefined function in metric_ops (e.g., `streaming_accuracy`) or a custom
  function you define.

  If we would like the accuracy to be weighted by "input_key", we can add that
  as the `weight_key` argument.

  ```
  "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                    prediction_key="prediction_key",
                                    label_key="label_key",
                                    weight_key="input_key")
  ```

  An end-to-end example is as follows:

  ```
  estimator = tf.contrib.learn.Estimator(...)
  estimator.fit(...)
  _ = estimator.evaluate(
      input_fn=input_fn,
      steps=1,
      metrics={
          'prediction accuracy':
              metric_spec.MetricSpec(
                  metric_fn=prediction_accuracy_fn,
                  prediction_key="prediction_key",
                  label_key="label_key")
      })
  ```

  """

  @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.')
  def __init__(self,
               metric_fn,
               prediction_key=None,
               label_key=None,
               weight_key=None):
    """Constructor.

    Creates a MetricSpec.

    Args:
      metric_fn: A function to use as a metric. See `_adapt_metric_fn` for
        rules on how `predictions`, `labels`, and `weights` are passed to this
        function. This must return either a single `Tensor`, which is
        interpreted as a value of this metric, or a pair
        `(value_op, update_op)`, where `value_op` is the op to call to
        obtain the value of the metric, and `update_op` should be run for
        each batch to update internal state.
      prediction_key: The key for a tensor in the `predictions` dict (output
        from the `model_fn`) to use as the `predictions` input to the
        `metric_fn`. Optional. If `None`, the `model_fn` must return a single
        tensor or a dict with only a single entry as `predictions`.
      label_key: The key for a tensor in the `labels` dict (output from the
        `input_fn`) to use as the `labels` input to the `metric_fn`.
        Optional. If `None`, the `input_fn` must return a single tensor or a
        dict with only a single entry as `labels`.
      weight_key: The key for a tensor in the `inputs` dict (output from the
        `input_fn`) to use as the `weights` input to the `metric_fn`.
        Optional. If `None`, no weights will be passed to the `metric_fn`.

    Raises:
      ValueError: if `metric_fn` has ambiguous labels/predictions/weights
        args, or lacks a labels (weights) arg while `label_key`
        (`weight_key`) is set.
    """
    self._metric_fn_name = _fn_name(metric_fn)
    self._metric_fn = _adapt_metric_fn(
        metric_fn=metric_fn,
        metric_fn_name=self._metric_fn_name,
        is_labels_required=label_key is not None,
        is_weights_required=weight_key is not None)
    self._prediction_key = prediction_key
    self._label_key = label_key
    self._weight_key = weight_key

  @property
  def prediction_key(self):
    return self._prediction_key

  @property
  def label_key(self):
    return self._label_key

  @property
  def weight_key(self):
    return self._weight_key

  @property
  def metric_fn(self):
    """Metric function.

    This function accepts named args: `predictions`, `labels`, `weights`. It
    returns a single `Tensor` or `(value_op, update_op)` pair. See `metric_fn`
    constructor argument for more details.

    Returns:
      Function, see `metric_fn` constructor argument for more details.
    """
    return self._metric_fn

  def __str__(self):
    # A single tuple-based format so tuple-valued keys cannot break formatting.
    return ('MetricSpec(metric_fn=%s, prediction_key=%s, label_key=%s, '
            'weight_key=%s)' % (self._metric_fn_name, self.prediction_key,
                                self.label_key, self.weight_key))

  def create_metric_ops(self, inputs, labels, predictions):
    """Connect our `metric_fn` to the specified members of the given dicts.

    This function will call the (adapted) `metric_fn` given in our constructor
    as follows:

    ```
      metric_fn(predictions[self.prediction_key],
                labels[self.label_key],
                weights=weights[self.weight_key])
    ```

    And returns the result. The `weights` argument is only passed if
    `self.weight_key` is not `None`.

    `predictions` and `labels` may be single tensors as well as dicts. If
    `predictions` is a single tensor, `self.prediction_key` must be `None`. If
    `predictions` is a single element dict, `self.prediction_key` is allowed to
    be `None`. Conversely, if `labels` is a single tensor, `self.label_key` must
    be `None`. If `labels` is a single element dict, `self.label_key` is allowed
    to be `None`.

    Args:
      inputs: A dict of inputs produced by the `input_fn`
      labels: A dict of labels or a single label tensor produced by the
        `input_fn`.
      predictions: A dict of predictions or a single tensor produced by the
        `model_fn`.

    Returns:
      The result of calling `metric_fn`.

    Raises:
      ValueError: If `predictions` or `labels` is a single `Tensor` and
        `self.prediction_key` or `self.label_key` is not `None`; or if
        `self.label_key` is `None` but `labels` is a dict with more than one
        element, or if `self.prediction_key` is `None` but `predictions` is a
        dict with more than one element.
      KeyError: If a specified key is missing from its dict.
    """
    def _get_dict(name, dict_or_tensor, key):
      """Get a single tensor or an element of a dict or raise ValueError."""
      # Compare against `None` (not truthiness) to match the constructor, which
      # treats any non-`None` key as specified.
      if key is not None:
        if not isinstance(dict_or_tensor, dict):
          # `dict_or_tensor` is wrapped in the format tuple so a tuple/list
          # value cannot be mistaken for multiple format arguments.
          raise ValueError(
              'MetricSpec with %s_key specified requires %ss dict, got %s.\n'
              'You must not provide a %s_key if you only have a single Tensor '
              'as %ss.' % (name, name, dict_or_tensor, name, name))
        if key not in dict_or_tensor:
          raise KeyError(
              'Key \'%s\' missing from %s.' % (key, dict_or_tensor.keys()))
        return dict_or_tensor[key]
      if isinstance(dict_or_tensor, dict):
        if len(dict_or_tensor) != 1:
          raise ValueError(
              'MetricSpec without specified %s_key requires %ss tensor or '
              'single element dict, got %s' % (name, name, dict_or_tensor))
        return six.next(six.itervalues(dict_or_tensor))
      return dict_or_tensor

    # Get the predictions.
    prediction = _get_dict('prediction', predictions, self.prediction_key)

    # Get the labels.
    label = _get_dict('label', labels, self.label_key)

    try:
      return self.metric_fn(
          labels=label,
          predictions=prediction,
          weights=(None if self.weight_key is None
                   else inputs[self.weight_key]))
    except Exception as ex:
      logging.error('Could not create metric ops for %s, %s.', self, ex)
      raise