• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Dagger Authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package dagger.hilt.android.plugin
17 
18 import dagger.hilt.android.plugin.util.isClassFile
19 import dagger.hilt.android.plugin.util.isJarFile
20 import java.io.File
21 import java.io.FileInputStream
22 import java.util.zip.ZipInputStream
23 import javassist.ClassPool
24 import javassist.CtClass
25 import javassist.Modifier
26 import javassist.bytecode.Bytecode
27 import javassist.bytecode.CodeIterator
28 import javassist.bytecode.Opcode
29 import org.slf4j.LoggerFactory
30 
/**
 * Short name for Javassist's `javassist.bytecode.ByteArray` helper class. The alias avoids a
 * clash with Kotlin's stdlib `ByteArray`.
 */
typealias CodeArray = javassist.bytecode.ByteArray
32 
/**
 * A helper class for performing the Hilt bytecode transform.
 *
 * Create it with the list of all available class path inputs along with the root output directory
 * and use [AndroidEntryPointClassTransformer.transformFile] or
 * [AndroidEntryPointClassTransformer.transformJarContents] to perform the actual transformation.
 *
 * Classes annotated with one of [ANDROID_ENTRY_POINT_ANNOTATIONS] are rewritten to extend their
 * `Hilt_` prefixed generated superclass instead of their declared superclass.
 *
 * @param taskName name used as a prefix in log messages.
 * @param allInputs class path entries appended to the [ClassPool] used to resolve classes.
 * @param sourceRootOutputDir root directory where rewritten class files are written; it is created
 *   on construction if missing.
 * @param copyNonTransformed when true, classes that are not transformed are also written to
 *   [sourceRootOutputDir]. Transforming jars is not supported in this mode.
 */
