/*
 * Copyright (c) 2017 Mockito contributors
 * This program is made available under the terms of the MIT License.
 */
package org.mockitoutil;

import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

/**
 * Custom classloader to load classes in a hierarchic realm.
 *
 * <p>A class is loaded (and cached) by this loader itself, instead of going through the
 * inherited delegation, whenever the {@link ReloadClassPredicate} accepts its name.
 *
 * <p>The cache of classes loaded by this realm is a plain {@link HashMap} and
 * {@link #loadClass(String)} performs no synchronization of its own.
 */
public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

    /** Classes already found by this realm, keyed by their qualified name. */
    private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();

    /** Decides which qualified class names are loaded by this realm itself. */
    private final ReloadClassPredicate reloadClassPredicate;

    /**
     * Creates a realm using the default parent class loader of {@link URLClassLoader}.
     *
     * @param reloadClassPredicate selects the classes this realm loads itself
     * @throws IllegalStateException if one of the classpath roots cannot be located
     */
    public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls());
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /**
     * Creates a realm with an explicit parent class loader.
     *
     * @param parentClassLoader parent handed to {@link URLClassLoader}
     * @param reloadClassPredicate selects the classes this realm loads itself
     * @throws IllegalStateException if one of the classpath roots cannot be located
     */
    public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls(), parentClassLoader);
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /**
     * Collects the classpath roots this loader searches: the root containing this class,
     * the one containing Mockito and the one containing Byte Buddy.
     */
    private static URL[] getPossibleClassPathsUrls() {
        return new URL[]{
            obtainClassPath(),
            obtainClassPath("org.mockito.Mockito"),
            obtainClassPath("net.bytebuddy.ByteBuddy")
        };
    }

    /** Returns the classpath root from which this very class was loaded. */
    private static URL obtainClassPath() {
        String className = SimplePerRealmReloadingClassLoader.class.getName();
        return obtainClassPath(className);
    }

    /**
     * Returns the classpath root that contains the given class.
     *
     * @param className qualified name of a class expected to be on the classpath
     * @return the URL of the class resource with the class' own relative path stripped off
     * @throws IllegalStateException if the class resource cannot be found
     * @throws RuntimeException if the stripped location is not a valid URL
     */
    private static URL obtainClassPath(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
        if (resource == null) {
            // getResource signals "not found" with null; fail with the offending
            // resource name instead of an anonymous NullPointerException.
            throw new IllegalStateException("Classloader couldn't find resource '" + path + "' on the classpath");
        }
        String url = resource.toExternalForm();

        try {
            // Drop the trailing "<package path>/<Class>.class" to keep only the root.
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Loads the class in this realm when the predicate accepts it, caching the result so
     * the same {@code Class} is returned on later requests; otherwise falls back to the
     * inherited {@code loadClass} implementation.
     *
     * @param qualifiedClassName qualified name of the class to load
     * @return the loaded class
     * @throws ClassNotFoundException if the class cannot be found
     */
    @Override
    public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
        if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
            return useParentClassLoaderFor(qualifiedClassName);
        }

        // Only classes returned by findClass are ever stored, so null means "not cached yet".
        Class<?> cachedClass = classHashMap.get(qualifiedClassName);
        if (cachedClass != null) {
            return cachedClass;
        }

        Class<?> foundClass = findClass(qualifiedClassName);
        saveFoundClass(qualifiedClassName, foundClass);
        return foundClass;
    }

    /** Remembers a class found by this realm under its qualified name. */
    private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
        classHashMap.put(qualifiedClassName, foundClass);
    }

    /** Loads the class through the inherited {@code loadClass} implementation. */
    private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
        return super.loadClass(qualifiedName);
    }

    /**
     * Instantiates the named class through its public no-arg constructor inside this realm
     * and invokes it as a {@link Callable}.
     *
     * @param callableCalledInClassLoaderRealm qualified name of a class implementing {@link Callable}
     * @return whatever the callable returns
     * @throws IllegalArgumentException if the instantiated class does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
        return callInRealm(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
    }

    /**
     * Instantiates the named class through the constructor matching {@code argTypes} inside
     * this realm and invokes it as a {@link Callable}.
     *
     * @param callableCalledInClassLoaderRealm qualified name of a class implementing {@link Callable}
     * @param argTypes parameter types identifying the public constructor to use
     * @param args arguments passed to that constructor
     * @return whatever the callable returns
     * @throws IllegalArgumentException if the instantiated class does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
        return callInRealm(callableCalledInClassLoaderRealm, argTypes, args);
    }

    /**
     * Shared implementation of both {@code doInRealm} overloads: installs this loader as the
     * thread context class loader, builds the instance, calls it, and always restores the
     * previous context class loader.
     */
    private Object callInRealm(String callableClassName, Class<?>[] argTypes, Object[] args) throws Exception {
        Thread currentThread = Thread.currentThread();
        ClassLoader previous = currentThread.getContextClassLoader();
        try {
            currentThread.setContextClassLoader(this);
            Object instance = this.loadClass(callableClassName).getConstructor(argTypes).newInstance(args);
            if (instance instanceof Callable) {
                Callable<?> callableInRealm = (Callable<?>) instance;
                return callableInRealm.call();
            }
        } finally {
            currentThread.setContextClassLoader(previous);
        }

        throw new IllegalArgumentException("qualified name '" + callableClassName + "' should represent a class implementing Callable");
    }

    /**
     * Decides, per qualified class name, whether the realm loads the class itself.
     */
    public interface ReloadClassPredicate {
        /**
         * @param qualifiedName qualified name of the class about to be loaded
         * @return {@code true} if the realm should load the class itself
         */
        boolean acceptReloadOf(String qualifiedName);
    }
}