# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for automatic outside compilation for TF 2.0/Keras."""

import os

from absl import flags
import numpy as np

from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
from tensorboard.plugins.image import summary_v2 as image_summary_v2
from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager.context import set_soft_device_placement
from tensorflow.python.framework import ops
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import initializers
from tensorflow.python.keras.distribute import distribute_strategy_test
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import sequential as sequential_model_lib
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import convolutional as conv_layer_lib
from tensorflow.python.keras.layers import core as layer_lib
from tensorflow.python.keras.layers import pooling as pool_layer_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tpu_strategy_util

NUM_CLASSES = 4

# Number of training steps run by the Keras `fit` tests below.
_STEPS_PER_EPOCH = 10
# `update_freq` passed to the TensorBoard callback in the Keras `fit` tests.
_TENSORBOARD_UPDATE_FREQ = 2

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')


def get_tpu_cluster_resolver():
  """Returns a TPUClusterResolver configured from the TPU flags.

  Uses the values of the --tpu, --zone and --project command-line flags.
  """
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy():
  """Connects to the flag-selected TPU and returns a TPUStrategy for it.

  Side effects: connects to the TPU cluster via `remote.connect_to_cluster`
  and initializes the TPU system before the strategy is constructed.
  """
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_strategy_lib.TPUStrategy(resolver)


class LayerForScalarSummary(base_layer.Layer):
  """A pass-through layer that only records scalar values to summary."""

  def call(self, x):
    # Add summary scalar using compat v2 implementation.
    scalar_summary_v2.scalar('custom_scalar_summary_v2', math_ops.reduce_sum(x))
    return x


class LayerForImageSummary(base_layer.Layer):
  """A pass-through layer that only records image values to summary."""

  def call(self, x):
    # Add summary image using compat v2 implementation.
    image_summary_v2.image('custom_image_summary_v2', x)

    return x


class LayerForHistogramSummary(base_layer.Layer):
  """A pass-through layer that records histogram values to summary."""

  def call(self, x):
    # Add summary histogram using compat v2 implementation.
    histogram_summary_v2.histogram('custom_histogram_summary_v2', x)

    return x


