1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertFalse; 20 import static org.junit.Assert.assertNotNull; 21 import static org.junit.Assert.assertTrue; 22 import static org.junit.Assert.fail; 23 24 import java.lang.ref.Reference; 25 import java.lang.ref.ReferenceQueue; 26 import java.util.concurrent.BlockingQueue; 27 import java.util.concurrent.LinkedBlockingQueue; 28 import java.util.concurrent.TimeUnit; 29 import java.util.concurrent.atomic.AtomicBoolean; 30 31 import org.junit.Test; 32 import org.junit.runner.RunWith; 33 import org.junit.runners.JUnit4; 34 import org.tensorflow.EagerSession.ResourceCleanupStrategy; 35 36 @RunWith(JUnit4.class) 37 public class EagerSessionTest { 38 39 @Test closeSessionTwiceDoesNotFail()40 public void closeSessionTwiceDoesNotFail() { 41 try (EagerSession s = EagerSession.create()) { 42 s.close(); 43 } 44 } 45 46 @Test cleanupResourceOnSessionClose()47 public void cleanupResourceOnSessionClose() { 48 TestReference ref; 49 try (EagerSession s = 50 EagerSession.options() 51 .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE) 52 .build()) { 53 ref = new TestReference(s, new Object()); 54 assertFalse(ref.isDeleted()); 55 56 // check that reaching safe point did not release resources 57 buildOp(s); 58 assertFalse(ref.isDeleted()); 59 } 60 assertTrue(ref.isDeleted()); 61 } 62 63 @Test cleanupResourceOnSafePoints()64 public void cleanupResourceOnSafePoints() { 65 TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue(); 66 try (EagerSession s = 67 EagerSession.options() 68 .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS) 69 .buildForGcTest(gcQueue)) { 70 71 TestReference ref = new TestReference(s, new Object()); 72 assertFalse(ref.isDeleted()); 73 74 // garbage collecting the reference won't release until we reached safe point 75 gcQueue.collect(ref); 76 assertFalse(ref.isDeleted()); 77 buildOp(s); // safe point 78 assertTrue(ref.isDeleted()); 79 assertTrue(gcQueue.isEmpty()); 80 } 81 } 82 83 @Test cleanupResourceInBackground()84 public void cleanupResourceInBackground() { 85 TestGarbageCollectorQueue gcQueue = new TestGarbageCollectorQueue(); 86 try (EagerSession s = 87 EagerSession.options() 88 .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND) 89 .buildForGcTest(gcQueue)) { 90 91 TestReference ref = new TestReference(s, new Object()); 92 assertFalse(ref.isDeleted()); 93 94 gcQueue.collect(ref); 95 sleep(50); // allow some time to the background thread for cleaning up resources 96 assertTrue(ref.isDeleted()); 97 assertTrue(gcQueue.isEmpty()); 98 } 99 } 100 101 @Test clearedResourcesAreNotCleanedUp()102 public void clearedResourcesAreNotCleanedUp() { 103 TestReference ref; 104 try (EagerSession s = EagerSession.create()) { 105 ref = new TestReference(s, new Object()); 106 ref.clear(); 107 } 108 assertFalse(ref.isDeleted()); 109 } 110 111 @Test buildingOpWithClosedSessionFails()112 public void buildingOpWithClosedSessionFails() { 113 EagerSession s = EagerSession.create(); 114 s.close(); 115 try { 116 buildOp(s); 117 fail(); 118 } catch (IllegalStateException e) { 119 // ok 120 } 121 } 122 123 @Test addingReferenceToClosedSessionFails()124 public void addingReferenceToClosedSessionFails() { 125 EagerSession s = EagerSession.create(); 126 s.close(); 127 try { 128 new TestReference(s, new Object()); 129 fail(); 130 } catch (IllegalStateException e) { 131 // ok 132 } 133 } 134 135 @Test defaultSession()136 public void defaultSession() throws Exception { 137 EagerSession.Options options = 138 EagerSession.options().resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE); 139 EagerSession.initDefault(options); 140 EagerSession session = EagerSession.getDefault(); 141 assertNotNull(session); 142 assertEquals(ResourceCleanupStrategy.ON_SESSION_CLOSE, session.resourceCleanupStrategy()); 143 try { 144 EagerSession.initDefault(options); 145 fail(); 146 } catch (IllegalStateException e) { 147 // expected 148 } 149 try { 150 session.close(); 151 fail(); 152 } catch (IllegalStateException e) { 153 // expected 154 } 155 } 156 157 private static class TestReference extends EagerSession.NativeReference { 158 TestReference(EagerSession session, Object referent)159 TestReference(EagerSession session, Object referent) { 160 super(session, referent); 161 } 162 163 @Override delete()164 void delete() { 165 if (!deleted.compareAndSet(false, true)) { 166 fail("Reference was deleted more than once"); 167 } 168 } 169 isDeleted()170 boolean isDeleted() { 171 return deleted.get(); 172 } 173 174 private final AtomicBoolean deleted = new AtomicBoolean(); 175 } 176 177 private static class TestGarbageCollectorQueue extends ReferenceQueue<Object> { 178 179 @Override poll()180 public Reference<? extends Object> poll() { 181 return garbage.poll(); 182 } 183 184 @Override remove()185 public Reference<? extends Object> remove() throws InterruptedException { 186 return garbage.take(); 187 } 188 189 @Override remove(long timeout)190 public Reference<? extends Object> remove(long timeout) 191 throws IllegalArgumentException, InterruptedException { 192 return garbage.poll(timeout, TimeUnit.MILLISECONDS); 193 } 194 collect(TestReference ref)195 void collect(TestReference ref) { 196 garbage.add(ref); 197 } 198 isEmpty()199 boolean isEmpty() { 200 return garbage.isEmpty(); 201 } 202 203 private final BlockingQueue<TestReference> garbage = new LinkedBlockingQueue<>(); 204 } 205 buildOp(EagerSession s)206 private static void buildOp(EagerSession s) { 207 // Creating an operation is a safe point for resource cleanup 208 try { 209 s.opBuilder("Const", "Const"); 210 } catch (UnsupportedOperationException e) { 211 // TODO (karlllessard) remove this exception catch when EagerOperationBuilder is implemented 212 } 213 } 214 sleep(int millis)215 private static void sleep(int millis) { 216 try { 217 Thread.sleep(millis); 218 } catch (InterruptedException e) { 219 } 220 } 221 } 222