# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class Evaluator holds Metrics for the duration of an evaluation run."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.eager.python import datasets
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops


class Evaluator(object):
  """This holds and updates Metrics for the duration of a single eval run.

  Usage:
    evaluator = my_model.evaluator()  # or MyEvaluator(my_model)
    for example_batch in ...:
      evaluator(example_batch)
    results = evaluator.all_metric_results(optional_summary_logdir)

  Or, if you are getting your examples from a tf.data.Dataset, you can use
  the evaluate_on_dataset() method.

  Implementers of Evaluators should
  (a) Call `track_metric()` and/or `track_evaluator()` in __init__().
  (b) Override the `call()` method. It will be passed the output of the
      model's `eval_data()` method, and should call its contained metrics
      (treating them as callables) and any child Evaluators (using their
      call() method to avoid calling eval_data() again).

  Args:
    model: A `Model` object with an `eval_data()` method.
  """

  def __init__(self, model):
    self._model = model
    self._metrics = {}
    self._evaluators = {}
    if not context.executing_eagerly():
      self.call = function.defun(self.call)

  # ---- API for users ----
  def __call__(self, *args, **kwargs):
    """Update metrics with a minibatch of input examples.

    Args:
      *args:
      **kwargs: Arguments representing an input mini-batch of examples to
        pass to self.model.eval_data().

    Returns:
      The op to execute or None if executing eagerly.
    """
    return self.call(self._model.eval_data(*args, **kwargs))

  def init_variables(self):
    """Return an op for initializing all contained uninitialized variables.

    Only for graph execution. Should be called after variables are created
    in the first execution of __call__().

    Returns:
      An op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Evaluator.init_variables() not needed when "
                         "eager execution is enabled.")
    return control_flow_ops.group([m.init_variables() for _, m in self.metrics])

  def all_metric_results(self, summary_logdir=None):
    """Computes results for all contained metrics.

    Args:
      summary_logdir: An optional string. If specified, metric results
        will be written as summaries to this directory.

    Returns:
      A `dict` mapping string names to tensors.
    """
    if summary_logdir is None:
      with summary_ops.never_record_summaries():
        return self._all_metric_results()
    else:
      def f():
        with summary_ops.create_file_writer(
            summary_logdir).as_default(), summary_ops.always_record_summaries():
          return self._all_metric_results()

      if context.executing_eagerly():
        return f()
      else:
        return function.defun(f)()

  def _all_metric_results(self):
    """Implementation of `all_metric_results` in the summary context."""
    results = {}
    for name, metric in six.iteritems(self._metrics):
      results[name] = metric.result()
    for prefix, evaluator in six.iteritems(self._evaluators):
      for name, metric in six.iteritems(evaluator._metrics):  # pylint: disable=protected-access
        results[prefix + "/" + name] = metric.result()
    return results

  def evaluate_on_dataset(self, dataset, *args, **kwargs):
    """Convenience method for performing an eval on a Dataset.

    Args:
      dataset: Dataset object with the input data to evaluate on.
      *args:
      **kwargs: Optional additional arguments to __call__(), except
        `summary_logdir`: if specified, metrics will be written as summaries
        to this directory.

    Returns:
      @compatibility(eager)
      When eager execution is enabled, this returns the result of performing
      an evaluation as a dictionary. With graph execution, this returns a tuple
      (init_op, call_op, results_op) which may be executed using this code:
      ```python
      sess.run(init_op)
      try:
        while True:
          sess.run(call_op)
      except tf.errors.OutOfRangeError:
        pass
      return sess.run(results_op)  # A dictionary

      # equivalently:
      return evaluator.run_evaluation(init_op, call_op, results_op, sess=sess)
      ```
      @end_compatibility
    """
    summary_logdir = kwargs.pop("summary_logdir", None)
    if context.executing_eagerly():
      for example in datasets.Iterator(dataset):
        self.__call__(example, *args, **kwargs)
      return self.all_metric_results(summary_logdir)
    # Graph construction
    call_op = self.__call__(
        dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs)
    init_op = self.init_variables()
    results_op = self.all_metric_results(summary_logdir)
    return (init_op, call_op, results_op)

  @staticmethod
  def run_evaluation(init_op, call_op, results_op, sess=None):
    """Convenience method for running the ops returned by evaluate_on_dataset.

    Args:
      init_op: An op that initializes/resets evaluation state.
      call_op: An op that updates evaluation state on a mini-batch of examples.
        Must generate a tf.errors.OutOfRangeError when done.
      results_op: A dictionary of tensors that compute the final evaluation
        results from the evaluation state.
      sess: The Session to run the evaluation in. Defaults to the default
        Session.

    Returns:
      A dictionary of values, parallel to results_op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Evaluator.run_evaluation() not supported when "
                         "eager execution is enabled.")
    sess = sess or ops.get_default_session()
    sess.run(init_op)
    try:
      while True:
        sess.run(call_op)
    except errors_impl.OutOfRangeError:
      pass
    return sess.run(results_op)

  # ---- To be implemented by descendants ----
  def call(self, eval_data):
    """Update metrics using the output of self.model.

    Note: This function is executed as a graph function in graph mode.
    This means:
    a) Operations on the same resource are executed in textual order.
       This should make it easier to do things like add the updated
       value of a variable to another, for example.
    b) You don't need to worry about collecting the update ops to execute.
       All update ops added to the graph by this function will be executed.
    As a result, code should generally work the same way with graph or
    eager execution.

    Args:
      eval_data: The output of self.model.eval_data() on a mini-batch of
        examples.
    """
    raise NotImplementedError("Evaluators must define a call member function.")

  # ---- For use by descendants ----
  @property
  def model(self):
    return self._model

  def track_metric(self, metric):
    """Add a Metric to be tracked.

    Metrics can only be tracked by one `Evaluator`. Metrics must be
    tracked or they will not appear in `all_metric_results()`.

    Args:
      metric: A `Metric` object.

    Returns:
      The `metric` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `metric` is not of the correct type.
      ValueError: If there is a name collision between Metrics or `metric`
        has already been added to another `Evaluator`.
    """
    if not hasattr(self, "_metrics"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding metrics")
    if not isinstance(metric, metrics.Metric):
      raise TypeError(
          "Evaluator.track_metric() passed type %s, not a tfe.metrics.Metric" %
          (type(metric),))
    if metric.name in self._metrics:
      if metric is self._metrics[metric.name]:
        return metric
      raise ValueError(
          "Attempt to add two Metrics with the name '%s' to the same Evaluator "
          "'%s'" % (metric.name, self.__class__.__name__))
    # pylint: disable=protected-access
    if hasattr(metric, "_added_to_an_evaluator"):
      raise ValueError("Metric %s already added to Evaluator %s" %
                       (metric.name, metric._added_to_an_evaluator))
    metric._added_to_an_evaluator = self.__class__.__name__
    # pylint: enable=protected-access
    self._metrics[metric.name] = metric
    return metric

  def track_evaluator(self, prefix, evaluator):
    """Add a contained `Evaluator`.

    This is for delegating to another `Evaluator`, e.g. for when you have a
    model with multiple heads. Users should manually invoke the child
    `Evaluator`'s `call` method from their `call` method.

    Args:
      prefix: A string. Metrics from `evaluator` are exported with this
        prefix and a '/'.
      evaluator: An `Evaluator` object.

    Returns:
      The value of `evaluator` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `evaluator` is not of the correct type.
      ValueError: If an `Evaluator` has already been added with that `prefix`.
    """
    if not hasattr(self, "_evaluators"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding evaluators")
    if not isinstance(evaluator, Evaluator):
      raise TypeError(
          "Evaluator.track_evaluator() passed type %s, not a tfe.Evaluator." %
          (type(evaluator),))
    if prefix in self._evaluators:
      if evaluator is self._evaluators[prefix]:
        return evaluator
      raise ValueError(
          "Attempt to add two Evaluators with the same prefix '%s'." % prefix)
    self._evaluators[prefix] = evaluator
    return evaluator

  @property
  def metric_variables(self):
    v = []
    for metric in six.itervalues(self._metrics):
      v += metric.variables
    for evaluator in six.itervalues(self._evaluators):
      v += evaluator.metric_variables
    return v

  @property
  def metrics(self):
    """Returns a list of (prefix, metric) pairs."""
    m = []
    for metric in six.itervalues(self._metrics):
      m.append(("", metric))
    for prefix, evaluator in six.iteritems(self._evaluators):
      m += [(prefix + "/" + p, v) for p, v in evaluator.metrics]
    return m


class SparseSoftmaxEvaluator(Evaluator):
  """Evaluator for a sparse softmax model.

  Computes a standard set of metrics for single-label, multi-class
  models.

  Args:
    model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()`
      method produces a `dict` containing values for the loss, true
      label, predicted class, and optional weights.
    loss_key: Optional key for looking up the value of the loss in the
      `eval_data()` dict. Defaults to "loss".
    label_key: Optional key for looking up the value of the label in the
      `eval_data()` dict. Defaults to "label".
    predicted_class_key: Optional key for looking up the value of the
      predicted class in the `eval_data()` dict. Defaults to
      "predicted_class".
    weights_key: Optional key for looking up the value of the weights
      in the `eval_data()` dict. Defaults to "weights". Note that weights
      are optional, and default to 1 if not present in `eval_data`.
  """

  def __init__(self, model, loss_key="loss", label_key="label",
               predicted_class_key="predicted_class", weights_key="weights"):
    super(SparseSoftmaxEvaluator, self).__init__(model)
    # TODO(josh11b): Expand this to include everything from the standard
    # SparseSoftmax Head.
    self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))
    self.accuracy = self.track_metric(metrics.Accuracy())
    self.loss_key = loss_key
    self.label_key = label_key
    self.predicted_class_key = predicted_class_key
    self.weights_key = weights_key

  def call(self, eval_data):
    """Update metrics for the `eval_data` dict (described above)."""
    weights = eval_data.get(self.weights_key, None)
    if weights is None:
      self.avg_loss(eval_data[self.loss_key])
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key])
    else:
      self.avg_loss(eval_data[self.loss_key], weights=weights)
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key],
                    weights=weights)
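

# ---- Illustrative example (not part of the tfe API) ----
# A minimal sketch of the implementer contract described in the `Evaluator`
# class docstring: register metrics with `track_metric()` in __init__() and
# update them in `call()` from the model's `eval_data()` output. The class
# name and the "loss" key assumed in `eval_data` are hypothetical, chosen
# only for illustration.
class ExampleMeanLossEvaluator(Evaluator):
  """Example evaluator that tracks only the mean per-example loss.

  Assumes `model.eval_data()` returns a dict with a "loss" entry.
  """

  def __init__(self, model):
    super(ExampleMeanLossEvaluator, self).__init__(model)
    # Tracked metrics are included in all_metric_results().
    self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))

  def call(self, eval_data):
    """Update metrics from one mini-batch of `eval_data`."""
    self.avg_loss(eval_data["loss"])

# Example usage, assuming `my_model` has an `eval_data()` method and
# `my_dataset` is a tf.data.Dataset of example mini-batches:
#
#   evaluator = ExampleMeanLossEvaluator(my_model)
#
#   # Eager execution: returns a dict of metric results directly.
#   results = evaluator.evaluate_on_dataset(my_dataset)
#
#   # Graph execution: returns (init_op, call_op, results_op) instead.
#   init_op, call_op, results_op = evaluator.evaluate_on_dataset(my_dataset)
#   results = Evaluator.run_evaluation(init_op, call_op, results_op, sess=sess)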