package org.mockitoutil;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.objenesis.Objenesis;
import org.objenesis.ObjenesisStd;
import org.objenesis.instantiator.ObjectInstantiator;

import static java.lang.String.format;
import static java.util.Arrays.asList;

/**
 * Entry point and base class for a small family of class loader builders:
 * <ul>
 *     <li>{@link #isolatedClassLoader()} - loads only classes matching given prefixes,</li>
 *     <li>{@link #excludingClassLoader()} - refuses classes matching given prefixes,</li>
 *     <li>{@link #inMemoryClassLoader()} - defines classes from byte arrays held in memory.</li>
 * </ul>
 * It also offers {@link #using(ClassLoader)} to run a task inside a given class loader and
 * {@link #in(ClassLoader)} to list the classes a class loader owns.
 */
public abstract class ClassLoaders {
    /**
     * Parent class loader, initialised to the loader of this class.
     * In this file it is only read by {@link InMemoryClassLoaderBuilder#build()}.
     */
    protected ClassLoader parent = currentClassLoader();

    protected ClassLoaders() {
    }

    /** @return a new builder for a class loader that only loads explicitly allowed prefixes */
    public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
        return new IsolatedURLClassLoaderBuilder();
    }

    /** @return a new builder for a class loader that rejects the configured prefixes */
    public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
        return new ExcludingURLClassLoaderBuilder();
    }

    /** @return a new builder for a class loader backed by in-memory class definitions */
    public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
        return new InMemoryClassLoaderBuilder();
    }

    /**
     * @param classLoader the class loader to inspect
     * @return a finder able to list the classes reachable from {@code classLoader}
     */
    public static ReachableClassesFinder in(ClassLoader classLoader) {
        return new ReachableClassesFinder(classLoader);
    }

    /** @return the class loader that loaded {@link String} (may be {@code null}, see {@link Class#getClassLoader()}) */
    public static ClassLoader jdkClassLoader() {
        return String.class.getClassLoader();
    }

    /** @return {@link ClassLoader#getSystemClassLoader()} */
    public static ClassLoader systemClassLoader() {
        return ClassLoader.getSystemClassLoader();
    }

    /** @return the class loader that loaded this {@code ClassLoaders} class */
    public static ClassLoader currentClassLoader() {
        return ClassLoaders.class.getClassLoader();
    }

    /** @return the class loader configured by this builder */
    public abstract ClassLoader build();

    /**
     * Looks up a fixed list of classes by name and returns those that are present.
     *
     * @return the classes that could be loaded, never containing {@code null}
     */
    public static Class<?>[] coverageTool() {
        HashSet<Class<?>> classes = new HashSet<Class<?>>();
        classes.add(safeGetClass("net.sourceforge.cobertura.coveragedata.TouchCollector"));
        classes.add(safeGetClass("org.slf4j.LoggerFactory"));

        // safeGetClass yields null for absent classes; drop that single null entry if present
        classes.remove(null);
        return classes.toArray(new Class<?>[classes.size()]);
    }

    /**
     * @param className fully qualified class name
     * @return the class, or {@code null} when it cannot be found
     */
    private static Class<?> safeGetClass(String className) {
        try {
            return Class.forName(className);
        } catch (ClassNotFoundException e) {
            return null;
        }
    }

    /**
     * @param classLoader the class loader tasks should be reloaded and executed in
     * @return an executor bound to {@code classLoader}
     */
    public static ClassLoaderExecutor using(final ClassLoader classLoader) {
        return new ClassLoaderExecutor(classLoader);
    }

    /**
     * Runs a {@link Runnable} on a dedicated thread after re-creating it from the class
     * definition obtained through the configured class loader.
     */
    public static class ClassLoaderExecutor {
        private final ClassLoader classLoader;

        public ClassLoaderExecutor(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /**
         * Reloads {@code task} in the configured class loader and runs it on a single-use thread
         * whose context class loader is that class loader.
         *
         * <p>If the calling thread is interrupted while waiting, the interrupt flag is restored
         * and the method returns normally.</p>
         *
         * @param task the task to reload and run
         * @throws Exception the cause of the failure when the task could not be reloaded or threw;
         *                   here an {@link IllegalStateException} wrapping the original throwable
         */
        public void execute(final Runnable task) throws Exception {
            ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() {
                @Override
                public Thread newThread(Runnable r) {
                    Thread thread = Executors.defaultThreadFactory().newThread(r);
                    thread.setContextClassLoader(classLoader);
                    return thread;
                }
            });
            try {
                Future<?> taskFuture = executorService.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            reloadTaskInClassLoader(task).run();
                        } catch (Throwable throwable) {
                            // NOTE: the first placeholder receives the task, not the class loader
                            throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s",
                                                                   task,
                                                                   throwable.getMessage()),
                                                            throwable);
                        }
                    }
                });
                taskFuture.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (ExecutionException e) {
                throw this.<Exception>unwrapAndThrows(e);
            } finally {
                // Always release the worker thread, including when the task failed or the wait
                // was interrupted; otherwise the executor would leak.
                executorService.shutdownNow();
            }
        }

        /**
         * Rethrows the cause of {@code ex}. The generic cast is unchecked, which lets the cause
         * propagate whatever its actual type is.
         */
        @SuppressWarnings("unchecked")
        private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
            throw (T) ex.getCause();
        }

        /**
         * Creates an instance of the task's class as loaded by the configured class loader
         * (instantiated through Objenesis) and shallow-copies the compatible fields of
         * {@code task} into it.
         *
         * @param task the task to re-create
         * @return the re-created task
         * @throws IllegalStateException if the class or one of its declared fields cannot be
         *                               found or accessed in the configured class loader
         */
        Runnable reloadTaskInClassLoader(Runnable task) {
            try {
                @SuppressWarnings("unchecked")
                Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName());

                Objenesis objenesis = new ObjenesisStd();
                ObjectInstantiator<Runnable> thingyInstantiator = objenesis.getInstantiatorOf(taskClassReloaded);
                Runnable reloaded = thingyInstantiator.newInstance();

                // lenient shallow copy of class compatible fields
                for (Field field : task.getClass().getDeclaredFields()) {
                    Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
                    int modifiers = declaredField.getModifiers();
                    if (Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
                        // Skip static final fields (e.g. jacoco fields)
                        // otherwise IllegalAccessException (can be bypassed with Unsafe though)
                        // We may also miss coverage data.
                        continue;
                    }
                    // Only copy when both loaders agree on the field's type object; fields whose
                    // type differs between the two class versions are left untouched (don't copy this)
                    if (declaredField.getType() == field.getType()) {
                        field.setAccessible(true);
                        declaredField.setAccessible(true);
                        declaredField.set(reloaded, field.get(task));
                    }
                }

                return reloaded;
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException(e);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            } catch (NoSuchFieldException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Builder for a {@link LocalIsolatedURLClassLoader}: a URL class loader whose parent is
     * {@link #jdkClassLoader()} and which only finds classes matching the "private copy"
     * prefixes and none of the excluded prefixes.
     */
    public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        /** Adds class name prefixes the built class loader is allowed to load itself. */
        public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
            privateCopyPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        /** Adds file system paths (converted to URLs) to search for classes. */
        public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        /** Adds the class path root of each given class to the URLs to search. */
        public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        /** Adds the class path root of {@link ClassLoaders} itself to the URLs to search. */
        public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        /** Adds class name prefixes the built class loader must refuse to load. */
        public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) {
            excludedPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        @Override
        public ClassLoader build() {
            return new LocalIsolatedURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    privateCopyPrefixes,
                    excludedPrefixes
            );
        }
    }

    /**
     * URL class loader that only finds classes whose name starts with one of the private copy
     * prefixes and with none of the excluded prefixes.
     */
    static class LocalIsolatedURLClassLoader extends URLClassLoader {
        private final ArrayList<String> privateCopyPrefixes;
        private final ArrayList<String> excludedPrefixes;

        LocalIsolatedURLClassLoader(ClassLoader classLoader,
                                    URL[] urls,
                                    ArrayList<String> privateCopyPrefixes,
                                    ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            this.privateCopyPrefixes = privateCopyPrefixes;
            this.excludedPrefixes = excludedPrefixes;
        }

        /**
         * @throws ClassNotFoundException when {@code name} is not covered by a private copy
         *                                prefix, is covered by an excluded prefix, or cannot be
         *                                found in the configured URLs
         */
        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) {
                throw new ClassNotFoundException(format("Can only load classes with prefixes : %s, but not : %s",
                                                        privateCopyPrefixes,
                                                        excludedPrefixes));
            }
            try {
                return super.findClass(name);
            } catch (ClassNotFoundException cnfe) {
                throw new ClassNotFoundException(format("%s%n%s%n",
                                                        cnfe.getMessage(),
                                                        " Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"),
                                                 cnfe);
            }
        }

        private boolean classShouldBePrivate(String name) {
            for (String prefix : privateCopyPrefixes) {
                if (name.startsWith(prefix)) {
                    return true;
                }
            }
            return false;
        }

        private boolean classShouldBeExcluded(String name) {
            for (String prefix : excludedPrefixes) {
                if (name.startsWith(prefix)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Builder for a {@link LocalExcludingURLClassLoader}: a URL class loader whose parent is
     * {@link #jdkClassLoader()} and which refuses to find classes matching excluded prefixes.
     */
    public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
        private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
        private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();

        /** Adds class name prefixes the built class loader must refuse to load. */
        public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) {
            excludedPrefixes.addAll(asList(privatePrefixes));
            return this;
        }

        /** Adds file system paths (converted to URLs) to search for classes. */
        public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
            codeSourceUrls.addAll(pathsToURLs(urls));
            return this;
        }

        /** Adds the class path root of each given class to the URLs to search. */
        public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
            for (Class<?> clazz : classes) {
                codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
            }
            return this;
        }

        /** Adds the class path root of {@link ClassLoaders} itself to the URLs to search. */
        public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
            codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
            return this;
        }

        @Override
        public ClassLoader build() {
            return new LocalExcludingURLClassLoader(
                    jdkClassLoader(),
                    codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
                    excludedPrefixes
            );
        }
    }

    /** URL class loader that refuses to find classes whose name starts with an excluded prefix. */
    static class LocalExcludingURLClassLoader extends URLClassLoader {
        private final ArrayList<String> excludedPrefixes;

        LocalExcludingURLClassLoader(ClassLoader classLoader,
                                     URL[] urls,
                                     ArrayList<String> excludedPrefixes) {
            super(urls, classLoader);
            this.excludedPrefixes = excludedPrefixes;
        }

        /**
         * @throws ClassNotFoundException when {@code name} starts with an excluded prefix or
         *                                cannot be found in the configured URLs
         */
        @Override
        public Class<?> findClass(String name) throws ClassNotFoundException {
            if (classShouldBeExcluded(name)) {
                throw new ClassNotFoundException("classes with prefix : " + excludedPrefixes + " are excluded");
            }
            return super.findClass(name);
        }

        private boolean classShouldBeExcluded(String name) {
            for (String prefix : excludedPrefixes) {
                if (name.startsWith(prefix)) {
                    return true;
                }
            }
            return false;
        }
    }

    /** Builder for an {@link InMemoryClassLoader} fed with class definitions given as byte arrays. */
    public static class InMemoryClassLoaderBuilder extends ClassLoaders {
        private final Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();

        /** Replaces the default parent (the loader of {@link ClassLoaders}). */
        public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
            this.parent = parent;
            return this;
        }

        /**
         * @param name            the name under which the class is registered and later defined
         * @param classDefinition the class file bytes
         */
        public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
            inMemoryClassObjects.put(name, classDefinition);
            return this;
        }

        /**
         * The built class loader receives this builder's map itself, not a copy.
         */
        @Override
        public ClassLoader build() {
            return new InMemoryClassLoader(parent, inMemoryClassObjects);
        }
    }

    /**
     * Class loader that defines classes from a name-to-bytes map and exposes every registered
     * name as a {@code mem:} URL through {@link #getResources(String)}.
     */
    static class InMemoryClassLoader extends ClassLoader {
        public static final String SCHEME = "mem";
        private final Map<String, byte[]> inMemoryClassObjects;

        public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
            super(parent);
            this.inMemoryClassObjects = inMemoryClassObjects;
        }

        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            byte[] classDefinition = inMemoryClassObjects.get(name);
            if (classDefinition != null) {
                return defineClass(name, classDefinition, 0, classDefinition.length);
            }
            throw new ClassNotFoundException(name);
        }

        /**
         * Ignores the requested resource name and always enumerates one {@code mem:} URL per
         * registered class definition.
         */
        @Override
        public Enumeration<URL> getResources(String ignored) throws IOException {
            return inMemoryOnly();
        }

        private Enumeration<URL> inMemoryOnly() {
            final Set<String> names = inMemoryClassObjects.keySet();
            return new Enumeration<URL>() {
                private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this);
                private final Iterator<String> it = names.iterator();

                public boolean hasMoreElements() {
                    return it.hasNext();
                }

                public URL nextElement() {
                    try {
                        // The handler is passed explicitly, so no handler needs to be
                        // registered for the "mem" scheme
                        return new URL(null, SCHEME + ":" + it.next(), memHandler);
                    } catch (MalformedURLException rethrown) {
                        throw new IllegalStateException(rethrown);
                    }
                }
            };
        }
    }

    /** Stream handler serving {@code mem:} URLs from an {@link InMemoryClassLoader}. */
    public static class MemHandler extends URLStreamHandler {
        private final InMemoryClassLoader inMemoryClassLoader;

        public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
            this.inMemoryClassLoader = inMemoryClassLoader;
        }

        @Override
        protected URLConnection openConnection(URL url) throws IOException {
            return new MemURLConnection(url, inMemoryClassLoader);
        }

        /** Connection whose content is the class definition registered under the URL's path. */
        private static class MemURLConnection extends URLConnection {
            private final InMemoryClassLoader inMemoryClassLoader;
            private final String qualifiedName;

            public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
                super(url);
                this.inMemoryClassLoader = inMemoryClassLoader;
                qualifiedName = url.getPath();
            }

            @Override
            public void connect() throws IOException {
                // nothing to connect to, the data lives in memory
            }

            /**
             * @return a stream over the registered class definition
             * @throws IOException if no class definition is registered under the URL's path
             */
            @Override
            public InputStream getInputStream() throws IOException {
                byte[] classDefinition = inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName);
                if (classDefinition == null) {
                    // ByteArrayInputStream would otherwise fail with a NullPointerException
                    throw new IOException(format("No in-memory class definition registered under '%s'", qualifiedName));
                }
                return new ByteArrayInputStream(classDefinition);
            }
        }
    }

    /**
     * Resolves the class path root (directory or archive URL prefix) that contains the given class,
     * as seen by the class loader of {@link ClassLoaders}.
     *
     * @param className fully qualified class name
     * @return the URL of the class resource with the class' relative path stripped
     * @throws IllegalArgumentException if the class resource cannot be found
     * @throws RuntimeException         if the stripped URL is malformed
     */
    URL obtainCurrentClassPathOf(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = ClassLoaders.class.getClassLoader().getResource(path);
        if (resource == null) {
            // getResource returns null for unknown resources; fail with a descriptive message
            throw new IllegalArgumentException(format("Class '%s' is not reachable as resource '%s' from the current class loader",
                                                      className,
                                                      path));
        }
        String url = resource.toExternalForm();

        try {
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /** Converts file system paths to URLs, see {@link #pathToUrl(String)}. */
    List<URL> pathsToURLs(String... codeSourceUrls) {
        return pathsToURLs(Arrays.asList(codeSourceUrls));
    }

    private List<URL> pathsToURLs(List<String> codeSourceUrls) {
        ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size());
        for (String codeSourceUrl : codeSourceUrls) {
            urls.add(pathToUrl(codeSourceUrl));
        }
        return urls;
    }

    /**
     * @param path a file system path, made absolute before conversion
     * @return the corresponding {@code file:} URL
     * @throws IllegalArgumentException if the path cannot be expressed as a URL
     */
    private URL pathToUrl(String path) {
        try {
            return new File(path).getAbsoluteFile().toURI().toURL();
        } catch (MalformedURLException e) {
            throw new IllegalArgumentException("Path is malformed", e);
        }
    }

    /**
     * Lists the qualified names of the classes found under the roots of a class loader,
     * optionally omitting names that contain given substrings.
     */
    public static class ReachableClassesFinder {
        private static final String CLASS_FILE_SUFFIX = ".class";

        private final ClassLoader classLoader;
        private final Set<String> qualifiedNameSubstring = new HashSet<String>();

        ReachableClassesFinder(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /** Registers substrings; any class whose qualified name contains one of them is omitted. */
        public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
            this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
            return this;
        }

        /**
         * Enumerates the roots returned by {@code classLoader.getResources("")} and collects the
         * class names found there. {@code file:} roots are walked recursively; roots using the
         * {@link InMemoryClassLoader#SCHEME} scheme each denote a single class.
         *
         * @return the qualified names of the classes found, minus the omitted ones
         * @throws IOException              if the roots cannot be enumerated
         * @throws URISyntaxException       if a root URL is not a valid URI
         * @throws IllegalArgumentException if a root uses any other scheme
         */
        public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
            Enumeration<URL> roots = classLoader.getResources("");

            Set<String> classes = new HashSet<String>();
            while (roots.hasMoreElements()) {
                URI uri = roots.nextElement().toURI();

                if (uri.getScheme().equalsIgnoreCase("file")) {
                    addFromFileBasedClassLoader(classes, uri);
                } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) {
                    addFromInMemoryBasedClassLoader(classes, uri);
                } else {
                    throw new IllegalArgumentException(format("Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory", classLoader));
                }
            }
            return classes;
        }

        private void addFromFileBasedClassLoader(Set<String> classes, URI uri) {
            File root = new File(uri);
            classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
        }

        private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
            // For "mem:<name>" the scheme specific part is the registered class name
            String qualifiedName = uri.getSchemeSpecificPart();
            if (excludes(qualifiedName, qualifiedNameSubstring)) {
                classes.add(qualifiedName);
            }
        }

        /**
         * Recursively collects the qualified names of the {@code .class} files under {@code file}.
         *
         * @param root           the class path root, used to derive qualified names
         * @param file           the file or directory currently visited
         * @param packageFilters substrings that cause a class to be omitted
         * @return the retained qualified names, possibly empty
         */
        private Set<String> findClassQualifiedNames(File root, File file, Set<String> packageFilters) {
            if (file.isDirectory()) {
                Set<String> classes = new HashSet<String>();
                File[] files = file.listFiles();
                if (files == null) {
                    // listFiles returns null on I/O error or unreadable directory: nothing to add
                    return classes;
                }
                for (File children : files) {
                    classes.addAll(findClassQualifiedNames(root, children, packageFilters));
                }
                return classes;
            }

            if (file.getName().endsWith(CLASS_FILE_SUFFIX)) {
                String qualifiedName = classNameFor(root, file);
                if (excludes(qualifiedName, packageFilters)) {
                    return Collections.singleton(qualifiedName);
                }
            }
            return Collections.emptySet();
        }

        /**
         * Despite its name, this returns {@code true} when the class must be <em>kept</em>.
         *
         * @return {@code false} if {@code qualifiedName} contains any of the filters,
         *         {@code true} otherwise
         */
        private boolean excludes(String qualifiedName, Set<String> packageFilters) {
            for (String filter : packageFilters) {
                if (qualifiedName.contains(filter)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Derives the qualified class name from the location of a {@code .class} file relative to
         * its class path root. Callers only pass files whose name ends with {@code .class}.
         */
        private String classNameFor(File root, File file) {
            String temp = file.getAbsolutePath().substring(root.getAbsolutePath().length() + 1)
                              .replace(File.separatorChar, '.');
            // Strip the trailing suffix only: searching for the first ".class" would truncate
            // names whose package contains a segment starting with "class" (e.g. "org.classes.Foo")
            return temp.substring(0, temp.length() - CLASS_FILE_SUFFIX.length());
        }

    }
}