• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright (C) 2020 The Dagger Authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package dagger.hilt.android.plugin
17 
18 import dagger.hilt.android.plugin.util.isClassFile
19 import dagger.hilt.android.plugin.util.isJarFile
20 import java.io.File
21 import java.io.FileInputStream
22 import java.util.zip.ZipInputStream
23 import javassist.ClassPool
24 import javassist.CtClass
25 import javassist.Modifier
26 import javassist.bytecode.Bytecode
27 import javassist.bytecode.CodeIterator
28 import javassist.bytecode.Opcode
29 import org.slf4j.LoggerFactory
30 
/**
 * Short name for Javassist's bytecode read/write helper [javassist.bytecode.ByteArray].
 *
 * The alias exists so that it does not clash with Kotlin's standard library [kotlin.ByteArray].
 */
typealias CodeArray = javassist.bytecode.ByteArray
32 
/**
 * A helper class for performing the transform.
 *
 * Classes annotated with one of [ANDROID_ENTRY_POINT_ANNOTATIONS] are rewritten so that they extend
 * their Hilt generated `Hilt_`-prefixed superclass instead of their declared superclass.
 *
 * Create it with the list of all available inputs along with the root output directory and use
 * [AndroidEntryPointClassTransformer.transformFile] or
 * [AndroidEntryPointClassTransformer.transformJarContents] to perform the actual transformation.
 *
 * @param taskName name used as a prefix in log messages.
 * @param allInputs entries appended to the [ClassPool] class path so referenced classes (e.g. the
 *   generated `Hilt_` superclasses) can be resolved.
 * @param sourceRootOutputDir directory where re-written class files are written; it is created
 *   when this transformer is constructed.
 * @param copyNonTransformed when true, classes are written to [sourceRootOutputDir] even if they
 *   were not transformed.
 */
