# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class Evaluator holds Metrics for the duration of an evaluation run."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.eager.python import datasets
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops


class Evaluator(object):
  """This holds and updates Metrics for the duration of a single eval run.

  Usage:
    evaluator = my_model.evaluator() # or MyEvaluator(my_model)
    for example_batch in ...:
      evaluator(example_batch)
    results = evaluator.all_metric_results(optional_summary_logdir)

  Or, if you are getting your examples from a tf.data.Dataset, you can use
  the evaluate_on_dataset() method.

  Implementers of Evaluators should
  (a) Call `track_metric()` and/or `track_evaluator()` in __init__().
  (b) Override the `call()` method. It will be passed the output of the
      model's `eval_data()` method, and should call its contained metrics
      (treating them as callables) and any child Evaluators (using their
      call() method to avoid calling eval_data() again).

  Args:
    model: A `Model` object with an `eval_data()` method.
  """

  def __init__(self, model):
    self._model = model
    self._metrics = {}
    self._evaluators = {}
    if not context.executing_eagerly():
      self.call = function.defun(self.call)

  # ---- API for users ----
  def __call__(self, *args, **kwargs):
    """Update metrics with a minibatch of input examples.

    Args:
      *args: Positional arguments passed through to self.model.eval_data(),
        representing an input mini-batch of examples.
      **kwargs: Keyword arguments passed through to self.model.eval_data().

    Returns:
      The op to execute or None if executing eagerly.
    """
    return self.call(self._model.eval_data(*args, **kwargs))

  def init_variables(self):
    """Return an op for initializing all contained uninitialized variables.

    Only for graph execution. Should be called after variables are created
    in the first execution of __call__().

    Returns:
      An op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Evaluator.init_variables() not needed when "
                         "eager execution is enabled.")
    return control_flow_ops.group([m.init_variables() for _, m in self.metrics])

  def all_metric_results(self, summary_logdir=None):
    """Computes results for all contained metrics.

    Args:
      summary_logdir: An optional string. If specified, metric results
        will be written as summaries to this directory.

    Returns:
      A `dict` mapping string names to tensors.
    """
    if summary_logdir is None:
      with summary_ops.never_record_summaries():
        return self._all_metric_results()
    else:
      def f():
        with summary_ops.create_file_writer(
            summary_logdir).as_default(), summary_ops.always_record_summaries():
          return self._all_metric_results()

      if context.executing_eagerly():
        return f()
      else:
        return function.defun(f)()

  def _all_metric_results(self):
    """Implementation of `all_metric_results` in the summary context."""
    results = {}
    for name, metric in six.iteritems(self._metrics):
      results[name] = metric.result()
    for prefix, evaluator in six.iteritems(self._evaluators):
      for name, metric in six.iteritems(evaluator._metrics):  # pylint: disable=protected-access
        results[prefix + "/" + name] = metric.result()
    return results

  def evaluate_on_dataset(self, dataset, *args, **kwargs):
    """Convenience method for performing an eval on a Dataset.

    Args:
      dataset: Dataset object with the input data to evaluate on.
      *args: Optional additional positional arguments to __call__().
      **kwargs: Optional additional keyword arguments to __call__(), except
        `summary_logdir`: if specified, metrics will be written as summaries
        to this directory.

    Returns:
      @compatibility(eager)
      When eager execution is enabled, this returns the result of performing
      an evaluation as a dictionary. With graph execution, this returns a tuple
      (init_op, call_op, results_op) which may be executed using this code:
      ```python
        sess.run(init_op)
        try:
          while True:
            sess.run(call_op)
        except tf.errors.OutOfRangeError:
          pass
        return sess.run(results_op)  # A dictionary

        # equivalently:
        return evaluator.run_evaluation(init_op, call_op, results_op, sess=sess)
      ```
      @end_compatibility
    """
    summary_logdir = kwargs.pop("summary_logdir", None)
    if context.executing_eagerly():
      for example in datasets.Iterator(dataset):
        self.__call__(example, *args, **kwargs)
      return self.all_metric_results(summary_logdir)
    # Graph construction
    call_op = self.__call__(
        dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs)
    init_op = self.init_variables()
    results_op = self.all_metric_results(summary_logdir)
    return (init_op, call_op, results_op)

  @staticmethod
  def run_evaluation(init_op, call_op, results_op, sess=None):
    """Convenience method for running the ops returned by evaluate_on_dataset.

    Args:
      init_op: An op that initializes/resets evaluation state.
      call_op: An op that updates evaluation state on a mini-batch of examples.
        Must raise a tf.errors.OutOfRangeError when the data is exhausted.
      results_op: A dictionary of tensors that compute the final evaluation
        results from the evaluation state.
      sess: The Session to run the evaluation in. Defaults to the default
        Session.

    Returns:
      A dictionary of values, parallel to results_op.

    Raises:
      RuntimeError: if eager execution is enabled.

    @compatibility(eager)
    Only for graph execution.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Evaluator.run_evaluation() not supported when "
                         "eager execution is enabled.")
    sess = sess or ops.get_default_session()
    sess.run(init_op)
    try:
      while True:
        sess.run(call_op)
    except errors_impl.OutOfRangeError:
      pass
    return sess.run(results_op)

  # ---- To be implemented by descendants ---
  def call(self, eval_data):
    """Update metrics using the output of self.model.

    Note: This function is executed as a graph function in graph mode.
    This means:
    a) Operations on the same resource are executed in textual order.
       This should make it easier to do things like add the updated
       value of a variable to another, for example.
    b) You don't need to worry about collecting the update ops to execute.
       All update ops added to the graph by this function will be executed.
    As a result, code should generally work the same way with graph or
    eager execution.

    Args:
      eval_data: The output of self.model.eval_data() on a mini-batch of
        examples.
    """
    raise NotImplementedError("Evaluators must define a call member function.")

  # ---- For use by descendants ---
  @property
  def model(self):
    return self._model

  def track_metric(self, metric):
    """Add a Metric to be tracked.

    Metrics can only be tracked by one `Evaluator`. Metrics must be
    tracked or they will not appear in `all_metric_results()`.

    Args:
      metric: A `Metric` object.

    Returns:
      The `metric` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `metric` is not of the correct type.
      ValueError: If there is a name collision between Metrics or `metric`
        has already been added to another `Evaluator`.
    """
    if not hasattr(self, "_metrics"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding metrics")
    if not isinstance(metric, metrics.Metric):
      raise TypeError(
          "Evaluator.track_metric() passed type %s, not a tfe.metrics.Metric" %
          (type(metric),))
    if metric.name in self._metrics:
      if metric is self._metrics[metric.name]:
        return metric
      raise ValueError(
          "Attempt to add two Metrics with the name '%s' to the same Evaluator "
          "'%s'" % (metric.name, self.__class__.__name__))
    # pylint: disable=protected-access
    if hasattr(metric, "_added_to_an_evaluator"):
      raise ValueError("Metric %s already added to Evaluator %s" %
                       (metric.name, metric._added_to_an_evaluator))
    metric._added_to_an_evaluator = self.__class__.__name__
    # pylint: enable=protected-access
    self._metrics[metric.name] = metric
    return metric

  def track_evaluator(self, prefix, evaluator):
    """Add a contained `Evaluator`.

    This is for delegating to another `Evaluator`, e.g. for when you have a
    model with multiple heads. Users should manually invoke the child
    `Evaluator`'s `call` method from their `call` method.

    Args:
      prefix: A string. Metrics from `evaluator` are exported with this
        prefix and a '/'.
      evaluator: An `Evaluator` object.

    Returns:
      The value of `evaluator` passed into this function.

    Raises:
      RuntimeError: If called before __init__.
      TypeError: If `evaluator` is not of the correct type.
      ValueError: If an `Evaluator` has already been added with that `prefix`.
    """
    if not hasattr(self, "_evaluators"):
      raise RuntimeError(
          "Need to call Evaluator.__init__ before adding evaluators")
    if not isinstance(evaluator, Evaluator):
      raise TypeError(
          "Evaluator.track_evaluator() passed type %s, not a tfe.Evaluator." %
          (type(evaluator),))
    if prefix in self._evaluators:
      if evaluator is self._evaluators[prefix]:
        return evaluator
      raise ValueError(
          "Attempt to add two Evaluators with the same prefix '%s'." % prefix)
    self._evaluators[prefix] = evaluator
    return evaluator

  @property
  def metric_variables(self):
    v = []
    for metric in six.itervalues(self._metrics):
      v += metric.variables
    for evaluator in six.itervalues(self._evaluators):
      v += evaluator.metric_variables
    return v

  @property
  def metrics(self):
    """Returns a list of (prefix, metric) pairs."""
    m = []
    for metric in six.itervalues(self._metrics):
      m.append(("", metric))
    for prefix, evaluator in six.iteritems(self._evaluators):
      # Use distinct loop variable names so the accumulator `m` is not shadowed.
      m += [(prefix + "/" + p, met) for p, met in evaluator.metrics]
    return m


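# An illustrative sketch (not part of this module's API) of how a custom
# Evaluator subclass might look, following the guidance in the Evaluator
# docstring above: track metrics in __init__() and override call(). The
# `RegressionModel` it assumes, and the "loss"/"error" keys in that model's
# hypothetical eval_data() output, are placeholders for the example only.
#
#   class RegressionEvaluator(Evaluator):
#     """Tracks mean loss and mean absolute error for a regression model."""
#
#     def __init__(self, model):
#       super(RegressionEvaluator, self).__init__(model)
#       self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))
#       self.avg_error = self.track_metric(metrics.Mean("Avg Error"))
#
#     def call(self, eval_data):
#       # eval_data is whatever self.model.eval_data() returned for one batch.
#       self.avg_loss(eval_data["loss"])
#       self.avg_error(eval_data["error"])
#
# A multi-headed model could instead delegate to child Evaluators registered
# via track_evaluator(); the parent's call() would then invoke each child's
# call() directly with that head's slice of eval_data.

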
class SparseSoftmaxEvaluator(Evaluator):
  """Evaluator for a sparse softmax model.

  Computes a standard set of metrics for single-label, multi-class
  models.

  Args:
    model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()`
      method produces a `dict` containing values for the loss, true
      label, predicted class, and optional weights.
    loss_key: Optional key for looking up the value of the loss in the
      `eval_data()` dict. Defaults to "loss".
    label_key: Optional key for looking up the value of the label in the
      `eval_data()` dict. Defaults to "label".
    predicted_class_key: Optional key for looking up the value of the
      predicted class in the `eval_data()` dict. Defaults to "predicted_class".
    weights_key: Optional key for looking up the value of the weights
      in the `eval_data()` dict. Defaults to "weights". Note that weights
      are optional, and default to 1 if not present in `eval_data`.
  """

  def __init__(self, model, loss_key="loss", label_key="label",
               predicted_class_key="predicted_class", weights_key="weights"):
    super(SparseSoftmaxEvaluator, self).__init__(model)
    # TODO(josh11b): Expand this to include everything from the standard
    # SparseSoftmax Head.
    self.avg_loss = self.track_metric(metrics.Mean("Avg Loss"))
    self.accuracy = self.track_metric(metrics.Accuracy())
    self.loss_key = loss_key
    self.label_key = label_key
    self.predicted_class_key = predicted_class_key
    self.weights_key = weights_key

  def call(self, eval_data):
    """Update metrics for `eval_data` dict (described above)."""
    weights = eval_data.get(self.weights_key, None)
    if weights is None:
      self.avg_loss(eval_data[self.loss_key])
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key])
    else:
      self.avg_loss(eval_data[self.loss_key], weights=weights)
      self.accuracy(eval_data[self.label_key],
                    eval_data[self.predicted_class_key],
                    weights=weights)
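

# An illustrative end-to-end sketch (not executed as part of this module) of
# driving SparseSoftmaxEvaluator over a tf.data.Dataset. `my_model` is a
# hypothetical model whose eval_data() returns a dict with "loss", "label"
# and "predicted_class" entries; `my_dataset` and "/tmp/eval_logs" are
# placeholder inputs.
#
#   evaluator = SparseSoftmaxEvaluator(my_model)
#
#   # Eager execution: evaluate_on_dataset() runs the loop itself and returns
#   # a dict of results, optionally also writing them as summaries.
#   results = evaluator.evaluate_on_dataset(
#       my_dataset, summary_logdir="/tmp/eval_logs")
#
#   # Graph execution: evaluate_on_dataset() instead returns the three ops,
#   # which run_evaluation() drives inside a Session.
#   init_op, call_op, results_op = evaluator.evaluate_on_dataset(my_dataset)
#   with tf.Session() as sess:
#     results = Evaluator.run_evaluation(init_op, call_op, results_op,
#                                        sess=sess)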