# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
# pylint: disable=g-classes-have-attributes
"""Callbacks: utilities called at certain points during model training."""

import os
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.TensorBoard):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

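  For example, a minimal usage sketch (assuming TF1-style `tf.keras`;
  `model`, `x_train`, and `y_train` are placeholder names, not part of
  this module):

  ```python
  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir='./logs',     # event files are written here
      update_freq='epoch',  # write loss/metric summaries once per epoch
      profile_batch=2)      # profile the second batch; 0 disables profiling
  model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard])
  ```
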
  Args:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute activation and
        weight histograms for the layers of the model. If set to 0, histograms
        won't be computed. Validation data (or split) must be specified for
        histogram visualizations.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when write_graph is set to True.
      write_grads: whether to visualize gradient histograms in TensorBoard.
        `histogram_freq` must be greater than 0.
      batch_size: size of batch of inputs to feed to the network for histograms
        computation.
      write_images: whether to write model weights to visualize as an image in
        TensorBoard.
      embeddings_freq: frequency (in epochs) at which selected embedding layers
        will be saved. If set to 0, embeddings won't be computed. Data to be
        visualized in TensorBoard's Embedding tab must be passed as
        `embeddings_data`.
      embeddings_layer_names: a list of names of layers to keep an eye on. If
        None or an empty list, all embedding layers will be watched.
      embeddings_metadata: a dictionary that maps layer names to file names in
        which metadata for the corresponding embedding layer is saved.
          [Here are details](
            https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about the metadata file format. If the same metadata file is used
            for all embedding layers, a single string can be passed.
      embeddings_data: data to be embedded at the layers specified in
        `embeddings_layer_names`. Numpy array (if the model has a single input)
        or list of Numpy arrays (if the model has multiple inputs). Learn more
        about embeddings [in this guide](
          https://www.tensorflow.org/programmers_guide/embedding). See the
          sketch after this argument list.
      update_freq: `'batch'`, `'epoch'`, or an integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, say `1000`, the
        callback will write the metrics and losses to TensorBoard every 1000
        samples. Note that writing too frequently to TensorBoard can slow down
        your training.
      profile_batch: Profile the batch to sample compute characteristics. By
        default, it will profile the second batch. Set profile_batch=0 to
        disable profiling.

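  Example of enabling embedding visualization (the layer name `'embedding'`,
  the metadata file `metadata.tsv`, and `x_val` are placeholders):

  ```python
  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir='./logs',
      embeddings_freq=1,                     # save embeddings every epoch
      embeddings_layer_names=['embedding'],  # None watches all Embedding layers
      embeddings_metadata='metadata.tsv',    # one file shared by all layers
      embeddings_data=x_val)
  ```
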
  Raises:
      ValueError: If histogram_freq is set and no validation data is provided.

  @compatibility(eager)
  Using the `TensorBoard` callback will work when eager execution is enabled,
  with the restriction that outputting histogram summaries of weights and
  gradients is not supported. Consequently, `histogram_freq` will be ignored.
  @end_compatibility
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='./logs',
               histogram_freq=0,
               batch_size=32,
               write_graph=True,
               write_grads=False,
               write_images=False,
               embeddings_freq=0,
               embeddings_layer_names=None,
               embeddings_metadata=None,
               embeddings_data=None,
               update_freq='epoch',
               profile_batch=2):
    # Don't call super's init since it is an eager-only version.
    callbacks.Callback.__init__(self)
    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    if self.histogram_freq and context.executing_eagerly():
      logging.warning(
          UserWarning('Weight and gradient histograms not supported for eager '
                      'execution, setting `histogram_freq` to `0`.'))
      self.histogram_freq = 0
    self.merged = None
    self.write_graph = write_graph
    self.write_grads = write_grads
    self.write_images = write_images
    self.batch_size = batch_size
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0
    self.embeddings_freq = embeddings_freq
    self.embeddings_layer_names = embeddings_layer_names
    self.embeddings_metadata = embeddings_metadata
    self.embeddings_data = embeddings_data
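    # `'batch'` is mapped to 1 so the samples-seen check in `on_batch_end`
    # fires after every batch.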
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    # TODO(fishx): Add a link to the full profiler tutorial.
    self._profile_batch = profile_batch
    # True when the profiler was successfully started by this callback.
    # We track the status here to make sure callbacks do not interfere with
    # each other. The callback will only stop the profiler it started.
    self._profiler_started = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
    self._chief_worker_only = True

  def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
      self.writer = summary_ops_v2.create_file_writer_v2(self.log_dir)
      if not model.run_eagerly and self.write_graph:
        with self.writer.as_default():
          summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
      self.writer = tf_summary.FileWriter(self.log_dir)

  def _make_histogram_ops(self, model):
    """Defines histogram ops when histogram_freq > 0."""
    # only make histogram summary op if it hasn't already been made
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # dense layer kernel case
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # convnet case
              if K.image_data_format() == 'channels_last':
                # switch to channels_first to display
                # every kernel as a separate image
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # bias case
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

            grads = [
                grad.values if is_indexed_slices(grad) else grad
                for grad in grads
            ]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)

  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # histogram summaries only enabled in graph mode
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embedding_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils_v1.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embedding_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create embeddings_metadata dictionary
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # If embedding_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure that '
                          'TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)

  def _fetch_callback(self, summary):
    self.writer.add_summary(summary, self._total_val_batches_seen)
    self._total_val_batches_seen += 1

  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Args:
        step: the global step to use for TensorBoard.
        logs: dict. Keys are scalar summary names, values are
            NumPy scalars.

    """
    logs = logs or {}
    if context.executing_eagerly():
      # use v2 summary ops
      with self.writer.as_default(), summary_ops_v2.record_if(True):
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # use FileWriter from v1 summary
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()

  def on_train_batch_begin(self, batch, logs=None):
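    # `_total_batches_seen` is zero-based while `profile_batch` is one-based,
    # so the profiler starts just before the `profile_batch`-th batch runs.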
    if self._total_batches_seen == self._profile_batch - 1:
      self._start_profiler()

  def on_train_batch_end(self, batch, logs=None):
    return self.on_batch_end(batch, logs)

  def on_test_begin(self, logs=None):
    pass

  def on_test_end(self, logs=None):
    pass

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Also stops the profiler if it was started by `on_train_batch_begin`.
    """
    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      batch_logs = {('batch_' + k): v
                    for k, v in logs.items()
                    if k not in ['batch', 'size', 'num_steps']}
      self._write_custom_summaries(self._total_batches_seen, batch_logs)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
    self._stop_profiler()

  def on_train_begin(self, logs=None):
    pass

  def on_epoch_begin(self, epoch, logs=None):
    """Add histogram op to Model eval_function callbacks, reset batch count."""

    # check if histogram summary should be run for this epoch
    if self.histogram_freq and epoch % self.histogram_freq == 0:
      # pylint: disable=protected-access
      # add the histogram summary op if it should run this epoch
      self.model._make_test_function()
      if self.merged not in self.model.test_function.fetches:
        self.model.test_function.fetches.append(self.merged)
        self.model.test_function.fetch_callbacks[
            self.merged] = self._fetch_callback
      # pylint: enable=protected-access

  def on_epoch_end(self, epoch, logs=None):
    """Checks if summary ops should run next epoch, logs scalar summaries."""

    # don't output batch_size and
    # batch number as TensorBoard summaries
    logs = {('epoch_' + k): v
            for k, v in logs.items()
            if k not in ['batch', 'size', 'num_steps']}
    if self.update_freq == 'epoch':
      step = epoch
    else:
      step = self._samples_seen
    self._write_custom_summaries(step, logs)

    # pop the histogram summary op after each epoch
    if self.histogram_freq:
      # pylint: disable=protected-access
      if self.merged in self.model.test_function.fetches:
        self.model.test_function.fetches.remove(self.merged)
      if self.merged in self.model.test_function.fetch_callbacks:
        self.model.test_function.fetch_callbacks.pop(self.merged)
      # pylint: enable=protected-access

    if self.embeddings_data is None and self.embeddings_freq:
      raise ValueError('To visualize embeddings, embeddings_data must '
                       'be provided.')

    if self.embeddings_freq and self.embeddings_data is not None:
      if epoch % self.embeddings_freq == 0:
        # We need a second forward-pass here because we're passing
        # the `embeddings_data` explicitly. This design allows passing
        # arbitrary data as `embeddings_data` and results from the fact
        # that we need to know the size of the `tf.Variable`s which
        # hold the embeddings in `set_model`. At this point, however,
        # the `validation_data` is not yet set.

        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        i = 0
        sess = K.get_session()
        while i < n_samples:
          step = min(self.batch_size, n_samples - i)
          batch = slice(i, i + step)

          if isinstance(self.model.input, list):
            feed_dict = {
                model_input: embeddings_data[idx][batch]
                for idx, model_input in enumerate(self.model.input)
            }
          else:
            feed_dict = {self.model.input: embeddings_data[0][batch]}

          feed_dict.update({self.batch_id: i, self.step: step})

          if not isinstance(K.learning_phase(), int):
            feed_dict[K.learning_phase()] = False

          sess.run(self.assign_embeddings, feed_dict=feed_dict)
          self.saver.save(sess,
                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                          epoch)

          i += self.batch_size

  def on_train_end(self, logs=None):
    self._stop_profiler()
    self.writer.close()

  def _start_profiler(self):
    """Starts the profiler if currently inactive."""
    if self._profiler_started:
      return
    try:
      profiler.start(logdir=self.log_dir)
      self._profiler_started = True
    except errors.AlreadyExistsError as e:
      # Profiler errors should not be fatal.
      logging.error('Failed to start profiler: %s', e.message)

  def _stop_profiler(self):
    """Stops the profiler if currently active."""
    if not self._profiler_started:
      return
    try:
      profiler.stop()
    except errors.UnavailableError as e:
      # Profiler errors should not be fatal.
      logging.error('Failed to stop profiler: %s', e.message)
    finally:
      self._profiler_started = False