internal class AndroidEntryPointClassTransformer(
  val taskName: String,
  allInputs: List<File>,
  private val sourceRootOutputDir: File,
  private val copyNonTransformed: Boolean
) {
  private val logger = LoggerFactory.getLogger(AndroidEntryPointClassTransformer::class.java)

  // A ClassPool created from the given input files, this allows us to use the higher
  // level Javassist APIs, but requires class parsing/loading.
  private val classPool: ClassPool = ClassPool(true).also { pool ->
    allInputs.forEach {
      pool.appendClassPath(it.path)
    }
  }

  init {
    sourceRootOutputDir.mkdirs()
  }

  /**
   * Transforms the classes inside the jar and copies re-written class files if and only if they are
   * transformed.
   *
   * @param inputFile The jar file to transform, must be a jar.
   * @return true if at least one class within the jar was transformed.
   * @throws IllegalArgumentException if [inputFile] is not a jar.
   * @throws IllegalStateException if this transformer was created with `copyNonTransformed`.
   */
  fun transformJarContents(inputFile: File): Boolean {
    require(inputFile.isJarFile()) {
      "Invalid file, '$inputFile' is not a jar."
    }
    // Validate transform is not applied to a jar when copying is enabled, meaning the transformer
    // is being used in the Android transform API pipeline which does not need to transform jars
    // and handles copying them.
    check(!copyNonTransformed) {
      "Transforming a jar is not supported with 'copyNonTransformed'."
    }
    var transformed = false
    ZipInputStream(FileInputStream(inputFile)).use { input ->
      var entry = input.nextEntry
      while (entry != null) {
        if (entry.isClassFile()) {
          // makeClass reads only the current entry: the zip stream reports end-of-stream at the
          // entry boundary.
          val clazz = classPool.makeClass(input, false)
          transformed = transformClassToOutput(clazz) || transformed
        }
        entry = input.nextEntry
      }
    }
    return transformed
  }

  /**
   * Transform a single class file.
   *
   * @param inputFile The file to transform, must be a class file.
   * @return true if the class file was transformed.
   * @throws IllegalStateException if [inputFile] is not a class file.
   */
  fun transformFile(inputFile: File): Boolean {
    check(inputFile.isClassFile()) {
      "Invalid file, '$inputFile' is not a class."
    }
    val clazz = inputFile.inputStream().use { classPool.makeClass(it, false) }
    return transformClassToOutput(clazz)
  }

  /**
   * Transforms [clazz] and writes it to [sourceRootOutputDir] when it was transformed or when
   * `copyNonTransformed` is enabled.
   *
   * @return true if [clazz] was transformed.
   */
  private fun transformClassToOutput(clazz: CtClass): Boolean {
    val transformed = transformClass(clazz)
    if (transformed || copyNonTransformed) {
      clazz.writeFile(sourceRootOutputDir.path)
    }
    return transformed
  }

  /**
   * Re-parents [clazz] onto its Hilt generated superclass if it carries one of the
   * [ANDROID_ENTRY_POINT_ANNOTATIONS].
   *
   * @return true if [clazz] was modified.
   */
  private fun transformClass(clazz: CtClass): Boolean {
    if (ANDROID_ENTRY_POINT_ANNOTATIONS.none { clazz.hasAnnotation(it) }) {
      // Not an Android entry point annotated class, don't do anything.
      return false
    }

    // TODO(danysantiago): Handle classes with '$' in their name if they do become an issue.
    val superclassName = clazz.classFile.superclass
    val entryPointSuperclassName = hiltSuperclassNameFor(clazz)
    logger.info(
      "[$taskName] Transforming ${clazz.name} to extend $entryPointSuperclassName instead of " +
        "$superclassName."
    )
    val entryPointSuperclass = classPool.get(entryPointSuperclassName)
    clazz.superclass = entryPointSuperclass
    transformSuperMethodCalls(clazz, superclassName, entryPointSuperclassName)

    // Check if Hilt generated class is a BroadcastReceiver with the marker field which means
    // a super.onReceive invocation has to be inserted in the implementation.
    if (entryPointSuperclass.declaredFields.any { it.name == "onReceiveBytecodeInjectionMarker" }) {
      transformOnReceive(clazz, entryPointSuperclassName)
    }

    return true
  }

  /**
   * Returns the fully qualified name of the Hilt generated superclass of [clazz]:
   * `<package>.Hilt_<SimpleName>`, with `$` (nested class separators) replaced by `_`.
   *
   * Javassist reports a null package name for classes in the default package; in that case no
   * package prefix is added (instead of producing `"null.Hilt_..."`).
   */
  private fun hiltSuperclassNameFor(clazz: CtClass): String {
    val hiltSimpleName = "Hilt_" + clazz.simpleName.replace("$", "_")
    val packageName: String? = clazz.packageName
    return if (packageName.isNullOrEmpty()) hiltSimpleName else "$packageName.$hiltSimpleName"
  }

  /**
   * Iterates over each declared method, finding in its bodies super calls. (e.g. super.onCreate())
   * and rewrites the method reference of the invokespecial instruction to one that uses the new
   * superclass.
   *
   * The invokespecial instruction is emitted for code that between other things also invokes a
   * method of a superclass of the current class. The opcode invokespecial takes two operands, each
   * of 8 bit, that together represent an address in the constant pool to a method reference. The
   * method reference is computed at compile-time by looking at the direct superclass declaration,
   * but at runtime the code behaves like invokevirtual, whereas the actual method invoked is looked
   * up based on the class hierarchy.
   *
   * However, it has been observed that on APIs 19 to 22 the Android Runtime (ART) jumps over the
   * direct superclass and into the method reference class, causing unexpected behaviours.
   * Therefore, this method performs the additional transformation to rewrite direct super call
   * invocations to use a method reference whose class in the pool is the new superclass. Note that
   * this is not necessary for constructor calls since the Javassist library takes care of those.
   *
   * Static methods are scanned too: synthetic accessors generated for `Outer.super.foo()` calls made
   * from lambdas or inner classes are static methods that contain such an invokespecial.
   * Invocations of `<init>` are never rewritten because, outside constructors, they initialize a
   * newly created instance of the old superclass (`new OldSuperclass(...)`) rather than perform a
   * super call; redirecting them would construct the wrong class.
   *
   * @see: https://docs.oracle.com/javase/specs/jvms/se11/html/jvms-6.html#jvms-6.5.invokespecial
   * @see: https://source.android.com/devices/tech/dalvik/dalvik-bytecode
   */
  private fun transformSuperMethodCalls(
    clazz: CtClass,
    oldSuperclassName: String,
    newSuperclassName: String
  ) {
    val constantPool = clazz.classFile.constPool
    clazz.declaredMethods
      .filter {
        // isMethod excludes constructors and static initializers; abstract and native methods
        // have no code attribute to rewrite.
        it.methodInfo.isMethod &&
          !Modifier.isAbstract(it.modifiers) &&
          !Modifier.isNative(it.modifiers)
      }
      .forEach { method ->
        val codeAttr = method.methodInfo.codeAttribute ?: return@forEach
        val code = codeAttr.code
        codeAttr.iterator().forEachInstruction { index, opcode ->
          // We are only interested in 'invokespecial' instructions.
          if (opcode != Opcode.INVOKESPECIAL) {
            return@forEachInstruction
          }
          // If the method reference of the instruction is not using the old superclass then we
          // should not rewrite it.
          val methodRef = CodeArray.readU16bit(code, index + 1)
          val currentClassRef = constantPool.getMethodrefClassName(methodRef)
          if (currentClassRef != oldSuperclassName) {
            return@forEachInstruction
          }
          // Outside of constructors an '<init>' invokespecial is object creation, not a super
          // call, so it must keep pointing at the class that was instantiated.
          if (constantPool.getMethodrefName(methodRef) == CONSTRUCTOR_NAME) {
            return@forEachInstruction
          }
          val nameAndTypeRef = constantPool.getMethodrefNameAndType(methodRef)
          val newSuperclassRef = constantPool.addClassInfo(newSuperclassName)
          val newMethodRef = constantPool.addMethodrefInfo(newSuperclassRef, nameAndTypeRef)
          logger.info(
            "[$taskName] Redirecting an invokespecial in " +
              "${clazz.name}.${method.name}:${method.signature} at code index $index from " +
              "method ref #$methodRef to #$newMethodRef."
          )
          // Overwrite the two operand bytes in place; the instruction length is unchanged.
          CodeArray.write16bit(newMethodRef, code, index + 1)
        }
      }
  }

  // Iterate over each instruction in a CodeIterator, passing its code index and opcode.
  private fun CodeIterator.forEachInstruction(body: CodeIterator.(Int, Int) -> Unit) {
    while (hasNext()) {
      val index = next()
      this.body(index, byteAt(index))
    }
  }

  /**
   * For a BroadcastReceiver insert a super call in the onReceive method implementation since
   * after the class is transformed onReceive will no longer be abstract (it is implemented by
   * Hilt generated receiver).
   *
   * The call `super.onReceive(context, intent)` is inserted at the start of the method body and
   * the stack map is rebuilt afterwards.
   */
  private fun transformOnReceive(clazz: CtClass, entryPointSuperclassName: String) {
    val method = clazz.declaredMethods.first {
      it.name + it.signature == ON_RECEIVE_METHOD_NAME + ON_RECEIVE_METHOD_SIGNATURE
    }
    val constantPool = clazz.classFile.constPool
    val newCode = Bytecode(constantPool).apply {
      addAload(0) // Loads 'this'
      addAload(1) // Loads method param 1 (Context)
      addAload(2) // Loads method param 2 (Intent)
      addInvokespecial(
        entryPointSuperclassName, ON_RECEIVE_METHOD_NAME, ON_RECEIVE_METHOD_SIGNATURE
      )
    }
    val newCodeAttribute = newCode.toCodeAttribute()
    val currentCodeAttribute = method.methodInfo.codeAttribute
    currentCodeAttribute.maxStack =
      maxOf(newCodeAttribute.maxStack, currentCodeAttribute.maxStack)
    currentCodeAttribute.maxLocals =
      maxOf(newCodeAttribute.maxLocals, currentCodeAttribute.maxLocals)
    val codeIterator = currentCodeAttribute.iterator()
    val pos = codeIterator.insertEx(newCode.get()) // insert new code
    codeIterator.insert(newCodeAttribute.exceptionTable, pos) // offset exception table
    method.methodInfo.rebuildStackMap(clazz.classPool) // update stack table
  }

  companion object {
    val ANDROID_ENTRY_POINT_ANNOTATIONS = setOf(
      "dagger.hilt.android.AndroidEntryPoint",
      "dagger.hilt.android.HiltAndroidApp"
    )
    const val ON_RECEIVE_METHOD_NAME = "onReceive"
    const val ON_RECEIVE_METHOD_SIGNATURE =
      "(Landroid/content/Context;Landroid/content/Intent;)V"

    // JVM method name of instance initializers (constructors).
    private const val CONSTRUCTOR_NAME = "<init>"
  }
}
250