# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Callbacks: utilities called at certain points during model training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.TensorBoard):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Args:
    log_dir: the path of the directory where to save the log files to be
      parsed by TensorBoard.
    histogram_freq: frequency (in epochs) at which to compute activation and
      weight histograms for the layers of the model. If set to 0, histograms
      won't be computed. Validation data (or split) must be specified for
      histogram visualizations.
    write_graph: whether to visualize the graph in TensorBoard. The log file
      can become quite large when write_graph is set to True.
    write_grads: whether to visualize gradient histograms in TensorBoard.
      `histogram_freq` must be greater than 0.
    batch_size: size of batch of inputs to feed to the network for histogram
      computation.
    write_images: whether to write model weights to visualize as an image in
      TensorBoard.
    embeddings_freq: frequency (in epochs) at which selected embedding layers
      will be saved. If set to 0, embeddings won't be computed. Data to be
      visualized in TensorBoard's Embedding tab must be passed as
      `embeddings_data`.
    embeddings_layer_names: a list of names of layers to keep an eye on. If
      None or an empty list, all the embedding layers will be watched.
    embeddings_metadata: a dictionary which maps layer name to a file name in
      which metadata for this embedding layer is saved.
      [Here are details](
        https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
      about the metadata file format. If the same metadata file is used for
      all embedding layers, a single string can be passed.
    embeddings_data: data to be embedded at layers specified in
      `embeddings_layer_names`. Numpy array (if the model has a single input)
      or list of Numpy arrays (if the model has multiple inputs). Learn more
      about embeddings [in this guide](
        https://www.tensorflow.org/programmers_guide/embedding).
    update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
      writes the losses and metrics to TensorBoard after each batch. The same
      applies for `'epoch'`. If using an integer, say `1000`, the callback
      will write the metrics and losses to TensorBoard every 1000 samples.
      Note that writing too frequently to TensorBoard can slow down your
      training.
    profile_batch: Profile the batch to sample compute characteristics. By
      default, it will profile the second batch. Set profile_batch=0 to
      disable profiling.
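
  Example (a minimal usage sketch; the toy model and random data below are
  illustrative placeholders, not part of this module):

  ```python
  import numpy as np
  import tensorflow as tf

  model = tf.keras.models.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(8,))])
  model.compile(optimizer='sgd', loss='mse')

  # This class is exported as `tf.compat.v1.keras.callbacks.TensorBoard`.
  tensorboard = tf.compat.v1.keras.callbacks.TensorBoard(
      log_dir='./logs', update_freq='epoch', profile_batch=0)
  model.fit(np.random.random((64, 8)), np.random.random((64, 1)),
            epochs=2, callbacks=[tensorboard])
  ```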

  Raises:
    ValueError: If histogram_freq is set and no validation data is provided.

  @compatibility(eager)
  Using the `TensorBoard` callback will work when eager execution is enabled,
  with the restriction that outputting histogram summaries of weights and
  gradients is not supported. Consequently, `histogram_freq` will be ignored.
  @end_compatibility
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='./logs',
               histogram_freq=0,
               batch_size=32,
               write_graph=True,
               write_grads=False,
               write_images=False,
               embeddings_freq=0,
               embeddings_layer_names=None,
               embeddings_metadata=None,
               embeddings_data=None,
               update_freq='epoch',
               profile_batch=2):
    # Don't call super's init since it is an eager-only version.
    callbacks.Callback.__init__(self)
    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    if self.histogram_freq and context.executing_eagerly():
      logging.warning(
          UserWarning('Weight and gradient histograms not supported for eager '
                      'execution, setting `histogram_freq` to `0`.'))
      self.histogram_freq = 0
    self.merged = None
    self.write_graph = write_graph
    self.write_grads = write_grads
    self.write_images = write_images
    self.batch_size = batch_size
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0
    self.embeddings_freq = embeddings_freq
    self.embeddings_layer_names = embeddings_layer_names
    self.embeddings_metadata = embeddings_metadata
    self.embeddings_data = embeddings_data
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    # TODO(fishx): Add a link to the full profiler tutorial.
    self._profile_batch = profile_batch
    # One profiler session is running if it is True.
    self._is_profiling = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
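    # Writing from every worker would produce duplicate events under the
    # same log_dir.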
    self._chief_worker_only = True

  def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
      self.writer = summary_ops_v2.create_file_writer_v2(self.log_dir)
      if not model.run_eagerly and self.write_graph:
        with self.writer.as_default():
          summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
      self.writer = tf_summary.FileWriter(self.log_dir)

  def _make_histogram_ops(self, model):
    """Defines histogram ops when histogram_freq > 0."""
    # only make histogram summary op if it hasn't already been made
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # dense layer kernel case
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # convnet case
              if K.image_data_format() == 'channels_last':
                # switch to channels_first to display
                # every kernel as a separate image
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # bias case
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

            grads = [
                grad.values if is_indexed_slices(grad) else grad
                for grad in grads
            ]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)

  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # histogram summaries only enabled in graph mode
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embeddings_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils_v1.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embeddings_layer_names are not provided, get all of the embedding
      # layers from the model.
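      # Embedding layers are identified by their class name ('Embedding').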
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create embeddings_metadata dictionary
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # If embeddings_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure '
                          'that TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)

  def _fetch_callback(self, summary):
    self.writer.add_summary(summary, self._total_val_batches_seen)
    self._total_val_batches_seen += 1

  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Args:
      step: the global step to use for TensorBoard.
      logs: dict. Keys are scalar summary names, values are
        NumPy scalars.
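
    In eager mode the values are written with v2 summary ops; in graph mode
    a v1 `Summary` protocol buffer is built and added to the `FileWriter`.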
    """
    logs = logs or {}
    if context.executing_eagerly():
      # use v2 summary ops
      with self.writer.as_default(), summary_ops_v2.record_if(True):
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # use FileWriter from v1 summary
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()

  def on_train_batch_begin(self, batch, logs=None):
    if (not self._is_profiling and
        self._total_batches_seen == self._profile_batch - 1):
      profiler.start(self.log_dir)
      self._is_profiling = True

  def on_train_batch_end(self, batch, logs=None):
    return self.on_batch_end(batch, logs)

  def on_test_begin(self, logs=None):
    pass

  def on_test_end(self, logs=None):
    pass

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Stops the profiler if a profiling session is running.
    """
    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      batch_logs = {('batch_' + k): v
                    for k, v in logs.items()
                    if k not in ['batch', 'size', 'num_steps']}
      self._write_custom_summaries(self._total_batches_seen, batch_logs)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1

    if self._is_profiling:
      profiler.stop()
      self._is_profiling = False

  def on_train_begin(self, logs=None):
    pass

  def on_epoch_begin(self, epoch, logs=None):
    """Add histogram op to Model eval_function callbacks, reset batch count."""

    # check if histogram summary should be run for this epoch
    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._epoch = epoch
      # pylint: disable=protected-access
      # add the histogram summary op if it should run this epoch
      self.model._make_test_function()
      if self.merged not in self.model.test_function.fetches:
        self.model.test_function.fetches.append(self.merged)
        self.model.test_function.fetch_callbacks[
            self.merged] = self._fetch_callback
      # pylint: enable=protected-access

  def on_epoch_end(self, epoch, logs=None):
    """Checks if summary ops should run next epoch, logs scalar summaries."""

    # don't output batch_size and
    # batch number as TensorBoard summaries
    logs = {('epoch_' + k): v
            for k, v in logs.items()
            if k not in ['batch', 'size', 'num_steps']}
    if self.update_freq == 'epoch':
      step = epoch
    else:
      step = self._samples_seen
    self._write_custom_summaries(step, logs)

    # pop the histogram summary op after each epoch
    if self.histogram_freq:
      # pylint: disable=protected-access
      if self.merged in self.model.test_function.fetches:
        self.model.test_function.fetches.remove(self.merged)
      if self.merged in self.model.test_function.fetch_callbacks:
        self.model.test_function.fetch_callbacks.pop(self.merged)
      # pylint: enable=protected-access
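
    # If embeddings were requested but no data was supplied, there is
    # nothing to visualize, so fail loudly instead of skipping silently.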
    if self.embeddings_data is None and self.embeddings_freq:
      raise ValueError('To visualize embeddings, embeddings_data must '
                       'be provided.')

    if self.embeddings_freq and self.embeddings_data is not None:
      if epoch % self.embeddings_freq == 0:
        # We need a second forward-pass here because we're passing
        # the `embeddings_data` explicitly. This design allows passing
        # arbitrary data as `embeddings_data` and results from the fact
        # that we need to know the size of the `tf.Variable`s which
        # hold the embeddings in `set_model`. At this point, however,
        # the `validation_data` is not yet set.

        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        i = 0
        sess = K.get_session()
        while i < n_samples:
          step = min(self.batch_size, n_samples - i)
          batch = slice(i, i + step)

          if isinstance(self.model.input, list):
            feed_dict = {
                model_input: embeddings_data[idx][batch]
                for idx, model_input in enumerate(self.model.input)
            }
          else:
            feed_dict = {self.model.input: embeddings_data[0][batch]}

          feed_dict.update({self.batch_id: i, self.step: step})

          if not isinstance(K.learning_phase(), int):
            feed_dict[K.learning_phase()] = False

          sess.run(self.assign_embeddings, feed_dict=feed_dict)
          self.saver.save(sess,
                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                          epoch)

          i += self.batch_size

  def on_train_end(self, logs=None):
    if self._is_profiling:
      profiler.stop()
      self._is_profiling = False
    self.writer.close()