internal class AndroidEntryPointClassTransformer(
  val taskName: String,
  allInputs: List<File>,
  private val sourceRootOutputDir: File,
  private val copyNonTransformed: Boolean
) {
  private val logger = LoggerFactory.getLogger(AndroidEntryPointClassTransformer::class.java)

  // A ClassPool created from the given input files. This allows us to use the higher-level
  // Javassist APIs, but requires class parsing/loading.
  private val classPool: ClassPool = ClassPool(true).also { pool ->
    allInputs.forEach {
      pool.appendClassPath(it.path)
    }
  }

  init {
    sourceRootOutputDir.mkdirs()
  }

  /**
   * Transforms the classes inside the jar and copies re-written class files if and only if they are
   * transformed.
   *
   * @param inputFile The jar file to transform, must be a jar.
   * @return true if at least one class within the jar was transformed.
   * @throws IllegalArgumentException if [inputFile] is not a jar.
   * @throws IllegalStateException if this transformer was created with `copyNonTransformed`.
   */
  fun transformJarContents(inputFile: File): Boolean {
    require(inputFile.isJarFile()) {
      "Invalid file, '$inputFile' is not a jar."
    }
    // Validate transform is not applied to a jar when copying is enabled, meaning the transformer
    // is being used in the Android transform API pipeline which does not need to transform jars
    // and handles copying them.
    check(!copyNonTransformed) {
      "Transforming a jar is not supported with 'copyNonTransformed'."
    }
    var transformed = false
    ZipInputStream(FileInputStream(inputFile)).use { input ->
      var entry = input.nextEntry
      while (entry != null) {
        if (entry.isClassFile()) {
          val clazz = classPool.makeClass(input, false)
          // Transform first so every class is processed even after one was transformed.
          transformed = transformAndDetach(clazz) || transformed
        }
        entry = input.nextEntry
      }
    }
    return transformed
  }

  /**
   * Transform a single class file.
   *
   * @param inputFile The file to transform, must be a class file.
   * @return true if the class file was transformed.
   * @throws IllegalStateException if [inputFile] is not a class file.
   */
  fun transformFile(inputFile: File): Boolean {
    check(inputFile.isClassFile()) {
      "Invalid file, '$inputFile' is not a class."
    }
    val clazz = inputFile.inputStream().use { classPool.makeClass(it, false) }
    return transformAndDetach(clazz)
  }

  // Transforms and writes [clazz], always detaching it from the ClassPool afterwards, even if the
  // transform throws, so a half-modified CtClass is never left cached in the pool.
  private fun transformAndDetach(clazz: CtClass): Boolean =
    try {
      transformClassToOutput(clazz)
    } finally {
      clazz.detach()
    }

  // Transforms [clazz] and writes it to the output dir if it was transformed or copying is enabled.
  private fun transformClassToOutput(clazz: CtClass): Boolean {
    val transformed = transformClass(clazz)
    if (transformed || copyNonTransformed) {
      clazz.writeFile(sourceRootOutputDir.path)
    }
    return transformed
  }

  // Rewrites [clazz] to extend its Hilt generated superclass. Returns false, without modifying
  // anything, if [clazz] is not annotated with an Android entry point annotation.
  private fun transformClass(clazz: CtClass): Boolean {
    if (ANDROID_ENTRY_POINT_ANNOTATIONS.none { clazz.hasAnnotation(it) }) {
      // Not an Android entry point annotated class, don't do anything.
      return false
    }

    // TODO(danysantiago): Handle classes with '$' in their name if they do become an issue.
    val superclassName = clazz.classFile.superclass
    val entryPointSuperclassName = hiltSuperclassName(clazz)
    logger.info(
      "[$taskName] Transforming ${clazz.name} to extend $entryPointSuperclassName instead of " +
        "$superclassName."
    )
    val entryPointSuperclass = classPool.get(entryPointSuperclassName)
    clazz.superclass = entryPointSuperclass
    transformSuperMethodCalls(clazz, superclassName, entryPointSuperclassName)

    // Check if Hilt generated class is a BroadcastReceiver with the marker field which means
    // a super.onReceive invocation has to be inserted in the implementation.
    if (entryPointSuperclass.declaredFields.any { it.name == "onReceiveBytecodeInjectionMarker" }) {
      transformOnReceive(clazz, entryPointSuperclassName)
    }

    return true
  }

  /**
   * Returns the fully qualified name of the Hilt generated superclass of [clazz]. The name is
   * `<package>.Hilt_<SimpleName>`, with `$` in the simple name replaced by `_`. Javassist reports a
   * `null` package name for classes in the default package; those get no package prefix instead
   * of a bogus `null.` one.
   */
  private fun hiltSuperclassName(clazz: CtClass): String {
    val hiltSimpleName = "Hilt_" + clazz.simpleName.replace("$", "_")
    val packageName: String? = clazz.packageName
    return if (packageName.isNullOrEmpty()) hiltSimpleName else "$packageName.$hiltSimpleName"
  }

  /**
   * Finds super calls (e.g. super.onCreate()) in the body of each declared method and rewrites
   * the method reference of the invokespecial instruction to one that uses the new superclass.
   *
   * The invokespecial instruction is emitted for code that, among other things, invokes a method of
   * a superclass of the current class. The opcode takes two 8-bit operands that together form an
   * index into the constant pool pointing at a method reference. The method reference is computed
   * at compile time by looking at the direct superclass declaration. At runtime, however, the code
   * behaves like invokevirtual, and the method actually invoked is looked up based on the class
   * hierarchy.
   *
   * However, it has been observed that on APIs 19 to 22 the Android Runtime (ART) jumps over the
   * direct superclass and into the method reference class, causing unexpected behaviours.
   * Therefore, this method performs the additional transformation to rewrite direct super call
   * invocations to use a method reference whose class in the pool is the new superclass.
   *
   * Super constructor calls do not need this treatment since the Javassist library takes care of
   * those, and constructors are not scanned here. For the same reason, invokespecial instructions
   * targeting `<init>` are never rewritten. Outside a constructor they instantiate a new object of
   * the old superclass (e.g. `new FrameLayout(context)`) rather than call into this object's
   * superclass, and redirecting them would produce invalid bytecode.
   *
   * @see https://docs.oracle.com/javase/specs/jvms/se11/html/jvms-6.html#jvms-6.5.invokespecial
   * @see https://source.android.com/devices/tech/dalvik/dalvik-bytecode
   */
  private fun transformSuperMethodCalls(
    clazz: CtClass,
    oldSuperclassName: String,
    newSuperclassName: String
  ) {
    val constantPool = clazz.classFile.constPool
    // Loop invariant: the class entry of the new superclass is the same for every rewrite.
    val newSuperclassRef = constantPool.addClassInfo(newSuperclassName)
    clazz.declaredMethods
      .filter {
        // isMethod excludes constructors and static initializers; static, abstract and native
        // methods have no 'this' super calls (or no code) to rewrite.
        it.methodInfo.isMethod &&
          !Modifier.isStatic(it.modifiers) &&
          !Modifier.isAbstract(it.modifiers) &&
          !Modifier.isNative(it.modifiers)
      }
      .forEach { method ->
        val codeAttr = method.methodInfo.codeAttribute
        val code = codeAttr.code
        codeAttr.iterator().forEachInstruction { index, opcode ->
          // We are only interested in 'invokespecial' instructions.
          if (opcode != Opcode.INVOKESPECIAL) {
            return@forEachInstruction
          }
          // If the method reference of the instruction is not using the old superclass then we
          // should not rewrite it.
          val methodRef = CodeArray.readU16bit(code, index + 1)
          val currentClassRef = constantPool.getMethodrefClassName(methodRef)
          if (currentClassRef != oldSuperclassName) {
            return@forEachInstruction
          }
          // An '<init>' invocation in a non-constructor method is an object instantiation, not a
          // super call; rewriting it would produce invalid bytecode.
          if (constantPool.getMethodrefName(methodRef) == CONSTRUCTOR_METHOD_NAME) {
            return@forEachInstruction
          }
          val nameAndTypeRef = constantPool.getMethodrefNameAndType(methodRef)
          val newMethodRef = constantPool.addMethodrefInfo(newSuperclassRef, nameAndTypeRef)
          logger.info(
            "[$taskName] Redirecting an invokespecial in " +
              "${clazz.name}.${method.name}:${method.signature} at code index $index from " +
              "method ref #$methodRef to #$newMethodRef."
          )
          CodeArray.write16bit(newMethodRef, code, index + 1)
        }
      }
  }

  // Iterate over each instruction in a CodeIterator, passing its index and opcode to [body].
  private fun CodeIterator.forEachInstruction(body: CodeIterator.(Int, Int) -> Unit) {
    while (hasNext()) {
      val index = next()
      this.body(index, byteAt(index))
    }
  }

  /**
   * For a BroadcastReceiver, insert a super call at the start of the onReceive method
   * implementation. After the class is transformed, onReceive is no longer abstract in the
   * superclass, because the Hilt generated receiver implements it.
   *
   * @throws IllegalStateException if [clazz] does not declare `onReceive(Context, Intent)`.
   */
  private fun transformOnReceive(clazz: CtClass, entryPointSuperclassName: String) {
    val method = clazz.declaredMethods.firstOrNull {
      it.name == ON_RECEIVE_METHOD_NAME && it.signature == ON_RECEIVE_METHOD_SIGNATURE
    } ?: error(
      "[$taskName] Expected ${clazz.name} to declare " +
        "$ON_RECEIVE_METHOD_NAME$ON_RECEIVE_METHOD_SIGNATURE so that a call to " +
        "$entryPointSuperclassName.$ON_RECEIVE_METHOD_NAME can be inserted."
    )
    val constantPool = clazz.classFile.constPool
    val newCode = Bytecode(constantPool).apply {
      addAload(0) // Loads 'this'
      addAload(1) // Loads method param 1 (Context)
      addAload(2) // Loads method param 2 (Intent)
      addInvokespecial(
        entryPointSuperclassName, ON_RECEIVE_METHOD_NAME, ON_RECEIVE_METHOD_SIGNATURE
      )
    }
    val newCodeAttribute = newCode.toCodeAttribute()
    val currentCodeAttribute = method.methodInfo.codeAttribute
    // The merged code needs enough stack/locals for both the inserted and the original code.
    currentCodeAttribute.maxStack =
      maxOf(newCodeAttribute.maxStack, currentCodeAttribute.maxStack)
    currentCodeAttribute.maxLocals =
      maxOf(newCodeAttribute.maxLocals, currentCodeAttribute.maxLocals)
    val codeIterator = currentCodeAttribute.iterator()
    val pos = codeIterator.insertEx(newCode.get()) // insert new code
    codeIterator.insert(newCodeAttribute.exceptionTable, pos) // offset exception table
    method.methodInfo.rebuildStackMap(clazz.classPool) // update stack table
  }

  companion object {
    val ANDROID_ENTRY_POINT_ANNOTATIONS = setOf(
      "dagger.hilt.android.AndroidEntryPoint",
      "dagger.hilt.android.HiltAndroidApp"
    )
    const val ON_RECEIVE_METHOD_NAME = "onReceive"
    const val ON_RECEIVE_METHOD_SIGNATURE =
      "(Landroid/content/Context;Landroid/content/Intent;)V"

    // JVM internal name of instance constructors / initialization methods.
    private const val CONSTRUCTOR_METHOD_NAME = "<init>"
  }
}
253