• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains functions for evaluation and summarization of metrics."""
16
17import math
18import time
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import init_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import state_ops
26from tensorflow.python.ops import variable_scope
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.training import basic_session_run_hooks
29from tensorflow.python.training import monitored_session
30from tensorflow.python.training import session_run_hook
31
32
def _get_or_create_eval_step():
  """Returns the graph's eval step counter, creating it when absent.

  The default graph's `tf.GraphKeys.EVAL_STEP` collection is inspected. When it
  holds exactly one entry, that entry is returned. When it is empty, a scalar,
  non-trainable `int64` variable named `eval_step`, initialized to zero, is
  created and registered in both the `LOCAL_VARIABLES` and `EVAL_STEP`
  collections.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If more than one `Tensor` has been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  registered = ops.get_default_graph().get_collection(ops.GraphKeys.EVAL_STEP)

  # Reject the ambiguous case first, then reuse the existing counter if any.
  if len(registered) > 1:
    raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  if len(registered) == 1:
    return registered[0]

  # Nothing registered yet: create the counter and register it.
  return variable_scope.get_variable(
      'eval_step',
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP])
58
59
def _get_latest_eval_step_value(update_ops):
  """Reads the eval step counter after `update_ops` have been run.

  Args:
    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
      are run before reading the eval step value. For a dictionary only the
      values are used.

  Returns:
    A `Tensor` representing the value for the evaluation step.
  """
  dependencies = (
      list(update_ops.values()) if isinstance(update_ops, dict) else update_ops)

  # Both the counter lookup and the read happen inside the dependency scope so
  # that the read is ordered after `update_ops`.
  with ops.control_dependencies(dependencies):
    eval_step = _get_or_create_eval_step()
    return array_ops.identity(eval_step.read_value())
