# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # pylint: disable=protected-access # pylint: disable=g-import-not-at-top """Utilities related to model visualization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys from tensorflow.python.util.tf_export import keras_export try: # pydot-ng is a fork of pydot that is better maintained. import pydot_ng as pydot except ImportError: # pydotplus is an improved version of pydot try: import pydotplus as pydot except ImportError: # Fall back on pydot if necessary. try: import pydot except ImportError: pydot = None def _check_pydot(): try: # Attempt to create an image of a blank graph # to check the pydot/graphviz installation. pydot.Dot.create(pydot.Dot()) return True except Exception: # pylint: disable=broad-except # pydot raises a generic Exception here, # so no specific class can be caught. return False def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): """Convert a Keras model to dot format. Arguments: model: A Keras model instance. show_shapes: whether to display shape information. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot. Returns: A `pydot.Dot` instance representing the Keras model (or None if the Dot file could not be generated). Raises: ImportError: if graphviz or pydot are not available. """ from tensorflow.python.keras.layers.wrappers import Wrapper from tensorflow.python.keras.models import Sequential from tensorflow.python.util import nest check = _check_pydot() if not check: if 'IPython.core.magics.namespace' in sys.modules: # We don't raise an exception here in order to avoid crashing notebook # tests where graphviz is not available. print('Failed to import pydot. You must install pydot' ' and graphviz for `pydotprint` to work.') return else: raise ImportError('Failed to import pydot. You must install pydot' ' and graphviz for `pydotprint` to work.') dot = pydot.Dot() dot.set('rankdir', rankdir) dot.set('concentrate', True) dot.set_node_defaults(shape='record') if isinstance(model, Sequential): if not model.built: model.build() layers = model._layers # Create graph nodes. for layer in layers: layer_id = str(id(layer)) # Append a wrapped layer's label to node's label, if it exists. layer_name = layer.name class_name = layer.__class__.__name__ if isinstance(layer, Wrapper): layer_name = '{}({})'.format(layer_name, layer.layer.name) child_class_name = layer.layer.__class__.__name__ class_name = '{}({})'.format(class_name, child_class_name) # Create node's label. if show_layer_names: label = '{}: {}'.format(layer_name, class_name) else: label = class_name # Rebuild the label as a table including input/output shapes. if show_shapes: try: outputlabels = str(layer.output_shape) except AttributeError: outputlabels = 'multiple' if hasattr(layer, 'input_shape'): inputlabels = str(layer.input_shape) elif hasattr(layer, 'input_shapes'): inputlabels = ', '.join([str(ishape) for ishape in layer.input_shapes]) else: inputlabels = 'multiple' label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels, outputlabels) node = pydot.Node(layer_id, label=label) dot.add_node(node) # Connect nodes with edges. for layer in layers: layer_id = str(id(layer)) for i, node in enumerate(layer._inbound_nodes): node_key = layer.name + '_ib-' + str(i) if node_key in model._network_nodes: # pylint: disable=protected-access for inbound_layer in nest.flatten(node.inbound_layers): inbound_layer_id = str(id(inbound_layer)) layer_id = str(id(layer)) dot.add_edge(pydot.Edge(inbound_layer_id, layer_id)) return dot @keras_export('keras.utils.plot_model') def plot_model(model, to_file='model.png', show_shapes=False, show_layer_names=True, rankdir='TB'): """Converts a Keras model to dot format and save to a file. Arguments: model: A Keras model instance to_file: File name of the plot image. show_shapes: whether to display shape information. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: 'TB' creates a vertical plot; 'LR' creates a horizontal plot. Returns: A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks. """ dot = model_to_dot(model, show_shapes, show_layer_names, rankdir) if dot is None: return _, extension = os.path.splitext(to_file) if not extension: extension = 'png' else: extension = extension[1:] # Save image to disk. dot.write(to_file, format=extension) # Return the image as a Jupyter Image object, to be displayed in-line. # Note that we cannot easily detect whether the code is running in a # notebook, and thus we always return the Image if Jupyter is available. try: from IPython import display return display.Image(filename=to_file) except ImportError: pass