# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
# pylint: disable=g-classes-have-attributes
"""Callbacks: utilities called at certain points during model training."""

import os
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.TensorBoard):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Args:
    log_dir: the path of the directory where to save the log files to be
      parsed by TensorBoard.
    histogram_freq: frequency (in epochs) at which to compute activation and
      weight histograms for the layers of the model. If set to 0, histograms
      won't be computed. Validation data (or split) must be specified for
      histogram visualizations.
    write_graph: whether to visualize the graph in TensorBoard. The log file
      can become quite large when write_graph is set to True.
    write_grads: whether to visualize gradient histograms in TensorBoard.
      `histogram_freq` must be greater than 0.
    batch_size: size of batch of inputs to feed to the network for histogram
      computation.
    write_images: whether to write model weights to visualize as images in
      TensorBoard.
    embeddings_freq: frequency (in epochs) at which selected embedding layers
      will be saved. If set to 0, embeddings won't be computed. Data to be
      visualized in TensorBoard's Embedding tab must be passed as
      `embeddings_data`.
    embeddings_layer_names: a list of names of layers to keep an eye on. If
      None or an empty list, all the embedding layers will be watched.
    embeddings_metadata: a dictionary which maps layer names to the file
      names in which metadata for the corresponding embedding layers is
      saved.
      [Here are details](
        https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
      about the metadata file format. If the same metadata file is used for
      all embedding layers, a single string can be passed.
    embeddings_data: data to be embedded at layers specified in
      `embeddings_layer_names`. Numpy array (if the model has a single input)
      or list of Numpy arrays (if the model has multiple inputs). Learn more
      about embeddings [in this guide](
        https://www.tensorflow.org/programmers_guide/embedding).
    update_freq: `'batch'`, `'epoch'`, or an integer. When using `'batch'`,
      writes the losses and metrics to TensorBoard after each batch. The same
      applies for `'epoch'`. If using an integer, let's say `1000`, the
      callback will write the metrics and losses to TensorBoard every 1000
      samples. Note that writing too frequently to TensorBoard can slow down
      your training.
    profile_batch: Profile the batch to sample compute characteristics. By
      default, it will profile the second batch. Set profile_batch=0 to
      disable profiling.

  Raises:
    ValueError: If histogram_freq is set and no validation data is provided.

  @compatibility(eager)
  Using the `TensorBoard` callback will work when eager execution is enabled,
  with the restriction that outputting histogram summaries of weights and
  gradients is not supported. Consequently, `histogram_freq` will be ignored.
  @end_compatibility
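
  Example (a minimal usage sketch; `x_train` and `y_train` stand in for
  NumPy arrays that match the model's inputs and targets):

  ```python
  import tensorflow as tf

  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, input_shape=(20,), activation='softmax'),
  ])
  model.compile(optimizer='sgd', loss='categorical_crossentropy')

  # Write scalar summaries every 1000 samples. `histogram_freq=1` also
  # requests weight histograms every epoch, which requires validation data.
  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir='./logs', histogram_freq=1, update_freq=1000)
  model.fit(x_train, y_train,
            epochs=5,
            validation_split=0.2,
            callbacks=[tensorboard])
  ```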
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='./logs',
               histogram_freq=0,
               batch_size=32,
               write_graph=True,
               write_grads=False,
               write_images=False,
               embeddings_freq=0,
               embeddings_layer_names=None,
               embeddings_metadata=None,
               embeddings_data=None,
               update_freq='epoch',
               profile_batch=2):
    # Don't call super's init since it is an eager-only version.
    callbacks.Callback.__init__(self)
    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    if self.histogram_freq and context.executing_eagerly():
      logging.warning(
          UserWarning('Weight and gradient histograms not supported for eager '
                      'execution, setting `histogram_freq` to `0`.'))
      self.histogram_freq = 0
    self.merged = None
    self.write_graph = write_graph
    self.write_grads = write_grads
    self.write_images = write_images
    self.batch_size = batch_size
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0
    self.embeddings_freq = embeddings_freq
    self.embeddings_layer_names = embeddings_layer_names
    self.embeddings_metadata = embeddings_metadata
    self.embeddings_data = embeddings_data
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    # TODO(fishx): Add a link to the full profiler tutorial.
    self._profile_batch = profile_batch
    # True when the profiler was successfully started by this callback.
    # We track the status here to make sure callbacks do not interfere with
    # each other. The callback will only stop the profiler it started.
    self._profiler_started = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
    self._chief_worker_only = True

  def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
      self.writer = summary_ops_v2.create_file_writer_v2(self.log_dir)
      if not model.run_eagerly and self.write_graph:
        with self.writer.as_default():
          summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
      self.writer = tf_summary.FileWriter(self.log_dir)

  def _make_histogram_ops(self, model):
    """Defines histogram ops when histogram_freq > 0."""
    # Only make the histogram summary op if it hasn't already been made.
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # Dense layer kernel case.
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # Convnet case.
              if K.image_data_format() == 'channels_last':
                # Switch to channels_first to display every kernel as a
                # separate image.
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # Bias case.
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # Not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

            grads = [
                grad.values if is_indexed_slices(grad) else grad
                for grad in grads
            ]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)

  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # Histogram summaries are only enabled in graph mode.
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embeddings_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils_v1.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embeddings_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create the embeddings_metadata dictionary.
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # embeddings_metadata is already a dictionary (or None).
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure '
                          'that TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)

  def _fetch_callback(self, summary):
    self.writer.add_summary(summary, self._total_val_batches_seen)
    self._total_val_batches_seen += 1

  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Args:
      step: the global step to use for TensorBoard.
      logs: dict. Keys are scalar summary names, values are
        NumPy scalars.
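        For example: `{'loss': 0.2, 'acc': 0.9}` (the metric names here are
        illustrative; keys follow the model's compiled metrics).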
    """
    logs = logs or {}
    if context.executing_eagerly():
      # Use v2 summary ops.
      with self.writer.as_default(), summary_ops_v2.record_if(True):
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # Use FileWriter from v1 summary.
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()

  def on_train_batch_begin(self, batch, logs=None):
    if self._total_batches_seen == self._profile_batch - 1:
      self._start_profiler()

  def on_train_batch_end(self, batch, logs=None):
    return self.on_batch_end(batch, logs)

  def on_test_begin(self, logs=None):
    pass

  def on_test_end(self, logs=None):
    pass

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Also stops the profiler that was started for the profiled batch, if any.
    """
    # Don't output batch_size and batch number as TensorBoard summaries.
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      batch_logs = {('batch_' + k): v
                    for k, v in logs.items()
                    if k not in ['batch', 'size', 'num_steps']}
      self._write_custom_summaries(self._total_batches_seen, batch_logs)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
    self._stop_profiler()

  def on_train_begin(self, logs=None):
    pass

  def on_epoch_begin(self, epoch, logs=None):
    """Adds histogram ops to the model's test function for applicable epochs."""

    # Check if the histogram summary should be run for this epoch.
    if self.histogram_freq and epoch % self.histogram_freq == 0:
      # pylint: disable=protected-access
      # Add the histogram summary op if it should run this epoch.
      self.model._make_test_function()
      if self.merged not in self.model.test_function.fetches:
        self.model.test_function.fetches.append(self.merged)
      self.model.test_function.fetch_callbacks[
          self.merged] = self._fetch_callback
      # pylint: enable=protected-access

  def on_epoch_end(self, epoch, logs=None):
    """Checks if summary ops should run next epoch, logs scalar summaries."""

    # Don't output batch_size and batch number as TensorBoard summaries.
    logs = {('epoch_' + k): v
            for k, v in logs.items()
            if k not in ['batch', 'size', 'num_steps']}
    if self.update_freq == 'epoch':
      step = epoch
    else:
      step = self._samples_seen
    self._write_custom_summaries(step, logs)

    # Pop the histogram summary op after each epoch.
    if self.histogram_freq:
      # pylint: disable=protected-access
      if self.merged in self.model.test_function.fetches:
        self.model.test_function.fetches.remove(self.merged)
      if self.merged in self.model.test_function.fetch_callbacks:
        self.model.test_function.fetch_callbacks.pop(self.merged)
      # pylint: enable=protected-access

    if self.embeddings_data is None and self.embeddings_freq:
      raise ValueError('To visualize embeddings, embeddings_data must '
                       'be provided.')

    if self.embeddings_freq and self.embeddings_data is not None:
      if epoch % self.embeddings_freq == 0:
        # We need a second forward pass here because we're passing
        # the `embeddings_data` explicitly. This design allows passing
        # arbitrary data as `embeddings_data` and results from the fact
        # that we need to know the size of the `tf.Variable`s which
        # hold the embeddings in `set_model`. At this point, however,
        # the `validation_data` is not yet set.

        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        i = 0
        sess = K.get_session()
        while i < n_samples:
          step = min(self.batch_size, n_samples - i)
          batch = slice(i, i + step)

          if isinstance(self.model.input, list):
            feed_dict = {
                model_input: embeddings_data[idx][batch]
                for idx, model_input in enumerate(self.model.input)
            }
          else:
            feed_dict = {self.model.input: embeddings_data[0][batch]}

          feed_dict.update({self.batch_id: i, self.step: step})

          if not isinstance(K.learning_phase(), int):
            feed_dict[K.learning_phase()] = False

          sess.run(self.assign_embeddings, feed_dict=feed_dict)
          self.saver.save(sess,
                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                          epoch)

          i += self.batch_size

  def on_train_end(self, logs=None):
    self._stop_profiler()
    self.writer.close()

  def _start_profiler(self):
    """Starts the profiler if currently inactive."""
    if self._profiler_started:
      return
    try:
      profiler.start(logdir=self.log_dir)
      self._profiler_started = True
    except errors.AlreadyExistsError as e:
      # Profiler errors should not be fatal.
      logging.error('Failed to start profiler: %s', e.message)

  def _stop_profiler(self):
    """Stops the profiler if currently active."""
    if not self._profiler_started:
      return
    try:
      profiler.stop()
    except errors.UnavailableError as e:
      # Profiler errors should not be fatal.
      logging.error('Failed to stop profiler: %s', e.message)
    finally:
      self._profiler_started = False
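
# A minimal usage sketch for the embeddings visualization implemented above
# (illustrative only; `model` is assumed to be a compiled graph-mode Keras
# model containing an `Embedding` layer, with `x_train`/`y_train` matching
# NumPy arrays):
#
#   tensorboard = TensorBoard(log_dir='./logs',
#                             embeddings_freq=1,
#                             embeddings_data=x_train)
#   model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard])
#
# On each applicable epoch the callback re-runs a forward pass over
# `embeddings_data` and saves the embedding variables to
# `keras_embedding.ckpt` under `log_dir`, where TensorBoard's Projector
# plugin can read them.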