• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 package org.mockitoutil;
2 
3 import java.net.MalformedURLException;
4 import java.net.URL;
5 import java.net.URLClassLoader;
6 import java.util.HashMap;
7 import java.util.Map;
8 import java.util.concurrent.Callable;
9 
/**
 * Custom class loader that loads classes in a hierarchical realm.
 *
 * <p>Classes whose name is accepted by the {@link ReloadClassPredicate} are defined afresh
 * by this loader. They come from the class path roots that contain this class,
 * {@code org.mockito.Mockito} and {@code net.bytebuddy.ByteBuddy}. All other classes are
 * delegated to the parent class loader.
 *
 * <p>NOTE(review): the cache of reloaded classes is a plain {@link HashMap}, so an instance
 * does not appear to be designed for concurrent use from several threads.
 */
public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

    /** Classes this loader has defined itself, keyed by the name they were requested with. */
    private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();
    private final ReloadClassPredicate reloadClassPredicate;

    /**
     * Creates a realm that uses the default parent class loader.
     *
     * @param reloadClassPredicate decides which classes are reloaded in this realm; must not be null
     * @throws IllegalStateException if one of the class path roots cannot be located
     * @throws NullPointerException if {@code reloadClassPredicate} is null
     */
    public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls());
        this.reloadClassPredicate = checkPredicate(reloadClassPredicate);
    }

    /**
     * Creates a realm that delegates non-reloaded classes to {@code parentClassLoader}.
     *
     * @param parentClassLoader the parent used for classes the predicate rejects
     * @param reloadClassPredicate decides which classes are reloaded in this realm; must not be null
     * @throws IllegalStateException if one of the class path roots cannot be located
     * @throws NullPointerException if {@code reloadClassPredicate} is null
     */
    public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls(), parentClassLoader);
        this.reloadClassPredicate = checkPredicate(reloadClassPredicate);
    }

    /** Fails fast on a missing predicate instead of on the first {@link #loadClass(String)} call. */
    private static ReloadClassPredicate checkPredicate(ReloadClassPredicate reloadClassPredicate) {
        if (reloadClassPredicate == null) {
            throw new NullPointerException("reloadClassPredicate must not be null");
        }
        return reloadClassPredicate;
    }

    /** Class path roots of this utility class, Mockito and ByteBuddy (duplicates are harmless). */
    private static URL[] getPossibleClassPathsUrls() {
        return new URL[]{
                obtainClassPath(),
                obtainClassPath("org.mockito.Mockito"),
                obtainClassPath("net.bytebuddy.ByteBuddy")
        };
    }

    /** Class path root that contains this class. */
    private static URL obtainClassPath() {
        String className = SimplePerRealmReloadingClassLoader.class.getName();
        return obtainClassPath(className);
    }

    /**
     * Resolves the class path root (directory or jar root) that provides {@code className}.
     *
     * <p>The {@code .class} resource is looked up through the loader of this class. Its
     * resource path is then stripped from the end of the URL, leaving e.g. {@code file:/dir/}
     * or {@code jar:file:/x.jar!/}.
     *
     * @throws IllegalStateException if the class file cannot be found
     */
    private static URL obtainClassPath(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
        if (resource == null) {
            // Previously this surfaced as a bare NullPointerException with no hint of what was missing.
            throw new IllegalStateException("Classloader couldn't find '" + path + "' to obtain a classpath URL");
        }
        String url = resource.toExternalForm();

        try {
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Loads {@code qualifiedClassName}.
     *
     * <p>If the predicate accepts the name, the class is defined by this loader exactly once
     * and then served from the cache. Otherwise the usual parent-first delegation applies.
     */
    @Override
    public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
        if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
            return useParentClassLoaderFor(qualifiedClassName);
        }
        Class<?> reloadedClass = classHashMap.get(qualifiedClassName);
        if (reloadedClass == null) {
            // findClass defines the class; defining the same name twice would throw a LinkageError,
            // hence the cache.
            reloadedClass = findClass(qualifiedClassName);
            saveFoundClass(qualifiedClassName, reloadedClass);
        }
        return reloadedClass;
    }

    private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
        classHashMap.put(qualifiedClassName, foundClass);
    }

    /** Standard {@link URLClassLoader} delegation (parent first, then this loader's URLs). */
    private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
        return super.loadClass(qualifiedName);
    }

    /**
     * Instantiates {@code callableCalledInClassLoaderRealm} through this realm with its public
     * no-arg constructor and invokes {@link Callable#call()} on it. The thread context class
     * loader is set to this realm for the duration of the call.
     *
     * @return the value returned by the callable
     * @throws IllegalArgumentException if the class does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
        return doInRealm(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
    }

    /**
     * Instantiates {@code callableCalledInClassLoaderRealm} through this realm with the public
     * constructor matching {@code argTypes} and invokes {@link Callable#call()} on it. The thread
     * context class loader is set to this realm for the duration of the call and restored afterwards.
     *
     * @param callableCalledInClassLoaderRealm binary name of a class implementing {@link Callable}
     * @param argTypes constructor parameter types
     * @param args constructor arguments
     * @return the value returned by the callable
     * @throws IllegalArgumentException if the class does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
        ClassLoader current = Thread.currentThread().getContextClassLoader();
        try {
            Thread.currentThread().setContextClassLoader(this);
            Class<?> callableClass = this.loadClass(callableCalledInClassLoaderRealm);
            // Check the type before instantiating so that no constructor side effects run for a bad class.
            if (!Callable.class.isAssignableFrom(callableClass)) {
                throw new IllegalArgumentException("qualified name '" + callableCalledInClassLoaderRealm + "' should represent a class implementing Callable");
            }
            Callable<?> callableInRealm = (Callable<?>) callableClass.getConstructor(argTypes).newInstance(args);
            return callableInRealm.call();
        } finally {
            Thread.currentThread().setContextClassLoader(current);
        }
    }

    /** Decides, by binary class name, whether a class is reloaded in the realm. */
    public interface ReloadClassPredicate {
        /** @return true if {@code qualifiedName} should be defined by the realm rather than its parent */
        boolean acceptReloadOf(String qualifiedName);
    }
}
119