# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions for evaluation and summarization of metrics."""

import math
import time

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook


def _get_or_create_eval_step():
  """Gets or creates the eval step `Tensor`.

  The eval step is looked up in the `tf.GraphKeys.EVAL_STEP` collection of the
  default graph. If the collection is empty, a scalar int64 variable named
  `eval_step` is created (zero-initialized, non-trainable) and added to both
  the `LOCAL_VARIABLES` and `EVAL_STEP` collections.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  graph = ops.get_default_graph()
  eval_steps = graph.get_collection(ops.GraphKeys.EVAL_STEP)
  if len(eval_steps) > 1:
    raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  if len(eval_steps) == 1:
    return eval_steps[0]

  return variable_scope.get_variable(
      'eval_step',
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP])


def _get_latest_eval_step_value(update_ops):
  """Gets the eval step `Tensor` value after running `update_ops`.

  Args:
    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
      are run before reading the eval step value.

  Returns:
    A `Tensor` representing the value for the evaluation step.
  """
  if isinstance(update_ops, dict):
    update_ops = list(update_ops.values())

  # The control dependency forces the read to happen after `update_ops`.
  with ops.control_dependencies(update_ops):
    return array_ops.identity(_get_or_create_eval_step().read_value())


class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook used by the evaluation routines to run the `eval_ops` N times."""

  def __init__(self, num_evals, steps_per_run=1):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If set to None, will
        iterate the dataset until all inputs are exhausted.
      steps_per_run: Number of steps executed per run call.
    """
    self._num_evals = num_evals
    self._evals_completed = None
    self._steps_per_run_initial_value = steps_per_run
    # Assigned in `begin()`.
    self._steps_per_run_variable = None

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Stores the tensor that is fetched to read the completed eval count."""
    self._evals_completed = updated_eval_step

  def _steps_for_next_run(self, evals_completed):
    """Returns the number of steps the next run call should execute.

    Args:
      evals_completed: Number of evaluations that have completed so far.

    Returns:
      `steps_per_run` when `num_evals` is None; otherwise the smaller of
      `steps_per_run` and the number of evaluations still remaining.
    """
    if self._num_evals is None:
      return self._steps_per_run_initial_value
    return min(self._num_evals - evals_completed,
               self._steps_per_run_initial_value)

  def begin(self):
    self._steps_per_run_variable = (
        basic_session_run_hooks.get_or_create_steps_per_run_variable())

  def after_create_session(self, session, coord):
    # Update number of steps to run in the first run call.
    self._steps_per_run_variable.load(
        self._steps_for_next_run(0), session=session)

  def before_run(self, run_context):
    return session_run_hook.SessionRunArgs(
        {'evals_completed': self._evals_completed})

  def after_run(self, run_context, run_values):
    evals_completed = run_values.results['evals_completed']
    # Update number of steps to run in the next iteration.
    self._steps_per_run_variable.load(
        self._steps_for_next_run(evals_completed),
        session=run_context.session)

    if self._num_evals is None:
      logging.info('Evaluation [%d]', evals_completed)
    else:
      logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
    if self._num_evals is not None and evals_completed >= self._num_evals:
      run_context.request_stop()


class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook used by the evaluation routines to run the `eval_ops` N times."""

  def __init__(self, num_evals, log_progress=True):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If set to None, will
        iterate the dataset until all inputs are exhausted.
      log_progress: Whether to log evaluation progress, defaults to True.
    """
    # The number of evals to run for.
    self._num_evals = num_evals
    self._evals_completed = None
    self._log_progress = log_progress
    # Reduce logging frequency if there are 20 or more evaluations.
    if num_evals is None or num_evals < 20:
      self._log_frequency = 1
    else:
      self._log_frequency = math.floor(num_evals / 10.)

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Stores the tensor that is fetched to read the completed eval count."""
    self._evals_completed = updated_eval_step

  def before_run(self, run_context):
    return session_run_hook.SessionRunArgs(
        {'evals_completed': self._evals_completed})

  def after_run(self, run_context, run_values):
    evals_completed = run_values.results['evals_completed']
    if self._log_progress:
      if self._num_evals is None:
        logging.info('Evaluation [%d]', evals_completed)
      elif ((evals_completed % self._log_frequency) == 0 or
            (self._num_evals == evals_completed)):
        # Log every `_log_frequency` evals and always on the final one.
        logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
    if self._num_evals is not None and evals_completed >= self._num_evals:
      run_context.request_stop()


def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
  """Evaluates the model at the given checkpoint path.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that `evaluate_once` creates a local variable used to track the number of
  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Args:
    checkpoint_path: The path to a checkpoint to use for evaluation.
    master: The BNS address of the TensorFlow master.
    scaffold: A `tf.compat.v1.train.Scaffold` instance for initializing
      variables and restoring variables. Note that `scaffold.init_fn` is used by
      the function to restore the checkpoint. If you supply a custom init_fn,
      then it must also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
      `Tensors`, which is run until the session is requested to stop, commonly
      done by a `tf.contrib.training.StopAfterNEvalsHook`. A dictionary is
      copied before the eval step update op is added, so the caller's object is
      left unmodified.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
      the evaluation loop.
    config: An instance of `tf.compat.v1.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
  eval_step = _get_or_create_eval_step()

  # Prepare the run hooks. Copy so the caller's list is not appended to.
  hooks = list(hooks or [])

  if eval_ops is not None:
    if any(isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks):
      # Each run call executes `steps_per_run` steps, so advance the counter by
      # that amount instead of by one.
      steps_per_run_variable = (
          basic_session_run_hooks.get_or_create_steps_per_run_variable())
      update_eval_step = state_ops.assign_add(
          eval_step,
          math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
          use_locking=True)
    else:
      update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)

    if isinstance(eval_ops, dict):
      # Work on a shallow copy: inserting the update op into the caller's dict
      # would leak a graph-specific op into an object the caller still owns.
      eval_ops = dict(eval_ops)
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

    eval_step_value = _get_latest_eval_step_value(eval_ops)

    for h in hooks:
      if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
        h._set_evals_completed_tensor(eval_step_value)  # pylint: disable=protected-access

  logging.info('Starting evaluation at %s',
               time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime()))
  start = time.time()
  # Prepare the session creator.
  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_filename_with_path=checkpoint_path,
      master=master,
      config=config)

  final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
                                                        final_ops_feed_dict)
  hooks.append(final_ops_hook)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    if eval_ops is not None:
      while not session.should_stop():
        session.run(eval_ops, feed_dict)
    logging.info('Inference Time : %.5fs', time.time() - start)

  logging.info('Finished evaluation at %s',
               time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()))
  return final_ops_hook.final_ops_values