# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that set gc.DEBUG_SAVEALL and assert that no garbage was created.

The DEBUG_SAVEALL flag appears to stay set once it has been enabled, so these
tests are kept isolated in their own module for now.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test


class NoReferenceCycleTests(test_util.TensorFlowTestCase):
  """Checks that basic eager-mode operations leave no garbage behind.

  Every test is wrapped in `test_util.assert_no_garbage_created`, so a test
  fails when its body produces garbage that the decorator detects.
  """

  @test_util.assert_no_garbage_created
  def testEagerResourceVariables(self):
    """Constructing a ResourceVariable eagerly must not create garbage."""
    with context.eager_mode():
      # The variable is deliberately not kept: only its construction and the
      # dropping of the last reference to it are being exercised.
      resource_variable_ops.ResourceVariable(1.0, name="a")

  @test_util.assert_no_garbage_created
  def testTensorArrays(self):
    """Writing to and reading from an eager TensorArray must not create garbage."""
    # Element i is written at index i and must be read back unchanged. The
    # elements have different shapes, which is why infer_shape=False is used.
    elements = ([[4.0, 5.0]], [[1.0]], -3.0)

    with context.eager_mode():
      array = tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          tensor_array_name="foo",
          size=len(elements),
          infer_shape=False)

      # Each write returns the TensorArray to use for the next operation, so
      # rebind to the returned value and read from the final one.
      for index, element in enumerate(elements):
        array = array.write(index, element)

      reads = [array.read(index) for index in range(len(elements))]
      values = self.evaluate(reads)

      for expected, actual in zip(elements, values):
        self.assertAllEqual(expected, actual)


if __name__ == "__main__":
  test.main()