# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration test for TPUStrategy with regard to memory."""

import gc

import tensorflow as tf

from tensorflow.python.eager import context
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS
NUM_CLASS = 10


def get_dataset():

  def generate_data(_):
    image = tf.ones([500, 500, 3], dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return image, label

  def preprocess(image, label):
    label = tf.cast(label, tf.int32)
    label = tf.one_hot(label, NUM_CLASS)
    label = tf.reshape(label, [NUM_CLASS])
    return image, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.repeat()
  dataset = dataset.batch(128, drop_remainder=True)
  return dataset


class TpuMemoryTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    # Clear all cached tensors.
    context._reset_context()
    # Run garbage collection to free any tensors from previous runs.
    gc.collect()

    # Run a small program and copy the result to CPU. This flushes deferred
    # deallocations so that new memory is allocated in a less fragmented way.
    # Turning deferred deallocations off no longer seems to work.
    assert tf.reduce_sum(tf.random.uniform(
        (1024, 128), dtype=tf.float32)).numpy() > 1.0

    self.resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu="", project=None, zone=None)

    tf.config.experimental_connect_to_cluster(self.resolver)
    tf.tpu.experimental.initialize_tpu_system(self.resolver)

  def testAutoDefragInProgramLoading(self):
    # This test covers training a large model on TPU, where TPU HBM is not
    # big enough to hold all TPU buffers and still preserve stack space for
    # the TPU program. The runtime automatically unloads unused TPU programs
    # to free up space for TPU buffers. Holding many TPU buffers may also
    # fragment HBM and prevent a TPU program from loading properly, in which
    # case the runtime automatically defragments HBM in order to load a large
    # TPU program.
    strategy = tf.distribute.TPUStrategy(self.resolver)
    dataset = get_dataset()
    iterator = iter(
        strategy.experimental_distribute_dataset(dataset,
                                                 tf.distribute.InputOptions()))

    # Create a dummy big model that is close to the HBM limit (15G):
    #   Temp HBM: 11G
    #   Sharded variables size: 2G
    #   Unsharded variables size: 4G
    with strategy.scope():
      x = tf.keras.layers.Input(shape=(500, 500, 3), name="input")
      y = tf.keras.layers.Conv2D(
          384, (15, 15),
          strides=(2, 2),
          padding="valid",
          use_bias=False,
          kernel_initializer="he_normal",
          name="conv1")(x)
      y = tf.keras.layers.BatchNormalization(
          momentum=0.997, center=True, scale=True)(y)
      y = tf.keras.layers.Dense(
          10,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(y)
      y = tf.keras.layers.Conv2D(
          64, (9, 9),
          strides=(2, 2),
          padding="valid",
          use_bias=False,
          kernel_initializer="he_normal",
          name="conv2")(y)
      y = tf.keras.layers.Flatten()(y)
      y = tf.keras.layers.Dense(
          1024,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(y)
      y = tf.keras.layers.Dense(
          1024,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(y)
      y = tf.keras.layers.Dense(
          NUM_CLASS,
          activation="softmax",
          kernel_initializer=tf.random_normal_initializer(stddev=0.01))(y)
      model = tf.keras.Model(x, y)
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=0.0, reduction=tf.keras.losses.Reduction.NONE)
      model.compile(optimizer=optimizer, loss=loss_obj)

    @tf.function
    def train_step(iterator):

      def step_fn(inputs):
        images, targets = inputs
        with tf.GradientTape() as tape:
          outputs = model(images, training=True)
          loss = model.loss(targets, outputs)

        grads = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

      # Use a host training loop here to trigger weight-update sharding, which
      # introduces shard-variable and unshard-variable ops into the graph.
      # When the unshard-variable op runs, HBM does not have enough space for
      # the unsharded variables: 11G + 2G + 4G > 15G. The runtime therefore
      # has to automatically unload the step function to free up space for
      # the unshard-variable op.
      for _ in tf.range(tf.constant(20)):
        strategy.run(step_fn, args=(next(iterator),))

      # We want to load the step function again after the unshard-variable op.
      # However, there is not enough contiguous space due to fragmentation:
      # 15G - 2G - 4G < 11G. The runtime therefore has to automatically
      # defragment HBM in order to load the program successfully.
      strategy.run(step_fn, args=(next(iterator),))

      # A dummy result to indicate that this @tf.function has finished.
      return 1.0

    if FLAGS.tpu_use_tfrt:
      result = train_step(iterator)

      self.assertAllClose(1.0, result, atol=1e-07)
    else:
      # TPU StreamExecutor does not support auto-defrag in program loading,
      # so it returns a ResourceExhaustedError.
      with self.assertRaises(tf.errors.ResourceExhaustedError):
        _ = train_step(iterator)

  def testAutoDefragInBufferAllocation(self):
    if not FLAGS.tpu_use_tfrt:
      self.skipTest(
          "TPU StreamExecutor does not support auto-defrag in allocation.")
    with tf.device("TPU:0"):
      # DF has ~15G HBM. The following seven buffers consume most of it: each
      # (2, 256, 1024, 1024) float32 buffer is 2G.
      # pylint: disable=unused-variable
      buffer_2g_1 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_2 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_3 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_4 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_5 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_6 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      buffer_2g_7 = tf.random.uniform((2, 256, 1024, 1024), dtype=tf.float32)
      # pylint: enable=unused-variable

      # Deallocate two of the buffers.
      del buffer_2g_1, buffer_2g_3
      gc.collect()

      # The buffers we just deallocated do not provide a contiguous region
      # large enough to allocate 4G, so this allocation triggers auto-defrag.
      buffer_4g = tf.random.uniform((4, 256, 1024, 1024), dtype=tf.float32)

    self.assertEndsWith(buffer_4g.device, "device:TPU:0")


if __name__ == "__main__":
  tf.test.main()