• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertFalse;
20 import static org.junit.Assert.assertNotNull;
21 import static org.junit.Assert.assertTrue;
22 import static org.junit.Assert.fail;
23 
24 import java.lang.ref.Reference;
25 import java.lang.ref.ReferenceQueue;
26 import java.util.concurrent.BlockingQueue;
27 import java.util.concurrent.LinkedBlockingQueue;
28 import java.util.concurrent.TimeUnit;
29 import java.util.concurrent.atomic.AtomicBoolean;
30 
31 import org.junit.Test;
32 import org.junit.runner.RunWith;
33 import org.junit.runners.JUnit4;
34 import org.tensorflow.EagerSession.ResourceCleanupStrategy;
35 
36 @RunWith(JUnit4.class)
37 public class EagerSessionTest {
38 
39   @Test
closeSessionTwiceDoesNotFail()40   public void closeSessionTwiceDoesNotFail() {
41     try (EagerSession s = EagerSession.create()) {
42       s.close();
43     }
44   }
45 
46   @Test
cleanupResourceOnSessionClose()47   public void cleanupResourceOnSessionClose() {
48     TestReference ref;
49     try (EagerSession s =
50         EagerSession.options()
51             .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE)
52             .build()) {
53       ref = new TestReference(s, new Object());
54       assertFalse(ref.isDeleted());
55 
56       // check that reaching safe point did not release resources
57       buildOp(s);
58       assertFalse(ref.isDeleted());
59     }
60     assertTrue(ref.isDeleted());
61   }
62 
63   @Test
cleanupResourceOnSafePoints()64   public void cleanupResourceOnSafePoints() {
65     TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
66     try (EagerSession s =
67         EagerSession.options()
68             .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS)
69             .buildForGcTest(gcQueue)) {
70 
71       TestReference ref = new TestReference(s, new Object());
72       assertFalse(ref.isDeleted());
73 
74       // garbage collecting the reference won't release until we reached safe point
75       gcQueue.collect(ref);
76       assertFalse(ref.isDeleted());
77       buildOp(s); // safe point
78       assertTrue(ref.isDeleted());
79       assertTrue(gcQueue.isEmpty());
80     }
81   }
82 
83   @Test
cleanupResourceInBackground()84   public void cleanupResourceInBackground() {
85     TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue();
86     try (EagerSession s =
87         EagerSession.options()
88             .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND)
89             .buildForGcTest(gcQueue)) {
90 
91       TestReference ref = new TestReference(s, new Object());
92       assertFalse(ref.isDeleted());
93 
94       gcQueue.collect(ref);
95       sleep(50); // allow some time to the background thread for cleaning up resources
96       assertTrue(ref.isDeleted());
97       assertTrue(gcQueue.isEmpty());
98     }
99   }
100 
101   @Test
clearedResourcesAreNotCleanedUp()102   public void clearedResourcesAreNotCleanedUp() {
103     TestReference ref;
104     try (EagerSession s = EagerSession.create()) {
105       ref = new TestReference(s, new Object());
106       ref.clear();
107     }
108     assertFalse(ref.isDeleted());
109   }
110 
111   @Test
buildingOpWithClosedSessionFails()112   public void buildingOpWithClosedSessionFails() {
113     EagerSession s = EagerSession.create();
114     s.close();
115     try {
116       buildOp(s);
117       fail();
118     } catch (IllegalStateException e) {
119       // ok
120     }
121   }
122 
123   @Test
addingReferenceToClosedSessionFails()124   public void addingReferenceToClosedSessionFails() {
125     EagerSession s = EagerSession.create();
126     s.close();
127     try {
128       new TestReference(s, new Object());
129       fail();
130     } catch (IllegalStateException e) {
131       // ok
132     }
133   }
134 
135   @Test
defaultSession()136   public void defaultSession() throws Exception {
137     EagerSession.Options options =
138         EagerSession.options().resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE);
139     EagerSession.initDefault(options);
140     EagerSession session = EagerSession.getDefault();
141     assertNotNull(session);
142     assertEquals(ResourceCleanupStrategy.ON_SESSION_CLOSE, session.resourceCleanupStrategy());
143     try {
144       EagerSession.initDefault(options);
145       fail();
146     } catch (IllegalStateException e) {
147       // expected
148     }
149     try {
150       session.close();
151       fail();
152     } catch (IllegalStateException e) {
153       // expected
154     }
155   }
156 
157   private static class TestReference extends EagerSession.NativeReference {
158 
TestReference(EagerSession session, Object referent)159     TestReference(EagerSession session, Object referent) {
160       super(session, referent);
161     }
162 
163     @Override
delete()164     void delete() {
165       if (!deleted.compareAndSet(false, true)) {
166         fail("Reference was deleted more than once");
167       }
168     }
169 
isDeleted()170     boolean isDeleted() {
171       return deleted.get();
172     }
173 
174     private final AtomicBoolean deleted = new AtomicBoolean();
175   }
176 
177   private static class TestGarbageCollectorQueue extends ReferenceQueue<Object> {
178 
179     @Override
poll()180     public Reference<? extends Object> poll() {
181       return garbage.poll();
182     }
183 
184     @Override
remove()185     public Reference<? extends Object> remove() throws InterruptedException {
186       return garbage.take();
187     }
188 
189     @Override
remove(long timeout)190     public Reference<? extends Object> remove(long timeout)
191         throws IllegalArgumentException, InterruptedException {
192       return garbage.poll(timeout, TimeUnit.MILLISECONDS);
193     }
194 
collect(TestReference ref)195     void collect(TestReference ref) {
196       garbage.add(ref);
197     }
198 
isEmpty()199     boolean isEmpty() {
200       return garbage.isEmpty();
201     }
202 
203     private final BlockingQueue<TestReference> garbage = new LinkedBlockingQueue<>();
204   }
205 
buildOp(EagerSession s)206   private static void buildOp(EagerSession s) {
207     // Creating an operation is a safe point for resource cleanup
208     try {
209       s.opBuilder("Const", "Const");
210     } catch (UnsupportedOperationException e) {
211       // TODO (karlllessard) remove this exception catch when EagerOperationBuilder is implemented
212     }
213   }
214 
sleep(int millis)215   private static void sleep(int millis) {
216     try {
217       Thread.sleep(millis);
218     } catch (InterruptedException e) {
219     }
220   }
221 }
222