• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for automatic outside compilation for TF 2.0/Keras."""
16
17import os
18
19from absl import flags
20import numpy as np
21
22from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
23from tensorboard.plugins.image import summary_v2 as image_summary_v2
24from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
25from tensorflow.python.compat import v2_compat
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
28from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
29from tensorflow.python.eager import def_function
30from tensorflow.python.eager import remote
31from tensorflow.python.eager.context import set_soft_device_placement
32from tensorflow.python.framework import ops
33from tensorflow.python.keras import callbacks
34from tensorflow.python.keras import initializers
35from tensorflow.python.keras.distribute import distribute_strategy_test
36from tensorflow.python.keras.engine import base_layer
37from tensorflow.python.keras.engine import sequential as sequential_model_lib
38from tensorflow.python.keras.engine import training
39from tensorflow.python.keras.layers import convolutional as conv_layer_lib
40from tensorflow.python.keras.layers import core as layer_lib
41from tensorflow.python.keras.layers import pooling as pool_layer_lib
42from tensorflow.python.lib.io import file_io
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import summary_ops_v2
45# from tensorflow.python.platform import flags
46from tensorflow.python.platform import test
47from tensorflow.python.summary import summary_iterator
48from tensorflow.python.tpu import tpu_strategy_util
49
# Number of target classes; used by the synthetic image dataset targets and
# by the width of the MNIST model's softmax head.
NUM_CLASSES = 4

# Command-line flags that identify the Cloud TPU the tests connect to.
FLAGS = flags.FLAGS
flags.DEFINE_string(name='tpu', default='', help='Name of TPU to connect to.')
flags.DEFINE_string(
    name='project', default=None, help='Name of GCP project with TPU.')
flags.DEFINE_string(
    name='zone', default=None, help='Name of GCP zone with TPU.')
56
57
def get_tpu_cluster_resolver():
  """Builds a TPU cluster resolver from the `--tpu`, `--zone`, `--project` flags.

  Returns:
    A `TPUClusterResolver` constructed from the current flag values.
  """
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
65
66
def get_tpu_strategy():
  """Connects to the flag-selected TPU and returns a `TPUStrategy` for it.

  Side effects: connects this process to the remote cluster and initializes
  the TPU system before the strategy object is created.
  """
  cluster = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster)
  tpu_strategy_util.initialize_tpu_system(cluster)
  return tpu_strategy_lib.TPUStrategy(cluster)
72
73
class LayerForScalarSummary(base_layer.Layer):
  """Identity layer that records the sum of its input as a scalar summary.

  The input is returned unchanged; the only work done is emitting the
  `custom_scalar_summary_v2` summary.
  """

  def call(self, x):
    total = math_ops.reduce_sum(x)
    # TensorBoard v2 scalar summary API; no explicit step is passed here.
    scalar_summary_v2.scalar('custom_scalar_summary_v2', data=total)
    return x
81
82
class LayerForImageSummary(base_layer.Layer):
  """Identity layer that records its input as an image summary.

  The input is returned unchanged; the only work done is emitting the
  `custom_image_summary_v2` summary.
  """

  def call(self, x):
    # TensorBoard v2 image summary API; no explicit step is passed here.
    image_summary_v2.image('custom_image_summary_v2', data=x)
    return x
91
92
class LayerForHistogramSummary(base_layer.Layer):
  """Identity layer that records its input as a histogram summary.

  The input is returned unchanged; the only work done is emitting the
  `custom_histogram_summary_v2` summary.
  """

  def call(self, x):
    # TensorBoard v2 histogram summary API; no explicit step is passed here.
    histogram_summary_v2.histogram('custom_histogram_summary_v2', data=x)
    return x
101
102
class CustomModel(training.Model):
  """Subclassed Keras model whose `call` emits summaries.

  The forward pass is two bias-free dense layers (4096 units, then 4 units),
  followed by pass-through layers that record a scalar summary and then a
  histogram summary of the final activations.
  """

  def __init__(self, name=None):
    """Creates the model and its sub-layers.

    Args:
      name: Optional model name, forwarded to `training.Model`. When None,
        Keras picks its default name for the model.
    """
    # `name` used to be accepted but silently dropped; forward it so that a
    # caller-supplied name takes effect. Passing None keeps the old behavior.
    super(CustomModel, self).__init__(name=name)
    # Both kernels use a Glorot-normal initializer with seed=0, presumably so
    # the weights are reproducible across runs.
    self._my_layers = [
        layer_lib.Dense(
            4096,
            name='dense1',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
        layer_lib.Dense(
            4,
            name='dense2',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
    ]
    self.histogram_summary_layer = LayerForHistogramSummary()
    self.scalar_summary_layer = LayerForScalarSummary()

  def call(self, x):
    """Runs the dense stack, then records the scalar and histogram summaries."""
    for layer in self._my_layers:
      x = layer(x)
    x = self.scalar_summary_layer(x)
    return self.histogram_summary_layer(x)
128
129
def get_image_dataset():
  """Returns a dataset of all-zero 28x28x3 images paired with zero targets.

  Ten float32 examples (targets of width `NUM_CLASSES`) are repeated 100
  times and batched by 10 with the partial final batch dropped, which gives
  100 full batches.
  """
  images = np.zeros((10, 28, 28, 3), dtype=np.float32)
  labels = np.zeros((10, NUM_CLASSES), dtype=np.float32)
  return (dataset_ops.Dataset.from_tensor_slices((images, labels))
          .repeat(100)
          .batch(10, drop_remainder=True))
137
138
def mnist_model(input_shape):
  """Creates a small MNIST-style CNN wrapped in summary-recording layers.

  Args:
    input_shape: Shape of one input image without the batch dimension, e.g.
      `(28, 28, 3)`.

  Returns:
    A `Sequential` model. Its first layer records an image summary of the
    input, and its last layer records a histogram summary of the
    `NUM_CLASSES`-way softmax output.
  """
  model = sequential_model_lib.Sequential()

  # Pass-through layer that visualizes the input images. It is the first
  # layer, so it must be the one carrying `input_shape`. Sequential only
  # consults the first layer's input shape; an `input_shape` on a later layer
  # (as the Conv2D previously had) is ignored.
  model.add(LayerForImageSummary(input_shape=input_shape))

  model.add(conv_layer_lib.Conv2D(32, kernel_size=(3, 3), activation='relu'))
  model.add(conv_layer_lib.Conv2D(64, (3, 3), activation='relu'))
  model.add(pool_layer_lib.MaxPooling2D(pool_size=(2, 2)))
  model.add(layer_lib.Dropout(0.25))
  model.add(layer_lib.Flatten())
  model.add(layer_lib.Dense(128, activation='relu'))
  model.add(layer_lib.Dropout(0.5))
  model.add(layer_lib.Dense(NUM_CLASSES, activation='softmax'))

  # Pass-through layer that records the class probabilities as a histogram.
  model.add(LayerForHistogramSummary())
  return model
160
161
class AutoOutsideCompilationWithKerasTest(test.TestCase):
  """Checks that summaries emitted inside TPU-run Keras code reach event files.

  Each test connects to the TPU named by the flags, runs a short training loop
  that emits summaries, and then counts the matching summary values in the
  event files written under a per-test temporary directory.
  """

  def setUp(self):
    super(AutoOutsideCompilationWithKerasTest, self).setUp()
    v2_compat.enable_v2_behavior()
    set_soft_device_placement(True)
    self.summary_dir = self.get_temp_dir()

  def validate_recorded_summary_file(self, event_files, summary_dict,
                                     expected_count):
    """Asserts that each tag in `summary_dict` was recorded `expected_count` times.

    Args:
      event_files: Paths of the event files to scan.
      summary_dict: Maps summary tag -> starting count (callers pass 0). The
        dict is not modified; counting happens on a copy.
      expected_count: Required final count for every tag in `summary_dict`.
    """
    counts = dict(summary_dict)
    for event_file in event_files:
      for event in summary_iterator.summary_iterator(event_file):
        for value in event.summary.value:
          if value.tag in counts:
            counts[value.tag] += 1

    for tag, count in counts.items():
      self.assertEqual(
          count, expected_count,
          msg='Summary tag %r recorded %d time(s), expected %d; scanned event '
          'files: %s' % (tag, count, expected_count, event_files))

  # Backward-compatible alias for the helper's original (misspelled) name.
  validate_recorded_sumary_file = validate_recorded_summary_file

  def testV2SummaryWithKerasSequentialModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = mnist_model((28, 28, 3))
      model.compile('sgd', 'mse')

      dataset = get_image_dataset()
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      events_count_dictionary = {
          'sequential/layer_for_histogram_summary/custom_histogram_summary_v2':
              0,
          'sequential/layer_for_image_summary/custom_image_summary_v2':
              0,
      }

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      # A total of 10 steps are run and the callback records summaries every
      # 2 batches, so each tag should appear in 5 events.
      self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                          5)

  def testV2SummaryWithKerasSubclassedModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = CustomModel()
      model.compile('sgd', 'mse')

      dataset = distribute_strategy_test.get_dataset(strategy)
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      events_count_dictionary = {
          ('custom_model/layer_for_scalar_summary/'
           'custom_scalar_summary_v2'):
              0,
          ('custom_model/layer_for_histogram_summary/'
           'custom_histogram_summary_v2'):
              0
      }

      # A total of 10 steps are run and the callback records summaries every
      # 2 batches, so each tag should appear in 5 events.
      self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                          5)

  def testSummaryWithCustomTrainingLoop(self):
    strategy = get_tpu_strategy()

    writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)
    with strategy.scope():
      model = distribute_strategy_test.get_model()
      model.compile('sgd', 'mse')

    @def_function.function
    def custom_function(dataset):

      def _custom_step(features, labels):
        del labels
        logits = model(features)
        with summary_ops_v2.record_if(True), writer.as_default():
          scalar_summary_v2.scalar(
              'logits',
              math_ops.reduce_sum(logits),
              step=model.optimizer.iterations)
        return logits

      iterator = iter(dataset)
      # Each element is a (features, labels) pair. Unpack it explicitly so
      # both halves become separate positional arguments of `_custom_step`;
      # the previous `args=(next(iterator))` only worked because the
      # parentheses did not actually build a tuple.
      features, labels = next(iterator)
      output = strategy.unwrap(
          strategy.run(_custom_step, args=(features, labels)))
      return output

    dataset = strategy.experimental_distribute_dataset(
        distribute_strategy_test.get_dataset(strategy))

    custom_function(dataset)
    writer.close()

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'event*'))
    events_count_dictionary = {
        'logits': 0,
    }
    # `custom_function` runs a single step.
    self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                        1)
282
283
if __name__ == '__main__':
  # Enable eager execution process-wide first, then hand control to the
  # TensorFlow test runner. These two calls must stay in this order.
  ops.enable_eager_execution()
  test.main()
287