1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A SessionRunHook extends `session.run()` calls for the `MonitoredSession`. 16 17SessionRunHooks are useful to track training, report progress, request early 18stopping and more. SessionRunHooks use the observer pattern and notify at the 19following points: 20 - when a session starts being used 21 - before a call to the `session.run()` 22 - after a call to the `session.run()` 23 - when the session closed 24 25A SessionRunHook encapsulates a piece of reusable/composable computation that 26can piggyback a call to `MonitoredSession.run()`. A hook can add any 27ops-or-tensor/feeds to the run call, and when the run call finishes with success 28gets the outputs it requested. Hooks are allowed to add ops to the graph in 29`hook.begin()`. The graph is finalized after the `begin()` method is called. 30 31There are a few pre-defined hooks: 32 - StopAtStepHook: Request stop based on global_step 33 - CheckpointSaverHook: saves checkpoint 34 - LoggingTensorHook: outputs one or more tensor values to log 35 - NanTensorHook: Request stop if given `Tensor` contains Nans. 36 - SummarySaverHook: saves summaries to a summary writer 37 38For more specific needs, you can create custom hooks: 39 class ExampleHook(SessionRunHook): 40 def begin(self): 41 # You can add ops to the graph here. 42 print('Starting the session.') 43 self.your_tensor = ... 44 45 def after_create_session(self, session, coord): 46 # When this is called, the graph is finalized and 47 # ops can no longer be added to the graph. 48 print('Session created.') 49 50 def before_run(self, run_context): 51 print('Before calling session.run().') 52 return SessionRunArgs(self.your_tensor) 53 54 def after_run(self, run_context, run_values): 55 print('Done running one step. The value of my tensor: %s', 56 run_values.results) 57 if you-need-to-stop-loop: 58 run_context.request_stop() 59 60 def end(self, session): 61 print('Done with the session.') 62 63To understand how hooks interact with calls to `MonitoredSession.run()`, 64look at following code: 65 with MonitoredTrainingSession(hooks=your_hooks, ...) as sess: 66 while not sess.should_stop(): 67 sess.run(your_fetches) 68 69Above user code leads to following execution: 70 call hooks.begin() 71 sess = tf.compat.v1.Session() 72 call hooks.after_create_session() 73 while not stop is requested: 74 call hooks.before_run() 75 try: 76 results = sess.run(merged_fetches, feed_dict=merged_feeds) 77 except (errors.OutOfRangeError, StopIteration): 78 break 79 call hooks.after_run() 80 call hooks.end() 81 sess.close() 82 83Note that if sess.run() raises OutOfRangeError or StopIteration then 84hooks.after_run() will not be called but hooks.end() will still be called. 85If sess.run() raises any other exception then neither hooks.after_run() nor 86hooks.end() will be called. 87""" 88 89from __future__ import absolute_import 90from __future__ import division 91from __future__ import print_function 92 93import collections 94from tensorflow.python.util.tf_export import tf_export 95 96 97@tf_export(v1=["train.SessionRunHook"]) 98class SessionRunHook(object): 99 """Hook to extend calls to MonitoredSession.run().""" 100 101 def begin(self): 102 """Called once before using the session. 103 104 When called, the default graph is the one that will be launched in the 105 session. The hook can modify the graph by adding new operations to it. 106 After the `begin()` call the graph will be finalized and the other callbacks 107 can not modify the graph anymore. Second call of `begin()` on the same 108 graph, should not change the graph. 109 """ 110 pass 111 112 def after_create_session(self, session, coord): # pylint: disable=unused-argument 113 """Called when new TensorFlow session is created. 114 115 This is called to signal the hooks that a new session has been created. This 116 has two essential differences with the situation in which `begin` is called: 117 118 * When this is called, the graph is finalized and ops can no longer be added 119 to the graph. 120 * This method will also be called as a result of recovering a wrapped 121 session, not only at the beginning of the overall session. 122 123 Args: 124 session: A TensorFlow Session that has been created. 125 coord: A Coordinator object which keeps track of all threads. 126 """ 127 pass 128 129 def before_run(self, run_context): # pylint: disable=unused-argument 130 """Called before each call to run(). 131 132 You can return from this call a `SessionRunArgs` object indicating ops or 133 tensors to add to the upcoming `run()` call. These ops/tensors will be run 134 together with the ops/tensors originally passed to the original run() call. 135 The run args you return can also contain feeds to be added to the run() 136 call. 137 138 The `run_context` argument is a `SessionRunContext` that provides 139 information about the upcoming `run()` call: the originally requested 140 op/tensors, the TensorFlow Session. 141 142 At this point graph is finalized and you can not add ops. 143 144 Args: 145 run_context: A `SessionRunContext` object. 146 147 Returns: 148 None or a `SessionRunArgs` object. 149 """ 150 return None 151 152 def after_run(self, 153 run_context, # pylint: disable=unused-argument 154 run_values): # pylint: disable=unused-argument 155 """Called after each call to run(). 156 157 The `run_values` argument contains results of requested ops/tensors by 158 `before_run()`. 159 160 The `run_context` argument is the same one send to `before_run` call. 161 `run_context.request_stop()` can be called to stop the iteration. 162 163 If `session.run()` raises any exceptions then `after_run()` is not called. 164 165 Args: 166 run_context: A `SessionRunContext` object. 167 run_values: A SessionRunValues object. 168 """ 169 pass 170 171 def end(self, session): # pylint: disable=unused-argument 172 """Called at the end of session. 173 174 The `session` argument can be used in case the hook wants to run final ops, 175 such as saving a last checkpoint. 176 177 If `session.run()` raises exception other than OutOfRangeError or 178 StopIteration then `end()` is not called. 179 Note the difference between `end()` and `after_run()` behavior when 180 `session.run()` raises OutOfRangeError or StopIteration. In that case 181 `end()` is called but `after_run()` is not called. 182 183 Args: 184 session: A TensorFlow Session that will be soon closed. 185 """ 186 pass 187 188 189@tf_export(v1=["train.SessionRunArgs"]) 190class SessionRunArgs( 191 collections.namedtuple("SessionRunArgs", 192 ["fetches", "feed_dict", "options"])): 193 """Represents arguments to be added to a `Session.run()` call. 194 195 Args: 196 fetches: Exactly like the 'fetches' argument to Session.Run(). 197 Can be a single tensor or op, a list of 'fetches' or a dictionary 198 of fetches. For example: 199 fetches = global_step_tensor 200 fetches = [train_op, summary_op, global_step_tensor] 201 fetches = {'step': global_step_tensor, 'summ': summary_op} 202 Note that this can recurse as expected: 203 fetches = {'step': global_step_tensor, 204 'ops': [train_op, check_nan_op]} 205 feed_dict: Exactly like the `feed_dict` argument to `Session.Run()` 206 options: Exactly like the `options` argument to `Session.run()`, i.e., a 207 config_pb2.RunOptions proto. 208 """ 209 210 def __new__(cls, fetches, feed_dict=None, options=None): 211 return super(SessionRunArgs, cls).__new__(cls, fetches, feed_dict, options) 212 213 214@tf_export(v1=["train.SessionRunContext"]) 215class SessionRunContext(object): 216 """Provides information about the `session.run()` call being made. 217 218 Provides information about original request to `Session.Run()` function. 219 SessionRunHook objects can stop the loop by calling `request_stop()` of 220 `run_context`. In the future we may use this object to add more information 221 about run without changing the Hook API. 222 """ 223 224 def __init__(self, original_args, session): 225 """Initializes SessionRunContext.""" 226 self._original_args = original_args 227 self._session = session 228 self._stop_requested = False 229 230 @property 231 def original_args(self): 232 """A `SessionRunArgs` object holding the original arguments of `run()`. 233 234 If user called `MonitoredSession.run(fetches=a, feed_dict=b)`, then this 235 field is equal to SessionRunArgs(a, b). 236 237 Returns: 238 A `SessionRunArgs` object 239 """ 240 return self._original_args 241 242 @property 243 def session(self): 244 """A TensorFlow session object which will execute the `run`.""" 245 return self._session 246 247 @property 248 def stop_requested(self): 249 """Returns whether a stop is requested or not. 250 251 If true, `MonitoredSession` stops iterations. 252 Returns: 253 A `bool` 254 """ 255 return self._stop_requested 256 257 def request_stop(self): 258 """Sets stop requested field. 259 260 Hooks can use this function to request stop of iterations. 261 `MonitoredSession` checks whether this is called or not. 262 """ 263 self._stop_requested = True 264 265 266@tf_export(v1=["train.SessionRunValues"]) 267class SessionRunValues( 268 collections.namedtuple("SessionRunValues", 269 ["results", "options", "run_metadata"])): 270 """Contains the results of `Session.run()`. 271 272 In the future we may use this object to add more information about result of 273 run without changing the Hook API. 274 275 Args: 276 results: The return values from `Session.run()` corresponding to the fetches 277 attribute returned in the RunArgs. Note that this has the same shape as 278 the RunArgs fetches. For example: 279 fetches = global_step_tensor 280 => results = nparray(int) 281 fetches = [train_op, summary_op, global_step_tensor] 282 => results = [None, nparray(string), nparray(int)] 283 fetches = {'step': global_step_tensor, 'summ': summary_op} 284 => results = {'step': nparray(int), 'summ': nparray(string)} 285 options: `RunOptions` from the `Session.run()` call. 286 run_metadata: `RunMetadata` from the `Session.run()` call. 287 """ 288