• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockitoutil;
6 
7 import java.net.MalformedURLException;
8 import java.net.URL;
9 import java.net.URLClassLoader;
10 import java.util.HashMap;
11 import java.util.Map;
12 import java.util.concurrent.Callable;
13 
/**
 * Custom class loader that loads classes in a hierarchic "realm".
 *
 * <p>Every class whose name is accepted by the {@link ReloadClassPredicate} is defined by this
 * loader itself (read from its own class path) instead of being delegated to the parent, so the
 * realm gets its own copy of that class. Every other class is delegated to the parent class
 * loader.
 */
public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

    /**
     * Classes already defined by this loader, keyed by binary name. Only accessed while holding
     * {@link #getClassLoadingLock(String)}; defining the same class twice would fail with a
     * {@link LinkageError}, so this cache must never be read/written without that lock.
     */
    private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();

    /** Decides which class names are (re)defined in this realm rather than delegated. */
    private final ReloadClassPredicate reloadClassPredicate;

    /**
     * Creates a realm whose class path contains the locations of this class, Mockito and
     * ByteBuddy, delegating to the default parent class loader.
     *
     * @param reloadClassPredicate decides which classes are redefined in this realm
     * @throws IllegalStateException if one of those anchor classes cannot be located
     */
    public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls());
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /**
     * Creates a realm whose class path contains the locations of this class, Mockito and
     * ByteBuddy.
     *
     * @param parentClassLoader loader used for every class the predicate rejects
     * @param reloadClassPredicate decides which classes are redefined in this realm
     * @throws IllegalStateException if one of those anchor classes cannot be located
     */
    public SimplePerRealmReloadingClassLoader(
            ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
        this(getPossibleClassPathsUrls(), parentClassLoader, reloadClassPredicate);
    }

    /**
     * Creates a realm that redefines accepted classes from an explicit class path.
     *
     * @param classPathUrls locations searched for classes accepted by the predicate
     * @param parentClassLoader loader used for every class the predicate rejects
     * @param reloadClassPredicate decides which classes are redefined in this realm
     */
    public SimplePerRealmReloadingClassLoader(
            URL[] classPathUrls,
            ClassLoader parentClassLoader,
            ReloadClassPredicate reloadClassPredicate) {
        super(classPathUrls, parentClassLoader);
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /** Class-path roots containing this class, Mockito and ByteBuddy respectively. */
    private static URL[] getPossibleClassPathsUrls() {
        return new URL[] {
            obtainClassPath(),
            obtainClassPath("org.mockito.Mockito"),
            obtainClassPath("net.bytebuddy.ByteBuddy")
        };
    }

    /** Class-path root that contains this very class. */
    private static URL obtainClassPath() {
        String className = SimplePerRealmReloadingClassLoader.class.getName();
        return obtainClassPath(className);
    }

    /**
     * Derives the class-path root (directory or jar root) from which {@code className} is
     * visible to this class's own loader, by stripping the class-file path from the resource URL.
     *
     * @throws IllegalStateException if the class file cannot be found
     */
    private static URL obtainClassPath(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
        if (resource == null) {
            // Previously an uninformative NullPointerException.
            throw new IllegalStateException(
                    "Classloader couldn't find '" + path + "' to derive its classpath URL");
        }
        String url = resource.toExternalForm();

        try {
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Loads a class. Names accepted by the predicate are defined by this loader (once, then
     * served from the cache); all others are delegated to the standard parent-first lookup.
     */
    @Override
    public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
        if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
            return useParentClassLoaderFor(qualifiedClassName);
        }
        // Without this lock, two threads could both miss the cache and both call findClass,
        // and the second defineClass would throw a duplicate-definition LinkageError.
        synchronized (getClassLoadingLock(qualifiedClassName)) {
            Class<?> cached = classHashMap.get(qualifiedClassName);
            if (cached != null) {
                return cached;
            }
            Class<?> foundClass = findClass(qualifiedClassName);
            saveFoundClass(qualifiedClassName, foundClass);
            return foundClass;
        }
    }

    private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
        classHashMap.put(qualifiedClassName, foundClass);
    }

    private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
        return super.loadClass(qualifiedName);
    }

    /**
     * Loads the named class through this realm, instantiates it with its public no-arg
     * constructor and runs it as a {@link Callable}, with this loader installed as the thread's
     * context class loader for the duration of the call.
     *
     * @return the value returned by {@code call()}
     * @throws IllegalArgumentException if the instance does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
        // getConstructor()/newInstance() with empty arrays are equivalent to the no-arg forms.
        return doInRealm(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
    }

    /**
     * Loads the named class through this realm, instantiates it with the public constructor
     * matching {@code argTypes} and runs it as a {@link Callable}, with this loader installed as
     * the thread's context class loader for the duration of the call. The previous context class
     * loader is always restored.
     *
     * @param callableCalledInClassLoaderRealm binary name of a class implementing Callable
     * @param argTypes constructor parameter types
     * @param args constructor arguments
     * @return the value returned by {@code call()}
     * @throws IllegalArgumentException if the instance does not implement {@link Callable}
     * @throws Exception anything thrown while loading, instantiating or calling the class
     */
    public Object doInRealm(
            String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args)
            throws Exception {
        ClassLoader current = Thread.currentThread().getContextClassLoader();
        try {
            Thread.currentThread().setContextClassLoader(this);
            Object instance =
                    this.loadClass(callableCalledInClassLoaderRealm)
                            .getConstructor(argTypes)
                            .newInstance(args);
            if (instance instanceof Callable) {
                Callable<?> callableInRealm = (Callable<?>) instance;
                return callableInRealm.call();
            }
        } finally {
            Thread.currentThread().setContextClassLoader(current);
        }

        throw new IllegalArgumentException(
                "qualified name '"
                        + callableCalledInClassLoaderRealm
                        + "' should represent a class implementing Callable");
    }

    /** Decides, per binary class name, whether the class is redefined inside the realm. */
    public interface ReloadClassPredicate {
        /** Returns {@code true} if the class must be defined by the realm's own loader. */
        boolean acceptReloadOf(String qualifiedName);
    }
}
134