1 package org.robolectric.internal.bytecode; 2 3 import static com.google.common.base.StandardSystemProperty.JAVA_CLASS_PATH; 4 import static com.google.common.base.StandardSystemProperty.PATH_SEPARATOR; 5 import static org.robolectric.util.ReflectionHelpers.ClassParameter.from; 6 7 import com.google.common.base.Splitter; 8 import com.google.common.collect.ImmutableList; 9 import java.io.File; 10 import java.io.IOException; 11 import java.io.InputStream; 12 import java.net.MalformedURLException; 13 import java.net.URL; 14 import java.net.URLClassLoader; 15 import org.robolectric.util.Logger; 16 import org.robolectric.util.PerfStatsCollector; 17 import org.robolectric.util.ReflectionHelpers; 18 import org.robolectric.util.Util; 19 20 /** 21 * Class loader that modifies the bytecode of Android classes to insert calls to Robolectric's 22 * shadow classes. 23 */ 24 public class SandboxClassLoader extends URLClassLoader { 25 private final ClassLoader systemClassLoader; 26 private final ClassLoader urls; 27 private final InstrumentationConfiguration config; 28 private final ClassInstrumentor classInstrumentor; 29 private final ClassNodeProvider classNodeProvider; 30 SandboxClassLoader(InstrumentationConfiguration config)31 public SandboxClassLoader(InstrumentationConfiguration config) { 32 this(ClassLoader.getSystemClassLoader(), config); 33 } 34 SandboxClassLoader( ClassLoader systemClassLoader, InstrumentationConfiguration config, URL... urls)35 public SandboxClassLoader( 36 ClassLoader systemClassLoader, InstrumentationConfiguration config, URL... urls) { 37 super(getClassPathUrls(systemClassLoader), systemClassLoader.getParent()); 38 this.systemClassLoader = systemClassLoader; 39 40 this.config = config; 41 this.urls = new URLClassLoader(urls, null); 42 for (URL url : urls) { 43 Logger.debug("Loading classes from: %s", url); 44 } 45 46 ClassInstrumentor.Decorator decorator = new ShadowDecorator(); 47 classInstrumentor = createClassInstrumentor(decorator); 48 49 classNodeProvider = new ClassNodeProvider() { 50 @Override 51 protected byte[] getClassBytes(String internalClassName) throws ClassNotFoundException { 52 return getByteCode(internalClassName); 53 } 54 }; 55 } 56 getClassPathUrls(ClassLoader classloader)57 private static URL[] getClassPathUrls(ClassLoader classloader) { 58 if (classloader instanceof URLClassLoader) { 59 return ((URLClassLoader) classloader).getURLs(); 60 } 61 return parseJavaClassPath(); 62 } 63 64 // TODO(b/65488446): Use a public API once one is available. parseJavaClassPath()65 private static URL[] parseJavaClassPath() { 66 ImmutableList.Builder<URL> urls = ImmutableList.builder(); 67 for (String entry : Splitter.on(PATH_SEPARATOR.value()).split(JAVA_CLASS_PATH.value())) { 68 try { 69 try { 70 urls.add(new File(entry).toURI().toURL()); 71 } catch (SecurityException e) { // File.toURI checks to see if the file is a directory 72 urls.add(new URL("file", null, new File(entry).getAbsolutePath())); 73 } 74 } catch (MalformedURLException e) { 75 Logger.strict("malformed classpath entry: " + entry, e); 76 } 77 } 78 return urls.build().toArray(new URL[0]); 79 } 80 createClassInstrumentor(ClassInstrumentor.Decorator decorator)81 protected ClassInstrumentor createClassInstrumentor(ClassInstrumentor.Decorator decorator) { 82 return InvokeDynamic.ENABLED 83 ? new InvokeDynamicClassInstrumentor(decorator) 84 : new OldClassInstrumentor(decorator); 85 } 86 87 @Override getResource(String name)88 public URL getResource(String name) { 89 if (config.shouldAcquireResource(name)) { 90 return urls.getResource(name); 91 } 92 URL fromParent = super.getResource(name); 93 if (fromParent != null) { 94 return fromParent; 95 } 96 return urls.getResource(name); 97 } 98 getClassBytesAsStreamPreferringLocalUrls(String resName)99 private InputStream getClassBytesAsStreamPreferringLocalUrls(String resName) { 100 InputStream fromUrlsClassLoader = urls.getResourceAsStream(resName); 101 if (fromUrlsClassLoader != null) { 102 return fromUrlsClassLoader; 103 } 104 return super.getResourceAsStream(resName); 105 } 106 107 @Override findClass(String name)108 protected Class<?> findClass(String name) throws ClassNotFoundException { 109 if (config.shouldAcquire(name)) { 110 return PerfStatsCollector.getInstance().measure("load sandboxed class", 111 () -> maybeInstrumentClass(name)); 112 } else { 113 return systemClassLoader.loadClass(name); 114 } 115 } 116 maybeInstrumentClass(String className)117 protected Class<?> maybeInstrumentClass(String className) throws ClassNotFoundException { 118 final byte[] origClassBytes = getByteCode(className); 119 120 MutableClass mutableClass = PerfStatsCollector.getInstance().measure("analyze class", 121 () -> classInstrumentor.analyzeClass(origClassBytes, config, classNodeProvider) 122 ); 123 124 try { 125 final byte[] bytes; 126 if (config.shouldInstrument(mutableClass)) { 127 bytes = PerfStatsCollector.getInstance().measure("instrument class", 128 () -> classInstrumentor.instrumentToBytes(mutableClass) 129 ); 130 } else { 131 bytes = postProcessUninstrumentedClass(mutableClass, origClassBytes); 132 } 133 ensurePackage(className); 134 return defineClass(className, bytes, 0, bytes.length); 135 } catch (Exception e) { 136 throw new ClassNotFoundException("couldn't load " + className, e); 137 } catch (OutOfMemoryError e) { 138 System.err.println("[ERROR] couldn't load " + className + " in " + this); 139 throw e; 140 } 141 } 142 postProcessUninstrumentedClass( MutableClass mutableClass, byte[] origClassBytes)143 protected byte[] postProcessUninstrumentedClass( 144 MutableClass mutableClass, byte[] origClassBytes) { 145 return origClassBytes; 146 } 147 148 @Override getPackage(String name)149 protected Package getPackage(String name) { 150 Package aPackage = super.getPackage(name); 151 if (aPackage != null) { 152 return aPackage; 153 } 154 155 return ReflectionHelpers.callInstanceMethod(systemClassLoader, "getPackage", 156 from(String.class, name)); 157 } 158 getByteCode(String className)159 protected byte[] getByteCode(String className) throws ClassNotFoundException { 160 String classFilename = className.replace('.', '/') + ".class"; 161 try (InputStream classBytesStream = getClassBytesAsStreamPreferringLocalUrls(classFilename)) { 162 if (classBytesStream == null) { 163 throw new ClassNotFoundException(className); 164 } 165 166 return Util.readBytes(classBytesStream); 167 } catch (IOException e) { 168 throw new ClassNotFoundException("couldn't load " + className, e); 169 } 170 } 171 ensurePackage(final String className)172 private void ensurePackage(final String className) { 173 int lastDotIndex = className.lastIndexOf('.'); 174 if (lastDotIndex != -1) { 175 String pckgName = className.substring(0, lastDotIndex); 176 Package pckg = getPackage(pckgName); 177 if (pckg == null) { 178 definePackage(pckgName, null, null, null, null, null, null, null); 179 } 180 } 181 } 182 183 } 184