75
76
class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook that stops evaluation after N evals, several steps per run call.

  Before each run call the hook loads the number of steps to execute into the
  steps-per-run variable, capping it so the total never exceeds `num_evals`.
  """

  def __init__(self, num_evals, steps_per_run=1):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If set to None, the hook
        never requests a stop, so evaluation continues until the inputs are
        exhausted.
      steps_per_run: Number of steps executed per run call.
    """
    self._steps_per_run_initial_value = steps_per_run
    self._num_evals = num_evals
    # Filled in later through `_set_evals_completed_tensor`.
    self._evals_completed = None

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Stores the tensor that is fetched to learn how many evals have run."""
    self._evals_completed = updated_eval_step

  def begin(self):
    """Looks up (or creates) the variable holding the steps per run call."""
    self._steps_per_run_variable = (
        basic_session_run_hooks.get_or_create_steps_per_run_variable())

  def after_create_session(self, session, coord):
    """Loads the step count to use for the first run call."""
    per_run = self._steps_per_run_initial_value
    if self._num_evals is not None:
      # Never schedule more steps than the total number of requested evals.
      first_run_steps = min(per_run, self._num_evals)
    else:
      first_run_steps = per_run
    self._steps_per_run_variable.load(first_run_steps, session=session)

  def before_run(self, run_context):
    """Requests the eval step counter alongside the caller's fetches."""
    fetches = {'evals_completed': self._evals_completed}
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Schedules the next run's step count, logs progress and maybe stops."""
    completed = run_values.results['evals_completed']
    target = self._num_evals
    run_forever = target is None

    # Update number of steps to run in the next iteration.
    if run_forever:
      next_steps = self._steps_per_run_initial_value
    else:
      next_steps = min(target - completed, self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(next_steps, session=run_context.session)

    if run_forever:
      logging.info('Evaluation [%d]', completed)
      return

    logging.info('Evaluation [%d/%d]', completed, target)
    if completed >= target:
      run_context.request_stop()
127
128
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook that requests a stop once `eval_ops` has been run N times."""

  def __init__(self, num_evals, log_progress=True):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If set to None, the hook
        never requests a stop, so evaluation continues until the inputs are
        exhausted.
      log_progress: Whether to log evaluation progress, defaults to True.
    """
    self._num_evals = num_evals
    self._log_progress = log_progress
    # Filled in later through `_set_evals_completed_tensor`.
    self._evals_completed = None

    # Log every evaluation for short (or unbounded) runs; with 20 or more
    # evaluations, log roughly ten times over the whole run instead.
    if num_evals is None or num_evals < 20:
      self._log_frequency = 1
    else:
      self._log_frequency = math.floor(num_evals / 10.)

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Stores the tensor that is fetched to learn how many evals have run."""
    self._evals_completed = updated_eval_step

  def before_run(self, run_context):
    """Requests the eval step counter alongside the caller's fetches."""
    fetches = {'evals_completed': self._evals_completed}
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Logs progress when enabled and stops once the target is reached."""
    completed = run_values.results['evals_completed']
    target = self._num_evals

    if target is None:
      # Unbounded evaluation: optionally log, never request a stop.
      if self._log_progress:
        logging.info('Evaluation [%d]', completed)
      return

    if self._log_progress:
      # Log at every `_log_frequency`-th evaluation and at the final one.
      if (completed % self._log_frequency) == 0 or (target == completed):
        logging.info('Evaluation [%d/%d]', completed, target)

    if completed >= target:
      run_context.request_stop()
166
167
def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
  """Evaluates the model at the given checkpoint path.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that `evaluate_once` creates a local variable used to track the number of
  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Neither `eval_ops` nor `hooks` is modified: the eval step update op and the
  final-ops hook are added to private copies.

  Args:
    checkpoint_path: The path to a checkpoint to use for evaluation.
    master: The BNS address of the TensorFlow master.
    scaffold: A tf.compat.v1.train.Scaffold instance for initializing variables
      and restoring variables. Note that `scaffold.init_fn` is used by the
      function to restore the checkpoint. If you supply a custom init_fn, then
      it must also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
      `Tensors`, which is run until the session is requested to stop, commonly
      done by a `tf.contrib.training.StopAfterNEvalsHook`. For a dictionary,
      the key `'update_eval_step'` is used for the eval step update op in the
      copy that is actually run.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
      the evaluation loop.
    config: An instance of `tf.compat.v1.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  eval_step = _get_or_create_eval_step()

  # Work on a copy so that appending the final-ops hook below does not alter
  # the caller's list.
  hooks = list(hooks or [])

  if eval_ops is not None:
    if any(isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks):
      # With a multi-step hook each run call covers several steps, so the
      # counter advances by the current value of the steps-per-run variable.
      steps_per_run_variable = (
          basic_session_run_hooks.get_or_create_steps_per_run_variable())
      update_eval_step = state_ops.assign_add(
          eval_step,
          math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
          use_locking=True)
    else:
      update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)

    # Attach the counter update to whatever form `eval_ops` takes.
    if isinstance(eval_ops, dict):
      # Shallow-copy first: inserting the key into the caller's own dict would
      # be a surprising side effect on an argument.
      eval_ops = eval_ops.copy()
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

    eval_step_value = _get_latest_eval_step_value(eval_ops)

    # Let the stop hooks fetch the up-to-date counter on every run call.
    for h in hooks:
      if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
        h._set_evals_completed_tensor(eval_step_value)  # pylint: disable=protected-access

  # NOTE(review): the start timestamp separates date and time with 'T' while
  # the finish timestamp below uses '-'. Both formats are kept unchanged in
  # case log consumers depend on them.
  logging.info('Starting evaluation at %s',
               time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime()))
  # The timer starts before the session exists, so the reported time covers
  # session creation as well as the evaluation loop.
  start = time.time()
  # Prepare the session creator.
  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_filename_with_path=checkpoint_path,
      master=master,
      config=config)

  final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
                                                        final_ops_feed_dict)
  hooks.append(final_ops_hook)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    if eval_ops is not None:
      while not session.should_stop():
        session.run(eval_ops, feed_dict)
  logging.info('Inference Time : %0.5fs', time.time() - start)

  logging.info('Finished evaluation at %s',
               time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()))
  return final_ops_hook.final_ops_values
274