1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras Vis utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python import keras 22from tensorflow.python.keras.utils import vis_utils 23from tensorflow.python.lib.io import file_io 24from tensorflow.python.ops import math_ops 25from tensorflow.python.platform import test 26 27 28class ModelToDotFormatTest(test.TestCase): 29 30 def test_plot_model_cnn(self): 31 model = keras.Sequential() 32 model.add( 33 keras.layers.Conv2D( 34 filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv')) 35 model.add(keras.layers.Flatten(name='flat')) 36 model.add(keras.layers.Dense(5, name='dense')) 37 dot_img_file = 'model_1.png' 38 try: 39 vis_utils.plot_model( 40 model, to_file=dot_img_file, show_shapes=True, show_dtype=True) 41 self.assertTrue(file_io.file_exists_v2(dot_img_file)) 42 file_io.delete_file_v2(dot_img_file) 43 except ImportError: 44 pass 45 46 def test_plot_model_with_wrapped_layers_and_models(self): 47 inputs = keras.Input(shape=(None, 3)) 48 lstm = keras.layers.LSTM(6, return_sequences=True, name='lstm') 49 x = lstm(inputs) 50 # Add layer inside a Wrapper 51 bilstm = keras.layers.Bidirectional( 52 keras.layers.LSTM(16, return_sequences=True, name='bilstm')) 53 x = bilstm(x) 54 # Add model inside a Wrapper 55 submodel = keras.Sequential( 56 [keras.layers.Dense(32, name='dense', input_shape=(None, 32))] 57 ) 58 wrapped_dense = keras.layers.TimeDistributed(submodel) 59 x = wrapped_dense(x) 60 # Add shared submodel 61 outputs = submodel(x) 62 model = keras.Model(inputs, outputs) 63 dot_img_file = 'model_2.png' 64 try: 65 vis_utils.plot_model( 66 model, 67 to_file=dot_img_file, 68 show_shapes=True, 69 show_dtype=True, 70 expand_nested=True) 71 self.assertTrue(file_io.file_exists_v2(dot_img_file)) 72 file_io.delete_file_v2(dot_img_file) 73 except ImportError: 74 pass 75 76 def test_plot_model_with_add_loss(self): 77 inputs = keras.Input(shape=(None, 3)) 78 outputs = keras.layers.Dense(1)(inputs) 79 model = keras.Model(inputs, outputs) 80 model.add_loss(math_ops.reduce_mean(outputs)) 81 dot_img_file = 'model_3.png' 82 try: 83 vis_utils.plot_model( 84 model, 85 to_file=dot_img_file, 86 show_shapes=True, 87 show_dtype=True, 88 expand_nested=True) 89 self.assertTrue(file_io.file_exists_v2(dot_img_file)) 90 file_io.delete_file_v2(dot_img_file) 91 except ImportError: 92 pass 93 94 model = keras.Sequential([ 95 keras.Input(shape=(None, 3)), keras.layers.Dense(1)]) 96 model.add_loss(math_ops.reduce_mean(model.output)) 97 dot_img_file = 'model_4.png' 98 try: 99 vis_utils.plot_model( 100 model, 101 to_file=dot_img_file, 102 show_shapes=True, 103 show_dtype=True, 104 expand_nested=True) 105 self.assertTrue(file_io.file_exists_v2(dot_img_file)) 106 file_io.delete_file_v2(dot_img_file) 107 except ImportError: 108 pass 109 110 111if __name__ == '__main__': 112 test.main() 113