• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockitoutil;
6 
7 import static java.lang.String.format;
8 import static java.util.Arrays.asList;
9 
10 import java.io.ByteArrayInputStream;
11 import java.io.File;
12 import java.io.IOException;
13 import java.io.InputStream;
14 import java.lang.reflect.Field;
15 import java.lang.reflect.Modifier;
16 import java.net.MalformedURLException;
17 import java.net.URI;
18 import java.net.URISyntaxException;
19 import java.net.URL;
20 import java.net.URLClassLoader;
21 import java.net.URLConnection;
22 import java.net.URLStreamHandler;
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.Collections;
26 import java.util.Enumeration;
27 import java.util.HashMap;
28 import java.util.HashSet;
29 import java.util.Iterator;
30 import java.util.List;
31 import java.util.Map;
32 import java.util.Set;
33 import java.util.concurrent.ExecutionException;
34 import java.util.concurrent.ExecutorService;
35 import java.util.concurrent.Executors;
36 import java.util.concurrent.Future;
37 import java.util.concurrent.ThreadFactory;
38 
39 import org.mockito.internal.configuration.plugins.Plugins;
40 import org.mockito.plugins.MemberAccessor;
41 import org.objenesis.Objenesis;
42 import org.objenesis.ObjenesisStd;
43 import org.objenesis.instantiator.ObjectInstantiator;
44 
45 public abstract class ClassLoaders {
46     protected ClassLoader parent = currentClassLoader();
47 
ClassLoaders()48     protected ClassLoaders() {}
49 
isolatedClassLoader()50     public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
51         return new IsolatedURLClassLoaderBuilder();
52     }
53 
excludingClassLoader()54     public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
55         return new ExcludingURLClassLoaderBuilder();
56     }
57 
inMemoryClassLoader()58     public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
59         return new InMemoryClassLoaderBuilder();
60     }
61 
in(ClassLoader classLoader)62     public static ReachableClassesFinder in(ClassLoader classLoader) {
63         return new ReachableClassesFinder(classLoader);
64     }
65 
jdkClassLoader()66     public static ClassLoader jdkClassLoader() {
67         return String.class.getClassLoader();
68     }
69 
systemClassLoader()70     public static ClassLoader systemClassLoader() {
71         return ClassLoader.getSystemClassLoader();
72     }
73 
currentClassLoader()74     public static ClassLoader currentClassLoader() {
75         return ClassLoaders.class.getClassLoader();
76     }
77 
build()78     public abstract ClassLoader build();
79 
coverageTool()80     public static Class<?>[] coverageTool() {
81         HashSet<Class<?>> classes = new HashSet<Class<?>>();
82         classes.add(safeGetClass("net.sourceforge.cobertura.coveragedata.TouchCollector"));
83         classes.add(safeGetClass("org.slf4j.LoggerFactory"));
84 
85         classes.remove(null);
86         return classes.toArray(new Class<?>[classes.size()]);
87     }
88 
safeGetClass(String className)89     private static Class<?> safeGetClass(String className) {
90         try {
91             return Class.forName(className);
92         } catch (ClassNotFoundException e) {
93             return null;
94         }
95     }
96 
using(final ClassLoader classLoader)97     public static ClassLoaderExecutor using(final ClassLoader classLoader) {
98         return new ClassLoaderExecutor(classLoader);
99     }
100 
101     public static class ClassLoaderExecutor {
102         private ClassLoader classLoader;
103 
ClassLoaderExecutor(ClassLoader classLoader)104         public ClassLoaderExecutor(ClassLoader classLoader) {
105             this.classLoader = classLoader;
106         }
107 
execute(final Runnable task)108         public void execute(final Runnable task) throws Exception {
109             ExecutorService executorService =
110                     Executors.newSingleThreadExecutor(
111                             new ThreadFactory() {
112                                 @Override
113                                 public Thread newThread(Runnable r) {
114                                     Thread thread = Executors.defaultThreadFactory().newThread(r);
115                                     thread.setContextClassLoader(classLoader);
116                                     return thread;
117                                 }
118                             });
119             try {
120                 Future<?> taskFuture =
121                         executorService.submit(
122                                 new Runnable() {
123                                     @Override
124                                     public void run() {
125                                         try {
126                                             reloadTaskInClassLoader(task).run();
127                                         } catch (Throwable throwable) {
128                                             throw new IllegalStateException(
129                                                     format(
130                                                             "Given task could not be loaded properly in the given classloader '%s', error '%s",
131                                                             task, throwable.getMessage()),
132                                                     throwable);
133                                         }
134                                     }
135                                 });
136                 taskFuture.get();
137                 executorService.shutdownNow();
138             } catch (InterruptedException e) {
139                 Thread.currentThread().interrupt();
140             } catch (ExecutionException e) {
141                 throw this.<Exception>unwrapAndThrows(e);
142             }
143         }
144 
145         @SuppressWarnings("unchecked")
unwrapAndThrows(ExecutionException ex)146         private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
147             throw (T) ex.getCause();
148         }
149 
reloadTaskInClassLoader(Runnable task)150         Runnable reloadTaskInClassLoader(Runnable task) {
151             try {
152                 @SuppressWarnings("unchecked")
153                 Class<Runnable> taskClassReloaded =
154                         (Class<Runnable>) classLoader.loadClass(task.getClass().getName());
155 
156                 Objenesis objenesis = new ObjenesisStd();
157                 ObjectInstantiator<Runnable> thingyInstantiator =
158                         objenesis.getInstantiatorOf(taskClassReloaded);
159                 Runnable reloaded = thingyInstantiator.newInstance();
160 
161                 // lenient shallow copy of class compatible fields
162                 for (Field field : task.getClass().getDeclaredFields()) {
163                     Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
164                     int modifiers = declaredField.getModifiers();
165                     if (Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
166                         // Skip static final fields (e.g. jacoco fields)
167                         // otherwise IllegalAccessException (can be bypassed with Unsafe though)
168                         // We may also miss coverage data.
169                         continue;
170                     }
171                     if (declaredField.getType() == field.getType()) { // don't copy this
172                         MemberAccessor accessor = Plugins.getMemberAccessor();
173                         accessor.set(declaredField, reloaded, accessor.get(field, task));
174                     }
175                 }
176 
177                 return reloaded;
178             } catch (ClassNotFoundException e) {
179                 throw new IllegalStateException(e);
180             } catch (IllegalAccessException e) {
181                 throw new IllegalStateException(e);
182             } catch (NoSuchFieldException e) {
183                 throw new IllegalStateException(e);
184             }
185         }
186     }
187 
188     public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
189         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
190         private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
191         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
192 
withPrivateCopyOf(String... privatePrefixes)193         public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
194             privateCopyPrefixes.addAll(asList(privatePrefixes));
195             return this;
196         }
197 
withCodeSourceUrls(String... urls)198         public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
199             codeSourceUrls.addAll(pathsToURLs(urls));
200             return this;
201         }
202 
withCodeSourceUrlOf(Class<?>.... classes)203         public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
204             for (Class<?> clazz : classes) {
205                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
206             }
207             return this;
208         }
209 
withCurrentCodeSourceUrls()210         public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
211             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
212             return this;
213         }
214 
without(String... privatePrefixes)215         public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) {
216             excludedPrefixes.addAll(asList(privatePrefixes));
217             return this;
218         }
219 
build()220         public ClassLoader build() {
221             return new LocalIsolatedURLClassLoader(
222                     jdkClassLoader(),
223                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
224                     privateCopyPrefixes,
225                     excludedPrefixes);
226         }
227     }
228 
229     static class LocalIsolatedURLClassLoader extends URLClassLoader {
230         private final ArrayList<String> privateCopyPrefixes;
231         private final ArrayList<String> excludedPrefixes;
232 
LocalIsolatedURLClassLoader( ClassLoader classLoader, URL[] urls, ArrayList<String> privateCopyPrefixes, ArrayList<String> excludedPrefixes)233         LocalIsolatedURLClassLoader(
234                 ClassLoader classLoader,
235                 URL[] urls,
236                 ArrayList<String> privateCopyPrefixes,
237                 ArrayList<String> excludedPrefixes) {
238             super(urls, classLoader);
239             this.privateCopyPrefixes = privateCopyPrefixes;
240             this.excludedPrefixes = excludedPrefixes;
241         }
242 
243         @Override
findClass(String name)244         public Class<?> findClass(String name) throws ClassNotFoundException {
245             if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) {
246                 throw new ClassNotFoundException(
247                         format(
248                                 "Can only load classes with prefixes : %s, but not : %s",
249                                 privateCopyPrefixes, excludedPrefixes));
250             }
251             try {
252                 return super.findClass(name);
253             } catch (ClassNotFoundException cnfe) {
254                 throw new ClassNotFoundException(
255                         format(
256                                 "%s%n%s%n",
257                                 cnfe.getMessage(),
258                                 "    Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"),
259                         cnfe);
260             }
261         }
262 
classShouldBePrivate(String name)263         private boolean classShouldBePrivate(String name) {
264             for (String prefix : privateCopyPrefixes) {
265                 if (name.startsWith(prefix)) return true;
266             }
267             return false;
268         }
269 
classShouldBeExcluded(String name)270         private boolean classShouldBeExcluded(String name) {
271             for (String prefix : excludedPrefixes) {
272                 if (name.startsWith(prefix)) return true;
273             }
274             return false;
275         }
276     }
277 
278     public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
279         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
280         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
281 
without(String... privatePrefixes)282         public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) {
283             excludedPrefixes.addAll(asList(privatePrefixes));
284             return this;
285         }
286 
withCodeSourceUrls(String... urls)287         public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
288             codeSourceUrls.addAll(pathsToURLs(urls));
289             return this;
290         }
291 
withCodeSourceUrlOf(Class<?>.... classes)292         public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
293             for (Class<?> clazz : classes) {
294                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
295             }
296             return this;
297         }
298 
withCurrentCodeSourceUrls()299         public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
300             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
301             return this;
302         }
303 
build()304         public ClassLoader build() {
305             return new LocalExcludingURLClassLoader(
306                     jdkClassLoader(),
307                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
308                     excludedPrefixes);
309         }
310     }
311 
312     static class LocalExcludingURLClassLoader extends URLClassLoader {
313         private final ArrayList<String> excludedPrefixes;
314 
LocalExcludingURLClassLoader( ClassLoader classLoader, URL[] urls, ArrayList<String> excludedPrefixes)315         LocalExcludingURLClassLoader(
316                 ClassLoader classLoader, URL[] urls, ArrayList<String> excludedPrefixes) {
317             super(urls, classLoader);
318             this.excludedPrefixes = excludedPrefixes;
319         }
320 
321         @Override
findClass(String name)322         public Class<?> findClass(String name) throws ClassNotFoundException {
323             if (classShouldBeExcluded(name))
324                 throw new ClassNotFoundException(
325                         "classes with prefix : " + excludedPrefixes + " are excluded");
326             return super.findClass(name);
327         }
328 
classShouldBeExcluded(String name)329         private boolean classShouldBeExcluded(String name) {
330             for (String prefix : excludedPrefixes) {
331                 if (name.startsWith(prefix)) return true;
332             }
333             return false;
334         }
335     }
336 
337     public static class InMemoryClassLoaderBuilder extends ClassLoaders {
338         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
339 
withParent(ClassLoader parent)340         public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
341             this.parent = parent;
342             return this;
343         }
344 
withClassDefinition(String name, byte[] classDefinition)345         public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
346             inMemoryClassObjects.put(name, classDefinition);
347             return this;
348         }
349 
build()350         public ClassLoader build() {
351             return new InMemoryClassLoader(parent, inMemoryClassObjects);
352         }
353     }
354 
355     static class InMemoryClassLoader extends ClassLoader {
356         public static final String SCHEME = "mem";
357         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
358 
InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects)359         public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
360             super(parent);
361             this.inMemoryClassObjects = inMemoryClassObjects;
362         }
363 
findClass(String name)364         protected Class<?> findClass(String name) throws ClassNotFoundException {
365             byte[] classDefinition = inMemoryClassObjects.get(name);
366             if (classDefinition != null) {
367                 return defineClass(name, classDefinition, 0, classDefinition.length);
368             }
369             throw new ClassNotFoundException(name);
370         }
371 
372         @Override
getResources(String ignored)373         public Enumeration<URL> getResources(String ignored) throws IOException {
374             return inMemoryOnly();
375         }
376 
inMemoryOnly()377         private Enumeration<URL> inMemoryOnly() {
378             final Set<String> names = inMemoryClassObjects.keySet();
379             return new Enumeration<URL>() {
380                 private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this);
381                 private final Iterator<String> it = names.iterator();
382 
383                 public boolean hasMoreElements() {
384                     return it.hasNext();
385                 }
386 
387                 public URL nextElement() {
388                     try {
389                         return new URL(null, SCHEME + ":" + it.next(), memHandler);
390                     } catch (MalformedURLException rethrown) {
391                         throw new IllegalStateException(rethrown);
392                     }
393                 }
394             };
395         }
396     }
397 
398     public static class MemHandler extends URLStreamHandler {
399         private InMemoryClassLoader inMemoryClassLoader;
400 
MemHandler(InMemoryClassLoader inMemoryClassLoader)401         public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
402             this.inMemoryClassLoader = inMemoryClassLoader;
403         }
404 
405         @Override
openConnection(URL url)406         protected URLConnection openConnection(URL url) throws IOException {
407             return new MemURLConnection(url, inMemoryClassLoader);
408         }
409 
410         private static class MemURLConnection extends URLConnection {
411             private final InMemoryClassLoader inMemoryClassLoader;
412             private String qualifiedName;
413 
MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader)414             public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
415                 super(url);
416                 this.inMemoryClassLoader = inMemoryClassLoader;
417                 qualifiedName = url.getPath();
418             }
419 
420             @Override
connect()421             public void connect() throws IOException {}
422 
423             @Override
getInputStream()424             public InputStream getInputStream() throws IOException {
425                 return new ByteArrayInputStream(
426                         inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName));
427             }
428         }
429     }
430 
obtainCurrentClassPathOf(String className)431     URL obtainCurrentClassPathOf(String className) {
432         String path = className.replace('.', '/') + ".class";
433         String url = ClassLoaders.class.getClassLoader().getResource(path).toExternalForm();
434 
435         try {
436             return new URL(url.substring(0, url.length() - path.length()));
437         } catch (MalformedURLException e) {
438             throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
439         }
440     }
441 
pathsToURLs(String... codeSourceUrls)442     List<URL> pathsToURLs(String... codeSourceUrls) {
443         return pathsToURLs(Arrays.asList(codeSourceUrls));
444     }
445 
pathsToURLs(List<String> codeSourceUrls)446     private List<URL> pathsToURLs(List<String> codeSourceUrls) {
447         ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size());
448         for (String codeSourceUrl : codeSourceUrls) {
449             URL url = pathToUrl(codeSourceUrl);
450             urls.add(url);
451         }
452         return urls;
453     }
454 
pathToUrl(String path)455     private URL pathToUrl(String path) {
456         try {
457             return new File(path).getAbsoluteFile().toURI().toURL();
458         } catch (MalformedURLException e) {
459             throw new IllegalArgumentException("Path is malformed", e);
460         }
461     }
462 
463     public static class ReachableClassesFinder {
464         private ClassLoader classLoader;
465         private Set<String> qualifiedNameSubstring = new HashSet<String>();
466 
ReachableClassesFinder(ClassLoader classLoader)467         ReachableClassesFinder(ClassLoader classLoader) {
468             this.classLoader = classLoader;
469         }
470 
omit(String... qualifiedNameSubstring)471         public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
472             this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
473             return this;
474         }
475 
listOwnedClasses()476         public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
477             Enumeration<URL> roots = classLoader.getResources("");
478 
479             Set<String> classes = new HashSet<String>();
480             while (roots.hasMoreElements()) {
481                 URI uri = roots.nextElement().toURI();
482 
483                 if (uri.getScheme().equalsIgnoreCase("file")) {
484                     addFromFileBasedClassLoader(classes, uri);
485                 } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) {
486                     addFromInMemoryBasedClassLoader(classes, uri);
487                 } else if (uri.getScheme().equalsIgnoreCase("jar")) {
488                     // Java 9+ returns "jar:file:" style urls for modules.
489                     // It's not a classes owned by mockito itself.
490                     // Just ignore it.
491                     continue;
492                 } else {
493                     throw new IllegalArgumentException(
494                             format(
495                                     "Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory",
496                                     classLoader));
497                 }
498             }
499             return classes;
500         }
501 
addFromFileBasedClassLoader(Set<String> classes, URI uri)502         private void addFromFileBasedClassLoader(Set<String> classes, URI uri) {
503             File root = new File(uri);
504             classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
505         }
506 
addFromInMemoryBasedClassLoader(Set<String> classes, URI uri)507         private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
508             String qualifiedName = uri.getSchemeSpecificPart();
509             if (excludes(qualifiedName, qualifiedNameSubstring)) {
510                 classes.add(qualifiedName);
511             }
512         }
513 
findClassQualifiedNames( File root, File file, Set<String> packageFilters)514         private Set<String> findClassQualifiedNames(
515                 File root, File file, Set<String> packageFilters) {
516             if (file.isDirectory()) {
517                 File[] files = file.listFiles();
518                 Set<String> classes = new HashSet<String>();
519                 for (File children : files) {
520                     classes.addAll(findClassQualifiedNames(root, children, packageFilters));
521                 }
522                 return classes;
523             } else {
524                 if (file.getName().endsWith(".class")) {
525                     String qualifiedName = classNameFor(root, file);
526                     if (excludes(qualifiedName, packageFilters)) {
527                         return Collections.singleton(qualifiedName);
528                     }
529                 }
530             }
531             return Collections.emptySet();
532         }
533 
excludes(String qualifiedName, Set<String> packageFilters)534         private boolean excludes(String qualifiedName, Set<String> packageFilters) {
535             for (String filter : packageFilters) {
536                 if (qualifiedName.contains(filter)) return false;
537             }
538             return true;
539         }
540 
classNameFor(File root, File file)541         private String classNameFor(File root, File file) {
542             String temp =
543                     file.getAbsolutePath()
544                             .substring(root.getAbsolutePath().length() + 1)
545                             .replace(File.separatorChar, '.');
546             return temp.subSequence(0, temp.indexOf(".class")).toString();
547         }
548     }
549 }
550