# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for memory leaks in eager execution.
16
17It is possible that this test suite will eventually become flaky due to taking
18too long to run (since the tests iterate many times), but for now they are
19helpful for finding memory leaks since not all PyObject leaks are found by
20introspection (test_util decorators). Please be careful adding new tests here.
21"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import six

from tensorflow.python import keras
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.variables import Variable

# memory_profiler might not be available in the OSS version of TensorFlow.
try:
  import memory_profiler  # pylint:disable=g-import-not-at-top
except ImportError:
  memory_profiler = None


class SingleLayerNet(keras.Model):
  """Simple keras model used to ensure that there are no leaks."""

  def __init__(self):
    super(SingleLayerNet, self).__init__()
    self.fc1 = keras.layers.Dense(5)

  def call(self, x):
    return self.fc1(x)


class MemoryTest(test.TestCase):

  def assertNotIncreasingMemory(self,
                                f,
                                num_iters=100000,
                                increase_threshold_absolute_mb=10):
    """Assert memory usage doesn't increase beyond given threshold for f."""

    with context.eager_mode():
      # Warm up.
      f()

      # Wait for background threads to start up and take over memory.
      # FIXME: The nature of this test leaves few other options. Maybe there
      # is a better way to do this.
      time.sleep(4)

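      # memory_profiler.memory_usage(-1) samples the current process; the
      # first element is its memory footprint in MB.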
      initial = memory_profiler.memory_usage(-1)[0]

      for _ in six.moves.range(num_iters):
        f()

      increase = memory_profiler.memory_usage(-1)[0] - initial

      assert increase < increase_threshold_absolute_mb, (
          "Increase is too high. Initial memory usage: %f MB. Increase: %f MB. "
          "Maximum allowed increase: %f MB.") % (initial, increase,
                                                 increase_threshold_absolute_mb)

  def testMemoryLeakAnonymousVariable(self):
    if memory_profiler is None:
      self.skipTest("memory_profiler required to run this test")

    def f():
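      # Create a variable and drop the only reference to it; each iteration
      # should release everything it allocated.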
      inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
      del inputs

    self.assertNotIncreasingMemory(f, num_iters=10000)

  def testMemoryLeakInSimpleModelForwardOnly(self):
    if memory_profiler is None:
      self.skipTest("memory_profiler required to run this test")

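    # The input tensor and the model are built once and reused, so only
    # per-call allocations inside f are being measured.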
    inputs = array_ops.zeros([32, 100], dtypes.float32)
    net = SingleLayerNet()

    def f():
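      # Run only the forward pass under a tape; the tape is discarded without
      # computing gradients.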
      with backprop.GradientTape():
        net(inputs)

    self.assertNotIncreasingMemory(f)

  def testMemoryLeakInSimpleModelForwardAndBackward(self):
    if memory_profiler is None:
      self.skipTest("memory_profiler required to run this test")

    inputs = array_ops.zeros([32, 100], dtypes.float32)
    net = SingleLayerNet()

    def f():
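      # Run the forward pass, then compute gradients and explicitly drop the
      # tape so that everything it recorded can be released.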
      with backprop.GradientTape() as tape:
        result = net(inputs)

      tape.gradient(result, net.variables)

      del tape

    self.assertNotIncreasingMemory(f)


if __name__ == "__main__":
  test.main()