# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Keras `vis_utils` module (`plot_model`)."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python import keras
from tensorflow.python.keras.utils import vis_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
26
27
class ModelToDotFormatTest(test.TestCase):
  """Checks that `vis_utils.plot_model` writes an image for various models."""

  def _assert_plot_model_writes_file(self, model, file_name, **plot_kwargs):
    """Plots `model` to an image file and asserts that the file was created.

    The file is written inside `self.get_temp_dir()` rather than the current
    working directory, so the test neither depends on the CWD being writable
    nor drops image files into it.

    If `plot_model` raises `ImportError` (presumably because its optional
    pydot/graphviz dependencies are not installed), the calling test is
    reported as skipped instead of silently passing without checking anything.

    Args:
      model: The Keras model to plot.
      file_name: Base name of the image file to write, e.g. 'model_1.png'.
      **plot_kwargs: Extra keyword arguments forwarded unchanged to
        `vis_utils.plot_model` (e.g. `show_shapes`, `expand_nested`).
    """
    dot_img_file = os.path.join(self.get_temp_dir(), file_name)
    # Only the plotting call may legitimately raise ImportError; keep the try
    # narrow so an assertion failure is never mistaken for a missing dep.
    try:
      vis_utils.plot_model(model, to_file=dot_img_file, **plot_kwargs)
    except ImportError as e:
      # Still best-effort (no hard failure without the optional deps), but
      # surfaced as a skip rather than a misleading pass.
      self.skipTest('plot_model dependencies unavailable: {}'.format(e))
    self.assertTrue(file_io.file_exists_v2(dot_img_file))
    file_io.delete_file_v2(dot_img_file)

  def test_plot_model_cnn(self):
    """Plots a Sequential Conv2D -> Flatten -> Dense model."""
    model = keras.Sequential()
    model.add(
        keras.layers.Conv2D(
            filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'))
    model.add(keras.layers.Flatten(name='flat'))
    model.add(keras.layers.Dense(5, name='dense'))
    self._assert_plot_model_writes_file(
        model, 'model_1.png', show_shapes=True, show_dtype=True)

  def test_plot_model_with_wrapped_layers_and_models(self):
    """Plots a functional model with wrapped layers and a shared submodel."""
    inputs = keras.Input(shape=(None, 3))
    lstm = keras.layers.LSTM(6, return_sequences=True, name='lstm')
    x = lstm(inputs)
    # Add a layer inside a Wrapper.
    bilstm = keras.layers.Bidirectional(
        keras.layers.LSTM(16, return_sequences=True, name='bilstm'))
    x = bilstm(x)
    # Add a model inside a Wrapper.
    submodel = keras.Sequential(
        [keras.layers.Dense(32, name='dense', input_shape=(None, 32))])
    wrapped_dense = keras.layers.TimeDistributed(submodel)
    x = wrapped_dense(x)
    # Call the same submodel again directly, so it is shared in the graph.
    outputs = submodel(x)
    model = keras.Model(inputs, outputs)
    self._assert_plot_model_writes_file(
        model,
        'model_2.png',
        show_shapes=True,
        show_dtype=True,
        expand_nested=True)

  def test_plot_model_with_add_loss(self):
    """Plots functional and Sequential models that carry an `add_loss` term."""
    inputs = keras.Input(shape=(None, 3))
    outputs = keras.layers.Dense(1)(inputs)
    functional_model = keras.Model(inputs, outputs)
    functional_model.add_loss(math_ops.reduce_mean(outputs))
    self._assert_plot_model_writes_file(
        functional_model,
        'model_3.png',
        show_shapes=True,
        show_dtype=True,
        expand_nested=True)

    sequential_model = keras.Sequential([
        keras.Input(shape=(None, 3)), keras.layers.Dense(1)])
    sequential_model.add_loss(math_ops.reduce_mean(sequential_model.output))
    self._assert_plot_model_writes_file(
        sequential_model,
        'model_4.png',
        show_shapes=True,
        show_dtype=True,
        expand_nested=True)
110
if __name__ == '__main__':
  # Running this module directly hands control to TensorFlow's test runner.
  test.main()
113