1 package org.mockitoutil; 2 3 import java.net.MalformedURLException; 4 import java.net.URL; 5 import java.net.URLClassLoader; 6 import java.util.HashMap; 7 import java.util.Map; 8 import java.util.concurrent.Callable; 9 10 /** 11 * Custom classloader to load classes in hierarchic realm. 12 * 13 * Each class can be reloaded in the realm if the LoadClassPredicate says so. 14 */ 15 public class SimplePerRealmReloadingClassLoader extends URLClassLoader { 16 17 private final Map<String,Class<?>> classHashMap = new HashMap<String, Class<?>>(); 18 private ReloadClassPredicate reloadClassPredicate; 19 SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate)20 public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) { 21 super(getPossibleClassPathsUrls()); 22 this.reloadClassPredicate = reloadClassPredicate; 23 } 24 SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate)25 public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) { 26 super(getPossibleClassPathsUrls(), parentClassLoader); 27 this.reloadClassPredicate = reloadClassPredicate; 28 } 29 getPossibleClassPathsUrls()30 private static URL[] getPossibleClassPathsUrls() { 31 return new URL[]{ 32 obtainClassPath(), 33 obtainClassPath("org.mockito.Mockito"), 34 obtainClassPath("net.bytebuddy.ByteBuddy") 35 }; 36 } 37 obtainClassPath()38 private static URL obtainClassPath() { 39 String className = SimplePerRealmReloadingClassLoader.class.getName(); 40 return obtainClassPath(className); 41 } 42 obtainClassPath(String className)43 private static URL obtainClassPath(String className) { 44 String path = className.replace('.', '/') + ".class"; 45 String url = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path).toExternalForm(); 46 47 try { 48 return new URL(url.substring(0, url.length() - path.length())); 49 } catch (MalformedURLException e) { 50 throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e); 51 } 52 } 53 54 55 56 @Override loadClass(String qualifiedClassName)57 public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException { 58 if(reloadClassPredicate.acceptReloadOf(qualifiedClassName)) { 59 // return customLoadClass(qualifiedClassName); 60 // Class<?> loadedClass = findLoadedClass(qualifiedClassName); 61 if(!classHashMap.containsKey(qualifiedClassName)) { 62 Class<?> foundClass = findClass(qualifiedClassName); 63 saveFoundClass(qualifiedClassName, foundClass); 64 return foundClass; 65 } 66 67 return classHashMap.get(qualifiedClassName); 68 } 69 return useParentClassLoaderFor(qualifiedClassName); 70 } 71 saveFoundClass(String qualifiedClassName, Class<?> foundClass)72 private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) { 73 classHashMap.put(qualifiedClassName, foundClass); 74 } 75 76 useParentClassLoaderFor(String qualifiedName)77 private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException { 78 return super.loadClass(qualifiedName); 79 } 80 81 doInRealm(String callableCalledInClassLoaderRealm)82 public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception { 83 ClassLoader current = Thread.currentThread().getContextClassLoader(); 84 try { 85 Thread.currentThread().setContextClassLoader(this); 86 Object instance = this.loadClass(callableCalledInClassLoaderRealm).getConstructor().newInstance(); 87 if (instance instanceof Callable) { 88 Callable<?> callableInRealm = (Callable<?>) instance; 89 return callableInRealm.call(); 90 } 91 } finally { 92 Thread.currentThread().setContextClassLoader(current); 93 } 94 throw new IllegalArgumentException("qualified name '" + callableCalledInClassLoaderRealm + "' should represent a class implementing Callable"); 95 } 96 97 doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args)98 public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception { 99 ClassLoader current = Thread.currentThread().getContextClassLoader(); 100 try { 101 Thread.currentThread().setContextClassLoader(this); 102 Object instance = this.loadClass(callableCalledInClassLoaderRealm).getConstructor(argTypes).newInstance(args); 103 if (instance instanceof Callable) { 104 Callable<?> callableInRealm = (Callable<?>) instance; 105 return callableInRealm.call(); 106 } 107 } finally { 108 Thread.currentThread().setContextClassLoader(current); 109 } 110 111 throw new IllegalArgumentException("qualified name '" + callableCalledInClassLoaderRealm + "' should represent a class implementing Callable"); 112 } 113 114 115 public interface ReloadClassPredicate { acceptReloadOf(String qualifiedName)116 boolean acceptReloadOf(String qualifiedName); 117 } 118 } 119