class CustomModel(training.Model):
  """Subclassed Keras model with summary ops in its `call` definition.

  The input goes through two bias-free Dense layers (4096 then 4 units).
  The result is then passed through a scalar-summary layer and a
  histogram-summary layer, both of which return their input unchanged.
  """

  def __init__(self, name=None):
    # Forward `name` so callers can actually rename the model. With the
    # default `None` Keras falls back to its auto-generated name, as before.
    super(CustomModel, self).__init__(name=name)
    self._my_layers = [
        layer_lib.Dense(
            4096,
            name='dense1',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
        layer_lib.Dense(
            4,
            name='dense2',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
    ]
    self.histogram_summary_layer = LayerForHistogramSummary()
    self.scalar_summary_layer = LayerForScalarSummary()

  def call(self, x):
    for layer in self._my_layers:
      x = layer(x)
    x = self.scalar_summary_layer(x)
    return self.histogram_summary_layer(x)


def get_image_dataset():
  """Returns a batched dataset of all-zero 28x28x3 images and targets.

  Ten (image, target) examples are repeated 100 times and batched by 10 with
  `drop_remainder=True`. Targets have NUM_CLASSES columns.
  """
  inputs = np.zeros((10, 28, 28, 3), dtype=np.float32)
  targets = np.zeros((10, NUM_CLASSES), dtype=np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


def mnist_model(input_shape):
  """Creates a MNIST model."""
  model = sequential_model_lib.Sequential()

  # Adding custom pass-through layer to visualize input images.
  model.add(LayerForImageSummary())

  # NOTE(review): `input_shape` is given to the second layer rather than the
  # first; confirm Sequential is expected to honor it here.
  model.add(
      conv_layer_lib.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
  model.add(conv_layer_lib.Conv2D(64, (3, 3), activation='relu'))
  model.add(pool_layer_lib.MaxPooling2D(pool_size=(2, 2)))
  model.add(layer_lib.Dropout(0.25))
  model.add(layer_lib.Flatten())
  model.add(layer_lib.Dense(128, activation='relu'))
  model.add(layer_lib.Dropout(0.5))
  model.add(layer_lib.Dense(NUM_CLASSES, activation='softmax'))

  # Adding custom pass-through layer for summary recording.
  model.add(LayerForHistogramSummary())
  return model


class AutoOutsideCompilationWithKerasTest(test.TestCase):
  """Checks that summary ops inside TPU-run Keras code are written to disk."""

  def setUp(self):
    super(AutoOutsideCompilationWithKerasTest, self).setUp()
    v2_compat.enable_v2_behavior()
    set_soft_device_placement(True)
    self.summary_dir = self.get_temp_dir()

  def validate_recorded_summary_file(self, event_files, summary_dict,
                                     expected_count):
    """Asserts that every tag in `summary_dict` was recorded `expected_count` times.

    Args:
      event_files: Paths of TensorBoard event files to scan.
      summary_dict: Mapping from summary tag to an initial count (normally
        0). It is updated in place with the number of matching values found.
      expected_count: Number of occurrences expected for every tag.
    """
    for event_file in event_files:
      for event in summary_iterator.summary_iterator(event_file):
        for value in event.summary.value:
          if value.tag in summary_dict:
            summary_dict[value.tag] += 1

    for tag, count in summary_dict.items():
      self.assertEqual(
          count, expected_count,
          'Unexpected number of summaries recorded for tag %r' % tag)

  # Backward-compatible alias for the original (misspelled) helper name.
  validate_recorded_sumary_file = validate_recorded_summary_file

  def testV2SummaryWithKerasSequentialModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = mnist_model((28, 28, 3))
      model.compile('sgd', 'mse')

      dataset = get_image_dataset()
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=_TENSORBOARD_UPDATE_FREQ)
      model.fit(
          dataset,
          steps_per_epoch=_STEPS_PER_EPOCH,
          epochs=1,
          callbacks=[tensorboard_callback])

    events_count_dictionary = {
        'sequential/layer_for_histogram_summary/custom_histogram_summary_v2':
            0,
        'sequential/layer_for_image_summary/custom_image_summary_v2':
            0,
    }

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'train', 'event*'))
    # A total of _STEPS_PER_EPOCH (10) steps are run and the summary ops should
    # be invoked every _TENSORBOARD_UPDATE_FREQ (2) batches, so each tag
    # should be recorded 5 times.
    self.validate_recorded_summary_file(
        event_files, events_count_dictionary,
        _STEPS_PER_EPOCH // _TENSORBOARD_UPDATE_FREQ)

  def testV2SummaryWithKerasSubclassedModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = CustomModel()
      model.compile('sgd', 'mse')

      dataset = distribute_strategy_test.get_dataset(strategy)
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=_TENSORBOARD_UPDATE_FREQ)
      model.fit(
          dataset,
          steps_per_epoch=_STEPS_PER_EPOCH,
          epochs=1,
          callbacks=[tensorboard_callback])

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'train', 'event*'))
    events_count_dictionary = {
        ('custom_model/layer_for_scalar_summary/'
         'custom_scalar_summary_v2'):
            0,
        ('custom_model/layer_for_histogram_summary/'
         'custom_histogram_summary_v2'):
            0
    }

    # A total of _STEPS_PER_EPOCH (10) steps are run and the summary ops should
    # be invoked every _TENSORBOARD_UPDATE_FREQ (2) batches, so each tag
    # should be recorded 5 times.
    self.validate_recorded_summary_file(
        event_files, events_count_dictionary,
        _STEPS_PER_EPOCH // _TENSORBOARD_UPDATE_FREQ)

  def testSummaryWithCustomTrainingLoop(self):
    strategy = get_tpu_strategy()

    writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)
    try:
      with strategy.scope():
        model = distribute_strategy_test.get_model()
        model.compile('sgd', 'mse')

      @def_function.function
      def custom_function(dataset):

        def _custom_step(features, labels):
          del labels
          logits = model(features)
          with summary_ops_v2.record_if(True), writer.as_default():
            scalar_summary_v2.scalar(
                'logits',
                math_ops.reduce_sum(logits),
                step=model.optimizer.iterations)
          return logits

        iterator = iter(dataset)
        # `strategy.run` takes a tuple of positional arguments for the step
        # function; unpack the (features, labels) element explicitly.
        features, labels = next(iterator)
        output = strategy.unwrap(
            strategy.run(_custom_step, args=(features, labels)))
        return output

      dataset = strategy.experimental_distribute_dataset(
          distribute_strategy_test.get_dataset(strategy))

      custom_function(dataset)
    finally:
      # Release the writer even if the training step above raises.
      writer.close()

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'event*'))
    events_count_dictionary = {
        'logits': 0,
    }
    self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                        1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()