# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This file contains integration test for TPUStrategy in regards to memory."""

import gc

import tensorflow as tf

from tensorflow.python.eager import context
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS
NUM_CLASS = 10


def get_dataset():
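  """Returns an infinite dataset of constant images with one-hot labels.

  Each element is a (500, 500, 3) float32 image of ones paired with a
  one-hot label of depth NUM_CLASS, batched to 128 with drop_remainder.
  """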

  def generate_data(_):
    image = tf.ones([500, 500, 3], dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return image, label

  def preprocess(image, label):
    label = tf.cast(label, tf.int32)
    label = tf.one_hot(label, NUM_CLASS)
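    # one_hot of the [1]-shaped label has shape [1, NUM_CLASS]; flatten it to
    # [NUM_CLASS] so batching yields [batch, NUM_CLASS] labels.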
    label = tf.reshape(label, [NUM_CLASS])
    return image, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.repeat()
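  # Each 128-image float32 batch is 128 * 500 * 500 * 3 * 4 bytes ~= 384MB.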
  dataset = dataset.batch(128, drop_remainder=True)
  return dataset


class TpuMemoryTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    # Clear all cached tensors
    context._reset_context()
    # Run garbage collection to free any tensors from previous runs.
    gc.collect()

    # Run a small program and copy the result to CPU.
    # This causes deferred deallocations to be flushed and new memory to be
    # allocated in a less fragmented way.
    # Turning deferred deallocations off no longer seems to work.
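    # The sum of 1024 * 128 uniform values in [0, 1) is ~65k in expectation,
    # so the > 1.0 check essentially always passes; it only forces the
    # computation to run and its result to be copied back to the host.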
    assert tf.reduce_sum(tf.random.uniform(
        (1024, 128), dtype=tf.float32)).numpy() > 1.0

    self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu="", project=None, zone=None)

    tf.config.experimental_connect_to_cluster(self.resolver)
    tf.tpu.experimental.initialize_tpu_system(self.resolver)

  def testAutoDefragInProgramLoading(self):
    # This test covers the case of training a large model on TPU, where TPU
    # HBM is not big enough to hold all TPU buffers and still leave room for
    # the TPU program. The runtime automatically unloads unused TPU programs
    # to free up space for TPU buffers. Having many TPU buffers may also
    # fragment HBM and prevent a TPU program from loading properly; the
    # runtime automatically defragments HBM in order to load a large TPU
    # program.

    strategy = tf.distribute.TPUStrategy(self.resolver)
    dataset = get_dataset()
    iterator = iter(
        strategy.experimental_distribute_dataset(dataset,
                                                 tf.distribute.InputOptions()))

    # Create a dummy big model that is close to HBM limit (15G):
    # Temp HBM: 11G
    # Sharded variable size: 2G
    # Unsharded variables size: 4G
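    # (A rough accounting: most of the variable memory comes from the first
    # Dense(1024) layer after Flatten, roughly 118 * 118 * 64 * 1024 float32
    # weights, i.e. about 3.7G.)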
    with strategy.scope():
      x = tf.keras.layers.Input(shape=(500, 500, 3), name="input")
      y = tf.keras.layers.Conv2D(
          384, (15, 15),
          strides=(2, 2),
          padding="valid",
          use_bias=False,
          kernel_initializer="he_normal",
          name="conv1")(
              x)
      y = tf.keras.layers.BatchNormalization(
          momentum=0.997, center=True, scale=True)(
              y)
      y = tf.keras.layers.Dense(
          10,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(
              y)
      y = tf.keras.layers.Conv2D(
          64, (9, 9),
          strides=(2, 2),
          padding="valid",
          use_bias=False,
          kernel_initializer="he_normal",
          name="conv2")(
              y)
      y = tf.keras.layers.Flatten()(y)
      y = tf.keras.layers.Dense(
          1024,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(
              y)
      y = tf.keras.layers.Dense(
          1024,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(
              y)
      y = tf.keras.layers.Dense(
          NUM_CLASS,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(
              y)
      model = tf.keras.Model(x, y)
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=0.0, reduction=tf.keras.losses.Reduction.NONE)
      model.compile(optimizer=optimizer, loss=loss_obj)

    @tf.function
    def train_step(iterator):

      def step_fn(inputs):
        images, targets = inputs
        with tf.GradientTape() as tape:
          outputs = model(images, training=True)
          loss = model.loss(targets, outputs)

        grads = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

      # Use a host training loop here to trigger weight-update sharding, which
      # introduces shard-variable and unshard-variable ops into the graph.
      # When the unshard-variable op runs, HBM won't have enough space for the
      # unsharded variables: 11G + 2G + 4G > 15G. So the runtime will have to
      # automatically unload the step function to free up space for the
      # unshard-variable op.
      for _ in tf.range(tf.constant(20)):
        strategy.run(step_fn, args=(next(iterator),))

      # We want to load the step function again after the unshard-variable op.
      # However, there won't be enough contiguous space due to fragmentation:
      # 15G - 2G - 4G < 11G. So the runtime will have to automatically defrag
      # in order to load the program successfully.
      strategy.run(step_fn, args=(next(iterator),))

      # A dummy result to indicate this @tf.function has finished.
      return 1.0

    if FLAGS.tpu_use_tfrt:
      result = train_step(iterator)

      self.assertAllClose(1.0, result, atol=1e-07)
    else:
      # TPU StreamExecutor does not support auto-defrag in program loading,
      # so it will raise a ResourceExhaustedError.
      with self.assertRaises(tf.errors.ResourceExhaustedError):
        _ = train_step(iterator)

  def testAutoDefragInBufferAllocation(self):
    if not FLAGS.tpu_use_tfrt:
      self.skipTest(
          "TPU StreamExecutor does not support auto-defrag in allocation.")
    with tf.device("TPU:0"):
      # DF has ~15G of HBM. The following 7 buffers will consume most of it.
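      # Each (2, 256, 1024, 1024) float32 buffer is 2 * 256 * 1024 * 1024 * 4
      # bytes = 2GiB, so the seven buffers below total 14GiB.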
      # pylint: disable=unused-variable
      buffer_2g_1 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_2 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_3 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_4 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_5 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_6 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_7 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      # pylint: enable=unused-variable

      # Deallocate two buffers.
      del buffer_2g_1, buffer_2g_3
      gc.collect()

      # The buffers we just deallocated don't form a contiguous region large
      # enough to hold 4G, so the next allocation will trigger auto-defrag.
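      # (4 * 256 * 1024 * 1024 float32 elements * 4 bytes = 4GiB.)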
      buffer_4g = tf.random.uniform((4, 256, 1024, 1024), dtype=tf.float32)

    self.assertEndsWith(buffer_4g.device, "device:TPU:0")


if __name__ == "__main__":
  tf.test.main()