1 /* 2 * Copyright (c) 2017 Mockito contributors 3 * This program is made available under the terms of the MIT License. 4 */ 5 package org.mockitoutil; 6 7 import java.net.MalformedURLException; 8 import java.net.URL; 9 import java.net.URLClassLoader; 10 import java.util.HashMap; 11 import java.util.Map; 12 import java.util.concurrent.Callable; 13 14 /** 15 * Custom classloader to load classes in hierarchic realm. 16 * 17 * Each class can be reloaded in the realm if the LoadClassPredicate says so. 18 */ 19 public class SimplePerRealmReloadingClassLoader extends URLClassLoader { 20 21 private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>(); 22 private ReloadClassPredicate reloadClassPredicate; 23 SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate)24 public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) { 25 super(getPossibleClassPathsUrls()); 26 this.reloadClassPredicate = reloadClassPredicate; 27 } 28 SimplePerRealmReloadingClassLoader( ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate)29 public SimplePerRealmReloadingClassLoader( 30 ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) { 31 super(getPossibleClassPathsUrls(), parentClassLoader); 32 this.reloadClassPredicate = reloadClassPredicate; 33 } 34 getPossibleClassPathsUrls()35 private static URL[] getPossibleClassPathsUrls() { 36 return new URL[] { 37 obtainClassPath(), 38 obtainClassPath("org.mockito.Mockito"), 39 obtainClassPath("net.bytebuddy.ByteBuddy") 40 }; 41 } 42 obtainClassPath()43 private static URL obtainClassPath() { 44 String className = SimplePerRealmReloadingClassLoader.class.getName(); 45 return obtainClassPath(className); 46 } 47 obtainClassPath(String className)48 private static URL obtainClassPath(String className) { 49 String path = className.replace('.', '/') + ".class"; 50 String url = 51 SimplePerRealmReloadingClassLoader.class 52 .getClassLoader() 53 .getResource(path) 54 .toExternalForm(); 55 56 try { 57 return new URL(url.substring(0, url.length() - path.length())); 58 } catch (MalformedURLException e) { 59 throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e); 60 } 61 } 62 63 @Override loadClass(String qualifiedClassName)64 public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException { 65 if (reloadClassPredicate.acceptReloadOf(qualifiedClassName)) { 66 // return customLoadClass(qualifiedClassName); 67 // Class<?> loadedClass = findLoadedClass(qualifiedClassName); 68 if (!classHashMap.containsKey(qualifiedClassName)) { 69 Class<?> foundClass = findClass(qualifiedClassName); 70 saveFoundClass(qualifiedClassName, foundClass); 71 return foundClass; 72 } 73 74 return classHashMap.get(qualifiedClassName); 75 } 76 return useParentClassLoaderFor(qualifiedClassName); 77 } 78 saveFoundClass(String qualifiedClassName, Class<?> foundClass)79 private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) { 80 classHashMap.put(qualifiedClassName, foundClass); 81 } 82 useParentClassLoaderFor(String qualifiedName)83 private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException { 84 return super.loadClass(qualifiedName); 85 } 86 doInRealm(String callableCalledInClassLoaderRealm)87 public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception { 88 ClassLoader current = Thread.currentThread().getContextClassLoader(); 89 try { 90 Thread.currentThread().setContextClassLoader(this); 91 Object instance = 92 this.loadClass(callableCalledInClassLoaderRealm).getConstructor().newInstance(); 93 if (instance instanceof Callable) { 94 Callable<?> callableInRealm = (Callable<?>) instance; 95 return callableInRealm.call(); 96 } 97 } finally { 98 Thread.currentThread().setContextClassLoader(current); 99 } 100 throw new IllegalArgumentException( 101 "qualified name '" 102 + callableCalledInClassLoaderRealm 103 + "' should represent a class implementing Callable"); 104 } 105 doInRealm( String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args)106 public Object doInRealm( 107 String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) 108 throws Exception { 109 ClassLoader current = Thread.currentThread().getContextClassLoader(); 110 try { 111 Thread.currentThread().setContextClassLoader(this); 112 Object instance = 113 this.loadClass(callableCalledInClassLoaderRealm) 114 .getConstructor(argTypes) 115 .newInstance(args); 116 if (instance instanceof Callable) { 117 Callable<?> callableInRealm = (Callable<?>) instance; 118 return callableInRealm.call(); 119 } 120 } finally { 121 Thread.currentThread().setContextClassLoader(current); 122 } 123 124 throw new IllegalArgumentException( 125 "qualified name '" 126 + callableCalledInClassLoaderRealm 127 + "' should represent a class implementing Callable"); 128 } 129 130 public interface ReloadClassPredicate { acceptReloadOf(String qualifiedName)131 boolean acceptReloadOf(String qualifiedName); 132 } 133 } 134